Skip to content

Commit cd2a4d1

Browse files
committed
Added some more tests for strange behavious of lambdas in hashing
1 parent e2d068f commit cd2a4d1

File tree

2 files changed

+69
-5
lines changed

2 files changed

+69
-5
lines changed

tests/test_hash.py

+57-5
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,69 @@ def func(a):
9393

9494
obj1 = FloatAggregator(func)
9595

96-
@functools.wraps(func)
97-
def _func(a):
98-
return func(a)
96+
def decorator(func):
97+
@functools.wraps(func)
98+
def _func(a):
99+
return func(a)
99100

100-
obj2 = FloatAggregator(_func)
101+
return _func
102+
103+
obj2 = FloatAggregator(decorator(func))
101104

102105
assert custom_hash(obj1) != custom_hash(obj2)
103106

104107

108+
def test_hash_lambdas_same():
109+
def func(a, b):
110+
return np.mean(a) + b
111+
112+
def func2():
113+
return FloatAggregator(lambda a: func(a, 1))
114+
115+
obj1 = func2()
116+
obj2 = func2()
117+
118+
assert custom_hash(obj1) == custom_hash(obj2)
119+
120+
105121
def test_hash_lambdas_different():
122+
# This is quite interesting, these two lambdas are different, as they have different names, as they are
123+
# defined in the same scope. in the pevious test, where there was only on lambda defined, the names were the same
124+
# hence the hash the same.
106125
obj1 = FloatAggregator(lambda a: np.mean(a))
107-
obj2 = FloatAggregator(lambda a: np.mean(a))
126+
obj2 = obj1
127+
obj1 = FloatAggregator(lambda a: np.mean(a))
128+
assert custom_hash(obj1) != custom_hash(obj2)
129+
130+
131+
def test_hash_partials_same():
132+
def func(a, b):
133+
return np.mean(a) + b
134+
135+
obj1 = FloatAggregator(functools.partial(func, b=1))
136+
obj2 = FloatAggregator(functools.partial(func, b=1))
137+
138+
assert custom_hash(obj1) == custom_hash(obj2)
139+
140+
141+
def test_hash_partials_different():
142+
def func(a, b):
143+
return np.mean(a) + b
144+
145+
obj1 = FloatAggregator(functools.partial(func, b=1))
146+
obj2 = FloatAggregator(functools.partial(func, b=2))
147+
148+
assert custom_hash(obj1) != custom_hash(obj2)
149+
150+
151+
def test_hash_partials_different2():
152+
def func(a, b):
153+
return np.mean(a) + b
154+
155+
def func2(a, b):
156+
return np.mean(a) + b
157+
158+
obj1 = FloatAggregator(functools.partial(func, b=1))
159+
obj2 = FloatAggregator(functools.partial(func2, b=1))
108160

109161
assert custom_hash(obj1) != custom_hash(obj2)

tpcp/_hash.py

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pickle
55
import sys
66
import types
7+
import warnings
78
from pathlib import Path
89

910
from joblib.func_inspect import get_func_code
@@ -64,6 +65,17 @@ def save(self, obj):
6465
# However, in the context of tpcp, that is not really a concern. In most possible cases, this just means
6566
# that some (likely obscure) guardrail will not trigger for you.
6667
if isinstance(obj, types.FunctionType):
68+
if "<lambda>" in obj.__qualname__:
69+
warnings.warn(
70+
"You are attempting to hash a lambda defined within a closure, likely because you used it as a "
71+
"parameter to a tpcp object (e.g. an Aggregator). "
72+
"While this works most of the time, it can to lead to some unexpected false positive hash "
73+
"equalities, depending on how you define the lambdas. "
74+
"We highly recommend to use a named function or a `functools.partial` instead.",
75+
stacklevel=1,
76+
)
77+
# Note, that for lambdas this actully hashes the entire definition line.
78+
# This means potentially more of the surrounding code than the lambda itself is hashed.
6779
obj = ("F", obj.__qualname__, get_func_code(obj), vars(obj))
6880

6981
if isinstance(obj, type):

0 commit comments

Comments
 (0)