Skip to content

Commit

Permalink
Support item name replacement in Shape.replace()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Feb 20, 2024
1 parent 6740139 commit 1971004
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@

DEBUG_CHECKS = []

DimFilter = Union[str, tuple, list, set, 'Shape', Callable]
try:
DimFilter.__doc__ = """Dimension filters can be used with `Shape.only()` and `Shype.without()`, making them the standard tool for specifying sets of dimensions.
The following types can be used as dimension filters:
* `Shape` instances
* `tuple` or `list` objects containing dimension names as `str`
* Single `str` listing comma-separated dimension names
* Any function `filter(Shape) -> Shape`, such as `math.batch()`, `math.non_batch()`, `math.spatial()`, etc.
""" # docstring must be set explicitly
except AttributeError: # on older Python versions, this is not possible
pass


def enable_debug_checks():
"""
Expand Down Expand Up @@ -1039,7 +1053,7 @@ def _replace_names_and_types(self,
names[self.index(old_name)] = _apply_prefix(new_dim, types[self.index(old_name)])
return Shape(tuple(sizes), tuple(names), tuple(types), tuple(item_names))

def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', keep_item_names=True) -> 'Shape':
def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', keep_item_names=True, replace_item_names: DimFilter = None) -> 'Shape':
"""
Returns a copy of `self` with `dims` replaced by `new`.
Dimensions that are not present in `self` are ignored.
Expand All @@ -1051,6 +1065,7 @@ def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', keep_ite
new: New dimensions, must have same length as `dims`.
If a `Shape` is given, replaces the dimension types and item names as well.
keep_item_names: Keeps existing item names for dimensions where `new` does not specify item names if the new dimension has the same size.
replace_item_names: For which dims the item names should be replaced as well.
Returns:
`Shape` with same rank and dimension order as `self`.
Expand All @@ -1061,6 +1076,13 @@ def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', keep_ite
sizes = list(self.sizes)
types = list(self.types)
item_names = list(self.item_names)
for i in self.indices(self.only(replace_item_names)):
if item_names[i]:
if len(new) > len(dims):
raise NotImplementedError
else:
name_map = {d: n for d, n in zip(dims, new.names)}
item_names[i] = tuple([name_map.get(n, n) for n in item_names[i]])
if len(new) > len(dims): # Put all in one spot
assert len(dims) == 1, "Cannot replace 2+ dims by more replacements"
index = self.index(dims[0])
Expand Down Expand Up @@ -1306,20 +1328,6 @@ def __hash__(self):
EMPTY_SHAPE = Shape((), (), (), ())
""" Empty shape, `()` """

DimFilter = Union[str, tuple, list, set, Shape, Callable]
try:
DimFilter.__doc__ = """Dimension filters can be used with `Shape.only()` and `Shype.without()`, making them the standard tool for specifying sets of dimensions.
The following types can be used as dimension filters:
* `Shape` instances
* `tuple` or `list` objects containing dimension names as `str`
* Single `str` listing comma-separated dimension names
* Any function `filter(Shape) -> Shape`, such as `math.batch()`, `math.non_batch()`, `math.spatial()`, etc.
""" # docstring must be set explicitly
except AttributeError: # on older Python versions, this is not possible
pass


class IncompatibleShapes(Exception):
"""
Expand Down

0 comments on commit 1971004

Please sign in to comment.