Skip to content

Commit

Permalink
Added some more tests for strange behavious of lambdas in hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Aug 26, 2024
1 parent e2d068f commit cd2a4d1
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 5 deletions.
62 changes: 57 additions & 5 deletions tests/test_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,69 @@ def func(a):

obj1 = FloatAggregator(func)

@functools.wraps(func)
def _func(a):
return func(a)
def decorator(func):
@functools.wraps(func)
def _func(a):
return func(a)

obj2 = FloatAggregator(_func)
return _func

obj2 = FloatAggregator(decorator(func))

assert custom_hash(obj1) != custom_hash(obj2)


def test_hash_lambdas_same():
def func(a, b):
return np.mean(a) + b

def func2():
return FloatAggregator(lambda a: func(a, 1))

obj1 = func2()
obj2 = func2()

assert custom_hash(obj1) == custom_hash(obj2)


def test_hash_lambdas_different():
# This is quite interesting, these two lambdas are different, as they have different names, as they are
# defined in the same scope. in the pevious test, where there was only on lambda defined, the names were the same
# hence the hash the same.
obj1 = FloatAggregator(lambda a: np.mean(a))
obj2 = FloatAggregator(lambda a: np.mean(a))
obj2 = obj1
obj1 = FloatAggregator(lambda a: np.mean(a))
assert custom_hash(obj1) != custom_hash(obj2)


def test_hash_partials_same():
def func(a, b):
return np.mean(a) + b

obj1 = FloatAggregator(functools.partial(func, b=1))
obj2 = FloatAggregator(functools.partial(func, b=1))

assert custom_hash(obj1) == custom_hash(obj2)


def test_hash_partials_different():
def func(a, b):
return np.mean(a) + b

obj1 = FloatAggregator(functools.partial(func, b=1))
obj2 = FloatAggregator(functools.partial(func, b=2))

assert custom_hash(obj1) != custom_hash(obj2)


def test_hash_partials_different2():
def func(a, b):
return np.mean(a) + b

def func2(a, b):
return np.mean(a) + b

obj1 = FloatAggregator(functools.partial(func, b=1))
obj2 = FloatAggregator(functools.partial(func2, b=1))

assert custom_hash(obj1) != custom_hash(obj2)
12 changes: 12 additions & 0 deletions tpcp/_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pickle
import sys
import types
import warnings
from pathlib import Path

from joblib.func_inspect import get_func_code
Expand Down Expand Up @@ -64,6 +65,17 @@ def save(self, obj):
# However, in the context of tpcp, that is not really a concern. In most possible cases, this just means
# that some (likely obscure) guardrail will not trigger for you.
if isinstance(obj, types.FunctionType):
if "<lambda>" in obj.__qualname__:
warnings.warn(
"You are attempting to hash a lambda defined within a closure, likely because you used it as a "
"parameter to a tpcp object (e.g. an Aggregator). "
"While this works most of the time, it can to lead to some unexpected false positive hash "
"equalities, depending on how you define the lambdas. "
"We highly recommend to use a named function or a `functools.partial` instead.",
stacklevel=1,
)
# Note, that for lambdas this actully hashes the entire definition line.
# This means potentially more of the surrounding code than the lambda itself is hashed.
obj = ("F", obj.__qualname__, get_func_code(obj), vars(obj))

if isinstance(obj, type):
Expand Down

0 comments on commit cd2a4d1

Please sign in to comment.