Skip to content

Commit 0f313ca

Browse files
Improve and document decorator
1 parent 67b494d commit 0f313ca

File tree

2 files changed

+85
-27
lines changed

2 files changed

+85
-27
lines changed

src/aiidalab_qe/common/decorators.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,70 @@ def cache_per_thread(invalidator: str | None = None):
1414
`invalidator` : `str | None`
1515
The name of the attribute to watch for changes.
1616
If the attribute changes, the cache will be invalidated.
17+
18+
Usage
19+
-----
20+
Methods / free functions:
21+
>>> @cache_per_thread
22+
>>> def f(...): ...
23+
24+
Properties (must use this order):
25+
>>> @cache_per_thread(invalidator="uuid")
26+
>>> @property
27+
>>> def f(...): ...
1728
"""
1829

19-
def decorator(func):
20-
@functools.wraps(func)
21-
def wrapper(self, *args, **kwargs):
22-
cache = getattr(_thread_local, "cache", None)
23-
if cache is None:
24-
cache = {}
25-
_thread_local.cache = cache
30+
def cached_call(func, *args, **kwargs):
31+
# Get or initialize the cache for this thread
32+
cache = getattr(_thread_local, "cache", None)
33+
if cache is None:
34+
cache = {}
35+
_thread_local.cache = cache
36+
37+
# Begin constructing the cache key
38+
key_parts = [func]
39+
40+
if args:
41+
# Check if func is a method of a class
42+
if hasattr(args[0], "__dict__"):
43+
self = args[0]
44+
key_parts.append(id(self))
45+
46+
# Include the invalidator attribute if specified
47+
if invalidator is not None:
48+
key_parts.append(getattr(self, invalidator))
2649

27-
key_parts = [id(self), func]
28-
if invalidator is not None:
29-
key_parts.append(getattr(self, invalidator))
30-
if args or kwargs:
31-
key_parts.append(args)
32-
key_parts.append(frozenset(kwargs.items()))
33-
key = tuple(key_parts)
50+
# Include the arguments in the cache key
51+
if args or kwargs:
52+
key_parts.append(args)
53+
key_parts.append(frozenset(kwargs.items()))
3454

35-
if key not in cache:
36-
cache[key] = func(self, *args, **kwargs)
55+
key = tuple(key_parts)
3756

38-
return cache[key]
57+
# Check if the result is already cached
58+
if key not in cache:
59+
cache[key] = func(*args, **kwargs)
3960

61+
return cache[key]
62+
63+
def decorator(func):
64+
# Property case
4065
if isinstance(func, property):
66+
fget = func.fget
67+
if fget is None:
68+
raise TypeError("Property has no getter to wrap")
69+
70+
@functools.wraps(fget)
71+
def wrapper(self): # type: ignore
72+
return cached_call(fget, self)
73+
4174
return property(wrapper)
4275

76+
# method/function case
77+
@functools.wraps(func)
78+
def wrapper(*args, **kwargs):
79+
return cached_call(func, *args, **kwargs)
80+
4381
return wrapper
4482

4583
return decorator

tests/test_decorators.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,40 @@ def __init__(self, value):
88
self.value = value
99

1010
@cache_per_thread("value")
11-
def compute(self, x):
11+
@property
12+
def double_the_value(self):
13+
return self.value * 2
14+
15+
@cache_per_thread("value")
16+
def add_to_value(self, x):
1217
return self.value + x
1318

1419

20+
def test_cache_per_thread_property():
21+
obj = DummyClass(10)
22+
assert obj.double_the_value == 20 # First call, computes the result
23+
assert obj.double_the_value == 20 # Cached result, no recomputation
24+
25+
1526
def test_cache_per_thread_single_thread():
1627
obj = DummyClass(10)
17-
assert obj.compute(5) == 15 # First call, computes the result
18-
assert obj.compute(5) == 15 # Cached result, no recomputation
28+
assert obj.add_to_value(5) == 15 # First call, computes the result
29+
assert obj.add_to_value(5) == 15 # Cached result, no recomputation
1930

2031

2132
def test_cache_per_thread_different_args():
2233
obj = DummyClass(10)
23-
assert obj.compute(5) == 15 # First call with x=5
24-
assert obj.compute(3) == 13 # First call with x=3
25-
assert obj.compute(5) == 15 # Cached result for x=5
34+
assert obj.add_to_value(5) == 15 # First call with x=5
35+
assert obj.add_to_value(3) == 13 # First call with x=3
36+
assert obj.add_to_value(5) == 15 # Cached result for x=5
2637

2738

2839
def test_cache_per_thread_different_threads():
2940
obj = DummyClass(300) # must be > 256 to avoid Python integer interning
3041
results = []
3142

3243
def thread_func(x):
33-
results.append(obj.compute(x))
44+
results.append(obj.add_to_value(x))
3445

3546
thread1 = Thread(target=thread_func, args=(5,))
3647
thread2 = Thread(target=thread_func, args=(5,))
@@ -50,7 +61,16 @@ def thread_func(x):
5061

5162
def test_cache_invalidation():
5263
obj = DummyClass(10)
53-
assert obj.compute(5) == 15 # First call, computes the result
54-
assert obj.compute(5) == 15 # Cached result, no recomputation
64+
assert obj.add_to_value(5) == 15 # First call, computes the result
65+
assert obj.add_to_value(5) == 15 # Cached result, no recomputation
5566
obj.value = 20 # Invalidate cache by changing the invalidator attribute
56-
assert obj.compute(5) == 25 # Recomputes the result
67+
assert obj.add_to_value(5) == 25 # Recomputes the result
68+
69+
70+
def test_cache_per_thread_on_plain_function():
71+
@cache_per_thread()
72+
def compute(x):
73+
return 42 + x
74+
75+
assert compute(5) == 47 # First call, computes the result
76+
assert compute(5) == 47 # Cached result, no recomputation

0 commit comments

Comments
 (0)