Skip to content

Commit

Permalink
Fix grid_sample
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 8, 2025
1 parent 7c6e444 commit d9faf17
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 22 deletions.
2 changes: 1 addition & 1 deletion phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,7 @@ def broadcastable_native_tensors(*tensors) -> Tuple[Sequence[str], Shape, Sequen
tensors = [dense(t) for t in tensors]
var_names = tuple(set.union(*[set(variable_dim_names(t)) for t in tensors]))
natives = [t._transposed_native(var_names, False) if t.rank > 0 else t.native(None, False) for t in tensors]
broadcast_shape = merge_shapes(tensors)
broadcast_shape = merge_shapes(*tensors)
return var_names, broadcast_shape, natives


Expand Down
2 changes: 1 addition & 1 deletion phiml/math/extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def is_flexible(self) -> bool:
return False

def transform_coordinates(self, coordinates: Tensor, shape: Shape, **kwargs) -> Tensor:
return coordinates % shape.spatial
return coordinates % wrap(shape.spatial, channel(coordinates))[coordinates.vector.item_name_list]

def pad_values(self, value: Tensor, width: int, dim: str, upper_edge: bool, already_padded: Optional[dict] = None, **kwargs) -> Tensor:
if upper_edge:
Expand Down
4 changes: 4 additions & 0 deletions phiml/math/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,10 @@ def type(self) -> Callable:
def item_names(self):
return shape(self.obj).get_item_names(self.name)

@property
def item_name_list(self):
return list(shape(self.obj).get_item_names(self.name))

@property
def name_tensor(self):
dim = shape(self.obj)[self.name]
Expand Down
42 changes: 22 additions & 20 deletions tests/commit/math/test__ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,37 +249,39 @@ def test_grid_sample_1d(self):
assert_close(sampled, [0, 1, 0.5])

def test_grid_sample_backend_equality_2d(self):
grid = math.random_uniform(spatial(y=10, x=7))
coords = math.random_uniform(batch(mybatch=10) & spatial(x=3, y=2)) * vec(y=12, x=9)
grid_ = math.tensor(grid.native('x,y'), spatial('x,y'))
coords_ = coords.vector[::-1]
grid_yx = math.random_uniform(spatial(y=10, x=7))
coords_yx = math.random_uniform(batch(mybatch=10) & spatial(x=3, y=2)) * vec(y=12, x=9)
grid_xy = math.tensor(grid_yx.native('x,y'), spatial('x,y'))
coords_xy = coords_yx.vector['x,y']
for extrap in (extrapolation.ZERO, extrapolation.ONE, extrapolation.BOUNDARY, extrapolation.PERIODIC):
sampled = []
for backend in BACKENDS:
with backend:
grid = math.tensor(grid)
coords = math.tensor(coords)
grid_ = math.tensor(grid_)
coords_ = math.tensor(coords_)
sampled.append(math.grid_sample(grid, coords, extrap))
sampled.append(math.grid_sample(grid_, coords_, extrap))
grid_yx = math.tensor(grid_yx)
coords_yx = math.tensor(coords_yx)
grid_xy = math.tensor(grid_xy)
coords_xy = math.tensor(coords_xy)
sampled.append(math.grid_sample(grid_yx, coords_yx, extrap))
sampled.append(math.grid_sample(grid_xy, coords_xy, extrap))
sampled.append(math.grid_sample(grid_xy, coords_yx, extrap))
assert_close(*sampled, abs_tolerance=1e-5)

def test_grid_sample_backend_equality_2d_batched(self):
grid = math.random_uniform(batch(mybatch=10) & spatial(y=10, x=7))
coords = math.random_uniform(batch(mybatch=10) & spatial(x=3, y=2)) * vec(y=12, x=9)
grid_ = math.tensor(grid.native('mybatch,x,y'), batch('mybatch'), spatial('x,y'))
coords_ = coords.vector[::-1]
grid_yx = math.random_uniform(batch(mybatch=10) & spatial(y=10, x=7))
coords_yx = math.random_uniform(batch(mybatch=10) & spatial(x=3, y=2)) * vec(y=12, x=9)
grid_xy = math.tensor(grid_yx.native('mybatch,x,y'), batch('mybatch'), spatial('x,y'))
coords_xy = coords_yx.vector['x,y']
for extrap in (extrapolation.ZERO, extrapolation.ONE, extrapolation.BOUNDARY, extrapolation.PERIODIC):
sampled = []
for backend in BACKENDS:
with backend:
grid = math.tensor(grid)
coords = math.tensor(coords)
grid_ = math.tensor(grid_)
coords_ = math.tensor(coords_)
sampled.append(math.grid_sample(grid, coords, extrap))
sampled.append(math.grid_sample(grid_, coords_, extrap))
grid_yx = math.tensor(grid_yx)
coords_yx = math.tensor(coords_yx)
grid_xy = math.tensor(grid_xy)
coords_xy = math.tensor(coords_xy)
sampled.append(math.grid_sample(grid_yx, coords_yx, extrap))
sampled.append(math.grid_sample(grid_xy, coords_xy, extrap))
sampled.append(math.grid_sample(grid_xy, coords_yx, extrap))
assert_close(*sampled, abs_tolerance=1e-5)

def test_grid_sample_gradient_1d(self):
Expand Down

0 comments on commit d9faf17

Please sign in to comment.