diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 36231ab..80fabae 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -3209,7 +3209,7 @@ def close(*tensors, rel_tolerance: Union[float, Tensor] = 1e-5, abs_tolerance: U """ Checks whether all tensors have equal values within the specified tolerance. - Does not check that the shapes exactly match. + Does not check that the shapes exactly match but if shapes are incompatible, returns `False`. Unlike with `always_close()`, all shapes must be compatible and tensors with different shapes are reshaped before comparing. See Also: @@ -3233,6 +3233,8 @@ def close(*tensors, rel_tolerance: Union[float, Tensor] = 1e-5, abs_tolerance: U if all(t is tensors[0] for t in tensors): return True tensors = [wrap(t) for t in tensors] + if any(not tensors[0].shape.is_compatible(t.shape) for t in tensors[1:]): + return False c = True for other in tensors[1:]: c &= _close(tensors[0], other, rel_tolerance=rel_tolerance, abs_tolerance=abs_tolerance, equal_nan=equal_nan, reduce=reduce) diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index e20bbec..6e9cabb 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -720,6 +720,8 @@ def __rmod__(self, other): return self._op2(other, lambda x, y: y % x, lambda x, y: choose_backend(x, y).mod(y, x), 'rmod', '%') def __eq__(self, other) -> 'Tensor': + if self is other: + return expand(True, self.shape) if _EQUALITY_REDUCE[-1] == 'ref': return wrap(self is other) elif _EQUALITY_REDUCE[-1] == 'shape_and_value': @@ -729,7 +731,10 @@ def __eq__(self, other) -> 'Tensor': return wrap(close(self, other, rel_tolerance=0, abs_tolerance=0)) if other is None: other = float('nan') - return self._op2(other, lambda x, y: x == y, lambda x, y: choose_backend(x, y).equal(x, y), 'eq', '==') + if self.shape.is_compatible(shape(other)): + return self._op2(other, lambda x, y: x == y, lambda x, y: choose_backend(x, y).equal(x, y), 'eq', '==') + else: + return wrap(False) def __ne__(self, other) -> 'Tensor': if _EQUALITY_REDUCE[-1] == 'ref':