图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

400 lines
12 KiB

  1. """Kazoo handler helpers"""
  2. from collections import defaultdict
  3. import errno
  4. import functools
  5. import select
  6. import selectors
  7. import ssl
  8. import socket
  9. import time
  10. HAS_FNCTL = True
  11. try:
  12. import fcntl
  13. except ImportError: # pragma: nocover
  14. HAS_FNCTL = False
  15. # sentinel objects
  16. _NONE = object()
  17. class AsyncResult(object):
  18. """A one-time event that stores a value or an exception"""
  19. def __init__(self, handler, condition_factory, timeout_factory):
  20. self._handler = handler
  21. self._exception = _NONE
  22. self._condition = condition_factory()
  23. self._callbacks = []
  24. self._timeout_factory = timeout_factory
  25. self.value = None
  26. def ready(self):
  27. """Return true if and only if it holds a value or an
  28. exception"""
  29. return self._exception is not _NONE
  30. def successful(self):
  31. """Return true if and only if it is ready and holds a value"""
  32. return self._exception is None
  33. @property
  34. def exception(self):
  35. if self._exception is not _NONE:
  36. return self._exception
  37. def set(self, value=None):
  38. """Store the value. Wake up the waiters."""
  39. with self._condition:
  40. self.value = value
  41. self._exception = None
  42. self._do_callbacks()
  43. self._condition.notify_all()
  44. def set_exception(self, exception):
  45. """Store the exception. Wake up the waiters."""
  46. with self._condition:
  47. self._exception = exception
  48. self._do_callbacks()
  49. self._condition.notify_all()
  50. def get(self, block=True, timeout=None):
  51. """Return the stored value or raise the exception.
  52. If there is no value raises TimeoutError.
  53. """
  54. with self._condition:
  55. if self._exception is not _NONE:
  56. if self._exception is None:
  57. return self.value
  58. raise self._exception
  59. elif block:
  60. self._condition.wait(timeout)
  61. if self._exception is not _NONE:
  62. if self._exception is None:
  63. return self.value
  64. raise self._exception
  65. # if we get to this point we timeout
  66. raise self._timeout_factory()
  67. def get_nowait(self):
  68. """Return the value or raise the exception without blocking.
  69. If nothing is available, raises TimeoutError
  70. """
  71. return self.get(block=False)
  72. def wait(self, timeout=None):
  73. """Block until the instance is ready."""
  74. with self._condition:
  75. if not self.ready():
  76. self._condition.wait(timeout)
  77. return self._exception is not _NONE
  78. def rawlink(self, callback):
  79. """Register a callback to call when a value or an exception is
  80. set"""
  81. with self._condition:
  82. if callback not in self._callbacks:
  83. self._callbacks.append(callback)
  84. # Are we already set? Dispatch it now
  85. if self.ready():
  86. self._do_callbacks()
  87. def unlink(self, callback):
  88. """Remove the callback set by :meth:`rawlink`"""
  89. with self._condition:
  90. if self.ready():
  91. # Already triggered, ignore
  92. return
  93. if callback in self._callbacks:
  94. self._callbacks.remove(callback)
  95. def _do_callbacks(self):
  96. """Execute the callbacks that were registered by :meth:`rawlink`.
  97. If the handler is in running state this method only schedules
  98. the calls to be performed by the handler. If it's stopped,
  99. the callbacks are called right away."""
  100. for callback in self._callbacks:
  101. if self._handler.running:
  102. self._handler.completion_queue.put(
  103. functools.partial(callback, self)
  104. )
  105. else:
  106. functools.partial(callback, self)()
  107. def _set_fd_cloexec(fd):
  108. flags = fcntl.fcntl(fd, fcntl.F_GETFD)
  109. fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
  110. def _set_default_tcpsock_options(module, sock):
  111. sock.setsockopt(module.IPPROTO_TCP, module.TCP_NODELAY, 1)
  112. if HAS_FNCTL:
  113. _set_fd_cloexec(sock)
  114. return sock
  115. def create_socket_pair(module, port=0):
  116. """Create socket pair.
  117. If socket.socketpair isn't available, we emulate it.
  118. """
  119. # See if socketpair() is available.
  120. have_socketpair = hasattr(module, "socketpair")
  121. if have_socketpair:
  122. client_sock, srv_sock = module.socketpair()
  123. return client_sock, srv_sock
  124. # Create a non-blocking temporary server socket
  125. temp_srv_sock = module.socket()
  126. temp_srv_sock.setblocking(False)
  127. temp_srv_sock.bind(("", port))
  128. port = temp_srv_sock.getsockname()[1]
  129. temp_srv_sock.listen(1)
  130. # Create non-blocking client socket
  131. client_sock = module.socket()
  132. client_sock.setblocking(False)
  133. try:
  134. client_sock.connect(("localhost", port))
  135. except module.error as err:
  136. # EWOULDBLOCK is not an error, as the socket is non-blocking
  137. if err.errno != errno.EWOULDBLOCK:
  138. raise
  139. # Use select to wait for connect() to succeed.
  140. timeout = 1
  141. readable = select.select([temp_srv_sock], [], [], timeout)[0]
  142. if temp_srv_sock not in readable:
  143. raise Exception(
  144. "Client socket not connected in %s" " second(s)" % (timeout)
  145. )
  146. srv_sock, _ = temp_srv_sock.accept()
  147. return client_sock, srv_sock
  148. def create_tcp_socket(module):
  149. """Create a TCP socket with the CLOEXEC flag set."""
  150. type_ = module.SOCK_STREAM
  151. if hasattr(module, "SOCK_CLOEXEC"): # pragma: nocover
  152. # if available, set cloexec flag during socket creation
  153. type_ |= module.SOCK_CLOEXEC
  154. sock = module.socket(module.AF_INET, type_)
  155. _set_default_tcpsock_options(module, sock)
  156. return sock
  157. def create_tcp_connection(
  158. module,
  159. address,
  160. timeout=None,
  161. use_ssl=False,
  162. ca=None,
  163. certfile=None,
  164. keyfile=None,
  165. keyfile_password=None,
  166. verify_certs=True,
  167. options=None,
  168. ciphers=None,
  169. ):
  170. end = None
  171. if timeout is None:
  172. # thanks to create_connection() developers for
  173. # this ugliness...
  174. timeout = module.getdefaulttimeout()
  175. if timeout is not None:
  176. end = time.time() + timeout
  177. sock = None
  178. while True:
  179. timeout_at = end if end is None else end - time.time()
  180. # The condition is not '< 0' here because socket.settimeout treats 0 as
  181. # a special case to put the socket in non-blocking mode.
  182. if timeout_at is not None and timeout_at <= 0:
  183. break
  184. if use_ssl:
  185. # Disallow use of SSLv2 and V3 (meaning we require TLSv1.0+)
  186. context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  187. if options is not None:
  188. context.options = options
  189. else:
  190. context.options |= ssl.OP_NO_SSLv2
  191. context.options |= ssl.OP_NO_SSLv3
  192. if ciphers:
  193. context.set_ciphers(ciphers)
  194. # Load default CA certs
  195. context.load_default_certs(ssl.Purpose.SERVER_AUTH)
  196. # We must set check_hostname to False prior to setting
  197. # verify_mode to CERT_NONE.
  198. # TODO: Make hostname verification configurable as some users may
  199. # elect to use it.
  200. context.check_hostname = False
  201. context.verify_mode = (
  202. ssl.CERT_REQUIRED if verify_certs else ssl.CERT_NONE
  203. )
  204. if ca:
  205. context.load_verify_locations(ca)
  206. if certfile and keyfile:
  207. context.load_cert_chain(
  208. certfile=certfile,
  209. keyfile=keyfile,
  210. password=keyfile_password,
  211. )
  212. try:
  213. # Query the address to get back it's address family
  214. addrs = socket.getaddrinfo(
  215. address[0], address[1], 0, socket.SOCK_STREAM
  216. )
  217. conn = context.wrap_socket(module.socket(addrs[0][0]))
  218. conn.settimeout(timeout_at)
  219. conn.connect(address)
  220. sock = conn
  221. break
  222. except ssl.SSLError:
  223. raise
  224. else:
  225. try:
  226. # if we got a timeout, lets ensure that we decrement the time
  227. # otherwise there is no timeout set and we'll call it as such
  228. sock = module.create_connection(address, timeout_at)
  229. break
  230. except Exception as ex:
  231. errnum = ex.errno if isinstance(ex, OSError) else ex[0]
  232. if errnum == errno.EINTR:
  233. continue
  234. raise
  235. if sock is None:
  236. raise module.error
  237. _set_default_tcpsock_options(module, sock)
  238. return sock
  239. def capture_exceptions(async_result):
  240. """Return a new decorated function that propagates the exceptions of the
  241. wrapped function to an async_result.
  242. :param async_result: An async result implementing :class:`IAsyncResult`
  243. """
  244. def capture(function):
  245. @functools.wraps(function)
  246. def captured_function(*args, **kwargs):
  247. try:
  248. return function(*args, **kwargs)
  249. except Exception as exc:
  250. async_result.set_exception(exc)
  251. return captured_function
  252. return capture
  253. def wrap(async_result):
  254. """Return a new decorated function that propagates the return value or
  255. exception of wrapped function to an async_result. NOTE: Only propagates a
  256. non-None return value.
  257. :param async_result: An async result implementing :class:`IAsyncResult`
  258. """
  259. def capture(function):
  260. @capture_exceptions(async_result)
  261. def captured_function(*args, **kwargs):
  262. value = function(*args, **kwargs)
  263. if value is not None:
  264. async_result.set(value)
  265. return value
  266. return captured_function
  267. return capture
  268. def fileobj_to_fd(fileobj):
  269. """Return a file descriptor from a file object.
  270. Parameters:
  271. fileobj -- file object or file descriptor
  272. Returns:
  273. corresponding file descriptor
  274. Raises:
  275. TypeError if the object is invalid
  276. """
  277. if isinstance(fileobj, int):
  278. fd = fileobj
  279. else:
  280. try:
  281. fd = int(fileobj.fileno())
  282. except (AttributeError, TypeError, ValueError):
  283. raise TypeError("Invalid file object: " "{!r}".format(fileobj))
  284. if fd < 0:
  285. raise TypeError("Invalid file descriptor: {}".format(fd))
  286. return fd
  287. def selector_select(
  288. rlist, wlist, xlist, timeout=None, selectors_module=selectors
  289. ):
  290. """Selector-based drop-in replacement for select to overcome select
  291. limitation on a maximum filehandle value.
  292. """
  293. if timeout is not None:
  294. if not isinstance(timeout, (int, float)):
  295. raise TypeError("timeout must be a number")
  296. if timeout < 0:
  297. raise ValueError("timeout must be non-negative")
  298. events_mapping = {
  299. selectors_module.EVENT_READ: rlist,
  300. selectors_module.EVENT_WRITE: wlist,
  301. }
  302. fd_events = defaultdict(int)
  303. fd_fileobjs = defaultdict(list)
  304. for event, fileobjs in events_mapping.items():
  305. for fileobj in fileobjs:
  306. fd = fileobj_to_fd(fileobj)
  307. fd_events[fd] |= event
  308. fd_fileobjs[fd].append(fileobj)
  309. selector = selectors_module.DefaultSelector()
  310. for fd, events in fd_events.items():
  311. try:
  312. selector.register(fd, events)
  313. except (ValueError, OSError) as e:
  314. # gevent can raise OSError
  315. raise ValueError("Invalid event mask or fd") from e
  316. revents, wevents, xevents = [], [], []
  317. try:
  318. ready = selector.select(timeout)
  319. finally:
  320. selector.close()
  321. for info in ready:
  322. k, events = info
  323. if events & selectors_module.EVENT_READ:
  324. revents.extend(fd_fileobjs[k.fd])
  325. elif events & selectors_module.EVENT_WRITE:
  326. wevents.extend(fd_fileobjs[k.fd])
  327. return revents, wevents, xevents