1
+ from enum import Enum
1
2
import random
2
3
import socket
3
4
import sys
@@ -190,6 +191,27 @@ def cleanup_kwargs(**kwargs):
190
191
191
192
return connection_kwargs
192
193
194
+ class ReadFromReplicasMode (Enum ):
195
+ ReadFromPrimary = 0
196
+ ReadFromPrimaryAndReplica = 1
197
+ ReadFromReplicaOnly = 2
198
+
199
+ @staticmethod
200
+ def from_parameters (input : bool | "ReadFromReplicasMode" ):
201
+ if input == True :
202
+ return ReadFromReplicasMode .ReadFromPrimaryAndReplica
203
+ elif input == False :
204
+ return ReadFromReplicasMode .ReadFromPrimary
205
+ if not input in ReadFromReplicasMode :
206
+ raise RedisClusterException ("Argument 'read_from_replicas' must be a boolean or a value of ReadFromReplicasMode" )
207
+ return input
208
+
209
+ def get_replica_mode_for_command (self , command : str ):
210
+ if self == ReadFromReplicasMode .ReadFromPrimary :
211
+ return ReadFromReplicasMode .ReadFromPrimary
212
+ if not command in READ_COMMANDS :
213
+ return ReadFromReplicasMode .ReadFromPrimary
214
+ return self
193
215
194
216
class ClusterParser (DefaultParser ):
195
217
EXCEPTION_CLASSES = dict_merge (
@@ -503,7 +525,7 @@ def __init__(
503
525
retry : Optional ["Retry" ] = None ,
504
526
require_full_coverage : bool = False ,
505
527
reinitialize_steps : int = 5 ,
506
- read_from_replicas : bool = False ,
528
+ read_from_replicas : bool | ReadFromReplicasMode = False ,
507
529
dynamic_startup_nodes : bool = True ,
508
530
url : Optional [str ] = None ,
509
531
address_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
@@ -532,7 +554,9 @@ def __init__(
532
554
Enable read from replicas in READONLY mode. You can read possibly
533
555
stale data.
534
556
When set to true, read commands will be assigned between the
535
- primary and its replications in a Round-Robin manner.
557
+ primary and its replications in a Round-Robin manner. When set to
558
+ ReadFromReplicasMode.ReadFromReplicaOnly, it will only read from
559
+ the replicas
536
560
:param dynamic_startup_nodes:
537
561
Set the RedisCluster's startup nodes to all of the discovered nodes.
538
562
If true (default value), the cluster's discovered nodes will be used to
@@ -633,7 +657,7 @@ def __init__(
633
657
self .cluster_error_retry_attempts = cluster_error_retry_attempts
634
658
self .command_flags = self .__class__ .COMMAND_FLAGS .copy ()
635
659
self .node_flags = self .__class__ .NODE_FLAGS .copy ()
636
- self .read_from_replicas = read_from_replicas
660
+ self .read_from_replicas_mode = ReadFromReplicasMode . from_parameters ( read_from_replicas )
637
661
self .reinitialize_counter = 0
638
662
self .reinitialize_steps = reinitialize_steps
639
663
self .nodes_manager = NodesManager (
@@ -678,7 +702,7 @@ def on_connect(self, connection):
678
702
connection .set_parser (ClusterParser )
679
703
connection .on_connect ()
680
704
681
- if self .read_from_replicas :
705
+ if self .read_from_replicas != ReadFromReplicasMode . ReadFromPrimary :
682
706
# Sending READONLY command to server to configure connection as
683
707
# readonly. Since each cluster node may change its server type due
684
708
# to a failover, we should establish a READONLY connection
@@ -706,6 +730,13 @@ def get_primaries(self):
706
730
707
731
def get_replicas (self ):
708
732
return self .nodes_manager .get_nodes_by_server_type (REPLICA )
733
+
734
+ def get_read_from_replica_mode_for_command (self , command : str ):
735
+ if (
736
+ (self .read_from_replicas_mode == ReadFromReplicasMode .ReadFromPrimary ) or
737
+ (not command in READ_COMMANDS )):
738
+ return ReadFromReplicasMode .ReadFromPrimary
739
+ return self .read_from_replicas_mode
709
740
710
741
def get_random_node (self ):
711
742
return random .choice (list (self .nodes_manager .nodes_cache .values ()))
@@ -804,7 +835,7 @@ def pipeline(self, transaction=None, shard_hint=None):
804
835
result_callbacks = self .result_callbacks ,
805
836
cluster_response_callbacks = self .cluster_response_callbacks ,
806
837
cluster_error_retry_attempts = self .cluster_error_retry_attempts ,
807
- read_from_replicas = self .read_from_replicas ,
838
+ read_from_replicas_mode = self .read_from_replicas_mode ,
808
839
reinitialize_steps = self .reinitialize_steps ,
809
840
lock = self ._lock ,
810
841
)
@@ -922,7 +953,7 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
922
953
# get the node that holds the key's slot
923
954
slot = self .determine_slot (* args )
924
955
node = self .nodes_manager .get_node_from_slot (
925
- slot , self .read_from_replicas and command in READ_COMMANDS
956
+ slot , self .read_from_replicas_mode . get_replica_mode_for_command ( command )
926
957
)
927
958
return [node ]
928
959
@@ -1144,7 +1175,7 @@ def _execute_command(self, target_node, *args, **kwargs):
1144
1175
# refresh the target node
1145
1176
slot = self .determine_slot (* args )
1146
1177
target_node = self .nodes_manager .get_node_from_slot (
1147
- slot , self .read_from_replicas and command in READ_COMMANDS
1178
+ slot , self .read_from_replicas_mode . get_replica_mode_for_command ( command )
1148
1179
)
1149
1180
moved = False
1150
1181
@@ -1293,7 +1324,6 @@ def __del__(self):
1293
1324
if self .redis_connection is not None :
1294
1325
self .redis_connection .close ()
1295
1326
1296
-
1297
1327
class LoadBalancer :
1298
1328
"""
1299
1329
Round-Robin Load Balancing
@@ -1302,11 +1332,30 @@ class LoadBalancer:
1302
1332
def __init__ (self , start_index : int = 0 ) -> None :
1303
1333
self .primary_to_idx = {}
1304
1334
self .start_index = start_index
1305
-
1306
- def get_server_index (self , primary : str , list_size : int ) -> int :
1307
- server_index = self .primary_to_idx .setdefault (primary , self .start_index )
1308
- # Update the index
1309
- self .primary_to_idx [primary ] = (server_index + 1 ) % list_size
1335
+
1336
+ def get_node_from_slot (self , slot_index : int , slot_nodes : list [ClusterNode ] | None , read_from_replicas_mode : ReadFromReplicasMode ):
1337
+ if slot_nodes is None or len (slot_nodes ) == 0 :
1338
+ raise SlotNotCoveredError (
1339
+ f'Slot "{ slot_index } " not covered by the cluster. '
1340
+ )
1341
+ if read_from_replicas_mode == ReadFromReplicasMode .ReadFromPrimary :
1342
+ node_idx = 0
1343
+ else :
1344
+ skip_primary = read_from_replicas_mode == ReadFromReplicasMode .ReadFromReplicaOnly
1345
+ # get the server index in a Round-Robin manner
1346
+ primary_name = slot_nodes [0 ].name
1347
+ node_idx = self .read_load_balancer .get_server_index (
1348
+ primary_name , len (slot_nodes ), skip_primary
1349
+ )
1350
+ return slot_nodes [node_idx ]
1351
+
1352
+ def get_server_index (self , primary : str , list_size : int , skip_primary :bool ) -> int :
1353
+ # default to -1 if not found, so after incrementing it will be 0
1354
+ server_index = (self .primary_to_idx .get (primary , - 1 ) + 1 ) % list_size
1355
+ # If we skip primary, skip the zero-index node.
1356
+ if skip_primary and server_index == 0 and list_size > 1 :
1357
+ server_index = server_index + 1
1358
+ self .primary_to_idx [primary ] = server_index
1310
1359
return server_index
1311
1360
1312
1361
def reset (self ) -> None :
@@ -1401,41 +1450,23 @@ def _update_moved_slots(self):
1401
1450
# Reset moved_exception
1402
1451
self ._moved_exception = None
1403
1452
1404
- def get_node_from_slot (self , slot , read_from_replicas = False , server_type = None ):
1453
+ def get_node_from_slot (
1454
+ self , slot : int , read_from_replicas_mode : ReadFromReplicasMode
1455
+ ) -> "ClusterNode" :
1405
1456
"""
1406
1457
Gets a node that servers this hash slot
1407
1458
"""
1408
1459
if self ._moved_exception :
1409
1460
with self ._lock :
1410
1461
if self ._moved_exception :
1411
1462
self ._update_moved_slots ()
1412
-
1413
- if self .slots_cache .get (slot ) is None or len (self .slots_cache [slot ]) == 0 :
1414
- raise SlotNotCoveredError (
1415
- f'Slot "{ slot } " not covered by the cluster. '
1416
- f'"require_full_coverage={ self ._require_full_coverage } "'
1417
- )
1418
-
1419
- if read_from_replicas is True :
1420
- # get the server index in a Round-Robin manner
1421
- primary_name = self .slots_cache [slot ][0 ].name
1422
- node_idx = self .read_load_balancer .get_server_index (
1423
- primary_name , len (self .slots_cache [slot ])
1463
+
1464
+ return self .read_load_balancer .get_node_from_slot (
1465
+ slot ,
1466
+ self .slots_cache .get (slot , None ),
1467
+ read_from_replicas_mode ,
1424
1468
)
1425
- elif (
1426
- server_type is None
1427
- or server_type == PRIMARY
1428
- or len (self .slots_cache [slot ]) == 1
1429
- ):
1430
- # return a primary
1431
- node_idx = 0
1432
- else :
1433
- # return a replica
1434
- # randomly choose one of the replicas
1435
- node_idx = random .randint (1 , len (self .slots_cache [slot ]) - 1 )
1436
-
1437
- return self .slots_cache [slot ][node_idx ]
1438
-
1469
+
1439
1470
def get_nodes_by_server_type (self , server_type ):
1440
1471
"""
1441
1472
Get all nodes with the specified server type
@@ -1775,7 +1806,7 @@ def execute_command(self, *args):
1775
1806
channel = args [1 ]
1776
1807
slot = self .cluster .keyslot (channel )
1777
1808
node = self .cluster .nodes_manager .get_node_from_slot (
1778
- slot , self .cluster .read_from_replicas
1809
+ slot , self .cluster .read_from_replicas_mode
1779
1810
)
1780
1811
else :
1781
1812
# Get a random node
@@ -1915,7 +1946,7 @@ def __init__(
1915
1946
result_callbacks : Optional [Dict [str , Callable ]] = None ,
1916
1947
cluster_response_callbacks : Optional [Dict [str , Callable ]] = None ,
1917
1948
startup_nodes : Optional [List ["ClusterNode" ]] = None ,
1918
- read_from_replicas : bool = False ,
1949
+ read_from_replicas_mode : ReadFromReplicasMode = ReadFromReplicasMode . ReadFromPrimary ,
1919
1950
cluster_error_retry_attempts : int = 3 ,
1920
1951
reinitialize_steps : int = 5 ,
1921
1952
lock = None ,
@@ -1930,7 +1961,7 @@ def __init__(
1930
1961
result_callbacks or self .__class__ .RESULT_CALLBACKS .copy ()
1931
1962
)
1932
1963
self .startup_nodes = startup_nodes if startup_nodes else []
1933
- self .read_from_replicas = read_from_replicas
1964
+ self .read_from_replicas_mode = read_from_replicas_mode
1934
1965
self .command_flags = self .__class__ .COMMAND_FLAGS .copy ()
1935
1966
self .cluster_response_callbacks = cluster_response_callbacks
1936
1967
self .cluster_error_retry_attempts = cluster_error_retry_attempts
0 commit comments