Skip to content

Commit

Permalink
Tensor comparison == and close() now return False when shapes are not…
Browse files Browse the repository at this point in the history
… compatible
  • Loading branch information
Philipp Holl committed Dec 16, 2024
1 parent 4de3ed7 commit 8a7b20b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 3 additions & 1 deletion phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3209,7 +3209,7 @@ def close(*tensors, rel_tolerance: Union[float, Tensor] = 1e-5, abs_tolerance: U
"""
Checks whether all tensors have equal values within the specified tolerance.
Does not check that the shapes exactly match.
Does not check that the shapes exactly match but if shapes are incompatible, returns `False`.
Unlike with `always_close()`, all shapes must be compatible and tensors with different shapes are reshaped before comparing.
See Also:
Expand All @@ -3233,6 +3233,8 @@ def close(*tensors, rel_tolerance: Union[float, Tensor] = 1e-5, abs_tolerance: U
if all(t is tensors[0] for t in tensors):
return True
tensors = [wrap(t) for t in tensors]
if any(not tensors[0].shape.is_compatible(t.shape) for t in tensors[1:]):
return False
c = True
for other in tensors[1:]:
c &= _close(tensors[0], other, rel_tolerance=rel_tolerance, abs_tolerance=abs_tolerance, equal_nan=equal_nan, reduce=reduce)
Expand Down
7 changes: 6 additions & 1 deletion phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,8 @@ def __rmod__(self, other):
return self._op2(other, lambda x, y: y % x, lambda x, y: choose_backend(x, y).mod(y, x), 'rmod', '%')

def __eq__(self, other) -> 'Tensor':
if self is other:
return expand(True, self.shape)
if _EQUALITY_REDUCE[-1] == 'ref':
return wrap(self is other)
elif _EQUALITY_REDUCE[-1] == 'shape_and_value':
Expand All @@ -729,7 +731,10 @@ def __eq__(self, other) -> 'Tensor':
return wrap(close(self, other, rel_tolerance=0, abs_tolerance=0))
if other is None:
other = float('nan')
return self._op2(other, lambda x, y: x == y, lambda x, y: choose_backend(x, y).equal(x, y), 'eq', '==')
if self.shape.is_compatible(shape(other)):
return self._op2(other, lambda x, y: x == y, lambda x, y: choose_backend(x, y).equal(x, y), 'eq', '==')
else:
return wrap(False)

def __ne__(self, other) -> 'Tensor':
if _EQUALITY_REDUCE[-1] == 'ref':
Expand Down

0 comments on commit 8a7b20b

Please sign in to comment.