Skip to content

Commit 78ba4ad

Browse files
dvora-hvladvildanov
authored andcommitted
Support client side caching with ConnectionPool (#3099)
* sync * async * fixs connection mocks * fix async connection mock * fix test_asyncio/test_connection.py::test_single_connection * add test for cache blacklist and flushdb at the end of each test * fix review comments
1 parent 8d9a59f commit 78ba4ad

File tree

10 files changed

+318
-246
lines changed

10 files changed

+318
-246
lines changed

redis/cache.py renamed to redis/_cache.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,11 @@ class _LocalCache:
178178
"""
179179

180180
def __init__(
181-
self, max_size: int, ttl: int, eviction_policy: EvictionPolicy, **kwargs
181+
self,
182+
max_size: int = 100,
183+
ttl: int = 0,
184+
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
185+
**kwargs,
182186
):
183187
self.max_size = max_size
184188
self.ttl = ttl

redis/asyncio/client.py

Lines changed: 18 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
cast,
2424
)
2525

26+
from redis._cache import (
27+
DEFAULT_BLACKLIST,
28+
DEFAULT_EVICTION_POLICY,
29+
DEFAULT_WHITELIST,
30+
_LocalCache,
31+
)
2632
from redis._parsers.helpers import (
2733
_RedisCallbacks,
2834
_RedisCallbacksRESP2,
@@ -37,12 +43,6 @@
3743
)
3844
from redis.asyncio.lock import Lock
3945
from redis.asyncio.retry import Retry
40-
from redis.cache import (
41-
DEFAULT_BLACKLIST,
42-
DEFAULT_EVICTION_POLICY,
43-
DEFAULT_WHITELIST,
44-
_LocalCache,
45-
)
4646
from redis.client import (
4747
EMPTY_RESPONSE,
4848
NEVER_DECODE,
@@ -66,7 +66,7 @@
6666
TimeoutError,
6767
WatchError,
6868
)
69-
from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT
69+
from redis.typing import ChannelT, EncodableT, KeyT
7070
from redis.utils import (
7171
HIREDIS_AVAILABLE,
7272
_set_info_logger,
@@ -293,6 +293,13 @@ def __init__(
293293
"lib_version": lib_version,
294294
"redis_connect_func": redis_connect_func,
295295
"protocol": protocol,
296+
"cache_enable": cache_enable,
297+
"client_cache": client_cache,
298+
"cache_max_size": cache_max_size,
299+
"cache_ttl": cache_ttl,
300+
"cache_eviction_policy": cache_eviction_policy,
301+
"cache_blacklist": cache_blacklist,
302+
"cache_whitelist": cache_whitelist,
296303
}
297304
# based on input, setup appropriate connection args
298305
if unix_socket_path is not None:
@@ -349,16 +356,6 @@ def __init__(
349356
# on a set of redis commands
350357
self._single_conn_lock = asyncio.Lock()
351358

352-
self.client_cache = client_cache
353-
if cache_enable:
354-
self.client_cache = _LocalCache(
355-
cache_max_size, cache_ttl, cache_eviction_policy
356-
)
357-
if self.client_cache is not None:
358-
self.cache_blacklist = cache_blacklist
359-
self.cache_whitelist = cache_whitelist
360-
self.client_cache_initialized = False
361-
362359
def __repr__(self):
363360
return f"{self.__class__.__name__}<{self.connection_pool!r}>"
364361

@@ -370,10 +367,6 @@ async def initialize(self: _RedisT) -> _RedisT:
370367
async with self._single_conn_lock:
371368
if self.connection is None:
372369
self.connection = await self.connection_pool.get_connection("_")
373-
if self.client_cache is not None:
374-
self.connection._parser.set_invalidation_push_handler(
375-
self._cache_invalidation_process
376-
)
377370
return self
378371

379372
def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -592,8 +585,6 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
592585
close_connection_pool is None and self.auto_close_connection_pool
593586
):
594587
await self.connection_pool.disconnect()
595-
if self.client_cache:
596-
self.client_cache.flush()
597588

598589
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
599590
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
@@ -622,89 +613,28 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
622613
):
623614
raise error
624615

625-
def _cache_invalidation_process(
626-
self, data: List[Union[str, Optional[List[str]]]]
627-
) -> None:
628-
"""
629-
Invalidate (delete) all redis commands associated with a specific key.
630-
`data` is a list of strings, where the first string is the invalidation message
631-
and the second string is the list of keys to invalidate.
632-
(if the list of keys is None, then all keys are invalidated)
633-
"""
634-
if data[1] is not None:
635-
for key in data[1]:
636-
self.client_cache.invalidate(str_if_bytes(key))
637-
else:
638-
self.client_cache.flush()
639-
640-
async def _get_from_local_cache(self, command: str):
641-
"""
642-
If the command is in the local cache, return the response
643-
"""
644-
if (
645-
self.client_cache is None
646-
or command[0] in self.cache_blacklist
647-
or command[0] not in self.cache_whitelist
648-
):
649-
return None
650-
while not self.connection._is_socket_empty():
651-
await self.connection.read_response(push_request=True)
652-
return self.client_cache.get(command)
653-
654-
def _add_to_local_cache(
655-
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
656-
):
657-
"""
658-
Add the command and response to the local cache if the command
659-
is allowed to be cached
660-
"""
661-
if (
662-
self.client_cache is not None
663-
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
664-
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
665-
):
666-
self.client_cache.set(command, response, keys)
667-
668-
def delete_from_local_cache(self, command: str):
669-
"""
670-
Delete the command from the local cache
671-
"""
672-
try:
673-
self.client_cache.delete(command)
674-
except AttributeError:
675-
pass
676-
677616
# COMMAND EXECUTION AND PROTOCOL PARSING
678617
async def execute_command(self, *args, **options):
679618
"""Execute a command and return a parsed response"""
680619
await self.initialize()
681620
command_name = args[0]
682621
keys = options.pop("keys", None) # keys are used only for client side caching
683-
response_from_cache = await self._get_from_local_cache(args)
622+
pool = self.connection_pool
623+
conn = self.connection or await pool.get_connection(command_name, **options)
624+
response_from_cache = await conn._get_from_local_cache(args)
684625
if response_from_cache is not None:
685626
return response_from_cache
686627
else:
687-
pool = self.connection_pool
688-
conn = self.connection or await pool.get_connection(command_name, **options)
689-
690628
if self.single_connection_client:
691629
await self._single_conn_lock.acquire()
692630
try:
693-
if self.client_cache is not None and not self.client_cache_initialized:
694-
await conn.retry.call_with_retry(
695-
lambda: self._send_command_parse_response(
696-
conn, "CLIENT", *("CLIENT", "TRACKING", "ON")
697-
),
698-
lambda error: self._disconnect_raise(conn, error),
699-
)
700-
self.client_cache_initialized = True
701631
response = await conn.retry.call_with_retry(
702632
lambda: self._send_command_parse_response(
703633
conn, command_name, *args, **options
704634
),
705635
lambda error: self._disconnect_raise(conn, error),
706636
)
707-
self._add_to_local_cache(args, response, keys)
637+
conn._add_to_local_cache(args, response, keys)
708638
return response
709639
finally:
710640
if self.single_connection_client:

redis/asyncio/connection.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,15 @@
4646
ResponseError,
4747
TimeoutError,
4848
)
49-
from redis.typing import EncodableT
49+
from redis.typing import EncodableT, KeysT, ResponseT
5050
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
5151

52+
from .._cache import (
53+
DEFAULT_BLACKLIST,
54+
DEFAULT_EVICTION_POLICY,
55+
DEFAULT_WHITELIST,
56+
_LocalCache,
57+
)
5258
from .._parsers import (
5359
BaseParser,
5460
Encoder,
@@ -113,6 +119,9 @@ class AbstractConnection:
113119
"encoder",
114120
"ssl_context",
115121
"protocol",
122+
"client_cache",
123+
"cache_blacklist",
124+
"cache_whitelist",
116125
"_reader",
117126
"_writer",
118127
"_parser",
@@ -147,6 +156,13 @@ def __init__(
147156
encoder_class: Type[Encoder] = Encoder,
148157
credential_provider: Optional[CredentialProvider] = None,
149158
protocol: Optional[int] = 2,
159+
cache_enable: bool = False,
160+
client_cache: Optional[_LocalCache] = None,
161+
cache_max_size: int = 100,
162+
cache_ttl: int = 0,
163+
cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
164+
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
165+
cache_whitelist: List[str] = DEFAULT_WHITELIST,
150166
):
151167
if (username or password) and credential_provider is not None:
152168
raise DataError(
@@ -204,6 +220,14 @@ def __init__(
204220
if p < 2 or p > 3:
205221
raise ConnectionError("protocol must be either 2 or 3")
206222
self.protocol = protocol
223+
if cache_enable:
224+
_cache = _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy)
225+
else:
226+
_cache = None
227+
self.client_cache = client_cache if client_cache is not None else _cache
228+
if self.client_cache is not None:
229+
self.cache_blacklist = cache_blacklist
230+
self.cache_whitelist = cache_whitelist
207231

208232
def __del__(self, _warnings: Any = warnings):
209233
# For some reason, the individual streams don't get properly garbage
@@ -394,6 +418,11 @@ async def on_connect(self) -> None:
394418
# if a database is specified, switch to it. Also pipeline this
395419
if self.db:
396420
await self.send_command("SELECT", self.db)
421+
# if client caching is enabled, start tracking
422+
if self.client_cache:
423+
await self.send_command("CLIENT", "TRACKING", "ON")
424+
await self.read_response()
425+
self._parser.set_invalidation_push_handler(self._cache_invalidation_process)
397426

398427
# read responses from pipeline
399428
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -428,6 +457,9 @@ async def disconnect(self, nowait: bool = False) -> None:
428457
raise TimeoutError(
429458
f"Timed out closing connection after {self.socket_connect_timeout}"
430459
) from None
460+
finally:
461+
if self.client_cache:
462+
self.client_cache.flush()
431463

432464
async def _send_ping(self):
433465
"""Send PING, expect PONG in return"""
@@ -645,10 +677,62 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
645677
output.append(SYM_EMPTY.join(pieces))
646678
return output
647679

648-
def _is_socket_empty(self):
680+
def _socket_is_empty(self):
649681
"""Check if the socket is empty"""
650682
return not self._reader.at_eof()
651683

684+
def _cache_invalidation_process(
685+
self, data: List[Union[str, Optional[List[str]]]]
686+
) -> None:
687+
"""
688+
Invalidate (delete) all redis commands associated with a specific key.
689+
`data` is a list of strings, where the first string is the invalidation message
690+
and the second string is the list of keys to invalidate.
691+
(if the list of keys is None, then all keys are invalidated)
692+
"""
693+
if data[1] is not None:
694+
self.client_cache.flush()
695+
else:
696+
for key in data[1]:
697+
self.client_cache.invalidate(str_if_bytes(key))
698+
699+
async def _get_from_local_cache(self, command: str):
700+
"""
701+
If the command is in the local cache, return the response
702+
"""
703+
if (
704+
self.client_cache is None
705+
or command[0] in self.cache_blacklist
706+
or command[0] not in self.cache_whitelist
707+
):
708+
return None
709+
while not self._socket_is_empty():
710+
await self.read_response(push_request=True)
711+
return self.client_cache.get(command)
712+
713+
def _add_to_local_cache(
714+
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
715+
):
716+
"""
717+
Add the command and response to the local cache if the command
718+
is allowed to be cached
719+
"""
720+
if (
721+
self.client_cache is not None
722+
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
723+
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
724+
):
725+
self.client_cache.set(command, response, keys)
726+
727+
def delete_from_local_cache(self, command: str):
728+
"""
729+
Delete the command from the local cache
730+
"""
731+
try:
732+
self.client_cache.delete(command)
733+
except AttributeError:
734+
pass
735+
652736

653737
class Connection(AbstractConnection):
654738
"Manages TCP communication to and from a Redis server"

0 commit comments

Comments
 (0)