From b9561ff3cee81e65e4cd8dc25da3abb14eb10702 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 22 Dec 2024 16:27:33 +0100 Subject: [PATCH] Shape refactor 7 --- phiml/math/_magic_ops.py | 2 +- phiml/math/_ops.py | 2 +- phiml/math/_shape.py | 2 ++ phiml/math/_tensors.py | 2 +- phiml/math/_trace.py | 5 +++-- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/phiml/math/_magic_ops.py b/phiml/math/_magic_ops.py index c0680075..1790d6f4 100644 --- a/phiml/math/_magic_ops.py +++ b/phiml/math/_magic_ops.py @@ -591,7 +591,7 @@ def _shape_replace(shape: Shape, dims: DimFilter, new: DimFilter) -> Tuple[Shape new = concat_shapes_(*[d.with_size(item_names) for d in existing]) elif isinstance(new, str): new = parse_dim_order(new) - assert len(new) == len(dims), f"Number of names {new} does not match dims to replace {dims}" + assert len(new) == len(existing), f"Number of names {new} does not match dims to replace {existing}" new = concat_shapes_(*[Dim(n, dim.size, dim.dim_type, dim.slice_names) for dim, n in zip(existing, new)]) elif callable(new): new = new(**existing.untyped_dict) diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 73121c70..17e24877 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -1567,7 +1567,7 @@ def _mean(value: Tensor, dims: Shape) -> Tensor: if not dims: return value if isinstance(value, NativeTensor): - result = value.default_backend.mean(value._native, tuple(value._names.index(n) for n in dims.names)) + result = value.default_backend.mean(value._native, tuple(value._names.index(n) for n in dims.names if n in value._names)) return NativeTensor(result, [n for n in value._names if n not in dims], value.shape - dims) elif isinstance(value, TensorStack): if value._stack_dim in dims: diff --git a/phiml/math/_shape.py b/phiml/math/_shape.py index fde94313..971c44d6 100644 --- a/phiml/math/_shape.py +++ b/phiml/math/_shape.py @@ -1331,6 +1331,8 @@ def __and__(self, other): elif isinstance(other, (Dim, PureShape)): if not self: return other + if not other: + return self by_type = [EMPTY_SHAPE] * len(DIM_TYPES) by_type[TYPE_INDEX[self.dim_type]] = self by_type[TYPE_INDEX[other.dim_type]] = other diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index 50fe1f25..c3d7aaf9 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -1341,7 +1341,7 @@ def _getitem(self, selection: dict): assert isinstance(sel, int), f"Attempting slice missing dimension {name} with {selection}" gathered = self.default_backend.multi_slice(self._native, tuple(selections)) if selections else self._native new_native_shape = after_gather(self._shape[self._names], selection) - new_shape = self.collapsed_dims & new_native_shape + new_shape = after_gather(self.collapsed_dims, selection) & new_native_shape return NativeTensor(gathered, new_native_shape.names, new_shape) def _unstack(self, dim: str): diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py index ff89c0e4..ed72c129 100644 --- a/phiml/math/_trace.py +++ b/phiml/math/_trace.py @@ -8,7 +8,8 @@ from ..backend import choose_backend, NUMPY, Backend from ._ops import choose_backend_t, concat_tensor, scatter, zeros_like -from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal, non_channel, DEBUG_CHECKS +from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal, non_channel, DEBUG_CHECKS, \ + after_gather from ._magic_ops import stack, expand, rename_dims, unpack_dim, unstack, value_attributes from ._tensors import Tensor, wrap, disassemble_tree, disassemble_tensors, assemble_tree, TensorStack, may_vary_along, \ discard_constant_dims, variable_shape, NativeTensor, equality_by_shape_and_value @@ -122,7 +123,7 @@ def _is_tracer(self) -> bool: def _getitem(self, selection: dict): starts = {dim: (item.start or 0) if isinstance(item, slice) else item for dim, item in selection.items()} - new_shape = math.zeros(self._shape)[selection].shape + new_shape = after_gather(self._shape, selection) return self.shift(starts, new_shape, lambda v: v[selection], lambda b: b[selection], nonzero_edge=False) def shift(self, shifts: Dict[str, int],