Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions cashews/backends/diskcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,24 @@
from diskcache import Cache, FanoutCache

from cashews._typing import Key, Value
from cashews.serialize import SerializerMixin
from cashews.serialize import DEFAULT_SERIALIZER, Serializer
from cashews.utils import Bitarray

from .interface import NOT_EXIST, UNLIMITED, Backend


class _DiskCache(Backend):
class DiskCache(Backend):
def __init__(self, *args, directory=None, shards=8, **kwargs: Any) -> None:
serializer = kwargs.pop("serializer", DEFAULT_SERIALIZER)
self.__is_init = False
self._set_locks: dict[str, asyncio.Lock] = {}
self._sharded = shards > 1
if not self._sharded:
self._cache = Cache(directory=directory, **kwargs)
else:
self._cache = FanoutCache(directory=directory, shards=shards, **kwargs)
super().__init__(**kwargs)
super().__init__(serializer=serializer, **kwargs)
self._serializer: Serializer

async def init(self):
self.__is_init = True
Expand All @@ -46,6 +48,7 @@ async def set(
expire: float | None = None,
exist: bool | None = None,
) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
future = self._run_in_executor(self._set, key, value, expire, exist)
if exist is not None:
# we should have async lock until value real set
Expand All @@ -69,25 +72,34 @@ async def set_raw(self, key: Key, value: Any, **kwargs: Any):
return self._cache.set(key, value, **kwargs)

async def get(self, key: Key, default: Value | None = None) -> Value:
return await self._run_in_executor(self._cache.get, key, default)
value = await self._run_in_executor(self._cache.get, key, default)
return await self._serializer.decode(self, key=key, value=value, default=default)

async def get_raw(self, key: Key) -> Value:
return self._cache.get(key)

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value]:
return await self._run_in_executor(self._get_many, keys, default)
async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
if not keys:
return ()
values = await self._run_in_executor(self._get_many, keys, default)
values = await asyncio.gather(
*[self._serializer.decode(self, key=key, value=value, default=default) for key, value in zip(keys, values)]
)
return tuple(None if isinstance(value, Bitarray) else value for value in values)

def _get_many(self, keys: list[Key], default: Value | None = None):
values = []
for key in keys:
val = self._cache.get(key, default=default)
if isinstance(val, Bitarray):
val = None
values.append(val)
return values

async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
return await self._run_in_executor(self._set_many, pairs, expire)
_pairs = {}
for key, value in pairs.items():
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
_pairs[key] = value
return await self._run_in_executor(self._set_many, _pairs, expire)

def _set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
for key, value in pairs.items():
Expand Down Expand Up @@ -215,6 +227,7 @@ async def is_locked(
return await self.exists(key)

async def unlock(self, key: Key, value: Value) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=None)
return await self._run_in_executor(self._unlock, key, value)

def _unlock(self, key: Key, value: Value) -> bool:
Expand Down Expand Up @@ -269,7 +282,3 @@ async def set_pop(self, key: Key, count: int = 100) -> Iterable[str]:

async def get_keys_count(self) -> int:
return await self._run_in_executor(lambda: len(self._cache))


class DiskCache(SerializerMixin, _DiskCache):
pass
5 changes: 4 additions & 1 deletion cashews/backends/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from cashews.commands import ALL, Command
from cashews.exceptions import CacheBackendInteractionError, LockedError
from cashews.serialize import Serializer

if TYPE_CHECKING: # pragma: no cover
from cashews._typing import Default, Key, OnRemoveCallback, Value
Expand Down Expand Up @@ -226,8 +227,10 @@ def enable(self, *cmds: Command) -> None:


class Backend(ControlMixin, _BackendInterface, metaclass=ABCMeta):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args, serializer: Serializer | None = None, **kwargs) -> None:
super().__init__()
self._id = uuid.uuid4().hex
self._serializer = serializer
self._on_remove_callbacks: list[OnRemoveCallback] = []

def on_remove_callback(self, callback: OnRemoveCallback) -> None:
Expand Down
22 changes: 13 additions & 9 deletions cashews/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from copy import copy
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Mapping, overload

