Skip to content

Send socket data via non-blocking IO with send buffer #1912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions kafka/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __init__(self, **configs):
self._conns = Dict() # object to support weakrefs
self._api_versions = None
self._connecting = set()
self._sending = set()
self._refresh_on_disconnects = True
self._last_bootstrap = 0
self._bootstrap_fails = 0
Expand Down Expand Up @@ -532,6 +533,7 @@ def send(self, node_id, request, wakeup=True):
# we will need to call send_pending_requests()
# to trigger network I/O
future = conn.send(request, blocking=False)
self._sending.add(conn)

# Wakeup signal is useful in case another thread is
# blocked waiting for incoming network traffic while holding
Expand Down Expand Up @@ -604,14 +606,23 @@ def poll(self, timeout_ms=None, future=None):

return responses

def _register_send_sockets(self):
while self._sending:
conn = self._sending.pop()
try:
key = self._selector.get_key(conn._sock)
events = key.events | selectors.EVENT_WRITE
self._selector.modify(key.fileobj, events, key.data)
except KeyError:
self._selector.register(conn._sock, selectors.EVENT_WRITE, conn)

def _poll(self, timeout):
# This needs to be locked, but since it is only called from within the
# locked section of poll(), there is no additional lock acquisition here
processed = set()

# Send pending requests first, before polling for responses
for conn in six.itervalues(self._conns):
conn.send_pending_requests()
self._register_send_sockets()

start_select = time.time()
ready = self._selector.select(timeout)
Expand All @@ -623,10 +634,24 @@ def _poll(self, timeout):
if key.fileobj is self._wake_r:
self._clear_wake_fd()
continue

# Send pending requests if socket is ready to write
if events & selectors.EVENT_WRITE:
conn = key.data
if conn.connecting():
conn.connect()
else:
if conn.send_pending_requests_v2():
# If send is complete, we don't need to track write readiness
# for this socket anymore
if key.events ^ selectors.EVENT_WRITE:
self._selector.modify(
key.fileobj,
key.events ^ selectors.EVENT_WRITE,
key.data)
else:
self._selector.unregister(key.fileobj)

if not (events & selectors.EVENT_READ):
continue
conn = key.data
Expand Down
80 changes: 72 additions & 8 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def __init__(self, host, port, afi, **configs):
self.state = ConnectionStates.DISCONNECTED
self._reset_reconnect_backoff()
self._sock = None
self._send_buffer = b''
self._ssl_context = None
if self.config['ssl_context'] is not None:
self._ssl_context = self.config['ssl_context']
Expand Down Expand Up @@ -557,6 +558,32 @@ def _handle_sasl_handshake_response(self, future, response):
'kafka-python does not support SASL mechanism %s' %
self.config['sasl_mechanism']))

def _send_bytes(self, data):
    """Send some data via non-blocking IO

    Writes as much of ``data`` as the socket accepts without blocking
    and stops early as soon as the socket (or the SSL layer) reports
    that it would block.

    Note: this method is not synchronized internally; you should
    always hold the _lock before calling

    Returns: number of bytes actually sent (may be less than len(data))
    Raises: socket exception
    """
    total_sent = 0
    while total_sent < len(data):
        try:
            sent_bytes = self._sock.send(data[total_sent:])
            total_sent += sent_bytes
        except (SSLWantReadError, SSLWantWriteError):
            # SSL layer needs more handshake I/O before it can accept
            # more plaintext; treat as "would block" and retry later.
            break
        except (ConnectionError, TimeoutError) as e:
            # On py2, a would-block condition on a non-blocking socket
            # surfaces as a socket error with EWOULDBLOCK; anything else
            # is a real failure and must propagate.
            if six.PY2 and e.errno == errno.EWOULDBLOCK:
                break
            raise
        except BlockingIOError:
            # py3 raises BlockingIOError for a would-block send.
            if six.PY3:
                break
            raise
    return total_sent

def _send_bytes_blocking(self, data):
self._sock.settimeout(self.config['request_timeout_ms'] / 1000)
total_sent = 0
Expand Down Expand Up @@ -839,6 +866,7 @@ def close(self, error=None):
self._protocol = KafkaProtocol(
client_id=self.config['client_id'],
api_version=self.config['api_version'])
self._send_buffer = b''
if error is None:
error = Errors.Cancelled(str(self))
ifrs = list(self.in_flight_requests.items())
Expand Down Expand Up @@ -901,24 +929,60 @@ def _send(self, request, blocking=True):
return future

