@@ -1324,18 +1324,32 @@ class LoadBalancer:
1324
1324
Round-Robin Load Balancing
1325
1325
"""
1326
1326
1327
- def __init__ (self , start_index : int = 0 ) -> None :
1328
- self .primary_to_idx = {}
1329
- self .start_index = start_index
1327
+ def __init__ (self ) -> None :
1328
+ self .primary_name_to_last_used_index :dict [str ,int ] = {}
1329
+
1330
+ def get_node_from_slot (self , slot_index : int , slot_nodes : list [ClusterNode ] | None , read_from_replicas : bool ) -> ClusterNode :
1331
+ if slot_nodes is None or len (slot_nodes ) == 0 :
1332
+ raise SlotNotCoveredError (
1333
+ f'Slot "{ slot_index } " not covered by the cluster. '
1334
+ )
1335
+ if not read_from_replicas :
1336
+ node_idx = 0
1337
+ else :
1338
+ primary_name = slot_nodes [0 ].name
1339
+ node_idx = self .get_server_index (
1340
+ primary_name , len (slot_nodes )
1341
+ )
1342
+ return slot_nodes [node_idx ]
1330
1343
1331
- def get_server_index (self , primary : str , list_size : int ) -> int :
1332
- server_index = self .primary_to_idx .setdefault (primary , self .start_index )
1344
+ def get_server_index (self , primary : str , list_size : int ) -> int :
1345
+ # default to -1 if not found, so after incrementing it will be 0
1346
+ server_index = (self .primary_name_to_last_used_index .get (primary , - 1 ) + 1 ) % list_size
1333
1347
# Update the index
1334
- self .primary_to_idx [primary ] = ( server_index + 1 ) % list_size
1348
+ self .primary_name_to_last_used_index [primary ] = server_index
1335
1349
return server_index
1336
1350
1337
1351
def reset (self ) -> None :
1338
- self .primary_to_idx .clear ()
1352
+ self .primary_name_to_last_used_index .clear ()
1339
1353
1340
1354
1341
1355
class NodesManager :
@@ -1426,41 +1440,23 @@ def _update_moved_slots(self):
1426
1440
# Reset moved_exception
1427
1441
self ._moved_exception = None
1428
1442
1429
- def get_node_from_slot (self , slot , read_from_replicas = False , server_type = None ):
1443
+ def get_node_from_slot (
1444
+ self , slot : int , read_from_replicas : bool
1445
+ ) -> "ClusterNode" :
1430
1446
"""
1431
- Gets a node that servers this hash slot
1447
+ Gets a node that serves this hash slot
1432
1448
"""
1433
1449
if self ._moved_exception :
1434
1450
with self ._lock :
1435
1451
if self ._moved_exception :
1436
1452
self ._update_moved_slots ()
1437
-
1438
- if self .slots_cache . get ( slot ) is None or len ( self . slots_cache [ slot ]) == 0 :
1439
- raise SlotNotCoveredError (
1440
- f'Slot " { slot } " not covered by the cluster. '
1441
- f'"require_full_coverage= { self . _require_full_coverage } "'
1453
+
1454
+ return self .read_load_balancer . get_node_from_slot (
1455
+ slot ,
1456
+ self . slots_cache . get ( slot , None ),
1457
+ read_from_replicas ,
1442
1458
)
1443
-
1444
- if read_from_replicas is True :
1445
- # get the server index in a Round-Robin manner
1446
- primary_name = self .slots_cache [slot ][0 ].name
1447
- node_idx = self .read_load_balancer .get_server_index (
1448
- primary_name , len (self .slots_cache [slot ])
1449
- )
1450
- elif (
1451
- server_type is None
1452
- or server_type == PRIMARY
1453
- or len (self .slots_cache [slot ]) == 1
1454
- ):
1455
- # return a primary
1456
- node_idx = 0
1457
- else :
1458
- # return a replica
1459
- # randomly choose one of the replicas
1460
- node_idx = random .randint (1 , len (self .slots_cache [slot ]) - 1 )
1461
-
1462
- return self .slots_cache [slot ][node_idx ]
1463
-
1459
+
1464
1460
def get_nodes_by_server_type (self , server_type ):
1465
1461
"""
1466
1462
Get all nodes with the specified server type
0 commit comments