Coverage for src/click/_winconsole.py: 0%

158 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-05-21 11:44 +0200

1# This module is based on the excellent work by Adam Bartoš who 

2# provided a lot of what went into the implementation here in 

3# the discussion to issue1602 in the Python bug tracker. 

4# 

5# There are some general differences in regards to how this works 

6# compared to the original patches as we do not need to patch 

7# the entire interpreter but just work in our little world of 

8# echo and prompt. 

9import io 

10import sys 

11import time 

12import typing as t 

13from ctypes import byref 

14from ctypes import c_char 

15from ctypes import c_char_p 

16from ctypes import c_int 

17from ctypes import c_ssize_t 

18from ctypes import c_ulong 

19from ctypes import c_void_p 

20from ctypes import POINTER 

21from ctypes import py_object 

22from ctypes import Structure 

23from ctypes.wintypes import DWORD 

24from ctypes.wintypes import HANDLE 

25from ctypes.wintypes import LPCWSTR 

26from ctypes.wintypes import LPWSTR 

27 

28from ._compat import _NonClosingTextIOWrapper 

29 

30assert sys.platform == "win32" 

31import msvcrt # noqa: E402 

32from ctypes import windll # noqa: E402 

33from ctypes import WINFUNCTYPE # noqa: E402 

34 

# Pointer-to-ssize_t type, used for the shape/strides/suboffsets fields of
# the Py_buffer structure declared further down.
c_ssize_p = POINTER(c_ssize_t)

# Win32 entry points used by this module, resolved through ctypes.
kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
    ("CommandLineToArgvW", windll.shell32)
)
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))

# Standard stream handles, looked up once at import time.  -10, -11 and -12
# are the identifiers GetStdHandle uses for stdin, stdout and stderr.
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)

# Flags passed to PyObject_GetBuffer by get_buffer().
PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1

# Windows error codes that are compared against GetLastError() results.
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995

# Conventional file descriptor numbers of the three standard streams.
STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2

# The console reader reports end-of-file when the first byte it read equals
# this value (0x1A).
EOF = b"\x1a"
# Upper bound, in bytes, on what the console writer hands to a single
# WriteConsoleW call.
MAX_BYTES_WRITTEN = 32767

66 

try:
    from ctypes import pythonapi
except ImportError:
    # On PyPy we cannot get buffers so our ability to operate here is
    # severely limited.
    get_buffer = None
else:

    class Py_buffer(Structure):
        """ctypes declaration of the structure filled in by PyObject_GetBuffer.

        The field order and types are meant to match the interpreter's C
        ``Py_buffer`` struct, so they must not be reordered.
        """

        _fields_ = [
            ("buf", c_void_p),
            ("obj", py_object),
            ("len", c_ssize_t),
            ("itemsize", c_ssize_t),
            ("readonly", c_int),
            ("ndim", c_int),
            ("format", c_char_p),
            ("shape", c_ssize_p),
            ("strides", c_ssize_p),
            ("suboffsets", c_ssize_p),
            ("internal", c_void_p),
        ]

    PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
    PyBuffer_Release = pythonapi.PyBuffer_Release

    def get_buffer(obj, writable=False):
        """Return a ctypes ``c_char`` array placed over ``obj``'s buffer memory.

        :param obj: object handed to ``PyObject_GetBuffer``.
        :param writable: request the buffer with ``PyBUF_WRITABLE`` instead
            of ``PyBUF_SIMPLE``.
        :return: a ``c_char * len`` array created from the buffer's address.
        """
        buf = Py_buffer()
        flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
        PyObject_GetBuffer(py_object(obj), byref(buf), flags)

        try:
            buffer_type = c_char * buf.len
            # The array is built from the raw address only; it does not hold
            # the buffer view, which is released right after.
            # NOTE(review): presumably callers must keep ``obj`` alive while
            # they use the returned array -- confirm.
            return buffer_type.from_address(buf.buf)
        finally:
            PyBuffer_Release(byref(buf))

103 

104 

105class _WindowsConsoleRawIOBase(io.RawIOBase): 

106 def __init__(self, handle): 

107 self.handle = handle 

108 

109 def isatty(self): 

110 super().isatty() 

111 return True 

112 

113 

class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
    """Raw stream that reads UTF-16-LE data from a console via ReadConsoleW."""

    def readable(self):
        return True

    def readinto(self, b):
        """Fill ``b`` from the console and return the number of bytes read.

        Returns 0 for an empty ``b`` and when the first byte read is the
        ``EOF`` marker.

        :raises ValueError: if ``len(b)`` is odd.
        :raises OSError: if ``ReadConsoleW`` reports failure.
        """
        requested = len(b)
        if not requested:
            return 0
        if requested % 2:
            raise ValueError(
                "cannot read odd number of bytes from UTF-16-LE encoded console"
            )

        target = get_buffer(b, writable=True)
        units_read = c_ulong()

        # ReadConsoleW counts 16-bit code units, hence the division by two.
        succeeded = ReadConsoleW(
            HANDLE(self.handle),
            target,
            requested // 2,
            byref(units_read),
            None,
        )
        if GetLastError() == ERROR_OPERATION_ABORTED:
            # wait for KeyboardInterrupt
            time.sleep(0.1)
        if not succeeded:
            raise OSError(f"Windows error: {GetLastError()}")

        return 0 if target[0] == EOF else 2 * units_read.value

