Commit 5d1d424

Wrap consumer.poll() for KafkaConsumer iteration (#1902)
1 parent a9f513c commit 5d1d424

File tree

3 files changed (+74, -11 lines)


kafka/consumer/fetcher.py

Lines changed: 6 additions & 4 deletions
@@ -292,7 +292,7 @@ def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")):
         raise Errors.KafkaTimeoutError(
             "Failed to get offsets by timestamps in %s ms" % (timeout_ms,))
 
-    def fetched_records(self, max_records=None):
+    def fetched_records(self, max_records=None, update_offsets=True):
         """Returns previously fetched records and updates consumed offsets.
 
         Arguments:
@@ -330,10 +330,11 @@ def fetched_records(self, max_records=None):
             else:
                 records_remaining -= self._append(drained,
                                                   self._next_partition_records,
-                                                  records_remaining)
+                                                  records_remaining,
+                                                  update_offsets)
         return dict(drained), bool(self._completed_fetches)
 
-    def _append(self, drained, part, max_records):
+    def _append(self, drained, part, max_records, update_offsets):
         if not part:
             return 0
 
@@ -366,7 +367,8 @@ def _append(self, drained, part, max_records):
                 for record in part_records:
                     drained[tp].append(record)
 
-                self._subscriptions.assignment[tp].position = next_offset
+                if update_offsets:
+                    self._subscriptions.assignment[tp].position = next_offset
                 return len(part_records)
 
             else:
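
The new flag defers position updates from fetch time to consumption time. A minimal standalone sketch of that pattern (hypothetical classes and names, not kafka-python's internal API):

# Hypothetical sketch of deferred position updates, mirroring what
# update_offsets=False enables in the fetcher above.
class TinyFetcher(object):
    def __init__(self, buffered):
        self.position = {}         # partition -> next offset to consume
        self._buffered = buffered  # partition -> list of (offset, value) pairs

    def fetched_records(self, update_offsets=True):
        drained, self._buffered = self._buffered, {}
        if update_offsets:
            # batch mode: jump positions past everything returned
            for tp, records in drained.items():
                self.position[tp] = records[-1][0] + 1
        return drained

def iterate(fetcher):
    # iterator mode: advance the position one record at a time, so a
    # consumer that stops mid-batch never silently skips records
    for tp, records in fetcher.fetched_records(update_offsets=False).items():
        for offset, value in records:
            fetcher.position[tp] = offset + 1
            yield value

f = TinyFetcher({'tp0': [(5, 'a'), (6, 'b')]})
it = iterate(f)
next(it)                   # consume just one record...
print(f.position['tp0'])   # 6 -- position reflects exactly what was returned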

kafka/consumer/group.py

Lines changed: 63 additions & 6 deletions
@@ -302,7 +302,8 @@ class KafkaConsumer(six.Iterator):
         'sasl_plain_password': None,
         'sasl_kerberos_service_name': 'kafka',
         'sasl_kerberos_domain_name': None,
-        'sasl_oauth_token_provider': None
+        'sasl_oauth_token_provider': None,
+        'legacy_iterator': False,  # enable to revert to < 1.4.7 iterator
     }
     DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000
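
The flag is passed like any other consumer option; the broker and topic names below are placeholders:

from kafka import KafkaConsumer

# Opt back into the pre-1.4.7 iterator if the new poll()-based
# iteration causes a regression for your application.
consumer = KafkaConsumer('my-topic',
                         bootstrap_servers='localhost:9092',
                         legacy_iterator=True)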

@@ -597,7 +598,7 @@ def partitions_for_topic(self, topic):
         partitions = cluster.partitions_for_topic(topic)
         return partitions
 
-    def poll(self, timeout_ms=0, max_records=None):
+    def poll(self, timeout_ms=0, max_records=None, update_offsets=True):
         """Fetch data from assigned topics / partitions.
 
         Records are fetched and returned in batches by topic-partition.
@@ -621,6 +622,12 @@ def poll(self, timeout_ms=0, max_records=None):
             dict: Topic to list of records since the last fetch for the
                 subscribed list of topics and partitions.
         """
+        # Note: update_offsets is an internal-use-only argument. It exists to
+        # support the python iterator interface, which wraps consumer.poll()
+        # and requires that the partition offsets tracked by the fetcher are
+        # not updated until the iterator returns each record to the user. As
+        # such, the argument is not documented and should not be relied on by
+        # library users; it may change or be removed in the future.
         assert timeout_ms >= 0, 'Timeout must not be negative'
         if max_records is None:
             max_records = self.config['max_poll_records']
@@ -631,7 +638,7 @@ def poll(self, timeout_ms=0, max_records=None):
         start = time.time()
         remaining = timeout_ms
         while True:
-            records = self._poll_once(remaining, max_records)
+            records = self._poll_once(remaining, max_records, update_offsets=update_offsets)
             if records:
                 return records
 
@@ -641,7 +648,7 @@ def poll(self, timeout_ms=0, max_records=None):
             if remaining <= 0:
                 return {}
 
-    def _poll_once(self, timeout_ms, max_records):
+    def _poll_once(self, timeout_ms, max_records, update_offsets=True):
         """Do one round of polling. In addition to checking for new data, this does
         any needed heart-beating, auto-commits, and offset updates.
@@ -660,7 +667,7 @@ def _poll_once(self, timeout_ms, max_records):
 
         # If data is available already, e.g. from a previous network client
         # poll() call to commit, then just return it immediately
-        records, partial = self._fetcher.fetched_records(max_records)
+        records, partial = self._fetcher.fetched_records(max_records, update_offsets=update_offsets)
         if records:
             # Before returning the fetched records, we can send off the
             # next round of fetches and avoid block waiting for their
@@ -680,7 +687,7 @@ def _poll_once(self, timeout_ms, max_records):
         if self._coordinator.need_rejoin():
             return {}
 
-        records, _ = self._fetcher.fetched_records(max_records)
+        records, _ = self._fetcher.fetched_records(max_records, update_offsets=update_offsets)
         return records
 
     def position(self, partition):
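
For library users, poll() is unchanged: update_offsets defaults to True, so a normal poll still advances positions past everything it returns. A typical call, continuing from a consumer constructed as above (the handler is a placeholder):

records = consumer.poll(timeout_ms=1000, max_records=500)
for tp, messages in records.items():
    for message in messages:
        handle(message)  # placeholder application logic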
@@ -743,6 +750,9 @@ def pause(self, *partitions):
         for partition in partitions:
             log.debug("Pausing partition %s", partition)
             self._subscription.pause(partition)
+        # Because the iterator checks is_fetchable() on each iteration
+        # we expect pauses to get handled automatically and therefore
+        # we do not need to reset the full iterator (forcing a full refetch)
 
     def paused(self):
         """Get the partitions that were previously paused using
@@ -790,6 +800,8 @@ def seek(self, partition, offset):
         assert partition in self._subscription.assigned_partitions(), 'Unassigned partition'
         log.debug("Seeking to offset %s for partition %s", offset, partition)
         self._subscription.assignment[partition].seek(offset)
+        if not self.config['legacy_iterator']:
+            self._iterator = None
 
     def seek_to_beginning(self, *partitions):
         """Seek to the oldest available offset for partitions.
@@ -814,6 +826,8 @@ def seek_to_beginning(self, *partitions):
         for tp in partitions:
             log.debug("Seeking to beginning of partition %s", tp)
             self._subscription.need_offset_reset(tp, OffsetResetStrategy.EARLIEST)
+        if not self.config['legacy_iterator']:
+            self._iterator = None
 
     def seek_to_end(self, *partitions):
         """Seek to the most recent available offset for partitions.
@@ -838,6 +852,8 @@ def seek_to_end(self, *partitions):
         for tp in partitions:
             log.debug("Seeking to end of partition %s", tp)
             self._subscription.need_offset_reset(tp, OffsetResetStrategy.LATEST)
+        if not self.config['legacy_iterator']:
+            self._iterator = None
 
     def subscribe(self, topics=(), pattern=None, listener=None):
         """Subscribe to a list of topics, or a topic regex pattern.
@@ -913,6 +929,8 @@ def unsubscribe(self):
         self._client.cluster.need_all_topic_metadata = False
         self._client.set_topics([])
         log.debug("Unsubscribed all topics or patterns and assigned partitions")
+        if not self.config['legacy_iterator']:
+            self._iterator = None
 
     def metrics(self, raw=False):
         """Get metrics on consumer performance.
@@ -1075,6 +1093,25 @@ def _update_fetch_positions(self, partitions):
         # Then, do any offset lookups in case some positions are not known
         self._fetcher.update_fetch_positions(partitions)
 
+    def _message_generator_v2(self):
+        timeout_ms = 1000 * (self._consumer_timeout - time.time())
+        record_map = self.poll(timeout_ms=timeout_ms, update_offsets=False)
+        for tp, records in six.iteritems(record_map):
+            # Generators are stateful, and it is possible that the tp / records
+            # here may become stale during iteration -- i.e., we seek to a
+            # different offset, pause consumption, or lose assignment.
+            for record in records:
+                # is_fetchable(tp) should handle assignment changes and offset
+                # resets; for all other changes (e.g., seeks) we'll rely on the
+                # outer function destroying the existing iterator/generator
+                # via self._iterator = None
+                if not self._subscription.is_fetchable(tp):
+                    log.debug("Not returning fetched records for partition %s"
+                              " since it is no longer fetchable", tp)
+                    break
+                self._subscription.assignment[tp].position = record.offset + 1
+                yield record
+
     def _message_generator(self):
         assert self.assignment() or self.subscription() is not None, 'No topic subscription or manual partition assignment'
         while time.time() < self._consumer_timeout:
@@ -1127,6 +1164,26 @@ def __iter__(self):  # pylint: disable=non-iterator-returned
         return self
 
     def __next__(self):
+        # Now that the heartbeat thread runs in the background
+        # there should be no reason to maintain a separate iterator
+        # but we'll keep it available for a few releases just in case
+        if self.config['legacy_iterator']:
+            return self.next_v1()
+        else:
+            return self.next_v2()
+
+    def next_v2(self):
+        self._set_consumer_timeout()
+        while time.time() < self._consumer_timeout:
+            if not self._iterator:
+                self._iterator = self._message_generator_v2()
+            try:
+                return next(self._iterator)
+            except StopIteration:
+                self._iterator = None
+        raise StopIteration()
+
+    def next_v1(self):
         if not self._iterator:
             self._iterator = self._message_generator()
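
Nothing changes at the call site: iteration still goes through __next__, which now dispatches to next_v2() by default, and consumer_timeout_ms still bounds how long iteration blocks. For example (topic and broker are placeholders):

consumer = KafkaConsumer('my-topic',
                         bootstrap_servers='localhost:9092',
                         consumer_timeout_ms=5000)  # stop after 5s with no records
for message in consumer:   # __next__ -> next_v2() -> poll(update_offsets=False)
    print(message.topic, message.partition, message.offset)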

kafka/coordinator/base.py

Lines changed: 5 additions & 1 deletion
@@ -321,10 +321,14 @@ def poll_heartbeat(self):
         self.heartbeat.poll()
 
     def time_to_next_heartbeat(self):
+        """Returns seconds (float) remaining before next heartbeat should be sent
+
+        Note: Returns infinite if group is not joined
+        """
         with self._lock:
             # if we have not joined the group, we don't need to send heartbeats
             if self.state is MemberState.UNJOINED:
-                return sys.maxsize
+                return float('inf')
             return self.heartbeat.time_to_next_heartbeat()
 
     def _handle_join_success(self, member_assignment_bytes):
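
Returning float('inf') composes with min()-style deadline arithmetic when choosing poll timeouts, so the not-joined case needs no special handling. A small illustration (next_wakeup_ms is a hypothetical helper, not library code):

import sys

def next_wakeup_ms(request_timeout_ms, heartbeat_seconds):
    # 1000 * float('inf') is still inf, so min() naturally ignores the
    # heartbeat deadline whenever the group is not joined
    return min(request_timeout_ms, 1000 * heartbeat_seconds)

print(next_wakeup_ms(305000, float('inf')))  # 305000 -- no heartbeat pending
print(next_wakeup_ms(305000, 2.5))           # 2500.0 -- heartbeat due first
print(1000 * sys.maxsize)                    # the old sentinel is just a very
                                             # large int, not a true "never"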