from cashews.serialize import SerializerMixin
from cashews.utils import Bitarray, get_obj_size

from .interface import NOT_EXIST, UNLIMITED, Backend
Expand All @@ -22,7 +21,7 @@
_missed = object()


class _Memory(Backend):
class Memory(Backend):
"""
Inmemory backend lru with ttl
"""
Expand Down Expand Up @@ -74,17 +73,22 @@ async def set(
) -> bool:
if exist is not None and (key in self.store) is not exist:
return False
if self._serializer:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
self._set(key, value, expire)
return True

async def set_raw(self, key: Key, value: Value, **kwargs: Any) -> None:
self.store[key] = value
self.store[key] = (None, value)

async def get(self, key: Key, default: Value | None = None) -> Value:
return await self._get(key, default=default)

async def get_raw(self, key: Key) -> Value:
return self.store.get(key)
val = self.store.get(key)
if val:
return val[1]
return None

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
values = []
Expand All @@ -97,6 +101,8 @@ async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Valu

async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None):
for key, value in pairs.items():
if self._serializer:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
self._set(key, value, expire)

async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore
Expand Down Expand Up @@ -200,7 +206,9 @@ async def _get(self, key: Key, default: Default | None = None) -> Value | None:
if expire_at and expire_at < time.time():
await self._delete(key)
return default
return value
if not self._serializer:
return value
return await self._serializer.decode(self, key=key, value=value, default=default)

async def _key_exist(self, key: Key) -> bool:
return (await self._get(key, default=_missed)) is not _missed
Expand Down Expand Up @@ -279,7 +287,3 @@ async def close(self):
del self.__remove_expired_stop
self.__remove_expired_stop = None
self.__is_init = False


class Memory(SerializerMixin, _Memory):
pass
7 changes: 2 additions & 5 deletions cashews/backends/redis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from cashews.picklers import DEFAULT_PICKLE
from cashews.serialize import SerializerMixin

from .backend import _Redis

__all__ = ["Redis"]


class Redis(SerializerMixin, _Redis):
pickle_type = DEFAULT_PICKLE
class Redis(_Redis):
pass
17 changes: 11 additions & 6 deletions cashews/backends/redis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from cashews._typing import Key, Value
from cashews.backends.interface import Backend
from cashews.serialize import DEFAULT_SERIALIZER, Serializer

from .client import Redis, SafePipeline, SafeRedis

Expand Down Expand Up @@ -76,7 +77,8 @@ def __init__(
self._kwargs = kwargs
self._address = address
self.__is_init = False
super().__init__()
super().__init__(serializer=kwargs.pop("serializer", None))
self._serializer: Serializer = self._serializer or DEFAULT_SERIALIZER

@property
def is_init(self) -> bool:
Expand Down Expand Up @@ -105,6 +107,7 @@ async def set(
expire: float | None = None,
exist=None,
) -> bool:
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
nx = xx = False
if exist is True:
xx = True
Expand All @@ -118,6 +121,7 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None
px = int(expire * 1000) if expire else None
async with self._pipeline as pipe:
for key, value in pairs.items():
value = await self._serializer.encode(self, key=key, value=value, expire=expire)
await pipe.set(key, value, px=px)
await pipe.execute()

Expand Down Expand Up @@ -211,23 +215,24 @@ async def get_size(self, key: Key) -> int:

async def get(self, key: Key, default: Value | None = None) -> Value:
value = await self._client.get(key)
return self._transform_value(value, default)
return await self._transform_value(key, value, default)

async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]:
if not keys:
return ()
values = await self._client.mget(*keys)
if values is None:
return tuple([default] * len(keys))
return tuple(self._transform_value(value, default) for value in values)
return tuple(
await asyncio.gather(*[self._transform_value(key, value, default) for key, value in zip(keys, values)])
)

