diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index f16fbbe..e4bd391 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -1886,7 +1886,7 @@ def broadcastable_native_tensors(*tensors) -> Tuple[Sequence[str], Shape, Sequen tensors = [dense(t) for t in tensors] var_names = tuple(set.union(*[set(variable_dim_names(t)) for t in tensors])) natives = [t._transposed_native(var_names, False) if t.rank > 0 else t.native(None, False) for t in tensors] - broadcast_shape = merge_shapes(tensors) + broadcast_shape = merge_shapes(*tensors) return var_names, broadcast_shape, natives diff --git a/phiml/math/extrapolation.py b/phiml/math/extrapolation.py index 82b8003..d55b4c8 100644 --- a/phiml/math/extrapolation.py +++ b/phiml/math/extrapolation.py @@ -666,7 +666,7 @@ def is_flexible(self) -> bool: return False def transform_coordinates(self, coordinates: Tensor, shape: Shape, **kwargs) -> Tensor: - return coordinates % shape.spatial + return coordinates % wrap(shape.spatial, channel(coordinates))[coordinates.vector.item_name_list] def pad_values(self, value: Tensor, width: int, dim: str, upper_edge: bool, already_padded: Optional[dict] = None, **kwargs) -> Tensor: if upper_edge: diff --git a/phiml/math/magic.py b/phiml/math/magic.py index 405ef89..a5d3802 100644 --- a/phiml/math/magic.py +++ b/phiml/math/magic.py @@ -542,6 +542,10 @@ def type(self) -> Callable: def item_names(self): return shape(self.obj).get_item_names(self.name) + @property + def item_name_list(self): + return list(shape(self.obj).get_item_names(self.name)) + @property def name_tensor(self): dim = shape(self.obj)[self.name] diff --git a/tests/commit/math/test__ops.py b/tests/commit/math/test__ops.py index 87e5912..5c82d02 100644 --- a/tests/commit/math/test__ops.py +++ b/tests/commit/math/test__ops.py @@ -249,20 +249,21 @@ def test_grid_sample_1d(self): assert_close(sampled, [0, 1, 0.5]) def test_grid_sample_backend_equality_2d(self): - grid = math.random_uniform(spatial(y=10, x=7)) - coords = math.random_uniform(batch(mybatch=10) & spatial(x=3, y=2)) * vec(y=12, x=9) - grid_ = math.tensor(grid.native('x,y'), spatial('x,y')) - coords_ = coords.vector[::-1] + grid_yx = math.random_uniform(spatial(y=10, x=7)) + coords_yx = math.random_uniform(batch(mybatch=10) & spatial(x=3, y=2)) * vec(y=12, x=9) + grid_xy = math.tensor(grid_yx.native('x,y'), spatial('x,y')) + coords_xy = coords_yx.vector['x,y'] for extrap in (extrapolation.ZERO, extrapolation.ONE, extrapolation.BOUNDARY, extrapolation.PERIODIC): sampled = [] for backend in BACKENDS: with backend: - grid = math.tensor(grid) - coords = math.tensor(coords) - grid_ = math.tensor(grid_) - coords_ = math.tensor(coords_) - sampled.append(math.grid_sample(grid, coords, extrap)) - sampled.append(math.grid_sample(grid_, coords_, extrap)) + grid_yx = math.tensor(grid_yx) + coords_yx = math.tensor(coords_yx) + grid_xy = math.tensor(grid_xy) + coords_xy = math.tensor(coords_xy) + sampled.append(math.grid_sample(grid_yx, coords_yx, extrap)) + sampled.append(math.grid_sample(grid_xy, coords_xy, extrap)) + sampled.append(math.grid_sample(grid_xy, coords_yx, extrap)) assert_close(*sampled, abs_tolerance=1e-5) def test_grid_sample_backend_equality_2d_batched(self):