1
+ from enum import Enum
1
2
import random
2
3
import socket
3
4
import sys
@@ -190,6 +191,27 @@ def cleanup_kwargs(**kwargs):
190
191
191
192
return connection_kwargs
192
193
194
+ class ReadFromReplicasMode (Enum ):
195
+ ReadFromPrimary = 0
196
+ ReadFromPrimaryAndReplica = 1
197
+ ReadFromReplicaOnly = 2
198
+
199
+ @staticmethod
200
+ def from_parameters (input : bool | "ReadFromReplicasMode" ):
201
+ if input == True :
202
+ return ReadFromReplicasMode .ReadFromPrimaryAndReplica
203
+ elif input == False :
204
+ return ReadFromReplicasMode .ReadFromPrimary
205
+ if not input in ReadFromReplicasMode :
206
+ raise RedisClusterException ("Argument 'read_from_replicas' must be a boolean or a value of ReadFromReplicasMode" )
207
+ return input
208
+
209
+ def get_replica_mode_for_command (self , command : str ):
210
+ if self == ReadFromReplicasMode .ReadFromPrimary :
211
+ return ReadFromReplicasMode .ReadFromPrimary
212
+ if not command in READ_COMMANDS :
213
+ return ReadFromReplicasMode .ReadFromPrimary
214
+ return self
193
215
194
216
class ClusterParser (DefaultParser ):
195
217
EXCEPTION_CLASSES = dict_merge (
@@ -503,7 +525,7 @@ def __init__(
503
525
retry : Optional ["Retry" ] = None ,
504
526
require_full_coverage : bool = False ,
505
527
reinitialize_steps : int = 5 ,
506
- read_from_replicas : bool = False ,
528
+ read_from_replicas : bool | ReadFromReplicasMode = False ,
507
529
dynamic_startup_nodes : bool = True ,
508
530
url : Optional [str ] = None ,
509
531
address_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
@@ -532,7 +554,9 @@ def __init__(
532
554
Enable read from replicas in READONLY mode. You can read possibly
533
555
stale data.
534
556
When set to true, read commands will be assigned between the
535
- primary and its replications in a Round-Robin manner.
557
+ primary and its replications in a Round-Robin manner. When set to
558
+ ReadFromReplicasMode.ReadFromReplicaOnly, it will only read from
559
+ the replicas
536
560
:param dynamic_startup_nodes:
537
561
Set the RedisCluster's startup nodes to all of the discovered nodes.
538
562
If true (default value), the cluster's discovered nodes will be used to
@@ -633,7 +657,8 @@ def __init__(
633
657
self .cluster_error_retry_attempts = cluster_error_retry_attempts
634
658
self .command_flags = self .__class__ .COMMAND_FLAGS .copy ()
635
659
self .node_flags = self .__class__ .NODE_FLAGS .copy ()
636
- self .read_from_replicas = read_from_replicas
660
+ self .read_from_replicas_mode = ReadFromReplicasMode .from_parameters (read_from_replicas )
661
+ self .read_from_replicas_mode = read_from_replicas
637
662
self .reinitialize_counter = 0
638
663
self .reinitialize_steps = reinitialize_steps
639
664
self .nodes_manager = NodesManager (
@@ -678,7 +703,7 @@ def on_connect(self, connection):
678
703
connection .set_parser (ClusterParser )
679
704
connection .on_connect ()
680
705
681
- if self .read_from_replicas :
706
+ if self .read_from_replicas != ReadFromReplicasMode . ReadFromPrimary :
682
707
# Sending READONLY command to server to configure connection as
683
708
# readonly. Since each cluster node may change its server type due
684
709
# to a failover, we should establish a READONLY connection
@@ -706,6 +731,13 @@ def get_primaries(self):
706
731
707
732
def get_replicas (self ):
708
733
return self .nodes_manager .get_nodes_by_server_type (REPLICA )
734
+
735
+ def get_read_from_replica_mode_for_command (self , command : str ):
736
+ if (
737
+ (self .read_from_replicas_mode == ReadFromReplicasMode .ReadFromPrimary ) or
738
+ (not command in READ_COMMANDS )):
739
+ return ReadFromReplicasMode .ReadFromPrimary
740
+ return self .read_from_replicas_mode
709
741
710
742
def get_random_node (self ):
711
743
return random .choice (list (self .nodes_manager .nodes_cache .values ()))
@@ -804,7 +836,7 @@ def pipeline(self, transaction=None, shard_hint=None):
804
836
result_callbacks = self .result_callbacks ,
805
837
cluster_response_callbacks = self .cluster_response_callbacks ,
806
838
cluster_error_retry_attempts = self .cluster_error_retry_attempts ,
807
- read_from_replicas = self .read_from_replicas ,
839
+ read_from_replicas_mode = self .read_from_replicas_mode ,
808
840
reinitialize_steps = self .reinitialize_steps ,
809
841
lock = self ._lock ,
810
842
)
@@ -922,7 +954,7 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
922
954
# get the node that holds the key's slot
923
955
slot = self .determine_slot (* args )
924
956
node = self .nodes_manager .get_node_from_slot (
925
- slot , self .read_from_replicas and command in READ_COMMANDS
957
+ slot , self .read_from_replicas_mode . get_replica_mode_for_command ( command )
926
958
)
927
959
return [node ]
928
960
@@ -1144,7 +1176,7 @@ def _execute_command(self, target_node, *args, **kwargs):
1144
1176
# refresh the target node
1145
1177
slot = self .determine_slot (* args )
1146
1178
target_node = self .nodes_manager .get_node_from_slot (
1147
- slot , self .read_from_replicas and command in READ_COMMANDS
1179
+ slot , self .read_from_replicas_mode . get_replica_mode_for_command ( command )
1148
1180
)
1149
1181
moved = False
1150
1182
@@ -1293,7 +1325,6 @@ def __del__(self):
1293
1325
if self .redis_connection is not None :
1294
1326
self .redis_connection .close ()
1295
1327
1296
-
1297
1328
class LoadBalancer :
1298
1329
"""
1299
1330
Round-Robin Load Balancing
@@ -1302,11 +1333,30 @@ class LoadBalancer:
1302
1333
def __init__ (self , start_index : int = 0 ) -> None :
1303
1334
self .primary_to_idx = {}
1304
1335
self .start_index = start_index
1305
-
1306
- def get_server_index (self , primary : str , list_size : int ) -> int :
1307
- server_index = self .primary_to_idx .setdefault (primary , self .start_index )
1308
- # Update the index
1309
- self .primary_to_idx [primary ] = (server_index + 1 ) % list_size
1336
+
1337
+ def get_node_from_slot (self , slot_index : int , slot_nodes : list [ClusterNode ] | None , read_from_replicas_mode : ReadFromReplicasMode ):
1338
+ if slot_nodes is None or len (slot_nodes ) == 0 :
1339
+ raise SlotNotCoveredError (
1340
+ f'Slot "{ slot_index } " not covered by the cluster. '
1341
+ )
1342
+ if read_from_replicas_mode == ReadFromReplicasMode .ReadFromPrimary :
1343
+ node_idx = 0
1344
+ else :
1345
+ skip_primary = read_from_replicas_mode == ReadFromReplicasMode .ReadFromReplicaOnly
1346
+ # get the server index in a Round-Robin manner
1347
+ primary_name = slot_nodes [0 ].name
1348
+ node_idx = self .read_load_balancer .get_server_index (
1349
+ primary_name , len (slot_nodes ), skip_primary
1350
+ )
1351
+ return slot_nodes [node_idx ]
1352
+
1353
+ def get_server_index (self , primary : str , list_size : int , skip_primary :bool ) -> int :
1354
+ # default to -1 if not found, so after incrementing it will be 0
1355
+ server_index = (self .primary_to_idx .get (primary , - 1 ) + 1 ) % list_size
1356
+ # If we skip primary, skip the zero-index node.
1357
+ if skip_primary and server_index == 0 and list_size > 1 :
1358
+ server_index = server_index + 1
1359
+ self .primary_to_idx [primary ] = server_index
1310
1360
return server_index
1311
1361
1312
1362
def reset (self ) -> None :
@@ -1401,41 +1451,23 @@ def _update_moved_slots(self):
1401
1451
# Reset moved_exception
1402
1452
self ._moved_exception = None
1403
1453
1404
- def get_node_from_slot (self , slot , read_from_replicas = False , server_type = None ):
1454
+ def get_node_from_slot (
1455
+ self , slot : int , read_from_replicas_mode : ReadFromReplicasMode
1456
+ ) -> "ClusterNode" :
1405
1457
"""
1406
1458
Gets a node that servers this hash slot
1407
1459
"""
1408
1460
if self ._moved_exception :
1409
1461
with self ._lock :
1410
1462
if self ._moved_exception :
1411
1463
self ._update_moved_slots ()
1412
-
1413
- if self .slots_cache .get (slot ) is None or len (self .slots_cache [slot ]) == 0 :
1414
- raise SlotNotCoveredError (
1415
- f'Slot "{ slot } " not covered by the cluster. '
1416
- f'"require_full_coverage={ self ._require_full_coverage } "'
1417
- )
1418
-
1419
- if read_from_replicas is True :
1420
- # get the server index in a Round-Robin manner
1421
- primary_name = self .slots_cache [slot ][0 ].name
1422
- node_idx = self .read_load_balancer .get_server_index (
1423
- primary_name , len (self .slots_cache [slot ])
1464
+
1465
+ return self .read_load_balancer .get_node_from_slot (
1466
+ slot ,
1467
+ self .slots_cache .get (slot , None ),
1468
+ read_from_replicas_mode ,
1424
1469
)
1425
- elif (
1426
- server_type is None
1427
- or server_type == PRIMARY
1428
- or len (self .slots_cache [slot ]) == 1
1429
- ):
1430
- # return a primary
1431
- node_idx = 0
1432
- else :
1433
- # return a replica
1434
- # randomly choose one of the replicas
1435
- node_idx = random .randint (1 , len (self .slots_cache [slot ]) - 1 )
1436
-
1437
- return self .slots_cache [slot ][node_idx ]
1438
-
1470
+
1439
1471
def get_nodes_by_server_type (self , server_type ):
1440
1472
"""
1441
1473
Get all nodes with the specified server type
@@ -1775,7 +1807,7 @@ def execute_command(self, *args):
1775
1807
channel = args [1 ]
1776
1808
slot = self .cluster .keyslot (channel )
1777
1809
node = self .cluster .nodes_manager .get_node_from_slot (
1778
- slot , self .cluster .read_from_replicas
1810
+ slot , self .cluster .read_from_replicas_mode
1779
1811
)
1780
1812
else :
1781
1813
# Get a random node
@@ -1915,7 +1947,7 @@ def __init__(
1915
1947
result_callbacks : Optional [Dict [str , Callable ]] = None ,
1916
1948
cluster_response_callbacks : Optional [Dict [str , Callable ]] = None ,
1917
1949
startup_nodes : Optional [List ["ClusterNode" ]] = None ,
1918
- read_from_replicas : bool = False ,
1950
+ read_from_replicas_mode : ReadFromReplicasMode = ReadFromReplicasMode . ReadFromPrimary ,
1919
1951
cluster_error_retry_attempts : int = 3 ,
1920
1952
reinitialize_steps : int = 5 ,
1921
1953
lock = None ,
@@ -1930,7 +1962,7 @@ def __init__(
1930
1962
result_callbacks or self .__class__ .RESULT_CALLBACKS .copy ()
1931
1963
)
1932
1964
self .startup_nodes = startup_nodes if startup_nodes else []
1933
- self .read_from_replicas = read_from_replicas
1965
+ self .read_from_replicas_mode = read_from_replicas_mode
1934
1966
self .command_flags = self .__class__ .COMMAND_FLAGS .copy ()
1935
1967
self .cluster_response_callbacks = cluster_response_callbacks
1936
1968
self .cluster_error_retry_attempts = cluster_error_retry_attempts
0 commit comments