You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
852 lines
30 KiB
852 lines
30 KiB
"""Zookeeper Protocol Connection Handler"""
|
|
from binascii import hexlify
|
|
from contextlib import contextmanager
|
|
import copy
|
|
import logging
|
|
import random
|
|
import select
|
|
import socket
|
|
import ssl
|
|
import sys
|
|
import time
|
|
|
|
from kazoo.exceptions import (
|
|
AuthFailedError,
|
|
ConnectionDropped,
|
|
EXCEPTIONS,
|
|
SessionExpiredError,
|
|
NoNodeError,
|
|
SASLException,
|
|
)
|
|
from kazoo.loggingsupport import BLATHER
|
|
from kazoo.protocol.serialization import (
|
|
Auth,
|
|
Close,
|
|
Connect,
|
|
Exists,
|
|
GetChildren,
|
|
GetChildren2,
|
|
Ping,
|
|
PingInstance,
|
|
ReplyHeader,
|
|
SASL,
|
|
Transaction,
|
|
Watch,
|
|
int_struct,
|
|
)
|
|
from kazoo.protocol.states import (
|
|
Callback,
|
|
KeeperState,
|
|
WatchedEvent,
|
|
EVENT_TYPE_MAP,
|
|
)
|
|
from kazoo.retry import (
|
|
ForceRetryError,
|
|
RetryFailedError,
|
|
)
|
|
|
|
# pure-sasl is an optional dependency: record whether it could be imported so
# SASL authentication can fail with a clear error instead of a NameError.
try:
    import puresasl
    import puresasl.client
except ImportError:
    PURESASL_AVAILABLE = False
else:
    PURESASL_AVAILABLE = True


# Module-level logger, used when a handler is not given an explicit logger.
log = logging.getLogger(__name__)
|
|
|
|
|
|
# Sentinel requests used by the test-suite: when one of these is found at the
# head of the request queue the connection raises the matching error as if
# the server had caused it.
_SESSION_EXPIRED = object()
_CONNECTION_DROP = object()

# Sentinel returned by the connect routines to tell the outer retry loop to
# stop trying to connect.
STOP_CONNECTING = object()

# Watch event type codes as they appear in a Watch message.
CREATED_EVENT = 1
DELETED_EVENT = 2
CHANGED_EVENT = 3
CHILD_EVENT = 4

# Reserved (negative) xids identifying replies that are not answers to
# regular client requests.
WATCH_XID = -1
PING_XID = -2
AUTH_XID = -4

# Value handed back by the response reader once the reply to a Close request
# has been consumed.
CLOSE_RESPONSE = Close.type
|
|
|
|
if not sys.version_info > (3,):  # pragma: nocover

    def advance_iterator(it):
        """Return the next item of *it* using the Python 2 protocol."""
        return it.next()

else:  # pragma: nocover
    # Python 3: the builtin already does what we need.
    advance_iterator = next

    def buffer(obj, offset=0):
        """Return a zero-copy view of *obj* starting at *offset*.

        Stand-in for the Python 2 ``buffer`` builtin.
        """
        view = memoryview(obj)
        return view[offset:]
