diff --git a/src/fastapi_redis_cache/cache.py b/src/fastapi_redis_cache/cache.py index 5257a14..7fc84bb 100644 --- a/src/fastapi_redis_cache/cache.py +++ b/src/fastapi_redis_cache/cache.py @@ -19,16 +19,18 @@ ) -def cache(*, expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS): +def cache(*, expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS, fastapi_route: bool = True): """Enable caching behavior for the decorated function. Args: expire (Union[int, timedelta], optional): The number of seconds from now when the cached response should expire. Defaults to 31,536,000 seconds (i.e., the number of seconds in one year). + fastapi_route (bool): True if this is caching a FastAPI route function or + False if caching a standard Python function. """ - def outer_wrapper(func): + def fastapi_route_wrapper(func): @wraps(func) async def inner_wrapper(*args, **kwargs): """Return cached value if one exists, otherwise evaluate the wrapped function and cache the result.""" @@ -80,7 +82,54 @@ async def inner_wrapper(*args, **kwargs): return inner_wrapper - return outer_wrapper + def standard_wrapper(func): + @wraps(func) + async def inner_wrapper_async(*args, **kwargs): + """Return cached value if one exists, otherwise evaluate the wrapped function and cache the result.""" + + func_kwargs = kwargs.copy() + redis_cache = FastApiRedisCache() + if redis_cache.not_connected: + # if the redis client is not connected or request is not cacheable, no caching behavior is performed. + return await func(*args, **kwargs) + key = redis_cache.get_cache_key(func, *args, **kwargs) + ttl, in_cache = redis_cache.check_cache(key) + if in_cache: + return deserialize_json(in_cache) + + response_data = await func(*args, **kwargs) + ttl = calculate_ttl(expire) + redis_cache.add_to_cache(key, response_data, ttl) + return response_data + + @wraps(func) + def inner_wrapper_sync(*args, **kwargs): + """Return cached value if one exists, otherwise evaluate the wrapped function and cache the result.""" + + func_kwargs = kwargs.copy() + redis_cache = FastApiRedisCache() + if redis_cache.not_connected: + # if the redis client is not connected or request is not cacheable, no caching behavior is performed. + return func(*args, **kwargs) + key = redis_cache.get_cache_key(func, *args, **kwargs) + ttl, in_cache = redis_cache.check_cache(key) + if in_cache: + return deserialize_json(in_cache) + + response_data = func(*args, **kwargs) + ttl = calculate_ttl(expire) + redis_cache.add_to_cache(key, response_data, ttl) + return response_data + + if asyncio.iscoroutinefunction(func): + return inner_wrapper_async + else: + return inner_wrapper_sync + + if fastapi_route: + return fastapi_route_wrapper + else: + return standard_wrapper async def get_api_response_async(func, *args, **kwargs):