Skip to content

Commit 27cd93b

Browse files
authored
Additional BrokerConnection locks to synchronize protocol/IFR state (#1768)
1 parent ed4cab6 commit 27cd93b

File tree

1 file changed

+85
-61
lines changed

1 file changed

+85
-61
lines changed

kafka/conn.py

Lines changed: 85 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -589,11 +589,14 @@ def _try_authenticate_plain(self, future):
589589
self.config['sasl_plain_password']]).encode('utf-8'))
590590
size = Int32.encode(len(msg))
591591
try:
592-
self._send_bytes_blocking(size + msg)
592+
with self._lock:
593+
if not self._can_send_recv():
594+
return future.failure(Errors.NodeNotReadyError(str(self)))
595+
self._send_bytes_blocking(size + msg)
593596

594-
# The server will send a zero sized message (that is Int32(0)) on success.
595-
# The connection is closed on failure
596-
data = self._recv_bytes_blocking(4)
597+
# The server will send a zero sized message (that is Int32(0)) on success.
598+
# The connection is closed on failure
599+
data = self._recv_bytes_blocking(4)
597600

598601
except ConnectionError as e:
599602
log.exception("%s: Error receiving reply from server", self)
@@ -617,6 +620,9 @@ def _try_authenticate_gssapi(self, future):
617620
).canonicalize(gssapi.MechType.kerberos)
618621
log.debug('%s: GSSAPI name: %s', self, gssapi_name)
619622

623+
self._lock.acquire()
624+
if not self._can_send_recv():
625+
return future.failure(Errors.NodeNotReadyError(str(self)))
620626
# Establish security context and negotiate protection level
621627
# For reference RFC 2222, section 7.2.1
622628
try:
@@ -659,13 +665,16 @@ def _try_authenticate_gssapi(self, future):
659665
self._send_bytes_blocking(size + msg)
660666

661667
except ConnectionError as e:
668+
self._lock.release()
662669
log.exception("%s: Error receiving reply from server", self)
663670
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
664671
self.close(error=error)
665672
return future.failure(error)
666673
except Exception as e:
674+
self._lock.release()
667675
return future.failure(e)
668676

677+
self._lock.release()
669678
log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
670679
return future.success(True)
671680

@@ -674,6 +683,9 @@ def _try_authenticate_oauth(self, future):
674683

675684
msg = bytes(self._build_oauth_client_request().encode("utf-8"))
676685
size = Int32.encode(len(msg))
686+
self._lock.acquire()
687+
if not self._can_send_recv():
688+
return future.failure(Errors.NodeNotReadyError(str(self)))
677689
try:
678690
# Send SASL OAuthBearer request with OAuth token
679691
self._send_bytes_blocking(size + msg)
@@ -683,11 +695,14 @@ def _try_authenticate_oauth(self, future):
683695
data = self._recv_bytes_blocking(4)
684696

685697
except ConnectionError as e:
698+
self._lock.release()
686699
log.exception("%s: Error receiving reply from server", self)
687700
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
688701
self.close(error=error)
689702
return future.failure(error)
690703

