Skip to content

Commit 234cca5

Browse files
author
Gabriel Tincu
committed
client: allow for custom kafka clients
Provide the consumer, producer and admin client with the option to create the kafka client from a custom callable, thus allowing more flexibility in handling certain low level errors
1 parent 6f932ba commit 234cca5

File tree

4 files changed

+33
-11
lines changed

4 files changed

+33
-11
lines changed

kafka/admin/client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from kafka.admin.acl_resource import ACLOperation, ACLPermissionType, ACLFilter, ACL, ResourcePattern, ResourceType, \
1212
ACLResourcePatternType
13-
from kafka.client_async import KafkaClient, selectors
13+
from kafka.client_async import selectors
1414
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment, ConsumerProtocol
1515
import kafka.errors as Errors
1616
from kafka.errors import (
@@ -26,6 +26,7 @@
2626
from kafka.protocol.metadata import MetadataRequest
2727
from kafka.protocol.types import Array
2828
from kafka.structs import TopicPartition, OffsetAndMetadata, MemberInformation, GroupInformation
29+
from kafka.util import get_client_factory
2930
from kafka.version import __version__
3031

3132

@@ -146,6 +147,7 @@ class KafkaAdminClient(object):
146147
sasl mechanism handshake. Default: one of bootstrap servers
147148
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
148149
instance. (See kafka.oauth.abstract). Default: None
150+
client_factory (callable): Custom class / callable for creating KafkaClient instances
149151
150152
"""
151153
DEFAULT_CONFIG = {
@@ -186,6 +188,7 @@ class KafkaAdminClient(object):
186188
'metric_reporters': [],
187189
'metrics_num_samples': 2,
188190
'metrics_sample_window_ms': 30000,
191+
'client_factory': None,
189192
}
190193

191194
def __init__(self, **configs):
@@ -205,9 +208,11 @@ def __init__(self, **configs):
205208
reporters = [reporter() for reporter in self.config['metric_reporters']]
206209
self._metrics = Metrics(metric_config, reporters)
207210

208-
self._client = KafkaClient(metrics=self._metrics,
209-
metric_group_prefix='admin',
210-
**self.config)
211+
self._client = get_client_factory(self.config)(
212+
metrics=self._metrics,
213+
metric_group_prefix='admin',
214+
**self.config
215+
)
211216
self._client.check_version(timeout=(self.config['api_version_auto_timeout_ms'] / 1000))
212217

213218
# Get auto-discovered version from client if necessary

kafka/consumer/group.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from kafka.vendor import six
1111

12-
from kafka.client_async import KafkaClient, selectors
12+
from kafka.client_async import selectors
1313
from kafka.consumer.fetcher import Fetcher
1414
from kafka.consumer.subscription_state import SubscriptionState
1515
from kafka.coordinator.consumer import ConsumerCoordinator
@@ -18,6 +18,7 @@
1818
from kafka.metrics import MetricConfig, Metrics
1919
from kafka.protocol.offset import OffsetResetStrategy
2020
from kafka.structs import TopicPartition
21+
from kafka.util import get_client_factory
2122
from kafka.version import __version__
2223

2324
log = logging.getLogger(__name__)
@@ -244,6 +245,7 @@ class KafkaConsumer(six.Iterator):
244245
sasl mechanism handshake. Default: one of bootstrap servers
245246
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
246247
instance. (See kafka.oauth.abstract). Default: None
248+
client_factory (callable): Custom class / callable for creating KafkaClient instances
247249
248250
Note:
249251
Configuration parameters are described in more detail at
@@ -306,6 +308,7 @@ class KafkaConsumer(six.Iterator):
306308
'sasl_kerberos_domain_name': None,
307309
'sasl_oauth_token_provider': None,
308310
'legacy_iterator': False, # enable to revert to < 1.4.7 iterator
311+
'client_factory': None,
309312
}
310313
DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000
311314

@@ -353,7 +356,7 @@ def __init__(self, *topics, **configs):
353356
log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated',
354357
str(self.config['api_version']), str_version)
355358

356-
self._client = KafkaClient(metrics=self._metrics, **self.config)
359+
self._client = get_client_factory(self.config)(metrics=self._metrics, **self.config)
357360

358361
# Get auto-discovered version from client if necessary
359362
if self.config['api_version'] is None:

kafka/producer/kafka.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from kafka.vendor import six
1212

1313
import kafka.errors as Errors
14-
from kafka.client_async import KafkaClient, selectors
14+
from kafka.client_async import selectors
1515
from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd
1616
from kafka.metrics import MetricConfig, Metrics
1717
from kafka.partitioner.default import DefaultPartitioner
@@ -22,6 +22,7 @@
2222
from kafka.record.legacy_records import LegacyRecordBatchBuilder
2323
from kafka.serializer import Serializer
2424
from kafka.structs import TopicPartition
25+
from kafka.util import get_client_factory
2526

2627

2728
log = logging.getLogger(__name__)
@@ -280,6 +281,7 @@ class KafkaProducer(object):
280281
sasl mechanism handshake. Default: one of bootstrap servers
281282
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
282283
instance. (See kafka.oauth.abstract). Default: None
284+
client_factory (callable): Custom class / callable for creating KafkaClient instances
283285
284286
Note:
285287
Configuration parameters are described in more detail at
@@ -332,7 +334,8 @@ class KafkaProducer(object):
332334
'sasl_plain_password': None,
333335
'sasl_kerberos_service_name': 'kafka',
334336
'sasl_kerberos_domain_name': None,
335-
'sasl_oauth_token_provider': None
337+
'sasl_oauth_token_provider': None,
338+
'client_factory': None,
336339
}
337340

338341
_COMPRESSORS = {
@@ -378,9 +381,10 @@ def __init__(self, **configs):
378381
reporters = [reporter() for reporter in self.config['metric_reporters']]
379382
self._metrics = Metrics(metric_config, reporters)
380383

381-
client = KafkaClient(metrics=self._metrics, metric_group_prefix='producer',
382-
wakeup_timeout_ms=self.config['max_block_ms'],
383-
**self.config)
384+
client = get_client_factory(self.config)(
385+
metrics=self._metrics, metric_group_prefix='producer',
386+
wakeup_timeout_ms=self.config['max_block_ms'],
387+
**self.config)
384388

385389
# Get auto-discovered version from client if necessary
386390
if self.config['api_version'] is None:

kafka/util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import absolute_import
22

33
import binascii
4+
import kafka
45
import weakref
56

67
from kafka.vendor import six
@@ -64,3 +65,12 @@ class Dict(dict):
6465
See: https://docs.python.org/2/library/weakref.html
6566
"""
6667
pass
68+
69+
70+
def get_client_factory(config):
71+
if config.get('client_factory') is not None:
72+
client_factory = config['client_factory']
73+
assert callable(client_factory), "'client_factory' should be a callable or None, is {}".format(type(client_factory))
74+
else:
75+
client_factory = kafka.client_async.KafkaClient
76+
return client_factory

0 commit comments

Comments
 (0)