Skip to content

Commit 8a566db

Browse files
committed
improve cache
1 parent 7e8c0d5 commit 8a566db

File tree

2 files changed

+57
-7
lines changed

2 files changed

+57
-7
lines changed

indexd/auth/drivers/__init__.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,75 @@
11
import functools
22
import time
33

4+
import flask
5+
import jwt
46

5-
def timed_cache(ttl_seconds):
7+
8+
def request_auth_cache(maximum_ttl_seconds=1800):
69
"""
7-
Decorator to cache the result of a function for a specified time-to-live (TTL) in seconds.
10+
Decorator to cache the result of a function for a specified maximum TTL in seconds.
11+
The actual cache duration is determined by the 'token' parameter's expiration.
12+
If no token is provided, the maximum TTL is used and the Authorization header is included in the cache key.
813
"""
14+
915
def decorator(func):
1016
cache = {}
1117

1218
@functools.wraps(func)
1319
def wrapper(*args, **kwargs):
1420
key = functools._make_key(args, kwargs, typed=False)
1521
now = time.time()
22+
23+
# Extract token from args or kwargs
24+
token = kwargs.get("token")
25+
if token is None:
26+
# print("No token provided in kwargs")
27+
if type(args[0]) is str:
28+
# If the first argument is a string, assume it's the token
29+
token = args[0]
30+
else:
31+
token = args[1] if len(args) > 1 else None
32+
33+
# Calculate token expiration duration
34+
if token:
35+
# Decode the JWT token without verifying the signature to get the 'exp' claim
36+
# If the token is a string, decode it
37+
token = token.encode('utf-8') if isinstance(token, str) else token
38+
39+
# we could check for jwt.exceptions.DecodeError here, but we assume the token is valid
40+
# and just decode it to get the expiration time
41+
payload = jwt.decode(token, options={"verify_signature": False})
42+
43+
exp = payload.get("exp", now + maximum_ttl_seconds)
44+
token_ttl = max(0, exp - now)
45+
else:
46+
# If no token is provided, use the maximum TTL and add the Authorization header to the key.
47+
# This is useful for cases where the function does not require a token,
48+
# but still needs to cache based on the Authorization header.
49+
auth_header = flask.request.headers.get('Authorization', '')
50+
# Add the Authorization header to the key
51+
key = functools._make_key(args + (auth_header,), kwargs, typed=False)
52+
token_ttl = maximum_ttl_seconds
53+
54+
ttl = min(token_ttl, maximum_ttl_seconds)
55+
56+
# Check if the result is already cached and still valid
1657
if key in cache:
1758
result, timestamp = cache[key]
18-
if now - timestamp < ttl_seconds:
59+
if now - timestamp < ttl:
1960
return result
61+
62+
# If not cached or expired, call the function and cache the result
2063
result = func(*args, **kwargs)
2164
cache[key] = (result, now)
65+
66+
# Clean up any old cache entries
67+
keys_to_delete = [k for k, (v, t) in cache.items() if now - t >= ttl]
68+
for k in keys_to_delete:
69+
del cache[k]
70+
2271
return result
72+
2373
return wrapper
74+
2475
return decorator

indexd/auth/drivers/alchemy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import hashlib
2-
import sys
32

43
from contextlib import contextmanager
54

@@ -12,7 +11,7 @@
1211
from sqlalchemy.ext.declarative import declarative_base
1312

1413
from indexd.auth.driver import AuthDriverABC
15-
from indexd.auth.drivers import timed_cache
14+
from indexd.auth.drivers import request_auth_cache
1615

1716
from indexd.auth.errors import AuthError, AuthzError
1817

@@ -180,7 +179,7 @@ def resources(self):
180179
"Failed to get resources from Arborist. Please check your Arborist configuration."
181180
)
182181

183-
@timed_cache(1800) # Cache for 30 minutes (typical JWT expiration time)
182+
@request_auth_cache() # cache the result of the auth request
184183
def caching_auth_mapping(self, token):
185184
"""
186185
Returns a list of resources the user has access to.
@@ -196,7 +195,7 @@ def caching_auth_mapping(self, token):
196195
resources = self.arborist.auth_mapping()
197196
return resources
198197

199-
@timed_cache(1800) # Cache for 30 minutes (typical JWT expiration time)
198+
@request_auth_cache() # cache the result of the auth request
200199
def cached_auth_request(self, token, service, method, resource):
201200
"""
202201
Makes an authenticated request to Arborist and caches the result.

0 commit comments

Comments
 (0)