|
|
|
|
|
|
class RWPinger(object):
    """A Read/Write Server Pinger Iterable

    This object is initialized with the hosts iterable and the socket
    creation function. Anytime `next` is called on its iterator it
    yields either False, or a ``(host, port)`` tuple if it found a r/w
    capable Zookeeper node.

    A ping is only attempted once at least ``0.5`` seconds plus a random
    jitter of up to one second have elapsed since the previous attempt;
    the iterator yields False if it is advanced sooner than that.

    NOTE(review): earlier documentation advertised an exponential
    back-off between runs, but the delay handed to :meth:`_next_server`
    has always been the constant ``0.5``. That timing is kept as is.

    :param hosts: Iterable of ``(host, port)`` pairs to probe.
    :param connection_func: Callable taking a ``(host, port)`` tuple and
        returning a connected socket-like object.
    :param socket_handling: Zero-argument callable returning a context
        manager that wraps the socket calls; it is expected to raise
        ``ConnectionDropped`` for socket failures.

    """

    def __init__(self, hosts, connection_func, socket_handling):
        self.hosts = hosts
        self.connection = connection_func
        # Timestamp of the most recent ping attempt (None until iterated).
        self.last_attempt = None
        self.socket_handling = socket_handling

    def __iter__(self):
        """Yield the outcome of one r/w probe per step, forever."""
        if not self.last_attempt:
            self.last_attempt = time.time()
        delay = 0.5
        while True:
            yield self._next_server(delay)

    def _next_server(self, delay):
        """Probe at most one host with the ``isro`` four-letter command.

        :param delay: Minimum number of seconds (before jitter) that must
            have passed since the last attempt.
        :returns: ``(host, port)`` if the probed host answered ``rw``;
            ``False`` if it is too soon, the host answered anything else
            or the connection was dropped; ``None`` if ``self.hosts``
            produced no host at all.
        """
        jitter = random.randint(0, 100) / 100.0
        if time.time() < self.last_attempt + delay + jitter:
            # Skip rw ping checks if its too soon
            return False

        # Every path in the loop body returns, so a single host is probed
        # per call.
        for host, port in self.hosts:
            log.debug("Pinging server for r/w: %s:%s", host, port)
            self.last_attempt = time.time()
            try:
                with self.socket_handling():
                    sock = self.connection((host, port))
                    try:
                        sock.sendall(b"isro")
                        result = sock.recv(8192)
                    finally:
                        # Always release the socket, even when the
                        # exchange itself failed.
                        sock.close()
            except ConnectionDropped:
                return False
            if result == b"rw":
                return (host, port)
            return False
        return None
|
|
|
|
|
|
class RWServerAvailable(Exception):
    """Signal that a read/write capable server has been found.

    Raised from the ping path while connected in read-only mode so the
    connection loop can drop the current connection.
    """