def send_pending_requests(self):
    """Attempts to send pending request messages via blocking IO

    Can block on network if a request is larger than the socket send
    buffer. Sensors are updated outside the lock to keep the critical
    section minimal.

    Returns:
        True if all pending requests were sent (or there were none);
        False if the connection is not ready or sending failed (in
        which case the connection is closed with a KafkaConnectionError).
    """
    # NOTE(review): the visible span interleaved removed and added diff
    # lines (duplicate docstring, unreachable returns after
    # Errors.NodeNotReadyError / total_bytes); this is the coherent
    # post-change version returning booleans.
    try:
        with self._lock:
            if not self._can_send_recv():
                return False
            data = self._protocol.send_bytes()
            total_bytes = self._send_bytes_blocking(data)

        if self._sensors:
            self._sensors.bytes_sent.record(total_bytes)
        return True

    except (ConnectionError, TimeoutError) as e:
        log.exception("Error sending request data to %s", self)
        error = Errors.KafkaConnectionError("%s: %s" % (self, e))
        self.close(error=error)
        return False

def send_pending_requests_v2(self):
    """Attempts to send pending request messages via non-blocking IO

    Encoded request bytes are pulled from the protocol once and any
    leftover (unsent) bytes are held in self._send_buffer until the
    socket is writable again.

    Returns:
        True if the send buffer is fully drained;
        False if the connection is not ready, the socket would block
        with bytes remaining, or sending failed (in which case the
        connection is closed with a KafkaConnectionError).
    """
    try:
        with self._lock:
            if not self._can_send_recv():
                return False

            # _protocol.send_bytes returns encoded requests to send
            # we send them via _send_bytes()
            # and hold leftover bytes in _send_buffer
            if not self._send_buffer:
                self._send_buffer = self._protocol.send_bytes()

            total_bytes = 0
            if self._send_buffer:
                total_bytes = self._send_bytes(self._send_buffer)
                self._send_buffer = self._send_buffer[total_bytes:]

        if self._sensors:
            self._sensors.bytes_sent.record(total_bytes)
        # Return True iff send buffer is empty
        return len(self._send_buffer) == 0

    # Catch only connection-level failures here. The original span also
    # listed Exception in this tuple, which made the tuple catch
    # everything and silently converted programming errors into
    # connection closes; narrowed to match send_pending_requests().
    except (ConnectionError, TimeoutError) as e:
        log.exception("Error sending request data to %s", self)
        error = Errors.KafkaConnectionError("%s: %s" % (self, e))
        self.close(error=error)
        return False

def can_send_more(self):
"""Return True unless there are max_in_flight_requests_per_connection."""
Expand Down Expand Up @@ -979,7 +1043,7 @@ def _recv(self):
else:
recvd.append(data)

except SSLWantReadError:
except (SSLWantReadError, SSLWantWriteError):
break
except (ConnectionError, TimeoutError) as e:
if six.PY2 and e.errno == errno.EWOULDBLOCK:
Expand Down
8 changes: 6 additions & 2 deletions kafka/consumer/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,11 +674,15 @@ def _poll_once(self, timeout_ms, max_records, update_offsets=True):
# responses to enable pipelining while the user is handling the
# fetched records.
if not partial:
self._fetcher.send_fetches()
futures = self._fetcher.send_fetches()
if len(futures):
self._client.poll(timeout_ms=0)
return records

# Send any new fetches (won't resend pending fetches)
self._fetcher.send_fetches()
futures = self._fetcher.send_fetches()
if len(futures):
self._client.poll(timeout_ms=0)

timeout_ms = min(timeout_ms, self._coordinator.time_to_next_poll() * 1000)
self._client.poll(timeout_ms=timeout_ms)
Expand Down
4 changes: 3 additions & 1 deletion test/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
@pytest.fixture
def cli(mocker, conn):
    # KafkaClient fixture with the selector mocked out so that poll()
    # during setup performs no real network I/O; the initial poll drives
    # the bootstrap metadata update to completion.
    client = KafkaClient(api_version=(0, 9))
    mocker.patch.object(client, '_selector')
    client.poll(future=client.cluster.request_update())
    return client


def test_bootstrap(mocker, conn):
conn.state = ConnectionStates.CONNECTED
cli = KafkaClient(api_version=(0, 9))
mocker.patch.object(cli, '_selector')
future = cli.cluster.request_update()
cli.poll(future=future)

Expand Down Expand Up @@ -86,7 +88,7 @@ def test_maybe_connect(cli, conn):


def test_conn_state_change(mocker, cli, conn):
sel = mocker.patch.object(cli, '_selector')
sel = cli._selector

node_id = 0
cli._conns[node_id] = conn
Expand Down