diff --git a/cashews/backends/interface.py b/cashews/backends/interface.py index afa2cd72..d97e5458 100644 --- a/cashews/backends/interface.py +++ b/cashews/backends/interface.py @@ -146,7 +146,9 @@ async def is_locked( async def unlock(self, key: Key, value: Value) -> bool: ... @asynccontextmanager - async def lock(self, key: Key, expire: float, wait: bool = True) -> AsyncGenerator[None, None]: + async def lock( + self, key: Key, expire: float, wait: bool = True, check_interval: float = 0 + ) -> AsyncGenerator[None, None]: identifier = str(uuid.uuid4()) while True: lock = await self.set_lock(key, identifier, expire=expire) @@ -163,7 +165,7 @@ async def lock(self, key: Key, expire: float, wait: bool = True) -> AsyncGenerat return if wait: - await asyncio.sleep(0) + await asyncio.sleep(check_interval) continue raise LockedError(f"Key {key} is already locked") try: diff --git a/cashews/decorators/locked.py b/cashews/decorators/locked.py index 737f8958..7e46202b 100644 --- a/cashews/decorators/locked.py +++ b/cashews/decorators/locked.py @@ -21,6 +21,7 @@ def locked( ttl: TTL | None = None, wait: bool = True, prefix: str = "lock", + check_interval: float = 0, ) -> Callable[[DecoratedFunc], DecoratedFunc]: """ Decorator that can help you to solve Cache stampede problem (https://en.wikipedia.org/wiki/Cache_stampede), @@ -32,14 +33,15 @@ def locked( :param ttl: duration to lock wrapped function call :param wait: if true - wait till lock is released :param prefix: custom prefix for key, default 'lock' + :param check_interval: interval in seconds between lock checks while it is waiting for the lock to be released """ ttl = ttl_to_seconds(ttl) def _decor(func): _key_template = get_cache_key_template(func, key=key, prefix=prefix) if inspect.isasyncgenfunction(func): - return _asyncgen_lock(func, backend, ttl, _key_template, wait) - return _coroutine_lock(func, backend, ttl, _key_template, wait) + return _asyncgen_lock(func, backend, ttl, _key_template, wait, check_interval) + return _coroutine_lock(func, backend, ttl, _key_template, wait, check_interval) return _decor @@ -50,12 +52,13 @@ def _coroutine_lock( ttl: TTL | None, key_template: KeyOrTemplate, wait: bool, + check_interval: float, ) -> DecoratedFunc: @wraps(func) async def _wrap(*args, **kwargs): _ttl = ttl_to_seconds(ttl, *args, **kwargs, with_callable=True) _cache_key = get_cache_key(func, key_template, args, kwargs) - async with backend.lock(_cache_key, _ttl, wait=wait): + async with backend.lock(_cache_key, _ttl, wait=wait, check_interval=check_interval): return await func(*args, **kwargs) return _wrap # type: ignore[return-value] @@ -67,12 +70,13 @@ def _asyncgen_lock( ttl: TTL | None, key_template: KeyOrTemplate, wait: bool, + check_interval: float, ): @wraps(func) async def _wrap(*args, **kwargs): _ttl = ttl_to_seconds(ttl, *args, **kwargs, with_callable=True) _cache_key = get_cache_key(func, key_template, args, kwargs) - async with backend.lock(_cache_key, _ttl, wait=wait): + async with backend.lock(_cache_key, _ttl, wait=wait, check_interval=check_interval): async for chunk in func(*args, **kwargs): yield chunk return diff --git a/cashews/wrapper/decorators.py b/cashews/wrapper/decorators.py index 874b8fa2..dafc83f8 100644 --- a/cashews/wrapper/decorators.py +++ b/cashews/wrapper/decorators.py @@ -360,6 +360,7 @@ def locked( key: KeyOrTemplate | None = None, wait: bool = True, prefix: str = "locked", + check_interval: float = 0, ) -> Callable[[DecoratedFunc], DecoratedFunc]: return decorators.locked( backend=self, # type: ignore[arg-type] @@ -367,6 +368,7 @@ def locked( key=key, wait=wait, prefix=prefix, + check_interval=check_interval, ) def bloom( diff --git a/tests/conftest.py b/tests/conftest.py index 29555abb..d3802bb6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -88,6 +88,7 @@ async def _backend(request, redis_dsn, backend_factory): backend = backend_factory(Memory, check_interval=0.01) try: await backend.init() + await backend.clear() yield backend, request.param finally: await backend.close() diff --git a/tests/test_lock_decorator.py b/tests/test_lock_decorator.py index 1a828ca9..c15350b4 100644 --- a/tests/test_lock_decorator.py +++ b/tests/test_lock_decorator.py @@ -35,6 +35,28 @@ async def func(): assert mock.call_count == 20 +async def test_lock_cache_parallel_check_interval(cache): + mock = Mock() + mock_middleware = Mock() + + async def middleware(call, cmd, backend, *args, **kwargs): + mock_middleware() + return await call(*args, **kwargs) + + cache.add_middleware(middleware) + + @cache.locked(key="key", ttl=1, wait=True, check_interval=0.05) + async def func_interval(): + await asyncio.sleep(0.01) + mock() + + for _ in range(2): + await asyncio.gather(*[func_interval() for _ in range(2)]) + + assert mock.call_count == 4 + assert mock_middleware.call_count == 12 + + async def test_lock_cache_iterator(cache): mock = Mock() chunks = range(10) diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index c29092bf..26175e02 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -209,7 +209,7 @@ async def func(): async def test_cache_lock(cache: Cache): m = Mock() - @cache(ttl=3, lock=True) + @cache(ttl=3, lock=True, protected=False) async def my_func(val=1): await asyncio.sleep(0) # for task switching m(val)