|
|
|
|
|
|
class ConnectionHandler(object):
    """Zookeeper connection handler

    Owns the socket to the Zookeeper server and runs the loop that sends
    queued client requests, reads replies and watch events, and sends
    heartbeat pings.

    :param client: The owning client; its ``handler``, ``hosts``, request
        queues and session attributes are used throughout.
    :param retry_sleeper: Retry policy object; a copy of it drives the
        connect loop.
    :param logger: Logger to use; defaults to this module's logger.
    :param sasl_options: Optional dict of SASL options. When not None,
        SASL authentication is performed right after connecting.
    """

    def __init__(self, client, retry_sleeper, logger=None, sasl_options=None):
        self.client = client
        self.handler = client.handler
        self.retry_sleeper = retry_sleeper
        self.logger = logger or log

        # Our event objects
        self.connection_closed = client.handler.event_object()
        self.connection_closed.set()
        self.connection_stopped = client.handler.event_object()
        self.connection_stopped.set()
        self.ping_outstanding = client.handler.event_object()

        # Socket pair created in start(); the read end is selected on by
        # the connection loop to learn that a request is queued.
        self._read_sock = None
        self._write_sock = None

        self._socket = None
        self._xid = None
        self._rw_server = None
        self._ro_mode = False

        self._connection_routine = None

        self.sasl_options = sasl_options
        self.sasl_cli = None

    # This is instance specific to avoid odd thread bug issues in Python
    # during shutdown global cleanup
    @contextmanager
    def _socket_error_handling(self):
        """Translate ``socket.error``/``select.error`` into
        ``ConnectionDropped`` for the wrapped block."""
        try:
            yield
        except (socket.error, select.error) as e:
            err = getattr(e, "strerror", e)
            raise ConnectionDropped("socket connection error: %s" % (err,))

    def start(self):
        """Start the connection up

        :raises Exception: If a connection routine is already active.
        """
        if self.connection_closed.is_set():
            rw_sockets = self.handler.create_socket_pair()
            self._read_sock, self._write_sock = rw_sockets
            self.connection_closed.clear()
        if self._connection_routine:
            raise Exception(
                "Unable to start, connection routine already " "active."
            )
        self._connection_routine = self.handler.spawn(self.zk_loop)

    def stop(self, timeout=None):
        """Ensure the writer has stopped, wait to see if it does.

        :param timeout: Seconds to wait for the stopped event, or None to
            wait without a limit.
        :returns: Whether the stopped event is set.
        """
        self.connection_stopped.wait(timeout)
        if self._connection_routine:
            self._connection_routine.join()
            self._connection_routine = None
        return self.connection_stopped.is_set()

    def close(self):
        """Release resources held by the connection

        The connection can be restarted afterwards.

        :raises Exception: If the connection has not been stopped yet.
        """
        if not self.connection_stopped.is_set():
            raise Exception("Cannot close connection until it is stopped")
        self.connection_closed.set()
        # Detach both ends before closing so the attributes never refer
        # to a closed socket.
        ws, rs = self._write_sock, self._read_sock
        self._write_sock = self._read_sock = None
        if ws is not None:
            ws.close()
        if rs is not None:
            rs.close()

    def _server_pinger(self):
        """Returns a server pinger iterable, that will ping the next
        server in the list, and apply a back-off between attempts."""
        return RWPinger(
            self.client.hosts,
            self.handler.create_connection,
            self._socket_error_handling,
        )

    def _read_header(self, timeout):
        """Read one length-prefixed message and parse its reply header.

        :returns: ``(header, buffer, offset)`` where ``buffer`` is the
            message body and ``offset`` is the position following the
            header inside it.
        """
        b = self._read(4, timeout)
        length = int_struct.unpack(b)[0]
        b = self._read(length, timeout)
        header, offset = ReplyHeader.deserialize(b, 0)
        return header, b, offset

    def _read(self, length, timeout):
        """Read exactly ``length`` bytes from the server socket.

        :param length: Number of bytes to read.
        :param timeout: Timeout handed to each ``select`` call.
        :returns: The bytes read (``b""`` when ``length`` is not positive).
        :raises ConnectionDropped: If the peer closed the connection or a
            socket error occurred.
        """
        msgparts = []
        remaining = length
        with self._socket_error_handling():
            while remaining > 0:
                # Because of SSL framing, a select may not return when using
                # an SSL socket because the underlying physical socket may not
                # have anything to select, but the wrapped object may still
                # have something to read as it has previously gotten enough
                # data from the underlying socket.
                if (
                    hasattr(self._socket, "pending")
                    and self._socket.pending() > 0
                ):
                    pass
                else:
                    s = self.handler.select([self._socket], [], [], timeout)[0]
                    if not s:  # pragma: nocover
                        # If the read list is empty, we got a timeout. We don't
                        # have to check wlist and xlist as we don't set any
                        raise self.handler.timeout_exception(
                            "socket time-out during read"
                        )
                try:
                    chunk = self._socket.recv(remaining)
                except ssl.SSLError as e:
                    if e.errno in (
                        ssl.SSL_ERROR_WANT_READ,
                        ssl.SSL_ERROR_WANT_WRITE,
                    ):
                        continue
                    else:
                        raise
                if chunk == b"":
                    raise ConnectionDropped("socket connection broken")
                msgparts.append(chunk)
                remaining -= len(chunk)
            return b"".join(msgparts)

    def _invoke(self, timeout, request, xid=None):
        """A special writer used during connection establishment
        only

        Sends ``request`` and synchronously reads its reply.

        :returns: With a truthy ``xid``: the reply's zxid, or None when it
            is not positive. Without one: ``(obj, None)`` when the request
            has a ``deserialize`` attribute, else None.
        :raises RuntimeError: If the reply's xid differs from ``xid``.
        :raises ConnectionDropped: If the reply cannot be deserialized.
        """
        self._submit(request, timeout, xid)
        zxid = None
        if xid:
            header, buffer, offset = self._read_header(timeout)
            if header.xid != xid:
                raise RuntimeError(
                    "xids do not match, expected %r "
                    "received %r" % (xid, header.xid)
                )
            if header.zxid > 0:
                zxid = header.zxid
            if header.err:
                callback_exception = EXCEPTIONS[header.err]()
                self.logger.debug(
                    "Received error(xid=%s) %r", xid, callback_exception
                )
                raise callback_exception
            return zxid

        msg = self._read(4, timeout)
        length = int_struct.unpack(msg)[0]
        msg = self._read(length, timeout)

        if hasattr(request, "deserialize"):
            try:
                obj, _ = request.deserialize(msg, 0)
            except Exception:
                self.logger.exception(
                    "Exception raised during deserialization "
                    "of request: %s",
                    request,
                )

                # raise ConnectionDropped so connect loop will retry
                raise ConnectionDropped("invalid server response")
            self.logger.log(BLATHER, "Read response %s", obj)
            return obj, zxid

        return zxid

    def _submit(self, request, timeout, xid=None):
        """Submit a request object with a timeout value and optional
        xid

        The message written is a length prefix followed by the xid (when
        truthy), the request type (when truthy) and the serialized
        request.
        """
        b = bytearray()
        if xid:
            b.extend(int_struct.pack(xid))
        if request.type:
            b.extend(int_struct.pack(request.type))
        b += request.serialize()
        self.logger.log(
            (BLATHER if isinstance(request, Ping) else logging.DEBUG),
            "Sending request(xid=%s): %s",
            xid,
            request,
        )
        self._write(int_struct.pack(len(b)) + b, timeout)

    def _write(self, msg, timeout):
        """Write a raw msg to the socket

        :raises ConnectionDropped: If nothing could be sent or a socket
            error occurred.
        """
        sent = 0
        msg_length = len(msg)
        with self._socket_error_handling():
            while sent < msg_length:
                s = self.handler.select([], [self._socket], [], timeout)[1]
                if not s:  # pragma: nocover
                    # If the write list is empty, we got a timeout. We don't
                    # have to check rlist and xlist as we don't set any
                    raise self.handler.timeout_exception(
                        "socket time-out" " during write"
                    )
                msg_slice = buffer(msg, sent)
                try:
                    bytes_sent = self._socket.send(msg_slice)
                except ssl.SSLError as e:
                    if e.errno in (
                        ssl.SSL_ERROR_WANT_READ,
                        ssl.SSL_ERROR_WANT_WRITE,
                    ):
                        continue
                    else:
                        raise
                if not bytes_sent:
                    raise ConnectionDropped("socket connection broken")
                sent += bytes_sent

    def _read_watch_event(self, buffer, offset):
        """Decode a watch event and dispatch it to the registered
        watchers for its path.

        The matching watchers are removed from the client's registries.
        Unknown event types are logged and ignored.
        """
        client = self.client
        watch, offset = Watch.deserialize(buffer, offset)
        path = watch.path

        self.logger.debug("Received EVENT: %s", watch)

        watchers = []

        if watch.type in (CREATED_EVENT, CHANGED_EVENT):
            watchers.extend(client._data_watchers.pop(path, []))
        elif watch.type == DELETED_EVENT:
            watchers.extend(client._data_watchers.pop(path, []))
            watchers.extend(client._child_watchers.pop(path, []))
        elif watch.type == CHILD_EVENT:
            watchers.extend(client._child_watchers.pop(path, []))
        else:
            self.logger.warning("Received unknown event %r", watch.type)
            return

        # Strip the chroot if needed
        path = client.unchroot(path)
        ev = WatchedEvent(EVENT_TYPE_MAP[watch.type], client._state, path)

        # Last check to ignore watches if we've been stopped
        if client._stopped.is_set():
            return

        # Dump the watchers to the watch thread
        for watcher in watchers:
            client.handler.dispatch_callback(
                Callback("watch", watcher, (ev,))
            )

    def _read_response(self, header, buffer, offset):
        """Match a reply with the oldest pending request and resolve it.

        :returns: ``CLOSE_RESPONSE`` when the request was a ``Close``,
            otherwise None.
        :raises RuntimeError: If the reply's xid is not the one expected;
            the same exception is also set on the pending async object.
        """
        client = self.client
        request, async_object, xid = client._pending.popleft()
        if header.zxid and header.zxid > 0:
            client.last_zxid = header.zxid
        if header.xid != xid:
            exc = RuntimeError(
                "xids do not match, expected %r "
                "received %r" % (xid, header.xid)
            )
            async_object.set_exception(exc)
            raise exc

        # Determine if its an exists request and a no node error
        exists_error = (
            header.err == NoNodeError.code and request.type == Exists.type
        )

        # Set the exception if its not an exists error
        if header.err and not exists_error:
            callback_exception = EXCEPTIONS[header.err]()
            self.logger.debug(
                "Received error(xid=%s) %r", xid, callback_exception
            )
            if async_object:
                async_object.set_exception(callback_exception)
        elif request and async_object:
            if exists_error:
                # It's a NoNodeError, which is fine for an exists
                # request
                async_object.set(None)
            else:
                try:
                    response = request.deserialize(buffer, offset)
                except Exception as exc:
                    self.logger.exception(
                        "Exception raised during deserialization "
                        "of request: %s",
                        request,
                    )
                    async_object.set_exception(exc)
                    return
                self.logger.debug(
                    "Received response(xid=%s): %r", xid, response
                )

                # We special case a Transaction as we have to unchroot things
                if request.type == Transaction.type:
                    response = Transaction.unchroot(client, response)

                async_object.set(response)

            # Determine if watchers should be registered
            watcher = getattr(request, "watcher", None)
            if not client._stopped.is_set() and watcher:
                if isinstance(request, (GetChildren, GetChildren2)):
                    client._child_watchers[request.path].add(watcher)
                else:
                    client._data_watchers[request.path].add(watcher)

        if isinstance(request, Close):
            self.logger.log(BLATHER, "Read close response")
            return CLOSE_RESPONSE

    def _read_socket(self, read_timeout):
        """Called when there's something to read on the socket

        Routes the message by its xid: ping reply, auth reply, watch
        event, or the reply to a regular request.

        :returns: Whatever :meth:`_read_response` returns for regular
            replies, otherwise None.
        """
        client = self.client

        header, buffer, offset = self._read_header(read_timeout)
        if header.xid == PING_XID:
            self.logger.log(BLATHER, "Received Ping")
            self.ping_outstanding.clear()
        elif header.xid == AUTH_XID:
            self.logger.log(BLATHER, "Received AUTH")

            request, async_object, xid = client._pending.popleft()
            if header.err:
                async_object.set_exception(AuthFailedError())
                client._session_callback(KeeperState.AUTH_FAILED)
            else:
                async_object.set(True)
        elif header.xid == WATCH_XID:
            self._read_watch_event(buffer, offset)
        else:
            self.logger.log(BLATHER, "Reading for header %r", header)

            return self._read_response(header, buffer, offset)

    def _send_request(self, read_timeout, connect_timeout):
        """Called when we have something to send out on the socket

        Sends the request at the head of the client's queue and moves it
        to the pending list.
        """
        client = self.client
        try:
            request, async_object = client._queue[0]
        except IndexError:
            # Not actually something on the queue, this can occur if
            # something happens to cancel the request such that we
            # don't clear the socket below after sending
            try:
                # Clear possible inconsistence (no request in the queue
                # but have data in the read socket), which causes cpu to spin.
                self._read_sock.recv(1)
            except OSError:
                pass
            return

        # Special case for testing, if this is a _SessionExpire object
        # then throw a SessionExpiration error as if we were dropped
        if request is _SESSION_EXPIRED:
            raise SessionExpiredError("Session expired: Testing")
        if request is _CONNECTION_DROP:
            raise ConnectionDropped("Connection dropped: Testing")

        # Special case for auth packets
        if request.type == Auth.type:
            xid = AUTH_XID
        else:
            # Wrap around before exceeding the signed 32-bit maximum.
            self._xid = (self._xid % 2147483647) + 1
            xid = self._xid

        self._submit(request, connect_timeout, xid)
        client._queue.popleft()
        # Consume the wake-up byte that announced this request.
        self._read_sock.recv(1)
        client._pending.append((request, async_object, xid))

    def _send_ping(self, connect_timeout):
        """Send a heartbeat ping and, in read-only mode, probe for a
        read/write server.

        :raises RWServerAvailable: If the probe found a r/w server; it is
            remembered in ``self._rw_server``.
        """
        self.ping_outstanding.set()
        self._submit(PingInstance, connect_timeout, PING_XID)

        # Determine if we need to check for a r/w server
        if self._ro_mode:
            result = advance_iterator(self._ro_mode)
            if result:
                self._rw_server = result
                raise RWServerAvailable()

    def zk_loop(self):
        """Main Zookeeper handling loop"""
        self.logger.log(BLATHER, "ZK loop started")

        self.connection_stopped.clear()

        retry = self.retry_sleeper.copy()
        try:
            while not self.client._stopped.is_set():
                # If the connect_loop returns STOP_CONNECTING, stop retrying
                if retry(self._connect_loop, retry) is STOP_CONNECTING:
                    break
        except RetryFailedError:
            self.logger.warning(
                "Failed connecting to Zookeeper "
                "within the connection retry policy."
            )
        finally:
            self.connection_stopped.set()
            self.client._session_callback(KeeperState.CLOSED)
            self.logger.log(BLATHER, "Connection stopped")

    def _expand_client_hosts(self):
        """Resolve every configured host into its addresses.

        :returns: List of ``(host, ip, port)`` tuples, shuffled when the
            client asks for randomized hosts. Hosts that fail to resolve
            are logged and skipped.
        """
        # Expand the entire list in advance so we can randomize it if needed
        host_ports = []
        for host, port in self.client.hosts:
            try:
                host = host.strip()
                for rhost in socket.getaddrinfo(
                    host, port, 0, 0, socket.IPPROTO_TCP
                ):
                    host_ports.append((host, rhost[4][0], rhost[4][1]))
            except socket.gaierror as e:
                # Skip hosts that don't resolve
                self.logger.warning("Cannot resolve %s: %s", host, e)
        if self.client.randomize_hosts:
            random.shuffle(host_ports)
        return host_ports

    def _connect_loop(self, retry):
        """Try each resolved host once.

        :returns: ``STOP_CONNECTING`` when the client was stopped or a
            connect attempt asked to stop.
        :raises ForceRetryError: If no host resolved, or after a full
            pass that did not ask to stop.
        """
        # Iterate through the hosts a full cycle before starting over
        status = None
        host_ports = self._expand_client_hosts()

        # Check for an empty hostlist, indicating none resolved
        if not host_ports:
            raise ForceRetryError("No host resolved. Reconnecting")

        for host, hostip, port in host_ports:
            if self.client._stopped.is_set():
                status = STOP_CONNECTING
                break
            status = self._connect_attempt(host, hostip, port, retry)
            if status is STOP_CONNECTING:
                break

        if status is STOP_CONNECTING:
            return STOP_CONNECTING
        else:
            raise ForceRetryError("Reconnecting")

    def _connect_attempt(self, host, hostip, port, retry):
        """Connect to one server and service the connection until it
        ends.

        :returns: ``STOP_CONNECTING`` when the loop ended normally (the
            client was stopped or the close reply was read) or
            authentication failed; None after a dropped connection,
            time-out, expired session or a switch to a r/w server.
        """
        client = self.client
        KazooTimeoutError = self.handler.timeout_exception

        self._socket = None

        # Were we given a r/w server? If so, use that instead
        if self._rw_server:
            host, port = self._rw_server
            # The connection is made to (hostip, port), so the address has
            # to follow the override as well; the r/w server is only known
            # by the host it was configured with.
            hostip = host
            self._rw_server = None
            self.logger.log(
                BLATHER, "Found r/w server to use, %s:%s", host, port
            )

        if client._state != KeeperState.CONNECTING:
            client._session_callback(KeeperState.CONNECTING)

        try:
            self._xid = 0
            read_timeout, connect_timeout = self._connect(host, hostip, port)
            read_timeout = read_timeout / 1000.0
            connect_timeout = connect_timeout / 1000.0
            retry.reset()
            self.ping_outstanding.clear()
            last_send = time.time()
            with self._socket_error_handling():
                while not self.client._stopped.is_set():
                    # Watch for something to read or send
                    jitter_time = random.randint(1, 40) / 100.0
                    deadline = last_send + read_timeout / 2.0 - jitter_time
                    # Ensure our timeout is positive
                    timeout = max([deadline - time.time(), jitter_time])
                    s = self.handler.select(
                        [self._socket, self._read_sock], [], [], timeout
                    )[0]

                    if not s:
                        if self.ping_outstanding.is_set():
                            self.ping_outstanding.clear()
                            raise ConnectionDropped(
                                "outstanding heartbeat ping not received"
                            )
                    else:
                        if self._socket in s:
                            response = self._read_socket(read_timeout)
                            if response == CLOSE_RESPONSE:
                                break
                        # Check if any requests need sending before proceeding
                        # to process more responses. Otherwise the responses
                        # may choke out the requests. See PR#633.
                        if self._read_sock in s:
                            self._send_request(read_timeout, connect_timeout)
                            # Requests act as implicit pings.
                            last_send = time.time()
                            continue

                    if time.time() >= deadline:
                        self._send_ping(connect_timeout)
                        last_send = time.time()
            self.logger.info("Closing connection to %s:%s", host, port)
            client._session_callback(KeeperState.CLOSED)
            return STOP_CONNECTING
        except (ConnectionDropped, KazooTimeoutError) as e:
            if isinstance(e, ConnectionDropped):
                self.logger.warning("Connection dropped: %s", e)
            else:
                self.logger.warning("Connection time-out: %s", e)
            if client._state != KeeperState.CONNECTING:
                self.logger.warning("Transition to CONNECTING")
                client._session_callback(KeeperState.CONNECTING)
        except AuthFailedError as err:
            retry.reset()
            self.logger.warning("AUTH_FAILED closing: %s", err)
            client._session_callback(KeeperState.AUTH_FAILED)
            return STOP_CONNECTING
        except SessionExpiredError:
            retry.reset()
            self.logger.warning("Session has expired")
            client._session_callback(KeeperState.EXPIRED_SESSION)
        except RWServerAvailable:
            retry.reset()
            self.logger.warning("Found a RW server, dropping connection")
            client._session_callback(KeeperState.CONNECTING)
        except Exception:
            self.logger.exception("Unhandled exception in connection loop")
            raise
        finally:
            if self._socket is not None:
                self._socket.close()

    def _connect(self, host, hostip, port):
        """Open the socket, perform the session handshake and
        authenticate.

        :returns: ``(read_timeout, connect_timeout)`` derived from the
            session timeout negotiated with the server, in the same unit
            as that value (the caller divides both by 1000.0).
        :raises SessionExpiredError: If the server answers with a
            non-positive session timeout.
        """
        client = self.client
        self.logger.info(
            "Connecting to %s(%s):%s, use_ssl: %r",
            host,
            hostip,
            port,
            self.client.use_ssl,
        )

        self.logger.log(
            BLATHER,
            "    Using session_id: %r session_passwd: %s",
            client._session_id,
            hexlify(client._session_passwd),
        )

        with self._socket_error_handling():
            self._socket = self.handler.create_connection(
                address=(hostip, port),
                timeout=client._session_timeout / 1000.0,
                use_ssl=self.client.use_ssl,
                keyfile=self.client.keyfile,
                certfile=self.client.certfile,
                ca=self.client.ca,
                keyfile_password=self.client.keyfile_password,
                verify_certs=self.client.verify_certs,
            )

        self._socket.setblocking(0)

        connect = Connect(
            0,
            client.last_zxid,
            client._session_timeout,
            client._session_id or 0,
            client._session_passwd,
            client.read_only,
        )

        connect_result, zxid = self._invoke(
            client._session_timeout / 1000.0 / len(client.hosts), connect
        )

        if connect_result.time_out <= 0:
            raise SessionExpiredError("Session has expired")

        if zxid:
            client.last_zxid = zxid

        # Load return values
        client._session_id = connect_result.session_id
        client._protocol_version = connect_result.protocol_version
        negotiated_session_timeout = connect_result.time_out
        connect_timeout = negotiated_session_timeout / len(client.hosts)
        read_timeout = negotiated_session_timeout * 2.0 / 3.0
        client._session_passwd = connect_result.passwd

        self.logger.log(
            BLATHER,
            "Session created, session_id: %r session_passwd: %s\n"
            "    negotiated session timeout: %s\n"
            "    connect timeout: %s\n"
            "    read timeout: %s",
            client._session_id,
            hexlify(client._session_passwd),
            negotiated_session_timeout,
            connect_timeout,
            read_timeout,
        )

        if connect_result.read_only:
            client._session_callback(KeeperState.CONNECTED_RO)
            self._ro_mode = iter(self._server_pinger())
        else:
            client._session_callback(KeeperState.CONNECTED)
            self._ro_mode = None

        if self.sasl_options is not None:
            self._authenticate_with_sasl(host, connect_timeout / 1000.0)

        # Get a copy of the auth data before iterating, in case it is
        # changed.
        client_auth_data_copy = copy.copy(client.auth_data)

        for scheme, auth in client_auth_data_copy:
            ap = Auth(0, scheme, auth)
            zxid = self._invoke(connect_timeout / 1000.0, ap, xid=AUTH_XID)
            if zxid:
                client.last_zxid = zxid

        return read_timeout, connect_timeout

    def _authenticate_with_sasl(self, host, timeout):
        """Establish a SASL authenticated connection to the server.

        :param host: Host name handed to the SASL client (replaced by a
            fixed value for DIGEST-MD5).
        :param timeout: Timeout for each request/response exchange.
        :raises SASLException: If pure-sasl is missing or reports a
            library error.
        :raises AuthFailedError: On protocol or unknown errors, or when
            the connection is dropped during the exchange.
        :raises RuntimeError: If a reply carries an unexpected xid.
        """
        if not PURESASL_AVAILABLE:
            raise SASLException("Missing SASL support")

        if "service" not in self.sasl_options:
            self.sasl_options["service"] = "zookeeper"

        # NOTE: Zookeeper hardcoded the domain for Digest authentication
        # instead of using the hostname. See
        # zookeeper/util/SecurityUtils.java#L74 and Server/Client
        # initializations.
        if self.sasl_options["mechanism"] == "DIGEST-MD5":
            host = "zk-sasl-md5"

        sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient(
            host=host, **self.sasl_options
        )

        # Initialize the process with an empty challenge token
        challenge = None
        xid = 0

        while True:
            if sasl_cli.complete:
                break

            try:
                response = sasl_cli.process(challenge=challenge)
            except puresasl.SASLError as err:
                raise SASLException("library error") from err
            except puresasl.SASLProtocolException as exc:
                raise AuthFailedError("protocol error") from exc
            except Exception as exc:
                raise AuthFailedError("Unknown error") from exc

            if sasl_cli.complete and not response:
                break
            elif response is None:
                response = b""

            xid = (xid % 2147483647) + 1

            request = SASL(response)
            self._submit(request, timeout, xid)

            try:
                header, buffer, offset = self._read_header(timeout)
            except ConnectionDropped as exc:
                # Zookeeper simply drops connections with failed authentication
                raise AuthFailedError("Connection dropped in SASL") from exc

            if header.xid != xid:
                raise RuntimeError(
                    "xids do not match, expected %r "
                    "received %r" % (xid, header.xid)
                )

            if header.zxid > 0:
                self.client.last_zxid = header.zxid

            if header.err:
                callback_exception = EXCEPTIONS[header.err]()
                self.logger.debug(
                    "Received error(xid=%s) %r", xid, callback_exception
                )
                raise callback_exception

            challenge, _ = SASL.deserialize(buffer, offset)

        # If we made it here, authentication is ok, and we are connected.
        # Remove sensitive information from the object.
        sasl_cli.dispose()
|