Skip to content

Commit f72d5da

Browse files
johan-seesawjhenkens
authored andcommitted
Fix get_node_from_slot during resharding
1 parent 70b4f48 commit f72d5da

File tree

2 files changed

+40
-49
lines changed

2 files changed

+40
-49
lines changed

redis/asyncio/cluster.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,23 +1244,18 @@ def _update_moved_slots(self) -> None:
12441244
def get_node_from_slot(
12451245
self, slot: int, read_from_replicas: bool = False
12461246
) -> "ClusterNode":
1247+
"""
1248+
Gets a node that serves this hash slot
1249+
"""
12471250
if self._moved_exception:
12481251
self._update_moved_slots()
1249-
1250-
try:
1251-
if read_from_replicas:
1252-
# get the server index in a Round-Robin manner
1253-
primary_name = self.slots_cache[slot][0].name
1254-
node_idx = self.read_load_balancer.get_server_index(
1255-
primary_name, len(self.slots_cache[slot])
1256-
)
1257-
return self.slots_cache[slot][node_idx]
1258-
return self.slots_cache[slot][0]
1259-
except (IndexError, TypeError):
1260-
raise SlotNotCoveredError(
1261-
f'Slot "{slot}" not covered by the cluster. '
1262-
f'"require_full_coverage={self.require_full_coverage}"'
1252+
1253+
return self.read_load_balancer.get_node_from_slot(
1254+
slot,
1255+
self.slots_cache.get(slot, None),
1256+
read_from_replicas,
12631257
)
1258+
12641259

12651260
def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
12661261
return [

redis/cluster.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,18 +1324,32 @@ class LoadBalancer:
13241324
Round-Robin Load Balancing
13251325
"""
13261326

1327-
def __init__(self, start_index: int = 0) -> None:
1328-
self.primary_to_idx = {}
1329-
self.start_index = start_index
1327+
def __init__(self) -> None:
1328+
self.primary_name_to_last_used_index:dict[str,int] = {}
1329+
1330+
def get_node_from_slot(self, slot_index: int, slot_nodes: list[ClusterNode] | None, read_from_replicas: bool) -> ClusterNode:
1331+
if slot_nodes is None or len(slot_nodes) == 0:
1332+
raise SlotNotCoveredError(
1333+
f'Slot "{slot_index}" not covered by the cluster. '
1334+
)
1335+
if not read_from_replicas:
1336+
node_idx = 0
1337+
else:
1338+
primary_name = slot_nodes[0].name
1339+
node_idx = self.get_server_index(
1340+
primary_name, len(slot_nodes)
1341+
)
1342+
return slot_nodes[node_idx]
13301343

1331-
def get_server_index(self, primary: str, list_size: int) -> int:
1332-
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
1344+
def get_server_index(self, primary: str, list_size: int) -> int:
1345+
# default to -1 if not found, so after incrementing it will be 0
1346+
server_index = (self.primary_name_to_last_used_index.get(primary, -1) + 1) % list_size
13331347
# Update the index
1334-
self.primary_to_idx[primary] = (server_index + 1) % list_size
1348+
self.primary_name_to_last_used_index[primary] = server_index
13351349
return server_index
13361350

13371351
def reset(self) -> None:
1338-
self.primary_to_idx.clear()
1352+
self.primary_name_to_last_used_index.clear()
13391353

13401354

13411355
class NodesManager:
@@ -1426,41 +1440,23 @@ def _update_moved_slots(self):
14261440
# Reset moved_exception
14271441
self._moved_exception = None
14281442

1429-
def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
1443+
def get_node_from_slot(
1444+
self, slot: int, read_from_replicas: bool
1445+
) -> "ClusterNode":
14301446
"""
1431-
Gets a node that servers this hash slot
1447+
Gets a node that serves this hash slot
14321448
"""
14331449
if self._moved_exception:
14341450
with self._lock:
14351451
if self._moved_exception:
14361452
self._update_moved_slots()
1437-
1438-
if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0:
1439-
raise SlotNotCoveredError(
1440-
f'Slot "{slot}" not covered by the cluster. '
1441-
f'"require_full_coverage={self._require_full_coverage}"'
1453+
1454+
return self.read_load_balancer.get_node_from_slot(
1455+
slot,
1456+
self.slots_cache.get(slot, None),
1457+
read_from_replicas,
14421458
)
1443-
1444-
if read_from_replicas is True:
1445-
# get the server index in a Round-Robin manner
1446-
primary_name = self.slots_cache[slot][0].name
1447-
node_idx = self.read_load_balancer.get_server_index(
1448-
primary_name, len(self.slots_cache[slot])
1449-
)
1450-
elif (
1451-
server_type is None
1452-
or server_type == PRIMARY
1453-
or len(self.slots_cache[slot]) == 1
1454-
):
1455-
# return a primary
1456-
node_idx = 0
1457-
else:
1458-
# return a replica
1459-
# randomly choose one of the replicas
1460-
node_idx = random.randint(1, len(self.slots_cache[slot]) - 1)
1461-
1462-
return self.slots_cache[slot][node_idx]
1463-
1459+
14641460
def get_nodes_by_server_type(self, server_type):
14651461
"""
14661462
Get all nodes with the specified server type

0 commit comments

Comments
 (0)