Skip to content

Commit

Permalink
Shape refactor 7
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 22, 2024
1 parent d097fc4 commit b9561ff
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def _shape_replace(shape: Shape, dims: DimFilter, new: DimFilter) -> Tuple[Shape
new = concat_shapes_(*[d.with_size(item_names) for d in existing])
elif isinstance(new, str):
new = parse_dim_order(new)
assert len(new) == len(dims), f"Number of names {new} does not match dims to replace {dims}"
assert len(new) == len(existing), f"Number of names {new} does not match dims to replace {existing}"
new = concat_shapes_(*[Dim(n, dim.size, dim.dim_type, dim.slice_names) for dim, n in zip(existing, new)])
elif callable(new):
new = new(**existing.untyped_dict)
Expand Down
2 changes: 1 addition & 1 deletion phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1567,7 +1567,7 @@ def _mean(value: Tensor, dims: Shape) -> Tensor:
if not dims:
return value
if isinstance(value, NativeTensor):
result = value.default_backend.mean(value._native, tuple(value._names.index(n) for n in dims.names))
result = value.default_backend.mean(value._native, tuple(value._names.index(n) for n in dims.names if n in value._names))
return NativeTensor(result, [n for n in value._names if n not in dims], value.shape - dims)
elif isinstance(value, TensorStack):
if value._stack_dim in dims:
Expand Down
2 changes: 2 additions & 0 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,8 @@ def __and__(self, other):
elif isinstance(other, (Dim, PureShape)):
if not self:
return other
if not other:
return self
by_type = [EMPTY_SHAPE] * len(DIM_TYPES)
by_type[TYPE_INDEX[self.dim_type]] = self
by_type[TYPE_INDEX[other.dim_type]] = other
Expand Down
2 changes: 1 addition & 1 deletion phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,7 @@ def _getitem(self, selection: dict):
assert isinstance(sel, int), f"Attempting slice missing dimension {name} with {selection}"
gathered = self.default_backend.multi_slice(self._native, tuple(selections)) if selections else self._native
new_native_shape = after_gather(self._shape[self._names], selection)
new_shape = self.collapsed_dims & new_native_shape
new_shape = after_gather(self.collapsed_dims, selection) & new_native_shape
return NativeTensor(gathered, new_native_shape.names, new_shape)

def _unstack(self, dim: str):
Expand Down
5 changes: 3 additions & 2 deletions phiml/math/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from ..backend import choose_backend, NUMPY, Backend
from ._ops import choose_backend_t, concat_tensor, scatter, zeros_like
from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal, non_channel, DEBUG_CHECKS
from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal, non_channel, DEBUG_CHECKS, \
after_gather
from ._magic_ops import stack, expand, rename_dims, unpack_dim, unstack, value_attributes
from ._tensors import Tensor, wrap, disassemble_tree, disassemble_tensors, assemble_tree, TensorStack, may_vary_along, \
discard_constant_dims, variable_shape, NativeTensor, equality_by_shape_and_value
Expand Down Expand Up @@ -122,7 +123,7 @@ def _is_tracer(self) -> bool:

def _getitem(self, selection: dict):
starts = {dim: (item.start or 0) if isinstance(item, slice) else item for dim, item in selection.items()}
new_shape = math.zeros(self._shape)[selection].shape
new_shape = after_gather(self._shape, selection)
return self.shift(starts, new_shape, lambda v: v[selection], lambda b: b[selection], nonzero_edge=False)

def shift(self, shifts: Dict[str, int],
Expand Down

0 comments on commit b9561ff

Please sign in to comment.