Skip to content

Commit 778acbc

Browse files
committed
Cleanup handling of KAFKA_VERSION env var in tests
Now that we are using `pytest`, there is no need for a custom decorator because we can use `pytest.mark.skipif()`. This makes the code significantly simpler.
1 parent 6d1f715 commit 778acbc

File tree

6 files changed

+41
-119
lines changed

6 files changed

+41
-119
lines changed

test/conftest.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,8 @@
44

55
import pytest
66

7-
from test.fixtures import KafkaFixture, ZookeeperFixture, random_string, version as kafka_version
8-
9-
10-
@pytest.fixture(scope="module")
11-
def version():
12-
"""Return the Kafka version set in the OS environment"""
13-
return kafka_version()
7+
from test.fixtures import KafkaFixture, ZookeeperFixture, random_string
8+
from test.testutil import env_kafka_version
149

1510

1611
@pytest.fixture(scope="module")
@@ -28,9 +23,9 @@ def kafka_broker(kafka_broker_factory):
2823

2924

3025
@pytest.fixture(scope="module")
31-
def kafka_broker_factory(version, zookeeper):
26+
def kafka_broker_factory(zookeeper):
3227
"""Return a Kafka broker fixture factory"""
33-
assert version, 'KAFKA_VERSION must be specified to run integration tests'
28+
assert env_kafka_version(), 'KAFKA_VERSION must be specified to run integration tests'
3429

3530
_brokers = []
3631
def factory(**broker_params):

test/fixtures.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from kafka.protocol.admin import CreateTopicsRequest
2121
from kafka.protocol.metadata import MetadataRequest
2222
from test.service import ExternalService, SpawnedService
23+
from test.testutil import env_kafka_version
2324

2425
log = logging.getLogger(__name__)
2526

@@ -28,20 +29,6 @@ def random_string(length):
2829
return "".join(random.choice(string.ascii_letters) for i in range(length))
2930

3031

