Skip to content

Commit

Permalink
[geom] slice off Grid dims with int
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Mar 9, 2024
1 parent 8611cf6 commit a2d774f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
22 changes: 9 additions & 13 deletions phi/geom/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,33 +123,29 @@ def rotation_matrix(self) -> Optional[Tensor]:

def __getitem__(self, item):
item = slicing_dict(self, item)
bounds = self._bounds
resolution = self._resolution.after_gather(item)
bounds = self._bounds[{d: s for d, s in item.items() if d != 'vector'}]
if 'vector' in item:
resolution = resolution.only(item['vector'], reorder=True)
bounds = bounds.vector[item['vector']]
bounds = bounds.vector[resolution.name_list]
dx = self.size
gather_dict = {}
for dim, selection in item.items():
if dim in self._resolution:
if isinstance(selection, int):
start = selection
stop = selection + 1
elif isinstance(selection, slice):
if dim in resolution:
if isinstance(selection, slice):
start = selection.start or 0
if start < 0:
start += self.resolution.get_size(dim)
stop = selection.stop or self.resolution.get_size(dim)
if stop < 0:
stop += self.resolution.get_size(dim)
assert selection.step is None or selection.step == 1
else:
else: # int slices are not contained in resolution anymore
raise ValueError(f"Illegal selection: {item}")
dim_mask = math.wrap(self.resolution.mask(dim))
lower = bounds.lower + start * dim_mask * dx
upper = bounds.upper + (stop - self.resolution.get_size(dim)) * dim_mask * dx
bounds = Box(lower, upper)
gather_dict[dim] = slice(start, stop)
resolution = self._resolution.after_gather(gather_dict)
bounds = bounds[{d: s for d, s in item.items() if d != 'vector'}]
if 'vector' in item:
bounds = bounds[item['vector']] # resolution[item['vector']] will be done automatically
return UniformGrid(resolution, bounds)

def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Optional[int], **kwargs) -> 'Cuboid':
Expand Down
21 changes: 21 additions & 0 deletions tests/commit/geom/test__grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from unittest import TestCase


from phi import math
from phi.geom import UniformGrid, Box
from phi.math import batch, channel
from phi.math.magic import Shaped, Sliceable, Shapable
from phiml.math import vec, spatial


class TestBox(TestCase):

def test_slice_int(self):
grid = UniformGrid(x=4, y=3, z=2)
self.assertEqual(grid[{'z': 0}].resolution, spatial(x=4, y=3))
self.assertEqual(grid[{'z': 0}].bounds, grid.bounds['x,y'])

def test_slice(self):
grid = UniformGrid(x=4, y=3, z=2)
self.assertEqual(grid[{'z': slice(1, 2)}].resolution, spatial(x=4, y=3, z=1))
self.assertEqual(grid[{'z': slice(1, 2)}].bounds, Box(vec(x=0, y=0, z=1), vec(x=4, y=3, z=2)))

0 comments on commit a2d774f

Please sign in to comment.