147 

148 

class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
    """Raw stream that writes UTF-16-LE data to a console via WriteConsoleW."""

    def writable(self):
        return True

    @staticmethod
    def _get_error_message(errno):
        """Return a readable label for a Windows error code."""
        if errno == ERROR_SUCCESS:
            return "ERROR_SUCCESS"
        if errno == ERROR_NOT_ENOUGH_MEMORY:
            return "ERROR_NOT_ENOUGH_MEMORY"
        return f"Windows error {errno}"

    def write(self, b):
        """Write ``b`` to the console and return the number of bytes written.

        At most ``MAX_BYTES_WRITTEN`` bytes are passed on per call, so the
        return value can be smaller than ``len(b)``.

        :raises OSError: if nothing was written although ``b`` is not empty.
        """
        total = len(b)
        source = get_buffer(b)
        # WriteConsoleW counts 16-bit code units, hence the division by two.
        units_requested = min(total, MAX_BYTES_WRITTEN) // 2
        units_written = c_ulong()

        WriteConsoleW(
            HANDLE(self.handle),
            source,
            units_requested,
            byref(units_written),
            None,
        )
        written = 2 * units_written.value

        if total > 0 and written == 0:
            raise OSError(self._get_error_message(GetLastError()))
        return written

179 

180 

class ConsoleStream:
    """File-like object that pairs a text stream with a byte stream.

    ``str`` data goes to the text stream and everything else goes to the
    byte stream, which is also exposed as ``buffer``.  Attributes not
    defined here are looked up on the text stream.
    """

    def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
        self._text_stream = text_stream
        self.buffer = byte_stream

    @property
    def name(self) -> str:
        """Name reported by the byte stream."""
        return self.buffer.name

    def write(self, x: t.AnyStr) -> int:
        """Write ``x`` to the matching stream and return that stream's result."""
        if not isinstance(x, str):
            # Flush first (``flush`` resolves on the text stream through
            # ``__getattr__``); a failing flush is deliberately ignored.
            try:
                self.flush()
            except Exception:
                pass
            return self.buffer.write(x)
        return self._text_stream.write(x)

    def writelines(self, lines: t.Iterable[t.AnyStr]) -> None:
        """Write every item of ``lines`` with :meth:`write`."""
        for item in lines:
            self.write(item)

    def __getattr__(self, name: str) -> t.Any:
        # Only reached when normal attribute lookup fails.
        return getattr(self._text_stream, name)

    def isatty(self) -> bool:
        """Report whether the byte stream is a tty."""
        return self.buffer.isatty()

    def __repr__(self):
        return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"

211 

212 

def _make_console_stream(
    buffered: io.BufferedIOBase, buffer_stream: t.BinaryIO
) -> t.TextIO:
    """Build the ConsoleStream shared by the three standard-stream factories.

    :param buffered: buffered reader or writer sitting on a raw console object.
    :param buffer_stream: byte stream exposed as the result's ``buffer``.
    :return: a :class:`ConsoleStream` whose text side is a line-buffered,
        strict UTF-16-LE wrapper around ``buffered``.
    """
    text_stream = _NonClosingTextIOWrapper(
        buffered,
        "utf-16-le",
        "strict",
        line_buffering=True,
    )
    return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))


def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Return a console-backed text stream reading from the stdin handle."""
    return _make_console_stream(
        io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)), buffer_stream
    )


def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Return a console-backed text stream writing to the stdout handle."""
    return _make_console_stream(
        io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)), buffer_stream
    )


def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Return a console-backed text stream writing to the stderr handle."""
    return _make_console_stream(
        io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)), buffer_stream
    )

241 

242 

# Maps the file descriptor of a standard stream to the factory that builds
# its console-backed replacement.  The keys use the module's named
# descriptor constants instead of repeating the literals 0, 1 and 2.
_stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
    STDIN_FILENO: _get_text_stdin,
    STDOUT_FILENO: _get_text_stdout,
    STDERR_FILENO: _get_text_stderr,
}

248 

249 

250def _is_console(f: t.TextIO) -> bool: 

251 if not hasattr(f, "fileno"): 

252 return False 

253 

254 try: 

255 fileno = f.fileno() 

256 except (OSError, io.UnsupportedOperation): 

257 return False 

258 

259 handle = msvcrt.get_osfhandle(fileno) 

260 return bool(GetConsoleMode(handle, byref(DWORD()))) 

261 

262 

def _get_windows_console_stream(
    f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
) -> t.Optional[t.TextIO]:
    """Return a console-backed replacement for ``f``, or ``None``.

    A replacement is only built when buffer access is available, the
    requested encoding is ``"utf-16-le"`` or ``None``, the error mode is
    ``"strict"`` or ``None``, ``f`` is a console, its file descriptor has a
    registered factory, and ``f`` exposes a ``buffer`` attribute.
    """
    eligible = (
        get_buffer is not None
        and encoding in {"utf-16-le", None}
        and errors in {"strict", None}
        and _is_console(f)
    )
    if not eligible:
        return None

    factory = _stream_factories.get(f.fileno())
    if factory is None:
        return None

    byte_stream = getattr(f, "buffer", None)
    if byte_stream is None:
        return None

    return factory(byte_stream)