704+
self._lock.release()
705+
691706
if data != b'\x00\x00\x00\x00':
692707
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
693708
return future.failure(error)
@@ -787,26 +802,33 @@ def close(self, error=None):
787802
will be failed with this exception.
788803
Default: kafka.errors.KafkaConnectionError.
789804
"""
790-
if self.state is ConnectionStates.DISCONNECTED:
791-
if error is not None:
792-
log.warning('%s: Duplicate close() with error: %s', self, error)
793-
return
794-
log.info('%s: Closing connection. %s', self, error or '')
795-
self.state = ConnectionStates.DISCONNECTING
796-
self.config['state_change_callback'](self)
797-
self._update_reconnect_backoff()
798-
self._close_socket()
799-
self.state = ConnectionStates.DISCONNECTED
800-
self._sasl_auth_future = None
801-
self._protocol = KafkaProtocol(
802-
client_id=self.config['client_id'],
803-
api_version=self.config['api_version'])
804-
if error is None:
805-
error = Errors.Cancelled(str(self))
806-
while self.in_flight_requests:
807-
(_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
805+
with self._lock:
806+
if self.state is ConnectionStates.DISCONNECTED:
807+
return
808+
log.info('%s: Closing connection. %s', self, error or '')
809+
self.state = ConnectionStates.DISCONNECTING
810+
self.config['state_change_callback'](self)
811+
self._update_reconnect_backoff()
812+
self._close_socket()
813+
self.state = ConnectionStates.DISCONNECTED
814+
self._sasl_auth_future = None
815+
self._protocol = KafkaProtocol(
816+
client_id=self.config['client_id'],
817+
api_version=self.config['api_version'])
818+
if error is None:
819+
error = Errors.Cancelled(str(self))
820+
ifrs = list(self.in_flight_requests.items())
821+
self.in_flight_requests.clear()
822+
self.config['state_change_callback'](self)
823+
824+
# drop lock before processing futures
825+
for (_correlation_id, (future, _timestamp)) in ifrs:
808826
future.failure(error)
809-
self.config['state_change_callback'](self)
827+
828+
def _can_send_recv(self):
829+
"""Return True iff socket is ready for requests / responses"""
830+
return self.state in (ConnectionStates.AUTHENTICATING,
831+
ConnectionStates.CONNECTED)
810832

811833
def send(self, request, blocking=True):
812834
"""Queue request for async network send, return Future()"""
@@ -820,18 +842,20 @@ def send(self, request, blocking=True):
820842
return self._send(request, blocking=blocking)
821843

822844
def _send(self, request, blocking=True):
823-
assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
824845
future = Future()
825846
with self._lock:
847+
if not self._can_send_recv():
848+
return future.failure(Errors.NodeNotReadyError(str(self)))
849+
826850
correlation_id = self._protocol.send_request(request)
827851

828-
log.debug('%s Request %d: %s', self, correlation_id, request)
829-
if request.expect_response():
830-
sent_time = time.time()
831-
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
832-
self.in_flight_requests[correlation_id] = (future, sent_time)
833-
else:
834-
future.success(None)
852+
log.debug('%s Request %d: %s', self, correlation_id, request)
853+
if request.expect_response():
854+
sent_time = time.time()
855+
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
856+
self.in_flight_requests[correlation_id] = (future, sent_time)
857+
else:
858+
future.success(None)
835859

836860
# Attempt to replicate behavior from prior to introduction of
837861
# send_pending_requests() / async sends
@@ -842,16 +866,15 @@ def _send(self, request, blocking=True):
842866

843867
def send_pending_requests(self):
844868
"""Can block on network if request is larger than send_buffer_bytes"""
845-
if self.state not in (ConnectionStates.AUTHENTICATING,
846-
ConnectionStates.CONNECTED):
847-
return Errors.NodeNotReadyError(str(self))
848-
with self._lock:
849-
data = self._protocol.send_bytes()
850869
try:
851-
# In the future we might manage an internal write buffer
852-
# and send bytes asynchronously. For now, just block
853-
# sending each request payload
854-
total_bytes = self._send_bytes_blocking(data)
870+
with self._lock:
871+
if not self._can_send_recv():
872+
return Errors.NodeNotReadyError(str(self))
873+
# In the future we might manage an internal write buffer
874+
# and send bytes asynchronously. For now, just block
875+
# sending each request payload
876+
data = self._protocol.send_bytes()
877+
total_bytes = self._send_bytes_blocking(data)
855878
if self._sensors:
856879
self._sensors.bytes_sent.record(total_bytes)
857880
return total_bytes
@@ -871,18 +894,6 @@ def recv(self):
871894
872895
Return list of (response, future) tuples
873896
"""
874-
if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING:
875-
log.warning('%s cannot recv: socket not connected', self)
876-
# If requests are pending, we should close the socket and
877-
# fail all the pending request futures
878-
if self.in_flight_requests:
879-
self.close(Errors.KafkaConnectionError('Socket not connected during recv with in-flight-requests'))
880-
return ()
881-
882-
elif not self.in_flight_requests:
883-
log.warning('%s: No in-flight-requests to recv', self)
884-
return ()
885-
886897
responses = self._recv()
887898
if not responses and self.requests_timed_out():
888899
log.warning('%s timed out after %s ms. Closing connection.',
@@ -895,7 +906,8 @@ def recv(self):
895906
# augment respones w/ correlation_id, future, and timestamp
896907
for i, (correlation_id, response) in enumerate(responses):
897908
try:
898-
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
909+
with self._lock:
910+
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
899911
except KeyError:
900912
self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
901913
return ()
@@ -911,6 +923,12 @@ def recv(self):
911923
def _recv(self):
912924
"""Take all available bytes from socket, return list of any responses from parser"""
913925
recvd = []
926+
self._lock.acquire()
927+
if not self._can_send_recv():
928+
log.warning('%s cannot recv: socket not connected', self)
929+
self._lock.release()
930+
return ()
931+
914932
while len(recvd) < self.config['sock_chunk_buffer_count']:
915933
try:
916934
data = self._sock.recv(self.config['sock_chunk_bytes'])
@@ -920,6 +938,7 @@ def _recv(self):
920938
# without an exception raised
921939
if not data:
922940
log.error('%s: socket disconnected', self)
941+
self._lock.release()
923942
self.close(error=Errors.KafkaConnectionError('socket disconnected'))
924943
return []
925944
else:
@@ -932,11 +951,13 @@ def _recv(self):
932951
break
933952
log.exception('%s: Error receiving network data'
934953
' closing socket', self)
954+
self._lock.release()
935955
self.close(error=Errors.KafkaConnectionError(e))
936956
return []
937957
except BlockingIOError:
938958
if six.PY3:
939959
break
960+
self._lock.release()
940961
raise
941962

942963
recvd_data = b''.join(recvd)
@@ -946,20 +967,23 @@ def _recv(self):
946967
try:
947968
responses = self._protocol.receive_bytes(recvd_data)
948969
except Errors.KafkaProtocolError as e:
970+
self._lock.release()
949971
self.close(e)
950972
return []
951973
else:
974+
self._lock.release()
952975
return responses
953976

954977
def requests_timed_out(self):
955-
if self.in_flight_requests:
956-
get_timestamp = lambda v: v[1]
957-
oldest_at = min(map(get_timestamp,
958-
self.in_flight_requests.values()))
959-
timeout = self.config['request_timeout_ms'] / 1000.0
960-
if time.time() >= oldest_at + timeout:
961-
return True
962-
return False
978+
with self._lock:
979+
if self.in_flight_requests:
980+
get_timestamp = lambda v: v[1]
981+
oldest_at = min(map(get_timestamp,
982+
self.in_flight_requests.values()))
983+
timeout = self.config['request_timeout_ms'] / 1000.0
984+
if time.time() >= oldest_at + timeout:
985+
return True
986+
return False
963987

964988
def _handle_api_version_response(self, response):
965989
error_type = Errors.for_code(response.error_code)

0 commit comments

Comments
 (0)