Skip to content

Commit

Permalink
Support multi-type Shape.only(), without()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 25, 2023
1 parent 0fb9f40 commit bb27179
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 17 deletions.
47 changes: 31 additions & 16 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,16 +730,21 @@ def without(self, dims: 'DimFilter') -> 'Shape':
Returns:
Shape without specified dimensions
"""
if callable(dims):
if dims is None: # subtract none
return self
elif callable(dims):
dims = dims(self)
if isinstance(dims, str):
dims = parse_dim_order(dims)
if isinstance(dims, (tuple, list, set)):
return self[[i for i in range(self.rank) if self.names[i] not in dims]]
return self[[i for i in range(self.rank) if self.names[i] not in parse_dim_order(dims)]]
elif isinstance(dims, Shape):
return self[[i for i in range(self.rank) if self.names[i] not in dims.names]]
elif dims is None: # subtract none
return self
if isinstance(dims, (tuple, list, set)) and all([isinstance(d, str) for d in dims]):
return self[[i for i in range(self.rank) if self.names[i] not in dims]]
elif isinstance(dims, (tuple, list, set)):
result = self
for wo in dims:
result = result.without(wo)
return result
else:
raise ValueError(dims)

Expand Down Expand Up @@ -767,16 +772,26 @@ def only(self, dims: 'DimFilter', reorder=False):
dims = parse_dim_order(dims)
if isinstance(dims, Shape):
dims = dims.names
if not isinstance(dims, (tuple, list, set)):
raise ValueError(dims)
if reorder:
dims = [d.name if isinstance(d, Shape) else d for d in dims]
assert all(isinstance(d, str) for d in dims)
return self[[self.names.index(d) for d in dims if d in self.names]]
else:
dims = [d.name if isinstance(d, Shape) else d for d in dims]
assert all(isinstance(d, str) for d in dims)
return self[[i for i in range(self.rank) if self.names[i] in dims]]
if isinstance(dims, (tuple, list, set)):
dim_names = []
for d in dims:
if callable(d):
d = d(self)
if isinstance(d, str):
dim_names.append(d)
elif isinstance(d, Shape):
dim_names.extend(d.names)
else:
raise ValueError(f"Format not understood for Shape.only(): {dims}")
if reorder:
dim_names = [d.name if isinstance(d, Shape) else d for d in dim_names]
assert all(isinstance(d, str) for d in dim_names)
return self[[self.names.index(d) for d in dim_names if d in self.names]]
else:
dim_names = [d.name if isinstance(d, Shape) else d for d in dim_names]
assert all(isinstance(d, str) for d in dim_names)
return self[[i for i in range(self.rank) if self.names[i] in dim_names]]
raise ValueError(dims)

@property
def rank(self) -> int:
Expand Down
9 changes: 8 additions & 1 deletion tests/commit/math/test__shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,11 @@ def test_higher_dual(self):
self.assertEqual(spatial(x=4, y=3), pp.spatial)
self.assertEqual(math.EMPTY_SHAPE, pp.dual)


def test_only(self):
s = batch(b=10) & channel(vector='x,y')
self.assertEqual(batch(b=10), s.only([batch, spatial]))
self.assertEqual(s[(1, 0)], s.only([channel, batch], reorder=True))

def test_without(self):
s = batch(b=10) & channel(vector='x,y')
self.assertEqual(channel(vector='x,y'), s.without([batch, spatial]))

0 comments on commit bb27179

Please sign in to comment.