Skip to content

Commit

Permalink
Ref: Fix rename_dims()
Browse files Browse the repository at this point in the history
Ref: Add Shape.non_singleton
  • Loading branch information
holl- committed Jan 12, 2025
1 parent a7270e1 commit 13c9e53
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
11 changes: 11 additions & 0 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,8 @@ def rename_dims(value: PhiTreeNodeType,
old_dims, new_dims = _shape_replace(shape(value), dims, names)
if not new_dims:
return value
if new_dims.names == old_dims.names and new_dims == old_dims:
return value
# --- First try __replace_dims__ ---
if hasattr(value, '__replace_dims__'):
result = value.__replace_dims__(old_dims.names, new_dims, **kwargs)
Expand All @@ -590,6 +592,7 @@ def _shape_replace(shape: Shape, dims: DimFilter, new: DimFilter) -> Tuple[Shape
existing = shape.only(dims, reorder=True)
if not existing:
return EMPTY_SHAPE, EMPTY_SHAPE
# --- Replace based on type(new) ---
if isinstance(new, str) and new.startswith('(') and new.endswith(')'):
item_names = [s.strip() for s in new[1:-1].split(',')]
new = concat_shapes_(*[d.with_size(item_names) for d in existing])
Expand All @@ -612,6 +615,14 @@ def _shape_replace(shape: Shape, dims: DimFilter, new: DimFilter) -> Tuple[Shape
raise ValueError(f"Invalid item in names: {n}")
new = concat_shapes_(*new_dims)
elif isinstance(new, Shape):
if not callable(dims):
if isinstance(dims, Shape):
existing_idx = dims.indices(existing.names)
elif isinstance(dims, (tuple, list)):
existing_idx = [dims.index(n) for n in existing.names]
else:
raise NotImplementedError
new = new[existing_idx]
if not new.well_defined:
new = new.with_sizes(existing.sizes)
else:
Expand Down
11 changes: 11 additions & 0 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,9 @@ def non_uniform_shape(self):
@property
def singleton(self):
return self if _size_equal(self.size, 1) else EMPTY_SHAPE
@property
def non_singleton(self):
return EMPTY_SHAPE if _size_equal(self.size, 1) else self

@property
def well_defined(self):
Expand Down Expand Up @@ -1280,6 +1283,10 @@ def non_uniform_shape(self):
def singleton(self):
dims = {n: dim for n, dim in self.dims.items() if _size_equal(dim.size, 1)}
return next(iter(dims.values())) if len(dims) == 1 else PureShape(self.dim_type, dims)
@property
def non_singleton(self):
dims = {n: dim for n, dim in self.dims.items() if not _size_equal(dim.size, 1)}
return next(iter(dims.values())) if len(dims) == 1 else PureShape(self.dim_type, dims)

@property
def well_defined(self):
Expand Down Expand Up @@ -1716,6 +1723,10 @@ def non_uniform_shape(self):
def singleton(self):
dims = {n: dim for n, dim in self.dims.items() if _size_equal(dim.size, 1)}
return next(iter(dims.values())) if len(dims) == 1 else merge_shapes(dims)
@property
def non_singleton(self):
dims = {n: dim for n, dim in self.dims.items() if not _size_equal(dim.size, 1)}
return next(iter(dims.values())) if len(dims) == 1 else merge_shapes(dims)

@property
def well_defined(self):
Expand Down

0 comments on commit 13c9e53

Please sign in to comment.