31-
def version_str_to_tuple(version_str):
32-
"""Transform a version string into a tuple.
33-
34-
Example: '0.8.1.1' --> (0, 8, 1, 1)
35-
"""
36-
return tuple(map(int, version_str.split('.')))
37-
38-
39-
def version():
40-
if 'KAFKA_VERSION' not in os.environ:
41-
return ()
42-
return version_str_to_tuple(os.environ['KAFKA_VERSION'])
43-
44-
4532
def get_open_port():
4633
sock = socket.socket()
4734
sock.bind(("", 0))
@@ -477,7 +464,7 @@ def _create_topic(self, topic_name, num_partitions, replication_factor, timeout_
477464
num_partitions == self.partitions and \
478465
replication_factor == self.replicas:
479466
self._send_request(MetadataRequest[0]([topic_name]))
480-
elif version() >= (0, 10, 1, 0):
467+
elif env_kafka_version() >= (0, 10, 1, 0):
481468
request = CreateTopicsRequest[0]([(topic_name, num_partitions,
482469
replication_factor, [], [])], timeout_ms)
483470
result = self._send_request(request, timeout=timeout_ms)
@@ -497,7 +484,7 @@ def _create_topic(self, topic_name, num_partitions, replication_factor, timeout_
497484
'--replication-factor', self.replicas \
498485
if replication_factor is None \
499486
else replication_factor)
500-
if version() >= (0, 10):
487+
if env_kafka_version() >= (0, 10):
501488
args.append('--if-not-exists')
502489
env = self.kafka_run_class_env()
503490
proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

test/test_consumer_group.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@
1111
from kafka.coordinator.base import MemberState
1212
from kafka.structs import TopicPartition
1313

14-
from test.fixtures import random_string, version
14+
from test.fixtures import random_string
15+
from test.testutil import env_kafka_version
1516

1617

1718
def get_connect_str(kafka_broker):
1819
return kafka_broker.host + ':' + str(kafka_broker.port)
1920

2021

21-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
22-
def test_consumer(kafka_broker, topic, version):
22+
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
23+
def test_consumer(kafka_broker, topic):
2324
# The `topic` fixture is included because
2425
# 0.8.2 brokers need a topic to function well
2526
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
@@ -29,17 +30,16 @@ def test_consumer(kafka_broker, topic, version):
2930
assert consumer._client._conns[node_id].state is ConnectionStates.CONNECTED
3031
consumer.close()
3132

32-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
33-
def test_consumer_topics(kafka_broker, topic, version):
33+
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
34+
def test_consumer_topics(kafka_broker, topic):
3435
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
3536
# Necessary to drive the IO
3637
consumer.poll(500)
3738
assert topic in consumer.topics()
3839
assert len(consumer.partitions_for_topic(topic)) > 0
3940
consumer.close()
4041

41-
@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
42-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
42+
@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version')
4343
def test_group(kafka_broker, topic):
4444
num_partitions = 4
4545
connect_str = get_connect_str(kafka_broker)
@@ -129,7 +129,7 @@ def consumer_thread(i):
129129
threads[c] = None
130130

131131

132-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
132+
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
133133
def test_paused(kafka_broker, topic):
134134
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
135135
topics = [TopicPartition(topic, 1)]
@@ -148,8 +148,7 @@ def test_paused(kafka_broker, topic):
148148
consumer.close()
149149

150150

151-
@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
152-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
151+
@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version')
153152
def test_heartbeat_thread(kafka_broker, topic):
154153
group_id = 'test-group-' + random_string(6)
155154
consumer = KafkaConsumer(topic,

test/test_consumer_integration.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
)
1313
from kafka.structs import TopicPartition, OffsetAndTimestamp
1414

15-
from test.fixtures import random_string, version
16-
from test.testutil import kafka_versions, Timer, assert_message_count
15+
from test.fixtures import random_string
16+
from test.testutil import env_kafka_version, Timer, assert_message_count
1717

1818

19-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
19+
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
2020
def test_kafka_consumer(kafka_consumer_factory, send_messages):
2121
"""Test KafkaConsumer"""
2222
consumer = kafka_consumer_factory(auto_offset_reset='earliest')
@@ -35,7 +35,7 @@ def test_kafka_consumer(kafka_consumer_factory, send_messages):
3535
assert_message_count(messages[1], 100)
3636

3737

38-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
38+
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
3939
def test_kafka_consumer_unsupported_encoding(
4040
topic, kafka_producer_factory, kafka_consumer_factory):
4141
# Send a compressed message
@@ -53,7 +53,7 @@ def test_kafka_consumer_unsupported_encoding(
5353
consumer.poll(timeout_ms=2000)
5454

5555

56-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
56+
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
5757
def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages):
5858
TIMEOUT_MS = 500
5959
consumer = kafka_consumer_factory(auto_offset_reset='earliest',
@@ -92,7 +92,7 @@ def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages):
9292
assert t.interval >= (TIMEOUT_MS / 1000.0)
9393

9494

95-
@kafka_versions('>=0.8.1')
95+
@pytest.mark.skipif(env_kafka_version() < (0, 8, 1), reason="Requires KAFKA_VERSION >= 0.8.1")
9696
def test_kafka_consumer__offset_commit_resume(kafka_consumer_factory, send_messages):
9797
GROUP_ID = random_string(10)
9898

@@ -131,7 +131,7 @@ def test_kafka_consumer__offset_commit_resume(kafka_consumer_factory, send_messa
131131
assert_message_count(output_msgs1.extend(output_msgs2), 200)
132132

133133

134-
@kafka_versions('>=0.10.1')
134+
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
135135
def test_kafka_consumer_max_bytes_simple(kafka_consumer_factory, topic, send_messages):
136136
send_messages(range(100, 200), partition=0)
137137
send_messages(range(200, 300), partition=1)
@@ -150,7 +150,7 @@ def test_kafka_consumer_max_bytes_simple(kafka_consumer_factory, topic, send_mes
150150
assert seen_partitions == {TopicPartition(topic, 0), TopicPartition(topic, 1)}
151151

152152

153-
@kafka_versions('>=0.10.1')
153+
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
154154
def test_kafka_consumer_max_bytes_one_msg(kafka_consumer_factory, send_messages):
155155
# We send to only 1 partition so we don't have parallel requests to 2
156156
# nodes for data.
@@ -176,7 +176,7 @@ def test_kafka_consumer_max_bytes_one_msg(kafka_consumer_factory, send_messages)
176176
assert_message_count(fetched_msgs, 10)
177177

178178

179-
@kafka_versions('>=0.10.1')
179+
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
180180
def test_kafka_consumer_offsets_for_time(topic, kafka_consumer, kafka_producer):
181181
late_time = int(time.time()) * 1000
182182
middle_time = late_time - 1000
@@ -225,7 +225,7 @@ def test_kafka_consumer_offsets_for_time(topic, kafka_consumer, kafka_producer):
225225
assert offsets == {tp: late_msg.offset + 1}
226226

227227

228-
@kafka_versions('>=0.10.1')
228+
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
229229
def test_kafka_consumer_offsets_search_many_partitions(kafka_consumer, kafka_producer, topic):
230230
tp0 = TopicPartition(topic, 0)
231231
tp1 = TopicPartition(topic, 1)
@@ -263,7 +263,7 @@ def test_kafka_consumer_offsets_search_many_partitions(kafka_consumer, kafka_pro
263263
}
264264

265265

266-
@kafka_versions('<0.10.1')
266+
@pytest.mark.skipif(env_kafka_version() >= (0, 10, 1), reason="Requires KAFKA_VERSION < 0.10.1")
267267
def test_kafka_consumer_offsets_for_time_old(kafka_consumer, topic):
268268
consumer = kafka_consumer
269269
tp = TopicPartition(topic, 0)
@@ -272,7 +272,7 @@ def test_kafka_consumer_offsets_for_time_old(kafka_consumer, topic):
272272
consumer.offsets_for_times({tp: int(time.time())})
273273

274274

275-
@kafka_versions('>=0.10.1')
275+
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
276276
def test_kafka_consumer_offsets_for_times_errors(kafka_consumer_factory, topic):
277277
consumer = kafka_consumer_factory(fetch_max_wait_ms=200,
278278
request_timeout_ms=500)

test/test_producer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from kafka import KafkaConsumer, KafkaProducer, TopicPartition
99
from kafka.producer.buffer import SimpleBufferPool
10-
from test.fixtures import random_string, version
10+
from test.fixtures import random_string
11+
from test.testutil import env_kafka_version
1112

1213

1314
def test_buffer_pool():
@@ -22,13 +23,13 @@ def test_buffer_pool():
2223
assert buf2.read() == b''
2324

2425

25-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
26+
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
2627
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4'])
2728
def test_end_to_end(kafka_broker, compression):
2829

2930
if compression == 'lz4':
3031
# LZ4 requires 0.8.2
31-
if version() < (0, 8, 2):
32+
if env_kafka_version() < (0, 8, 2):
3233
return
3334
# python-lz4 crashes on older versions of pypy
3435
elif platform.python_implementation() == 'PyPy':
@@ -80,7 +81,7 @@ def test_kafka_producer_gc_cleanup():
8081
assert threading.active_count() == threads
8182

8283

83-
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
84+
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
8485
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4'])
8586
def test_kafka_producer_proper_record_metadata(kafka_broker, compression):
8687
connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)])
@@ -91,7 +92,7 @@ def test_kafka_producer_proper_record_metadata(kafka_broker, compression):
9192
magic = producer._max_usable_produce_magic()
9293

