图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

852 lines
30 KiB

  1. """Zookeeper Protocol Connection Handler"""
  2. from binascii import hexlify
  3. from contextlib import contextmanager
  4. import copy
  5. import logging
  6. import random
  7. import select
  8. import socket
  9. import ssl
  10. import sys
  11. import time
  12. from kazoo.exceptions import (
  13. AuthFailedError,
  14. ConnectionDropped,
  15. EXCEPTIONS,
  16. SessionExpiredError,
  17. NoNodeError,
  18. SASLException,
  19. )
  20. from kazoo.loggingsupport import BLATHER
  21. from kazoo.protocol.serialization import (
  22. Auth,
  23. Close,
  24. Connect,
  25. Exists,
  26. GetChildren,
  27. GetChildren2,
  28. Ping,
  29. PingInstance,
  30. ReplyHeader,
  31. SASL,
  32. Transaction,
  33. Watch,
  34. int_struct,
  35. )
  36. from kazoo.protocol.states import (
  37. Callback,
  38. KeeperState,
  39. WatchedEvent,
  40. EVENT_TYPE_MAP,
  41. )
  42. from kazoo.retry import (
  43. ForceRetryError,
  44. RetryFailedError,
  45. )
  46. try:
  47. import puresasl
  48. import puresasl.client
  49. PURESASL_AVAILABLE = True
  50. except ImportError:
  51. PURESASL_AVAILABLE = False
  52. log = logging.getLogger(__name__)
  53. # Special testing hook objects used to force a session expired error as
  54. # if it came from the server
  55. _SESSION_EXPIRED = object()
  56. _CONNECTION_DROP = object()
  57. STOP_CONNECTING = object()
  58. CREATED_EVENT = 1
  59. DELETED_EVENT = 2
  60. CHANGED_EVENT = 3
  61. CHILD_EVENT = 4
  62. WATCH_XID = -1
  63. PING_XID = -2
  64. AUTH_XID = -4
  65. CLOSE_RESPONSE = Close.type
  66. if sys.version_info > (3,): # pragma: nocover
  67. def buffer(obj, offset=0):
  68. return memoryview(obj)[offset:]
  69. advance_iterator = next
  70. else: # pragma: nocover
  71. def advance_iterator(it):
  72. return it.next()
  73. class RWPinger(object):
  74. """A Read/Write Server Pinger Iterable
  75. This object is initialized with the hosts iterator object and the
  76. socket creation function. Anytime `next` is called on its iterator
  77. it yields either False, or a host, port tuple if it found a r/w
  78. capable Zookeeper node.
  79. After the first run-through of hosts, an exponential back-off delay
  80. is added before the next run. This delay is tracked internally and
  81. the iterator will yield False if called too soon.
  82. """
  83. def __init__(self, hosts, connection_func, socket_handling):
  84. self.hosts = hosts
  85. self.connection = connection_func
  86. self.last_attempt = None
  87. self.socket_handling = socket_handling
  88. def __iter__(self):
  89. if not self.last_attempt:
  90. self.last_attempt = time.time()
  91. delay = 0.5
  92. while True:
  93. yield self._next_server(delay)
  94. def _next_server(self, delay):
  95. jitter = random.randint(0, 100) / 100.0
  96. while time.time() < self.last_attempt + delay + jitter:
  97. # Skip rw ping checks if its too soon
  98. return False
  99. for host, port in self.hosts:
  100. log.debug("Pinging server for r/w: %s:%s", host, port)
  101. self.last_attempt = time.time()
  102. try:
  103. with self.socket_handling():
  104. sock = self.connection((host, port))
  105. sock.sendall(b"isro")
  106. result = sock.recv(8192)
  107. sock.close()
  108. if result == b"rw":
  109. return (host, port)
  110. else:
  111. return False
  112. except ConnectionDropped:
  113. return False
  114. # Add some jitter between host pings
  115. while time.time() < self.last_attempt + jitter:
  116. return False
  117. delay *= 2
  118. class RWServerAvailable(Exception):
  119. """Thrown if a RW Server becomes available"""
  120. class ConnectionHandler(object):
  121. """Zookeeper connection handler"""
  122. def __init__(self, client, retry_sleeper, logger=None, sasl_options=None):
  123. self.client = client
  124. self.handler = client.handler
  125. self.retry_sleeper = retry_sleeper
  126. self.logger = logger or log
  127. # Our event objects
  128. self.connection_closed = client.handler.event_object()
  129. self.connection_closed.set()
  130. self.connection_stopped = client.handler.event_object()
  131. self.connection_stopped.set()
  132. self.ping_outstanding = client.handler.event_object()
  133. self._read_sock = None
  134. self._write_sock = None
  135. self._socket = None
  136. self._xid = None
  137. self._rw_server = None
  138. self._ro_mode = False
  139. self._connection_routine = None
  140. self.sasl_options = sasl_options
  141. self.sasl_cli = None
  142. # This is instance specific to avoid odd thread bug issues in Python
  143. # during shutdown global cleanup
  144. @contextmanager
  145. def _socket_error_handling(self):
  146. try:
  147. yield
  148. except (socket.error, select.error) as e:
  149. err = getattr(e, "strerror", e)
  150. raise ConnectionDropped("socket connection error: %s" % (err,))
  151. def start(self):
  152. """Start the connection up"""
  153. if self.connection_closed.is_set():
  154. rw_sockets = self.handler.create_socket_pair()
  155. self._read_sock, self._write_sock = rw_sockets
  156. self.connection_closed.clear()
  157. if self._connection_routine:
  158. raise Exception(
  159. "Unable to start, connection routine already " "active."
  160. )
  161. self._connection_routine = self.handler.spawn(self.zk_loop)
  162. def stop(self, timeout=None):
  163. """Ensure the writer has stopped, wait to see if it does."""
  164. self.connection_stopped.wait(timeout)
  165. if self._connection_routine:
  166. self._connection_routine.join()
  167. self._connection_routine = None
  168. return self.connection_stopped.is_set()
  169. def close(self):
  170. """Release resources held by the connection
  171. The connection can be restarted afterwards.
  172. """
  173. if not self.connection_stopped.is_set():
  174. raise Exception("Cannot close connection until it is stopped")
  175. self.connection_closed.set()
  176. ws, rs = self._write_sock, self._read_sock
  177. self._write_sock = self._read_sock = None
  178. if ws is not None:
  179. ws.close()
  180. if rs is not None:
  181. rs.close()
  182. def _server_pinger(self):
  183. """Returns a server pinger iterable, that will ping the next
  184. server in the list, and apply a back-off between attempts."""
  185. return RWPinger(
  186. self.client.hosts,
  187. self.handler.create_connection,
  188. self._socket_error_handling,
  189. )
  190. def _read_header(self, timeout):
  191. b = self._read(4, timeout)
  192. length = int_struct.unpack(b)[0]
  193. b = self._read(length, timeout)
  194. header, offset = ReplyHeader.deserialize(b, 0)
  195. return header, b, offset
  196. def _read(self, length, timeout):
  197. msgparts = []
  198. remaining = length
  199. with self._socket_error_handling():
  200. while remaining > 0:
  201. # Because of SSL framing, a select may not return when using
  202. # an SSL socket because the underlying physical socket may not
  203. # have anything to select, but the wrapped object may still
  204. # have something to read as it has previously gotten enough
  205. # data from the underlying socket.
  206. if (
  207. hasattr(self._socket, "pending")
  208. and self._socket.pending() > 0
  209. ):
  210. pass
  211. else:
  212. s = self.handler.select([self._socket], [], [], timeout)[0]
  213. if not s: # pragma: nocover
  214. # If the read list is empty, we got a timeout. We don't
  215. # have to check wlist and xlist as we don't set any
  216. raise self.handler.timeout_exception(
  217. "socket time-out during read"
  218. )
  219. try:
  220. chunk = self._socket.recv(remaining)
  221. except ssl.SSLError as e:
  222. if e.errno in (
  223. ssl.SSL_ERROR_WANT_READ,
  224. ssl.SSL_ERROR_WANT_WRITE,
  225. ):
  226. continue
  227. else:
  228. raise
  229. if chunk == b"":
  230. raise ConnectionDropped("socket connection broken")
  231. msgparts.append(chunk)
  232. remaining -= len(chunk)
  233. return b"".join(msgparts)
  234. def _invoke(self, timeout, request, xid=None):
  235. """A special writer used during connection establishment
  236. only"""
  237. self._submit(request, timeout, xid)
  238. zxid = None
  239. if xid:
  240. header, buffer, offset = self._read_header(timeout)
  241. if header.xid != xid:
  242. raise RuntimeError(
  243. "xids do not match, expected %r " "received %r",
  244. xid,
  245. header.xid,
  246. )
  247. if header.zxid > 0:
  248. zxid = header.zxid
  249. if header.err:
  250. callback_exception = EXCEPTIONS[header.err]()
  251. self.logger.debug(
  252. "Received error(xid=%s) %r", xid, callback_exception
  253. )
  254. raise callback_exception
  255. return zxid
  256. msg = self._read(4, timeout)
  257. length = int_struct.unpack(msg)[0]
  258. msg = self._read(length, timeout)
  259. if hasattr(request, "deserialize"):
  260. try:
  261. obj, _ = request.deserialize(msg, 0)
  262. except Exception:
  263. self.logger.exception(
  264. "Exception raised during deserialization "
  265. "of request: %s",
  266. request,
  267. )
  268. # raise ConnectionDropped so connect loop will retry
  269. raise ConnectionDropped("invalid server response")
  270. self.logger.log(BLATHER, "Read response %s", obj)
  271. return obj, zxid
  272. return zxid
  273. def _submit(self, request, timeout, xid=None):
  274. """Submit a request object with a timeout value and optional
  275. xid"""
  276. b = bytearray()
  277. if xid:
  278. b.extend(int_struct.pack(xid))
  279. if request.type:
  280. b.extend(int_struct.pack(request.type))
  281. b += request.serialize()
  282. self.logger.log(
  283. (BLATHER if isinstance(request, Ping) else logging.DEBUG),
  284. "Sending request(xid=%s): %s",
  285. xid,
  286. request,
  287. )
  288. self._write(int_struct.pack(len(b)) + b, timeout)
  289. def _write(self, msg, timeout):
  290. """Write a raw msg to the socket"""
  291. sent = 0
  292. msg_length = len(msg)
  293. with self._socket_error_handling():
  294. while sent < msg_length:
  295. s = self.handler.select([], [self._socket], [], timeout)[1]
  296. if not s: # pragma: nocover
  297. # If the write list is empty, we got a timeout. We don't
  298. # have to check rlist and xlist as we don't set any
  299. raise self.handler.timeout_exception(
  300. "socket time-out" " during write"
  301. )
  302. msg_slice = buffer(msg, sent)
  303. try:
  304. bytes_sent = self._socket.send(msg_slice)
  305. except ssl.SSLError as e:
  306. if e.errno in (
  307. ssl.SSL_ERROR_WANT_READ,
  308. ssl.SSL_ERROR_WANT_WRITE,
  309. ):
  310. continue
  311. else:
  312. raise
  313. if not bytes_sent:
  314. raise ConnectionDropped("socket connection broken")
  315. sent += bytes_sent
  316. def _read_watch_event(self, buffer, offset):
  317. client = self.client
  318. watch, offset = Watch.deserialize(buffer, offset)
  319. path = watch.path
  320. self.logger.debug("Received EVENT: %s", watch)
  321. watchers = []
  322. if watch.type in (CREATED_EVENT, CHANGED_EVENT):
  323. watchers.extend(client._data_watchers.pop(path, []))
  324. elif watch.type == DELETED_EVENT:
  325. watchers.extend(client._data_watchers.pop(path, []))
  326. watchers.extend(client._child_watchers.pop(path, []))
  327. elif watch.type == CHILD_EVENT:
  328. watchers.extend(client._child_watchers.pop(path, []))
  329. else:
  330. self.logger.warn("Received unknown event %r", watch.type)
  331. return
  332. # Strip the chroot if needed
  333. path = client.unchroot(path)
  334. ev = WatchedEvent(EVENT_TYPE_MAP[watch.type], client._state, path)
  335. # Last check to ignore watches if we've been stopped
  336. if client._stopped.is_set():
  337. return
  338. # Dump the watchers to the watch thread
  339. for watch in watchers:
  340. client.handler.dispatch_callback(Callback("watch", watch, (ev,)))
  341. def _read_response(self, header, buffer, offset):
  342. client = self.client
  343. request, async_object, xid = client._pending.popleft()
  344. if header.zxid and header.zxid > 0:
  345. client.last_zxid = header.zxid
  346. if header.xid != xid:
  347. exc = RuntimeError(
  348. "xids do not match, expected %r " "received %r",
  349. xid,
  350. header.xid,
  351. )
  352. async_object.set_exception(exc)
  353. raise exc
  354. # Determine if its an exists request and a no node error
  355. exists_error = (
  356. header.err == NoNodeError.code and request.type == Exists.type
  357. )
  358. # Set the exception if its not an exists error
  359. if header.err and not exists_error:
  360. callback_exception = EXCEPTIONS[header.err]()
  361. self.logger.debug(
  362. "Received error(xid=%s) %r", xid, callback_exception
  363. )
  364. if async_object:
  365. async_object.set_exception(callback_exception)
  366. elif request and async_object:
  367. if exists_error:
  368. # It's a NoNodeError, which is fine for an exists
  369. # request
  370. async_object.set(None)
  371. else:
  372. try:
  373. response = request.deserialize(buffer, offset)
  374. except Exception as exc:
  375. self.logger.exception(
  376. "Exception raised during deserialization "
  377. "of request: %s",
  378. request,
  379. )
  380. async_object.set_exception(exc)
  381. return
  382. self.logger.debug(
  383. "Received response(xid=%s): %r", xid, response
  384. )
  385. # We special case a Transaction as we have to unchroot things
  386. if request.type == Transaction.type:
  387. response = Transaction.unchroot(client, response)
  388. async_object.set(response)
  389. # Determine if watchers should be registered
  390. watcher = getattr(request, "watcher", None)
  391. if not client._stopped.is_set() and watcher:
  392. if isinstance(request, (GetChildren, GetChildren2)):
  393. client._child_watchers[request.path].add(watcher)
  394. else:
  395. client._data_watchers[request.path].add(watcher)
  396. if isinstance(request, Close):
  397. self.logger.log(BLATHER, "Read close response")
  398. return CLOSE_RESPONSE
  399. def _read_socket(self, read_timeout):
  400. """Called when there's something to read on the socket"""
  401. client = self.client
  402. header, buffer, offset = self._read_header(read_timeout)
  403. if header.xid == PING_XID:
  404. self.logger.log(BLATHER, "Received Ping")
  405. self.ping_outstanding.clear()
  406. elif header.xid == AUTH_XID:
  407. self.logger.log(BLATHER, "Received AUTH")
  408. request, async_object, xid = client._pending.popleft()
  409. if header.err:
  410. async_object.set_exception(AuthFailedError())
  411. client._session_callback(KeeperState.AUTH_FAILED)
  412. else:
  413. async_object.set(True)
  414. elif header.xid == WATCH_XID:
  415. self._read_watch_event(buffer, offset)
  416. else:
  417. self.logger.log(BLATHER, "Reading for header %r", header)
  418. return self._read_response(header, buffer, offset)
  419. def _send_request(self, read_timeout, connect_timeout):
  420. """Called when we have something to send out on the socket"""
  421. client = self.client
  422. try:
  423. request, async_object = client._queue[0]
  424. except IndexError:
  425. # Not actually something on the queue, this can occur if
  426. # something happens to cancel the request such that we
  427. # don't clear the socket below after sending
  428. try:
  429. # Clear possible inconsistence (no request in the queue
  430. # but have data in the read socket), which causes cpu to spin.
  431. self._read_sock.recv(1)
  432. except OSError:
  433. pass
  434. return
  435. # Special case for testing, if this is a _SessionExpire object
  436. # then throw a SessionExpiration error as if we were dropped
  437. if request is _SESSION_EXPIRED:
  438. raise SessionExpiredError("Session expired: Testing")
  439. if request is _CONNECTION_DROP:
  440. raise ConnectionDropped("Connection dropped: Testing")
  441. # Special case for auth packets
  442. if request.type == Auth.type:
  443. xid = AUTH_XID
  444. else:
  445. self._xid = (self._xid % 2147483647) + 1
  446. xid = self._xid
  447. self._submit(request, connect_timeout, xid)
  448. client._queue.popleft()
  449. self._read_sock.recv(1)
  450. client._pending.append((request, async_object, xid))
  451. def _send_ping(self, connect_timeout):
  452. self.ping_outstanding.set()
  453. self._submit(PingInstance, connect_timeout, PING_XID)
  454. # Determine if we need to check for a r/w server
  455. if self._ro_mode:
  456. result = advance_iterator(self._ro_mode)
  457. if result:
  458. self._rw_server = result
  459. raise RWServerAvailable()
  460. def zk_loop(self):
  461. """Main Zookeeper handling loop"""
  462. self.logger.log(BLATHER, "ZK loop started")
  463. self.connection_stopped.clear()
  464. retry = self.retry_sleeper.copy()
  465. try:
  466. while not self.client._stopped.is_set():
  467. # If the connect_loop returns STOP_CONNECTING, stop retrying
  468. if retry(self._connect_loop, retry) is STOP_CONNECTING:
  469. break
  470. except RetryFailedError:
  471. self.logger.warning(
  472. "Failed connecting to Zookeeper "
  473. "within the connection retry policy."
  474. )
  475. finally:
  476. self.connection_stopped.set()
  477. self.client._session_callback(KeeperState.CLOSED)
  478. self.logger.log(BLATHER, "Connection stopped")
  479. def _expand_client_hosts(self):
  480. # Expand the entire list in advance so we can randomize it if needed
  481. host_ports = []
  482. for host, port in self.client.hosts:
  483. try:
  484. host = host.strip()
  485. for rhost in socket.getaddrinfo(
  486. host, port, 0, 0, socket.IPPROTO_TCP
  487. ):
  488. host_ports.append((host, rhost[4][0], rhost[4][1]))
  489. except socket.gaierror as e:
  490. # Skip hosts that don't resolve
  491. self.logger.warning("Cannot resolve %s: %s", host, e)
  492. pass
  493. if self.client.randomize_hosts:
  494. random.shuffle(host_ports)
  495. return host_ports
  496. def _connect_loop(self, retry):
  497. # Iterate through the hosts a full cycle before starting over
  498. status = None
  499. host_ports = self._expand_client_hosts()
  500. # Check for an empty hostlist, indicating none resolved
  501. if len(host_ports) == 0:
  502. raise ForceRetryError("No host resolved. Reconnecting")
  503. for host, hostip, port in host_ports:
  504. if self.client._stopped.is_set():
  505. status = STOP_CONNECTING
  506. break
  507. status = self._connect_attempt(host, hostip, port, retry)
  508. if status is STOP_CONNECTING:
  509. break
  510. if status is STOP_CONNECTING:
  511. return STOP_CONNECTING
  512. else:
  513. raise ForceRetryError("Reconnecting")
  514. def _connect_attempt(self, host, hostip, port, retry):
  515. client = self.client
  516. KazooTimeoutError = self.handler.timeout_exception
  517. self._socket = None
  518. # Were we given a r/w server? If so, use that instead
  519. if self._rw_server:
  520. self.logger.log(
  521. BLATHER, "Found r/w server to use, %s:%s", host, port
  522. )
  523. host, port = self._rw_server
  524. self._rw_server = None
  525. if client._state != KeeperState.CONNECTING:
  526. client._session_callback(KeeperState.CONNECTING)
  527. try:
  528. self._xid = 0
  529. read_timeout, connect_timeout = self._connect(host, hostip, port)
  530. read_timeout = read_timeout / 1000.0
  531. connect_timeout = connect_timeout / 1000.0
  532. retry.reset()
  533. self.ping_outstanding.clear()
  534. last_send = time.time()
  535. with self._socket_error_handling():
  536. while not self.client._stopped.is_set():
  537. # Watch for something to read or send
  538. jitter_time = random.randint(1, 40) / 100.0
  539. deadline = last_send + read_timeout / 2.0 - jitter_time
  540. # Ensure our timeout is positive
  541. timeout = max([deadline - time.time(), jitter_time])
  542. s = self.handler.select(
  543. [self._socket, self._read_sock], [], [], timeout
  544. )[0]
  545. if not s:
  546. if self.ping_outstanding.is_set():
  547. self.ping_outstanding.clear()
  548. raise ConnectionDropped(
  549. "outstanding heartbeat ping not received"
  550. )
  551. else:
  552. if self._socket in s:
  553. response = self._read_socket(read_timeout)
  554. if response == CLOSE_RESPONSE:
  555. break
  556. # Check if any requests need sending before proceeding
  557. # to process more responses. Otherwise the responses
  558. # may choke out the requests. See PR#633.
  559. if self._read_sock in s:
  560. self._send_request(read_timeout, connect_timeout)
  561. # Requests act as implicit pings.
  562. last_send = time.time()
  563. continue
  564. if time.time() >= deadline:
  565. self._send_ping(connect_timeout)
  566. last_send = time.time()
  567. self.logger.info("Closing connection to %s:%s", host, port)
  568. client._session_callback(KeeperState.CLOSED)
  569. return STOP_CONNECTING
  570. except (ConnectionDropped, KazooTimeoutError) as e:
  571. if isinstance(e, ConnectionDropped):
  572. self.logger.warning("Connection dropped: %s", e)
  573. else:
  574. self.logger.warning("Connection time-out: %s", e)
  575. if client._state != KeeperState.CONNECTING:
  576. self.logger.warning("Transition to CONNECTING")
  577. client._session_callback(KeeperState.CONNECTING)
  578. except AuthFailedError as err:
  579. retry.reset()
  580. self.logger.warning("AUTH_FAILED closing: %s", err)
  581. client._session_callback(KeeperState.AUTH_FAILED)
  582. return STOP_CONNECTING
  583. except SessionExpiredError:
  584. retry.reset()
  585. self.logger.warning("Session has expired")
  586. client._session_callback(KeeperState.EXPIRED_SESSION)
  587. except RWServerAvailable:
  588. retry.reset()
  589. self.logger.warning("Found a RW server, dropping connection")
  590. client._session_callback(KeeperState.CONNECTING)
  591. except Exception:
  592. self.logger.exception("Unhandled exception in connection loop")
  593. raise
  594. finally:
  595. if self._socket is not None:
  596. self._socket.close()
  597. def _connect(self, host, hostip, port):
  598. client = self.client
  599. self.logger.info(
  600. "Connecting to %s(%s):%s, use_ssl: %r",
  601. host,
  602. hostip,
  603. port,
  604. self.client.use_ssl,
  605. )
  606. self.logger.log(
  607. BLATHER,
  608. " Using session_id: %r session_passwd: %s",
  609. client._session_id,
  610. hexlify(client._session_passwd),
  611. )
  612. with self._socket_error_handling():
  613. self._socket = self.handler.create_connection(
  614. address=(hostip, port),
  615. timeout=client._session_timeout / 1000.0,
  616. use_ssl=self.client.use_ssl,
  617. keyfile=self.client.keyfile,
  618. certfile=self.client.certfile,
  619. ca=self.client.ca,
  620. keyfile_password=self.client.keyfile_password,
  621. verify_certs=self.client.verify_certs,
  622. )
  623. self._socket.setblocking(0)
  624. connect = Connect(
  625. 0,
  626. client.last_zxid,
  627. client._session_timeout,
  628. client._session_id or 0,
  629. client._session_passwd,
  630. client.read_only,
  631. )
  632. connect_result, zxid = self._invoke(
  633. client._session_timeout / 1000.0 / len(client.hosts), connect
  634. )
  635. if connect_result.time_out <= 0:
  636. raise SessionExpiredError("Session has expired")
  637. if zxid:
  638. client.last_zxid = zxid
  639. # Load return values
  640. client._session_id = connect_result.session_id
  641. client._protocol_version = connect_result.protocol_version
  642. negotiated_session_timeout = connect_result.time_out
  643. connect_timeout = negotiated_session_timeout / len(client.hosts)
  644. read_timeout = negotiated_session_timeout * 2.0 / 3.0
  645. client._session_passwd = connect_result.passwd
  646. self.logger.log(
  647. BLATHER,
  648. "Session created, session_id: %r session_passwd: %s\n"
  649. " negotiated session timeout: %s\n"
  650. " connect timeout: %s\n"
  651. " read timeout: %s",
  652. client._session_id,
  653. hexlify(client._session_passwd),
  654. negotiated_session_timeout,
  655. connect_timeout,
  656. read_timeout,
  657. )
  658. if connect_result.read_only:
  659. client._session_callback(KeeperState.CONNECTED_RO)
  660. self._ro_mode = iter(self._server_pinger())
  661. else:
  662. client._session_callback(KeeperState.CONNECTED)
  663. self._ro_mode = None
  664. if self.sasl_options is not None:
  665. self._authenticate_with_sasl(host, connect_timeout / 1000.0)
  666. # Get a copy of the auth data before iterating, in case it is
  667. # changed.
  668. client_auth_data_copy = copy.copy(client.auth_data)
  669. for scheme, auth in client_auth_data_copy:
  670. ap = Auth(0, scheme, auth)
  671. zxid = self._invoke(connect_timeout / 1000.0, ap, xid=AUTH_XID)
  672. if zxid:
  673. client.last_zxid = zxid
  674. return read_timeout, connect_timeout
  675. def _authenticate_with_sasl(self, host, timeout):
  676. """Establish a SASL authenticated connection to the server."""
  677. if not PURESASL_AVAILABLE:
  678. raise SASLException("Missing SASL support")
  679. if "service" not in self.sasl_options:
  680. self.sasl_options["service"] = "zookeeper"
  681. # NOTE: Zookeeper hardcoded the domain for Digest authentication
  682. # instead of using the hostname. See
  683. # zookeeper/util/SecurityUtils.java#L74 and Server/Client
  684. # initializations.
  685. if self.sasl_options["mechanism"] == "DIGEST-MD5":
  686. host = "zk-sasl-md5"
  687. sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient(
  688. host=host, **self.sasl_options
  689. )
  690. # Inititalize the process with an empty challenge token
  691. challenge = None
  692. xid = 0
  693. while True:
  694. if sasl_cli.complete:
  695. break
  696. try:
  697. response = sasl_cli.process(challenge=challenge)
  698. except puresasl.SASLError as err:
  699. raise SASLException("library error") from err
  700. except puresasl.SASLProtocolException as exc:
  701. raise AuthFailedError("protocol error") from exc
  702. except Exception as exc:
  703. raise AuthFailedError("Unknown error") from exc
  704. if sasl_cli.complete and not response:
  705. break
  706. elif response is None:
  707. response = b""
  708. xid = (xid % 2147483647) + 1
  709. request = SASL(response)
  710. self._submit(request, timeout, xid)
  711. try:
  712. header, buffer, offset = self._read_header(timeout)
  713. except ConnectionDropped as exc:
  714. # Zookeeper simply drops connections with failed authentication
  715. raise AuthFailedError("Connection dropped in SASL") from exc
  716. if header.xid != xid:
  717. raise RuntimeError(
  718. "xids do not match, expected %r " "received %r",
  719. xid,
  720. header.xid,
  721. )
  722. if header.zxid > 0:
  723. self.client.last_zxid = header.zxid
  724. if header.err:
  725. callback_exception = EXCEPTIONS[header.err]()
  726. self.logger.debug(
  727. "Received error(xid=%s) %r", xid, callback_exception
  728. )
  729. raise callback_exception
  730. challenge, _ = SASL.deserialize(buffer, offset)
  731. # If we made it here, authentication is ok, and we are connected.
  732. # Remove sensible information from the object.
  733. sasl_cli.dispose()