Skip to content

Commit

Permalink
always_close() now returns False for incompatible tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 26, 2023
1 parent 2934bba commit 159df93
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
19 changes: 15 additions & 4 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2624,11 +2624,15 @@ def dtype(x) -> DType:
def always_close(t1: Union[Number, Tensor, bool], t2: Union[Number, Tensor, bool], rel_tolerance=1e-5, abs_tolerance=0, equal_nan=False) -> bool:
"""
Checks whether two tensors are guaranteed to be `close` in all values.
Unlike `close()`, this function can be used with JIT compilation.
Unlike `close()`, this function can be used with JIT compilation and with tensors of incompatible shapes.
Incompatible tensors are never close.
If one of the given tensors is being traced, the tensors are only equal if they reference the same native tensor.
Otherwise, an element-wise equality check is performed.
See Also:
`close()`.
Args:
t1: First tensor or number to compare.
t2: Second tensor or number to compare.
Expand All @@ -2644,7 +2648,10 @@ def always_close(t1: Union[Number, Tensor, bool], t2: Union[Number, Tensor, bool
if t1.available != t2.available:
return False
if t1.available and t2.available:
return close(t1, t2, rel_tolerance=rel_tolerance, abs_tolerance=abs_tolerance, equal_nan=equal_nan)
try:
return close(t1, t2, rel_tolerance=rel_tolerance, abs_tolerance=abs_tolerance, equal_nan=equal_nan)
except IncompatibleShapes:
return False
elif isinstance(t1, NativeTensor) and isinstance(t2, NativeTensor):
return t1._native is t2._native
else:
Expand All @@ -2656,10 +2663,14 @@ def close(*tensors, rel_tolerance=1e-5, abs_tolerance=0, equal_nan=False) -> boo
Checks whether all tensors have equal values within the specified tolerance.
Does not check that the shapes exactly match.
Tensors with different shapes are reshaped before comparing.
Unlike with `always_close()`, all shapes must be compatible and tensors with different shapes are reshaped before comparing.
See Also:
`always_close()`.
Args:
*tensors: `Tensor` or tensor-like (constant) each
*tensors: At least two `Tensor` or tensor-like objects.
The shapes of all tensors must be compatible but not all tensors must have all dimensions.
rel_tolerance: Relative tolerance
abs_tolerance: Absolute tolerance
equal_nan: If `True`, tensors are considered close if they are NaN in the same places.
Expand Down
1 change: 1 addition & 0 deletions tests/commit/math/test__ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def jit(x, y):
x = math.tensor(0)
y = math.tensor(0)
self.assertEqual(1, jit(x, y).native(), msg=b.name)
self.assertFalse(math.always_close(vec(x=0), vec(x=0, y=1)))

def test_assert_close_non_uniform(self):
t = math.stack([math.zeros(spatial(x=4)), math.zeros(spatial(x=3))], channel('stack'))
Expand Down

0 comments on commit 159df93

Please sign in to comment.