@staticmethod
def _transform_value(value: bytes | None, default: Value | None):
async def _transform_value(self, key: Key, value: bytes | None, default: Value | None):
if value is None:
return default
if value.isdigit():
return int(value)
return value
return await self._serializer.decode(self, key=key, value=value, default=default)

async def incr(self, key: Key, value: int = 1, expire: float | None = None) -> int:
if not expire:
Expand Down
7 changes: 4 additions & 3 deletions cashews/backends/redis/client_side.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ def __init__(
self._expire_for_recently_update = 5
self._listen_started = asyncio.Event()
self.__listen_stop = asyncio.Event()
super().__init__(*args, suppress=suppress, **kwargs)
kwargs["suppress"] = suppress
super().__init__(*args, **kwargs)

async def init(self):
self._listen_started = asyncio.Event()
self.__listen_stop = asyncio.Event()
self._listen_started.clear()
self.__listen_stop.clear()
await self._local_cache.init()
await self._recently_update.init()
await super().init()
Expand Down
2 changes: 2 additions & 0 deletions cashews/backends/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ class TransactionBackend(Backend):
"_local_cache",
"_to_delete",
"__disable",
"_id",
]

def __init__(self, backend: Backend):
self._backend = backend
self._local_cache = Memory()
self._to_delete: set[Key] = set()
super().__init__()
self._id = backend._id

def _key_is_delete(self, key: Key) -> bool:
if key in self._to_delete:
Expand Down
1 change: 1 addition & 0 deletions cashews/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Command(Enum):
DELETE_MANY = "delete_many"
DELETE_MATCH = "delete_match"

EXISTS = "exists"
EXIST = "exists"
SCAN = "scan"
INCR = "incr"
Expand Down
2 changes: 1 addition & 1 deletion cashews/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def set_callback(key: str, result: Any):
_data = None
else:
_key = calls[0][0]
_data = calls[0][1][0]["value"]
_data = calls[0][1]["value"]
_etag = await self._set_etag(_key, _data)
return self._response_etag(response, _etag, request_etag)

Expand Down
2 changes: 1 addition & 1 deletion cashews/decorators/cache/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, previous_level=0, unset_token=None):
self._previous_level = previous_level

def _set(self, key: Key, **kwargs: Any) -> None:
self._value.append((key, [kwargs]))
self._value.append((key, kwargs))

@property
def calls(self):
Expand Down
19 changes: 14 additions & 5 deletions cashews/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,20 @@ async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *ar

def memory_limit(min_bytes: int = 0, max_bytes: int | None = None) -> Middleware:
async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T | None:
if cmd != Command.SET:
return await call(*args, **kwargs)
value_size = get_obj_size(kwargs["value"])
if max_bytes and value_size > max_bytes or value_size < min_bytes:
return None
if cmd == Command.SET_MANY:
pairs = {}
for key, value in kwargs["pairs"].items():
value_size = get_obj_size(value)
if max_bytes and value_size > max_bytes or value_size < min_bytes:
continue
pairs[key] = value
if not pairs:
return None
kwargs["pairs"] = pairs
elif cmd == Command.SET:
value_size = get_obj_size(kwargs["value"])
if max_bytes and value_size > max_bytes or value_size < min_bytes:
return None
return await call(*args, **kwargs)

return _middleware
9 changes: 6 additions & 3 deletions cashews/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,13 @@ def generate_key_template(func: Callable, exclude_parameters: Container = ()) ->

class _Star:
def __getattr__(self, item):
return _Star()
return self

def __getitem__(self, item):
return _Star()
return self

def __call__(self, *args, **kwargs):
return "*"


def _check_key_params(key: KeyOrTemplate, func_params: Iterable[str]):
Expand Down Expand Up @@ -142,7 +145,7 @@ def _get_func_signature(func: Callable):


def _get_call_values(func: Callable, args: Args, kwargs: Kwargs):
if len(args) == 0:
if not args:
_kwargs = {**kwargs}
for name, parameter in _get_func_signature(func).parameters.items():
if parameter.kind != inspect.Parameter.VAR_KEYWORD and name in _kwargs:
Expand Down
Loading
Loading