diff --git a/phiml/dataclasses/__init__.py b/phiml/dataclasses/__init__.py index 5acb19a..79d14da 100644 --- a/phiml/dataclasses/__init__.py +++ b/phiml/dataclasses/__init__.py @@ -44,6 +44,6 @@ from functools import cached_property -from ._dataclasses import sliceable, data_fields, non_data_fields, config_fields, special_fields, replace, getitem +from ._dataclasses import sliceable, data_fields, non_data_fields, config_fields, special_fields, replace, getitem, equal, data_eq __all__ = [key for key in globals().keys() if not key.startswith('_')] diff --git a/phiml/dataclasses/_dataclasses.py b/phiml/dataclasses/_dataclasses.py index 551c361..4436f6a 100644 --- a/phiml/dataclasses/_dataclasses.py +++ b/phiml/dataclasses/_dataclasses.py @@ -8,7 +8,7 @@ from phiml.dataclasses._dep import get_unchanged_cache from phiml.math import DimFilter, shape, Shape from phiml.math._magic_ops import slice_, variable_attributes -from phiml.math._tensors import disassemble_tree, Tensor, assemble_tree +from phiml.math._tensors import disassemble_tree, Tensor, assemble_tree, equality_by_shape_and_value, equality_by_ref from phiml.math.magic import slicing_dict, BoundDim PhiMLDataclass = TypeVar("PhiMLDataclass") @@ -52,10 +52,23 @@ def __dataclass_repr__(obj): return f"{type(obj).__name__}[{content}]" cls.__repr__ = __dataclass_repr__ return cls + return wrap(cls) if cls is not None else wrap # See if we're being called as @dataclass or @dataclass(). - if cls is None: # See if we're being called as @dataclass or @dataclass(). - return wrap - return wrap(cls) + +def data_eq(cls=None, /, *, rel_tolerance=0., abs_tolerance=0., equal_nan=True, compare_tensors_by_ref=False): + def wrap(cls): + assert cls.__dataclass_params__.eq, f"@data_eq can only be used with dataclasses with eq=True." + cls.__default_dataclass_eq__ = cls.__eq__ + def __tensor_eq__(obj, other): + if compare_tensors_by_ref: + with equality_by_ref(): + return cls.__default_dataclass_eq__(obj, other) + with equality_by_shape_and_value(rel_tolerance, abs_tolerance, equal_nan): + return cls.__default_dataclass_eq__(obj, other) + cls.__eq__ = __tensor_eq__ + # __ne__ calls `not __eq__()` by default + return cls + return wrap(cls) if cls is not None else wrap # See if we're being called as @dataclass or @dataclass(). NON_ATTR_TYPES = str, int, float, complex, bool, Shape, slice, Callable @@ -236,3 +249,23 @@ def __getitem__(self, item): cache = {k: slice_(v, item) for k, v in obj.__dict__.items() if isinstance(getattr(type(obj), k, None), cached_property) and not isinstance(v, Shape)} new_obj.__dict__.update(cache) return new_obj + + +def equal(obj1, obj2, rel_tolerance=0., abs_tolerance=0., equal_nan=True): + """ + Checks if two + + Args: + obj1: + obj2: + rel_tolerance: + abs_tolerance: + equal_nan: + + Returns: + + """ + cls = type(obj1) + eq_fn = cls.__default_dataclass_eq__ if hasattr(cls, '__default_dataclass_eq__') else cls.__eq__ + with equality_by_shape_and_value(rel_tolerance, abs_tolerance, equal_nan): + return eq_fn(obj1, obj2) diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index af44fe0..e36b13a 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -944,7 +944,7 @@ def equality_by_ref(): @contextmanager -def equality_by_shape_and_value(rel_tolerance=0, abs_tolerance=0, equal_nan=False): +def equality_by_shape_and_value(rel_tolerance=0., abs_tolerance=0., equal_nan=False): """ Enables Tensor.__bool__ """ diff --git a/tests/commit/dataclasses/__init__.py b/tests/commit/dataclasses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/commit/dataclasses/test_dataclasses.py b/tests/commit/dataclasses/test_dataclasses.py new file mode 100644 index 0000000..ca06bc3 --- /dev/null +++ b/tests/commit/dataclasses/test_dataclasses.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass +from typing import Tuple, Sequence, Dict +from unittest import TestCase + +from phiml.dataclasses import data_fields, config_fields, special_fields, sliceable, data_eq, equal +from phiml.math import Tensor, vec, wrap, shape, assert_close + + +class TestDataclasses(TestCase): + + def test_field_types(self): + @dataclass + class Custom: + variable_attrs: Sequence[str] + age: int + next: 'Custom' + conf: Dict[str, Sequence[Tuple[str, float, complex, bool, int, slice]]] + data_names = [f.name for f in data_fields(Custom)] + self.assertEqual(['next'], data_names) + config_names = [f.name for f in config_fields(Custom)] + self.assertEqual(['age', 'conf'], config_names) + special_names = [f.name for f in special_fields(Custom)] + self.assertEqual(['variable_attrs'], special_names) + + def test_sliceable(self): + @sliceable + @dataclass(frozen=True) + class Custom: + pos: Tensor + edges: Dict[str, Tensor] + c = Custom(vec(x=1, y=2), {'lo': wrap([-1, 1], 'b:b')}) + self.assertEqual(('b', 'vector'), shape(c).names) + assert_close(c['y,x'].pos, c.pos['y,x']) + assert_close(c.vector['y,x'].pos, c.pos['y,x']) + + def test_data_eq(self): + @data_eq(abs_tolerance=.2) + @dataclass(frozen=True) + class Custom: + pos: Tensor + c1 = Custom(vec(x=0, y=1)) + c2 = Custom(vec(x=3, y=4)) + c11 = Custom(vec(x=.1, y=1.1)) + self.assertNotEqual(c1, c2) + self.assertEqual(c1, c1) + self.assertEqual(c1, c11) + self.assertFalse(equal(c1, c11, abs_tolerance=0)) + self.assertTrue(equal(c1, c2, abs_tolerance=3))