diff --git a/cashews/key_context.py b/cashews/key_context.py index 837294d5..ba2d1708 100644 --- a/cashews/key_context.py +++ b/cashews/key_context.py @@ -7,8 +7,8 @@ from typing import Any _REWRITE = "__rewrite" -_template_context: ContextVar[dict[str, Any]] = ContextVar("template_context") -_template_context.set({_REWRITE: False}) +TContext = dict[str, Any] +_template_context: ContextVar[TContext | None] = ContextVar("template_context", default=None) _empty = object() @@ -22,8 +22,7 @@ def context(rewrite=_empty, **values) -> Iterator[None]: ) else: rewrite = False - new_context = {**_template_context.get(), **values} - new_context[_REWRITE] = rewrite + new_context = {**_get_raw(), **values, _REWRITE: rewrite} token = _template_context.set(new_context) try: yield @@ -31,11 +30,15 @@ def context(rewrite=_empty, **values) -> Iterator[None]: _template_context.reset(token) -def get() -> tuple[dict[str, Any], bool]: - _context = {**_template_context.get()} +def get() -> tuple[TContext, bool]: + _context = {**_get_raw()} # a copy return _context, _context.pop(_REWRITE) +def _get_raw() -> TContext: + return _template_context.get() or {_REWRITE: False} + + def register(*names: str) -> None: warnings.warn( "`register_key_context` deprecated and will be removed in next release, use @ notation", @@ -43,4 +46,4 @@ def register(*names: str) -> None: stacklevel=2, ) new_names = dict.fromkeys(names, "") - _template_context.set({**new_names, **_template_context.get()}) + _template_context.set({**new_names, **_get_raw()}) diff --git a/tests/test_bugs.py b/tests/test_bugs.py new file mode 100644 index 00000000..aefc5118 --- /dev/null +++ b/tests/test_bugs.py @@ -0,0 +1,17 @@ +import asyncio + +from cashews import Cache + + +async def test_issue_382(): + cache = Cache() + cache.setup("mem://") + + @cache(ttl=60, key="item:{item_id}") + async def get_item(item_id: int): + return f"item_{item_id}" + + def sync_get_item(item_id: int): + return asyncio.run(get_item(item_id)) + + await asyncio.get_running_loop().run_in_executor(None, sync_get_item, 123)