Skip to content

Commit

Permalink
[geom] Update default equality checks
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Mar 29, 2024
1 parent f77b24c commit 9abccd5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 26 deletions.
2 changes: 1 addition & 1 deletion phi/field/_field_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def stack(fields: Sequence[Field], dim: Shape, dim_bounds: Box = None):
values = math.stack([f.values for f in fields], dim)
return PointCloud(elements, values, fields[0].extrapolation)
elif fields[0].is_mesh:
assert all([f.geometry.shallow_equals(fields[0].geometry) for f in fields])
assert all([f.geometry.shallow_equals(fields[0].geometry) for f in fields]), f"stacking fields with different geometries is not supported. Got {[f.geometry for f in fields]}"
values = math.stack([f.values for f in fields], dim)
return Field(fields[0].geometry, values, fields[0].extrapolation)
raise NotImplementedError(type(fields[0]))
Expand Down
33 changes: 8 additions & 25 deletions phi/geom/_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from phi import math
from phi.math import Tensor, Shape, EMPTY_SHAPE, non_channel, wrap, shape, Extrapolation
from phiml.math._magic_ops import variable_attributes, expand, stack
from phiml.math._magic_ops import variable_attributes, expand, stack, find_differences
from phi.math.magic import BoundDim, slicing_dict


Expand Down Expand Up @@ -418,18 +418,8 @@ def __eq__(self, other):
See Also:
`shallow_equals()`
"""
if self is other:
return True
if not isinstance(other, type(self)):
return False
if self.shape != other.shape:
return False
c1 = {a: getattr(self, a) for a in variable_attributes(self)}
c2 = {a: getattr(other, a) for a in variable_attributes(self)}
for c in c1.keys():
if c1[c] is not c2[c] and math.any(c1[c] != c2[c]):
return False
return True
differences = find_differences(self, other, compare_tensors_by_id=False)
return not differences

def shallow_equals(self, other):
"""
Expand All @@ -439,18 +429,8 @@ def shallow_equals(self, other):
The `shallow_equals()` check does not compare all tensor elements but merely checks whether the same tensors are referenced.
"""
if self is other:
return True
if not isinstance(other, type(self)):
return False
if self.shape != other.shape:
return False
c1 = {a: getattr(self, a) for a in variable_attributes(self)}
c2 = {a: getattr(other, a) for a in variable_attributes(self)}
for c in c1.keys():
if c1[c] is not c2[c]:
return False
return True
differences = find_differences(self, other, compare_tensors_by_id=True)
return not differences

@staticmethod
def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry':
Expand Down Expand Up @@ -651,6 +631,9 @@ def __init__(self, location: math.Tensor):
def __variable_attrs__(self):
return '_location',

def __value_attrs__(self):
return '_location',

def __with_attrs__(self, **updates):
if '_location' in updates:
result = Point.__new__(Point)
Expand Down

0 comments on commit 9abccd5

Please sign in to comment.