9394
# record headers are supported in 0.11.0
94-
if version() < (0, 11, 0):
95+
if env_kafka_version() < (0, 11, 0):
9596
headers = None
9697
else:
9798
headers = [("Header Key", b"Header Value")]

test/testutil.py

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,17 @@
11
from __future__ import absolute_import
22

3-
import functools
4-
import operator
3+
import os
54
import time
65

7-
import pytest
86

9-
from test.fixtures import version_str_to_tuple, version as kafka_version
7+
def env_kafka_version():
8+
"""Return the Kafka version set in the OS environment as a tuple.
109
11-
12-
def kafka_versions(*versions):
10+
Example: '0.8.1.1' --> (0, 8, 1, 1)
1311
"""
14-
Describe the Kafka versions this test is relevant to.
15-
16-
The versions are passed in as strings, for example:
17-
'0.11.0'
18-
'>=0.10.1.0'
19-
'>0.9', '<1.0' # since this accepts multiple versions args
20-
21-
The current KAFKA_VERSION will be evaluated against this version. If the
22-
result is False, then the test is skipped. Similarly, if KAFKA_VERSION is
23-
not set the test is skipped.
24-
25-
Note: For simplicity, this decorator accepts Kafka versions as strings even
26-
though the similarly functioning `api_version` only accepts tuples. Trying
27-
to convert it to tuples quickly gets ugly due to mixing operator strings
28-
alongside version tuples. While doable when one version is passed in, it
29-
isn't pretty when multiple versions are passed in.
30-
"""
31-
32-
def construct_lambda(s):
33-
if s[0].isdigit():
34-
op_str = '='
35-
v_str = s
36-
elif s[1].isdigit():
37-
op_str = s[0] # ! < > =
38-
v_str = s[1:]
39-
elif s[2].isdigit():
40-
op_str = s[0:2] # >= <=
41-
v_str = s[2:]
42-
else:
43-
raise ValueError('Unrecognized kafka version / operator: %s' % (s,))
44-
45-
op_map = {
46-
'=': operator.eq,
47-
'!': operator.ne,
48-
'>': operator.gt,
49-
'<': operator.lt,
50-
'>=': operator.ge,
51-
'<=': operator.le
52-
}
53-
op = op_map[op_str]
54-
version = version_str_to_tuple(v_str)
55-
return lambda a: op(a, version)
56-
57-
validators = map(construct_lambda, versions)
58-
59-
def real_kafka_versions(func):
60-
@functools.wraps(func)
61-
def wrapper(func, *args, **kwargs):
62-
version = kafka_version()
63-
64-
if not version:
65-
pytest.skip("no kafka version set in KAFKA_VERSION env var")
66-
67-
for f in validators:
68-
if not f(version):
69-
pytest.skip("unsupported kafka version")
70-
71-
return func(*args, **kwargs)
72-
return wrapper
73-
74-
return real_kafka_versions
12+
if 'KAFKA_VERSION' not in os.environ:
13+
return ()
14+
return tuple(map(int, os.environ['KAFKA_VERSION'].split('.')))
7515

7616

7717
def assert_message_count(messages, num_messages):

0 commit comments

Comments (0)