diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index ba35a7b7b8..e2a4fbe2cc 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -38,6 +38,7 @@ SLOT_ID, AbstractRedisCluster, LoadBalancer, + LoadBalancingStrategy, block_pipeline_command, get_node_name, parse_cluster_slots, @@ -65,6 +66,7 @@ from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( SSL_AVAILABLE, + deprecated_args, deprecated_function, get_lib_version, safe_str, @@ -121,9 +123,15 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand | See: https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters :param read_from_replicas: - | Enable read from replicas in READONLY mode. You can read possibly stale data. + | @deprecated - please use load_balancing_strategy instead + | Enable read from replicas in READONLY mode. When set to true, read commands will be assigned between the primary and its replicas in a Round-Robin manner. + The data read from replicas is eventually consistent with the data in primary nodes. + :param load_balancing_strategy: + | Enable read from replicas in READONLY mode and defines the load balancing + strategy that will be used for cluster node selection. + The data read from replicas is eventually consistent with the data in primary nodes. :param reinitialize_steps: | Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. 
If a MOVED error occurs and the cluster does not @@ -216,6 +224,11 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": "result_callbacks", ) + @deprecated_args( + args_to_warn=["read_from_replicas"], + reason="Please configure the 'load_balancing_strategy' instead", + version="5.0.3", + ) def __init__( self, host: Optional[str] = None, @@ -224,6 +237,7 @@ def __init__( startup_nodes: Optional[List["ClusterNode"]] = None, require_full_coverage: bool = True, read_from_replicas: bool = False, + load_balancing_strategy: Optional[LoadBalancingStrategy] = None, reinitialize_steps: int = 5, cluster_error_retry_attempts: int = 3, connection_error_retry_attempts: int = 3, @@ -322,7 +336,7 @@ def __init__( } ) - if read_from_replicas: + if read_from_replicas or load_balancing_strategy: # Call our on_connect function to configure READONLY mode kwargs["redis_connect_func"] = self.on_connect @@ -371,6 +385,7 @@ def __init__( ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.reinitialize_steps = reinitialize_steps self.cluster_error_retry_attempts = cluster_error_retry_attempts self.connection_error_retry_attempts = connection_error_retry_attempts @@ -589,6 +604,7 @@ async def _determine_nodes( self.nodes_manager.get_node_from_slot( await self._determine_slot(command, *args), self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, ) ] @@ -769,7 +785,11 @@ async def _execute_command( # refresh the target node slot = await self._determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and args[0] in READ_COMMANDS + slot, + self.read_from_replicas and args[0] in READ_COMMANDS, + self.load_balancing_strategy + if args[0] in READ_COMMANDS + else None, ) moved = False @@ -1231,17 +1251,23 @@ def _update_moved_slots(self) -> None: 
self._moved_exception = None def get_node_from_slot( - self, slot: int, read_from_replicas: bool = False + self, + slot: int, + read_from_replicas: bool = False, + load_balancing_strategy=None, ) -> "ClusterNode": if self._moved_exception: self._update_moved_slots() + if read_from_replicas is True and load_balancing_strategy is None: + load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN + try: - if read_from_replicas: - # get the server index in a Round-Robin manner + if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: + # get the server index using the strategy defined in load_balancing_strategy primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) + primary_name, len(self.slots_cache[slot]), load_balancing_strategy ) return self.slots_cache[slot][node_idx] return self.slots_cache[slot][0] diff --git a/redis/cluster.py b/redis/cluster.py index 8edf82e413..0488608a60 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -4,6 +4,7 @@ import threading import time from collections import OrderedDict +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union from redis._parsers import CommandsParser, Encoder @@ -482,6 +483,11 @@ class initializer. In the case of conflicting arguments, querystring """ return cls(url=url, **kwargs) + @deprecated_args( + args_to_warn=["read_from_replicas"], + reason="Please configure the 'load_balancing_strategy' instead", + version="5.0.3", + ) def __init__( self, host: Optional[str] = None, @@ -492,6 +498,7 @@ def __init__( require_full_coverage: bool = False, reinitialize_steps: int = 5, read_from_replicas: bool = False, + load_balancing_strategy: Optional["LoadBalancingStrategy"] = None, dynamic_startup_nodes: bool = True, url: Optional[str] = None, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, @@ -520,11 +527,16 @@ def __init__( cluster client. 
If not all slots are covered, RedisClusterException will be thrown. :param read_from_replicas: + @deprecated - please use load_balancing_strategy instead Enable read from replicas in READONLY mode. You can read possibly stale data. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. - :param dynamic_startup_nodes: + :param load_balancing_strategy: + Enable read from replicas in READONLY mode and defines the load balancing + strategy that will be used for cluster node selection. + The data read from replicas is eventually consistent with the data in primary nodes. + :param dynamic_startup_nodes: Set the RedisCluster's startup nodes to all of the discovered nodes. If true (default value), the cluster's discovered nodes will be used to determine the cluster nodes-slots mapping in the next topology refresh. @@ -629,6 +641,7 @@ def __init__( self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps if event_dispatcher is None: @@ -683,7 +696,7 @@ def on_connect(self, connection): """ connection.on_connect() - if self.read_from_replicas: + if self.read_from_replicas or self.load_balancing_strategy: # Sending READONLY command to server to configure connection as # readonly. 
Since each cluster node may change its server type due # to a failover, we should establish a READONLY connection @@ -810,6 +823,7 @@ def pipeline(self, transaction=None, shard_hint=None): cluster_response_callbacks=self.cluster_response_callbacks, cluster_error_retry_attempts=self.cluster_error_retry_attempts, read_from_replicas=self.read_from_replicas, + load_balancing_strategy=self.load_balancing_strategy, reinitialize_steps=self.reinitialize_steps, lock=self._lock, ) @@ -934,7 +948,9 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: # get the node that holds the key's slot slot = self.determine_slot(*args) node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS + slot, + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, ) return [node] @@ -1158,7 +1174,11 @@ def _execute_command(self, target_node, *args, **kwargs): # refresh the target node slot = self.determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS + slot, + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy + if command in READ_COMMANDS + else None, ) moved = False @@ -1307,6 +1327,12 @@ def __del__(self): self.redis_connection.close() +class LoadBalancingStrategy(Enum): + ROUND_ROBIN = "round_robin" + ROUND_ROBIN_REPLICAS = "round_robin_replicas" + RANDOM_REPLICA = "random_replica" + + class LoadBalancer: """ Round-Robin Load Balancing @@ -1316,15 +1342,38 @@ def __init__(self, start_index: int = 0) -> None: self.primary_to_idx = {} self.start_index = start_index - def get_server_index(self, primary: str, list_size: int) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - # Update the index - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index + def get_server_index( + self, + 
primary: str, + list_size: int, + load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN, + ) -> int: + if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA: + return self._get_random_replica_index(list_size) + else: + return self._get_round_robin_index( + primary, + list_size, + load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) def reset(self) -> None: self.primary_to_idx.clear() + def _get_random_replica_index(self, list_size: int) -> int: + return random.randint(1, list_size - 1) + + def _get_round_robin_index( + self, primary: str, list_size: int, replicas_only: bool + ) -> int: + server_index = self.primary_to_idx.setdefault(primary, self.start_index) + if replicas_only and server_index == 0: + # skip the primary node index + server_index = 1 + # Update the index for the next round + self.primary_to_idx[primary] = (server_index + 1) % list_size + return server_index + class NodesManager: def __init__( @@ -1428,7 +1477,21 @@ def _update_moved_slots(self): # Reset moved_exception self._moved_exception = None - def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): + @deprecated_args( + args_to_warn=["server_type"], + reason=( + "In case you need select some load balancing strategy " + "that will use replicas, please set it through 'load_balancing_strategy'" + ), + version="5.0.3", + ) + def get_node_from_slot( + self, + slot, + read_from_replicas=False, + load_balancing_strategy=None, + server_type=None, + ): """ Gets a node that servers this hash slot """ @@ -1443,11 +1506,14 @@ def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): f'"require_full_coverage={self._require_full_coverage}"' ) - if read_from_replicas is True: - # get the server index in a Round-Robin manner + if read_from_replicas is True and load_balancing_strategy is None: + load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN + + if len(self.slots_cache[slot]) > 1 and 
load_balancing_strategy: + # get the server index using the strategy defined in load_balancing_strategy primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) + primary_name, len(self.slots_cache[slot]), load_balancing_strategy ) elif ( server_type is None @@ -1730,7 +1796,7 @@ def __init__( first command execution. The node will be determined by: 1. Hashing the channel name in the request to find its keyslot 2. Selecting a node that handles the keyslot: If read_from_replicas is - set to true, a replica can be selected. + set to true or load_balancing_strategy is set, a replica can be selected. :type redis_cluster: RedisCluster :type node: ClusterNode @@ -1826,7 +1892,9 @@ def execute_command(self, *args): channel = args[1] slot = self.cluster.keyslot(channel) node = self.cluster.nodes_manager.get_node_from_slot( - slot, self.cluster.read_from_replicas + slot, + self.cluster.read_from_replicas, + self.cluster.load_balancing_strategy, ) else: # Get a random node @@ -1969,6 +2037,7 @@ def __init__( cluster_response_callbacks: Optional[Dict[str, Callable]] = None, startup_nodes: Optional[List["ClusterNode"]] = None, read_from_replicas: bool = False, + load_balancing_strategy: Optional[LoadBalancingStrategy] = None, cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 5, lock=None, @@ -1984,6 +2053,7 @@ def __init__( ) self.startup_nodes = startup_nodes if startup_nodes else [] self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.cluster_response_callbacks = cluster_response_callbacks self.cluster_error_retry_attempts = cluster_error_retry_attempts diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 4fbfcf62ce..f57718b44f 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -14,7 
+14,13 @@ from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff -from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name +from redis.cluster import ( + PIPELINE_BLOCKED_COMMANDS, + PRIMARY, + REPLICA, + LoadBalancingStrategy, + get_node_name, +) from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -181,7 +187,18 @@ def cmd_init_mock(self, r: ClusterNode) -> None: cmd_parser_initialize.side_effect = cmd_init_mock - return await RedisCluster(*args, **kwargs) + # Create a subclass of RedisCluster that overrides __del__ + class MockedRedisCluster(RedisCluster): + def __del__(self): + # Override to prevent connection cleanup attempts + pass + + @property + def connection_pool(self): + # Required abstract property implementation + return self.nodes_manager.get_default_node().redis_connection.connection_pool + + return await MockedRedisCluster(*args, **kwargs) def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: @@ -677,7 +694,24 @@ def cmd_init_mock(self, r: ClusterNode) -> None: assert execute_command.failed_calls == 1 assert execute_command.successful_calls == 1 - async def test_reading_from_replicas_in_round_robin(self) -> None: + @pytest.mark.parametrize( + "read_from_replicas,load_balancing_strategy,mocks_srv_ports", + [ + (True, None, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (True, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (False, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + ], + ) + async def test_reading_with_load_balancing_strategies( + self, + 
read_from_replicas: bool, + load_balancing_strategy: LoadBalancingStrategy, + mocks_srv_ports: List[int], + ) -> None: with mock.patch.multiple( Connection, send_command=mock.DEFAULT, @@ -693,19 +727,19 @@ async def test_reading_from_replicas_in_round_robin(self) -> None: async def execute_command_mock_first(self, *args, **options): await self.connection_class(**self.connection_kwargs).connect() # Primary - assert self.port == 7001 + assert self.port == mocks_srv_ports[0] execute_command.side_effect = execute_command_mock_second return "MOCK_OK" def execute_command_mock_second(self, *args, **options): # Replica - assert self.port == 7002 + assert self.port == mocks_srv_ports[1] execute_command.side_effect = execute_command_mock_third return "MOCK_OK" def execute_command_mock_third(self, *args, **options): # Primary - assert self.port == 7001 + assert self.port == mocks_srv_ports[2] return "MOCK_OK" # We don't need to create a real cluster connection but we @@ -720,9 +754,13 @@ def execute_command_mock_third(self, *args, **options): # Create a cluster with reading from replications read_cluster = await get_mocked_redis_client( - host=default_host, port=default_port, read_from_replicas=True + host=default_host, + port=default_port, + read_from_replicas=read_from_replicas, + load_balancing_strategy=load_balancing_strategy, ) - assert read_cluster.read_from_replicas is True + assert read_cluster.read_from_replicas is read_from_replicas + assert read_cluster.load_balancing_strategy is load_balancing_strategy # Check that we read from the slot's nodes in a round robin # matter. 
# 'foo' belongs to slot 12182 and the slot's nodes are: @@ -970,6 +1008,34 @@ async def test_get_and_set(self, r: RedisCluster) -> None: assert await r.get("integer") == str(integer).encode() assert (await r.get("unicode_string")).decode("utf-8") == unicode_string + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN, + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + async def test_get_and_set_with_load_balanced_client( + self, create_redis, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + r = await create_redis( + cls=RedisCluster, + load_balancing_strategy=load_balancing_strategy, + ) + + # get and set can't be tested independently of each other + assert await r.get("a") is None + + byte_string = b"value" + assert await r.set("byte_string", byte_string) + + # run the get command for the same key several times + # to iterate over the read nodes + assert await r.get("byte_string") == byte_string + assert await r.get("byte_string") == byte_string + assert await r.get("byte_string") == byte_string + async def test_mget_nonatomic(self, r: RedisCluster) -> None: assert await r.mget_nonatomic([]) == [] assert await r.mget_nonatomic(["a", "b"]) == [None, None] @@ -2370,11 +2436,14 @@ async def test_load_balancer(self, r: RedisCluster) -> None: primary2_name = n_manager.slots_cache[slot_2][0].name list1_size = len(n_manager.slots_cache[slot_1]) list2_size = len(n_manager.slots_cache[slot_2]) + + # default load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN # slot 1 assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary1_name, list1_size) == 1 assert lb.get_server_index(primary1_name, list1_size) == 2 assert lb.get_server_index(primary1_name, list1_size) == 0 + # slot 2 assert lb.get_server_index(primary2_name, list2_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 1 @@ -2384,6 +2453,29 @@ async def 
test_load_balancer(self, r: RedisCluster) -> None: assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 0 + # reset the indexes before load balancing strategy test + lb.reset() + # load balancer strategy: LoadBalancingStrategy.ROUND_ROBIN_REPLICAS + for i in [1, 2, 1]: + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) + assert srv_index == i + + # reset the indexes before load balancing strategy test + lb.reset() + # load balancer strategy: LoadBalancingStrategy.RANDOM_REPLICA + for i in range(5): + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA, + ) + + assert srv_index > 0 and srv_index <= 2 + async def test_init_slots_cache_not_all_slots_covered(self) -> None: """ Test that if not all slots are covered it should raise an exception @@ -2866,6 +2958,37 @@ async def test_readonly_pipeline_from_readonly_client( break assert executed_on_replica + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + async def test_readonly_pipeline_with_reading_from_replicas_strategies( + self, r: RedisCluster, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + """ + Test that the pipeline uses replicas for different replica-based + load balancing strategies. 
+ """ + # Set the load balancing strategy + r.load_balancing_strategy = load_balancing_strategy + key = "bar" + await r.set(key, "foo") + + async with r.pipeline() as pipe: + mock_all_nodes_resp(r, "MOCK_OK") + assert await pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] + slot_nodes = r.nodes_manager.slots_cache[r.keyslot(key)] + executed_on_replicas_only = True + for node in slot_nodes: + if node.server_type == PRIMARY: + if node._free.pop().read_response.await_count > 0: + executed_on_replicas_only = False + break + assert executed_on_replicas_only + async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None: """Test that the pipeline can be used concurrently.""" await asyncio.gather( diff --git a/tests/test_cluster.py b/tests/test_cluster.py index e64db3690b..b71908d396 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -4,6 +4,7 @@ import socket import socketserver import threading +from typing import List import warnings from queue import LifoQueue, Queue from time import sleep @@ -19,6 +20,7 @@ REDIS_CLUSTER_HASH_SLOTS, REPLICA, ClusterNode, + LoadBalancingStrategy, NodesManager, RedisCluster, get_node_name, @@ -202,7 +204,18 @@ def cmd_init_mock(self, r): cmd_parser_initialize.side_effect = cmd_init_mock - return RedisCluster(*args, **kwargs) + # Create a subclass of RedisCluster that overrides __del__ + class MockedRedisCluster(RedisCluster): + def __del__(self): + # Override to prevent connection cleanup attempts + pass + + @property + def connection_pool(self): + # Required abstract property implementation + return self.nodes_manager.get_default_node().redis_connection.connection_pool + + return MockedRedisCluster(*args, **kwargs) def mock_node_resp(node, response): @@ -590,7 +603,24 @@ def cmd_init_mock(self, r): assert parse_response.failed_calls == 1 assert parse_response.successful_calls == 1 - def test_reading_from_replicas_in_round_robin(self): + @pytest.mark.parametrize( + 
"read_from_replicas,load_balancing_strategy,mocks_srv_ports", + [ + (True, None, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (True, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (False, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + ], + ) + def test_reading_with_load_balancing_strategies( + self, + read_from_replicas: bool, + load_balancing_strategy: LoadBalancingStrategy, + mocks_srv_ports: List[int], + ): with patch.multiple( Connection, send_command=DEFAULT, @@ -603,19 +633,19 @@ def test_reading_from_replicas_in_round_robin(self): def parse_response_mock_first(connection, *args, **options): # Primary - assert connection.port == 7001 + assert connection.port == mocks_srv_ports[0] parse_response.side_effect = parse_response_mock_second return "MOCK_OK" def parse_response_mock_second(connection, *args, **options): # Replica - assert connection.port == 7002 + assert connection.port == mocks_srv_ports[1] parse_response.side_effect = parse_response_mock_third return "MOCK_OK" def parse_response_mock_third(connection, *args, **options): # Primary - assert connection.port == 7001 + assert connection.port == mocks_srv_ports[2] return "MOCK_OK" # We don't need to create a real cluster connection but we @@ -630,9 +660,13 @@ def parse_response_mock_third(connection, *args, **options): # Create a cluster with reading from replications read_cluster = get_mocked_redis_client( - host=default_host, port=default_port, read_from_replicas=True + host=default_host, + port=default_port, + read_from_replicas=read_from_replicas, + load_balancing_strategy=load_balancing_strategy, ) - assert read_cluster.read_from_replicas is True + assert read_cluster.read_from_replicas is read_from_replicas + assert 
read_cluster.load_balancing_strategy is load_balancing_strategy # Check that we read from the slot's nodes in a round robin # matter. # 'foo' belongs to slot 12182 and the slot's nodes are: @@ -640,16 +674,27 @@ def parse_response_mock_third(connection, *args, **options): read_cluster.get("foo") read_cluster.get("foo") read_cluster.get("foo") - mocks["send_command"].assert_has_calls( + expected_calls_list = [] + expected_calls_list.append(call("READONLY")) + expected_calls_list.append(call("GET", "foo", keys=["foo"])) + + if ( + load_balancing_strategy is None + or load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN + ): + # in the round robin strategy the primary node can also receive read + # requests and this means that there will be second node connected + expected_calls_list.append(call("READONLY")) + + expected_calls_list.extend( [ - call("READONLY"), - call("GET", "foo", keys=["foo"]), - call("READONLY"), call("GET", "foo", keys=["foo"]), call("GET", "foo", keys=["foo"]), ] ) + mocks["send_command"].assert_has_calls(expected_calls_list) + def test_keyslot(self, r): """ Test that method will compute correct key in all supported cases @@ -975,6 +1020,35 @@ def test_get_and_set(self, r): assert r.get("integer") == str(integer).encode() assert r.get("unicode_string").decode("utf-8") == unicode_string + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN, + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + def test_get_and_set_with_load_balanced_client( + self, request, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + r = _get_client( + cls=RedisCluster, + request=request, + load_balancing_strategy=load_balancing_strategy, + ) + + # get and set can't be tested independently of each other + assert r.get("a") is None + + byte_string = b"value" + assert r.set("byte_string", byte_string) + + # run the get command for the same key several times + # to iterate 
over the read nodes + assert r.get("byte_string") == byte_string + assert r.get("byte_string") == byte_string + assert r.get("byte_string") == byte_string + def test_mget_nonatomic(self, r): assert r.mget_nonatomic([]) == [] assert r.mget_nonatomic(["a", "b"]) == [None, None] @@ -2473,6 +2547,8 @@ def test_load_balancer(self, r): primary2_name = n_manager.slots_cache[slot_2][0].name list1_size = len(n_manager.slots_cache[slot_1]) list2_size = len(n_manager.slots_cache[slot_2]) + + # default load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN # slot 1 assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary1_name, list1_size) == 1 @@ -2487,6 +2563,29 @@ def test_load_balancer(self, r): assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 0 + # reset the indexes before load balancing strategy test + lb.reset() + # load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN_REPLICAS + for i in [1, 2, 1]: + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) + assert srv_index == i + + # reset the indexes before load balancing strategy test + lb.reset() # reset the indexes + # load balancer strategy: LoadBalancerStrategy.RANDOM_REPLICA + for i in range(5): + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA, + ) + + assert srv_index > 0 and srv_index <= 2 + def test_init_slots_cache_not_all_slots_covered(self): """ Test that if not all slots are covered it should raise an exception @@ -3333,6 +3432,45 @@ def test_readonly_pipeline_from_readonly_client(self, request): break assert executed_on_replica is True + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + def 
test_readonly_pipeline_with_reading_from_replicas_strategies( + self, request, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + """ + Test that the pipeline uses replicas for different replica-based + load balancing strategies. + """ + ro = _get_client( + RedisCluster, + request, + load_balancing_strategy=load_balancing_strategy, + ) + key = "bar" + ro.set(key, "foo") + import time + + time.sleep(0.2) + + with ro.pipeline() as readonly_pipe: + mock_all_nodes_resp(ro, "MOCK_OK") + assert readonly_pipe.load_balancing_strategy == load_balancing_strategy + assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] + slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] + executed_on_replicas_only = True + for node in slot_nodes: + if node.server_type == PRIMARY: + conn = node.redis_connection.connection + if conn.read_response.called: + executed_on_replicas_only = False + break + assert executed_on_replicas_only + @pytest.mark.onlycluster class TestClusterMonitor: diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 79301b93f1..549eeb49a2 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -23,7 +23,7 @@ class TestMultiprocessing: # The code in this module does not work with it, # hence the explicit change to 'fork' # See https://github.com/python/cpython/issues/125714 - if multiprocessing.get_start_method() == "forkserver": + if multiprocessing.get_start_method() in ["forkserver", "spawn"]: _mp_context = multiprocessing.get_context(method="fork") else: _mp_context = multiprocessing.get_context() @@ -119,7 +119,7 @@ def target(pool, parent_conn): assert child_conn in pool._available_connections assert parent_conn not in pool._available_connections - proc = multiprocessing.Process(target=target, args=(pool, parent_conn)) + proc = self._mp_context.Process(target=target, args=(pool, parent_conn)) proc.start() proc.join(3) assert proc.exitcode == 0