@@ -589,11 +589,14 @@ def _try_authenticate_plain(self, future):
589
589
self .config ['sasl_plain_password' ]]).encode ('utf-8' ))
590
590
size = Int32 .encode (len (msg ))
591
591
try :
592
- self ._send_bytes_blocking (size + msg )
592
+ with self ._lock :
593
+ if not self ._can_send_recv ():
594
+ return future .failure (Errors .NodeNotReadyError (str (self )))
595
+ self ._send_bytes_blocking (size + msg )
593
596
594
- # The server will send a zero sized message (that is Int32(0)) on success.
595
- # The connection is closed on failure
596
- data = self ._recv_bytes_blocking (4 )
597
+ # The server will send a zero sized message (that is Int32(0)) on success.
598
+ # The connection is closed on failure
599
+ data = self ._recv_bytes_blocking (4 )
597
600
598
601
except ConnectionError as e :
599
602
log .exception ("%s: Error receiving reply from server" , self )
@@ -617,6 +620,9 @@ def _try_authenticate_gssapi(self, future):
617
620
).canonicalize (gssapi .MechType .kerberos )
618
621
log .debug ('%s: GSSAPI name: %s' , self , gssapi_name )
619
622
623
+ self ._lock .acquire ()
624
+ if not self ._can_send_recv ():
625
+ return future .failure (Errors .NodeNotReadyError (str (self )))
620
626
# Establish security context and negotiate protection level
621
627
# For reference RFC 2222, section 7.2.1
622
628
try :
@@ -659,13 +665,16 @@ def _try_authenticate_gssapi(self, future):
659
665
self ._send_bytes_blocking (size + msg )
660
666
661
667
except ConnectionError as e :
668
+ self ._lock .release ()
662
669
log .exception ("%s: Error receiving reply from server" , self )
663
670
error = Errors .KafkaConnectionError ("%s: %s" % (self , e ))
664
671
self .close (error = error )
665
672
return future .failure (error )
666
673
except Exception as e :
674
+ self ._lock .release ()
667
675
return future .failure (e )
668
676
677
+ self ._lock .release ()
669
678
log .info ('%s: Authenticated as %s via GSSAPI' , self , gssapi_name )
670
679
return future .success (True )
671
680
@@ -674,6 +683,9 @@ def _try_authenticate_oauth(self, future):
674
683
675
684
msg = bytes (self ._build_oauth_client_request ().encode ("utf-8" ))
676
685
size = Int32 .encode (len (msg ))
686
+ self ._lock .acquire ()
687
+ if not self ._can_send_recv ():
688
+ return future .failure (Errors .NodeNotReadyError (str (self )))
677
689
try :
678
690
# Send SASL OAuthBearer request with OAuth token
679
691
self ._send_bytes_blocking (size + msg )
@@ -683,11 +695,14 @@ def _try_authenticate_oauth(self, future):
683
695
data = self ._recv_bytes_blocking (4 )
684
696
685
697
except ConnectionError as e :
698
+ self ._lock .release ()
686
699
log .exception ("%s: Error receiving reply from server" , self )
687
700
error = Errors .KafkaConnectionError ("%s: %s" % (self , e ))
688
701
self .close (error = error )
689
702
return future .failure (error )
690
703
704
+ self ._lock .release ()
705
+
691
706
if data != b'\x00 \x00 \x00 \x00 ' :
692
707
error = Errors .AuthenticationFailedError ('Unrecognized response during authentication' )
693
708
return future .failure (error )
@@ -787,26 +802,33 @@ def close(self, error=None):
787
802
will be failed with this exception.
788
803
Default: kafka.errors.KafkaConnectionError.
789
804
"""
790
- if self .state is ConnectionStates .DISCONNECTED :
791
- if error is not None :
792
- log .warning ('%s: Duplicate close() with error: %s' , self , error )
793
- return
794
- log .info ('%s: Closing connection. %s' , self , error or '' )
795
- self .state = ConnectionStates .DISCONNECTING
796
- self .config ['state_change_callback' ](self )
797
- self ._update_reconnect_backoff ()
798
- self ._close_socket ()
799
- self .state = ConnectionStates .DISCONNECTED
800
- self ._sasl_auth_future = None
801
- self ._protocol = KafkaProtocol (
802
- client_id = self .config ['client_id' ],
803
- api_version = self .config ['api_version' ])
804
- if error is None :
805
- error = Errors .Cancelled (str (self ))
806
- while self .in_flight_requests :
807
- (_correlation_id , (future , _timestamp )) = self .in_flight_requests .popitem ()
805
+ with self ._lock :
806
+ if self .state is ConnectionStates .DISCONNECTED :
807
+ return
808
+ log .info ('%s: Closing connection. %s' , self , error or '' )
809
+ self .state = ConnectionStates .DISCONNECTING
810
+ self .config ['state_change_callback' ](self )
811
+ self ._update_reconnect_backoff ()
812
+ self ._close_socket ()
813
+ self .state = ConnectionStates .DISCONNECTED
814
+ self ._sasl_auth_future = None
815
+ self ._protocol = KafkaProtocol (
816
+ client_id = self .config ['client_id' ],
817
+ api_version = self .config ['api_version' ])
818
+ if error is None :
819
+ error = Errors .Cancelled (str (self ))
820
+ ifrs = list (self .in_flight_requests .items ())
821
+ self .in_flight_requests .clear ()
822
+ self .config ['state_change_callback' ](self )
823
+
824
+ # drop lock before processing futures
825
+ for (_correlation_id , (future , _timestamp )) in ifrs :
808
826
future .failure (error )
809
- self .config ['state_change_callback' ](self )
827
+
828
+ def _can_send_recv (self ):
829
+ """Return True iff socket is ready for requests / responses"""
830
+ return self .state in (ConnectionStates .AUTHENTICATING ,
831
+ ConnectionStates .CONNECTED )
810
832
811
833
def send (self , request , blocking = True ):
812
834
"""Queue request for async network send, return Future()"""
@@ -820,18 +842,20 @@ def send(self, request, blocking=True):
820
842
return self ._send (request , blocking = blocking )
821
843
822
844
def _send (self , request , blocking = True ):
823
- assert self .state in (ConnectionStates .AUTHENTICATING , ConnectionStates .CONNECTED )
824
845
future = Future ()
825
846
with self ._lock :
847
+ if not self ._can_send_recv ():
848
+ return future .failure (Errors .NodeNotReadyError (str (self )))
849
+
826
850
correlation_id = self ._protocol .send_request (request )
827
851
828
- log .debug ('%s Request %d: %s' , self , correlation_id , request )
829
- if request .expect_response ():
830
- sent_time = time .time ()
831
- assert correlation_id not in self .in_flight_requests , 'Correlation ID already in-flight!'
832
- self .in_flight_requests [correlation_id ] = (future , sent_time )
833
- else :
834
- future .success (None )
852
+ log .debug ('%s Request %d: %s' , self , correlation_id , request )
853
+ if request .expect_response ():
854
+ sent_time = time .time ()
855
+ assert correlation_id not in self .in_flight_requests , 'Correlation ID already in-flight!'
856
+ self .in_flight_requests [correlation_id ] = (future , sent_time )
857
+ else :
858
+ future .success (None )
835
859
836
860
# Attempt to replicate behavior from prior to introduction of
837
861
# send_pending_requests() / async sends
@@ -842,16 +866,15 @@ def _send(self, request, blocking=True):
842
866
843
867
def send_pending_requests (self ):
844
868
"""Can block on network if request is larger than send_buffer_bytes"""
845
- if self .state not in (ConnectionStates .AUTHENTICATING ,
846
- ConnectionStates .CONNECTED ):
847
- return Errors .NodeNotReadyError (str (self ))
848
- with self ._lock :
849
- data = self ._protocol .send_bytes ()
850
869
try :
851
- # In the future we might manage an internal write buffer
852
- # and send bytes asynchronously. For now, just block
853
- # sending each request payload
854
- total_bytes = self ._send_bytes_blocking (data )
870
+ with self ._lock :
871
+ if not self ._can_send_recv ():
872
+ return Errors .NodeNotReadyError (str (self ))
873
+ # In the future we might manage an internal write buffer
874
+ # and send bytes asynchronously. For now, just block
875
+ # sending each request payload
876
+ data = self ._protocol .send_bytes ()
877
+ total_bytes = self ._send_bytes_blocking (data )
855
878
if self ._sensors :
856
879
self ._sensors .bytes_sent .record (total_bytes )
857
880
return total_bytes
@@ -871,18 +894,6 @@ def recv(self):
871
894
872
895
Return list of (response, future) tuples
873
896
"""
874
- if not self .connected () and not self .state is ConnectionStates .AUTHENTICATING :
875
- log .warning ('%s cannot recv: socket not connected' , self )
876
- # If requests are pending, we should close the socket and
877
- # fail all the pending request futures
878
- if self .in_flight_requests :
879
- self .close (Errors .KafkaConnectionError ('Socket not connected during recv with in-flight-requests' ))
880
- return ()
881
-
882
- elif not self .in_flight_requests :
883
- log .warning ('%s: No in-flight-requests to recv' , self )
884
- return ()
885
-
886
897
responses = self ._recv ()
887
898
if not responses and self .requests_timed_out ():
888
899
log .warning ('%s timed out after %s ms. Closing connection.' ,
@@ -895,7 +906,8 @@ def recv(self):
895
906
# augment respones w/ correlation_id, future, and timestamp
896
907
for i , (correlation_id , response ) in enumerate (responses ):
897
908
try :
898
- (future , timestamp ) = self .in_flight_requests .pop (correlation_id )
909
+ with self ._lock :
910
+ (future , timestamp ) = self .in_flight_requests .pop (correlation_id )
899
911
except KeyError :
900
912
self .close (Errors .KafkaConnectionError ('Received unrecognized correlation id' ))
901
913
return ()
@@ -911,6 +923,12 @@ def recv(self):
911
923
def _recv (self ):
912
924
"""Take all available bytes from socket, return list of any responses from parser"""
913
925
recvd = []
926
+ self ._lock .acquire ()
927
+ if not self ._can_send_recv ():
928
+ log .warning ('%s cannot recv: socket not connected' , self )
929
+ self ._lock .release ()
930
+ return ()
931
+
914
932
while len (recvd ) < self .config ['sock_chunk_buffer_count' ]:
915
933
try :
916
934
data = self ._sock .recv (self .config ['sock_chunk_bytes' ])
@@ -920,6 +938,7 @@ def _recv(self):
920
938
# without an exception raised
921
939
if not data :
922
940
log .error ('%s: socket disconnected' , self )
941
+ self ._lock .release ()
923
942
self .close (error = Errors .KafkaConnectionError ('socket disconnected' ))
924
943
return []
925
944
else :
@@ -932,11 +951,13 @@ def _recv(self):
932
951
break
933
952
log .exception ('%s: Error receiving network data'
934
953
' closing socket' , self )
954
+ self ._lock .release ()
935
955
self .close (error = Errors .KafkaConnectionError (e ))
936
956
return []
937
957
except BlockingIOError :
938
958
if six .PY3 :
939
959
break
960
+ self ._lock .release ()
940
961
raise
941
962
942
963
recvd_data = b'' .join (recvd )
@@ -946,20 +967,23 @@ def _recv(self):
946
967
try :
947
968
responses = self ._protocol .receive_bytes (recvd_data )
948
969
except Errors .KafkaProtocolError as e :
970
+ self ._lock .release ()
949
971
self .close (e )
950
972
return []
951
973
else :
974
+ self ._lock .release ()
952
975
return responses
953
976
954
977
def requests_timed_out (self ):
955
- if self .in_flight_requests :
956
- get_timestamp = lambda v : v [1 ]
957
- oldest_at = min (map (get_timestamp ,
958
- self .in_flight_requests .values ()))
959
- timeout = self .config ['request_timeout_ms' ] / 1000.0
960
- if time .time () >= oldest_at + timeout :
961
- return True
962
- return False
978
+ with self ._lock :
979
+ if self .in_flight_requests :
980
+ get_timestamp = lambda v : v [1 ]
981
+ oldest_at = min (map (get_timestamp ,
982
+ self .in_flight_requests .values ()))
983
+ timeout = self .config ['request_timeout_ms' ] / 1000.0
984
+ if time .time () >= oldest_at + timeout :
985
+ return True
986
+ return False
963
987
964
988
def _handle_api_version_response (self , response ):
965
989
error_type = Errors .for_code (response .error_code )
0 commit comments