diff --git a/phiml/math/_magic_ops.py b/phiml/math/_magic_ops.py index 522adae..4e7b32c 100644 --- a/phiml/math/_magic_ops.py +++ b/phiml/math/_magic_ops.py @@ -5,8 +5,9 @@ from numbers import Number from typing import TypeVar, Tuple, Dict, Union, Optional, Sequence, Any, Callable -from . import channel -from ._shape import Shape, DimFilter, batch, instance, shape, non_batch, merge_shapes, concat_shapes, spatial, parse_dim_order, dual, auto, parse_shape_spec, DIM_FUNCTIONS, INV_CHAR +from . import channel, EMPTY_SHAPE +from ._shape import Shape, DimFilter, batch, instance, shape, non_batch, merge_shapes, concat_shapes, spatial, parse_dim_order, dual, auto, parse_shape_spec, DIM_FUNCTIONS, \ + INV_CHAR, concat_shapes_, Dim from .magic import Sliceable, Shaped, Shapable, PhiTreeNode from ..backend import choose_backend, NoBackendFound from ..backend._dtype import DType @@ -88,7 +89,7 @@ def unstack(value, dim: DimFilter) -> tuple: (0.0, 0.0, 0.0, 0.0, 0.0) """ assert isinstance(value, Sliceable) and isinstance(value, Shaped), f"Cannot unstack {type(value).__name__}. Must be Sliceable and Shaped, see https://tum-pbs.github.io/PhiML/phiml/math/magic.html" - dims = shape(value).only(dim) + dims = shape(value).only(dim, reorder=True) if dims.rank == 0: return value, if dims.rank == 1: @@ -102,7 +103,7 @@ def unstack(value, dim: DimFilter) -> tuple: else: # multiple dimensions if hasattr(value, '__pack_dims__'): packed_dim = batch('_unstack') - value_packed = value.__pack_dims__(dims.names, packed_dim, pos=None) + value_packed = value.__pack_dims__(dims, packed_dim, pos=None) if value_packed is not NotImplemented: return unstack(value_packed, packed_dim) unstack_dim = _any_uniform_dim(dims) @@ -547,41 +548,54 @@ def rename_dims(value: PhiTreeNodeType, Same type as `value`. 
""" if isinstance(value, Shape): - return value._replace_names_and_types(dims, names) + old_dims, new_dims = _shape_replace(value, dims, names) + return value.replace(old_dims, new_dims) elif isinstance(value, (Number, bool)): return value assert isinstance(value, Shapable) and isinstance(value, Shaped), f"value must be a Shape or Shapable but got {type(value).__name__}" - dims = shape(value).only(dims).names if callable(dims) else parse_dim_order(dims) - existing_dims = shape(value).only(dims, reorder=True) - if isinstance(names, str) and names.startswith('(') and names.endswith(')'): - item_names = [s.strip() for s in names[1:-1].split(',')] - names = [shape(value)[d].with_size(item_names) for d in dims] - elif isinstance(names, str): - names = parse_dim_order(names) - elif callable(names): - names = names(**existing_dims.untyped_dict) - dims = existing_dims - assert len(dims) == len(names), f"names and dims must be of equal length but got #dims={len(dims)} and #names={len(names)}" - if not existing_dims: + old_dims, new_dims = _shape_replace(shape(value), dims, names) + if not new_dims: return value - existing_names = [n for i, n in enumerate(names) if dims[i] in existing_dims] - existing_names = existing_dims._replace_names_and_types(existing_dims, existing_names) # --- First try __replace_dims__ --- if hasattr(value, '__replace_dims__'): - result = value.__replace_dims__(existing_dims.names, existing_names, **kwargs) + result = value.__replace_dims__(old_dims.names, new_dims, **kwargs) if result is not NotImplemented: return result # --- Next try Tree Node --- if isinstance(value, PhiTreeNode): - return tree_map(rename_dims, value, all_attributes, treat_layout_as_leaf=True, dims=existing_dims, names=existing_names, **kwargs) + return tree_map(rename_dims, value, all_attributes, treat_layout_as_leaf=True, dims=old_dims, names=new_dims, **kwargs) # --- Fallback: unstack and stack --- - if shape(value).only(existing_dims).volume > 8: - 
warnings.warn(f"rename_dims() default implementation is slow on large dimensions ({existing_dims}). Please implement __replace_dims__() for {type(value).__name__} as defined in phiml.math.magic", RuntimeWarning, stacklevel=2) - for old_name, new_dim in zip(existing_dims.names, existing_names): + if shape(value).only(old_dims).volume > 8: + warnings.warn(f"rename_dims() default implementation is slow on large dimensions ({old_dims}). Please implement __replace_dims__() for {type(value).__name__} as defined in phiml.math.magic", RuntimeWarning, stacklevel=2) + for old_name, new_dim in zip(old_dims.names, new_dims): value = stack(unstack(value, old_name), new_dim, **kwargs) return value +def _shape_replace(shape: Shape, dims: DimFilter, new: DimFilter) -> Tuple[Shape, Shape]: # _replace_names_and_types + if callable(dims): + existing = dims(shape) + elif isinstance(dims, Shape): + existing = dims.only(shape) + else: + dims = parse_dim_order(dims) + existing = shape.only(dims, reorder=True) + if not existing: + return EMPTY_SHAPE, EMPTY_SHAPE + if isinstance(new, str) and new.startswith('(') and new.endswith(')'): + item_names = [s.strip() for s in new[1:-1].split(',')] + new = concat_shapes_(*[d.with_size(item_names) for d in existing]) + elif isinstance(new, str): + new = parse_dim_order(new) + assert len(new) == len(existing), f"Number of names {new} does not match dims to replace {dims}" + new = concat_shapes_(*[Dim(n, dim.size, dim.dim_type, dim.slice_names) for dim, n in zip(existing, new)]) + elif callable(new): + new = new(**existing.untyped_dict) + assert len(existing) == len(new), f"Number of names {new} does not match dims to replace {dims}" + return existing, new + + + def b2i(value: PhiTreeNodeType) -> PhiTreeNodeType: """ Change the type of all *batch* dimensions of `value` to *instance* dimensions. See `rename_dims`. 
""" return rename_dims(value, batch, instance) @@ -668,7 +682,7 @@ def pack_dims(value, dims: DimFilter, packed_dim: Union[Shape, str], pos: Option return unpack_dim(value, dims, packed_dim, **kwargs) # --- First try __pack_dims__ --- if hasattr(value, '__pack_dims__'): - result = value.__pack_dims__(dims.names, packed_dim, pos, **kwargs) + result = value.__pack_dims__(dims, packed_dim, pos, **kwargs) if result is not NotImplemented: return result # --- Next try Tree Node --- diff --git a/phiml/math/_nd.py b/phiml/math/_nd.py index e1271e9..5cfd119 100644 --- a/phiml/math/_nd.py +++ b/phiml/math/_nd.py @@ -6,7 +6,7 @@ from . import extrapolation as extrapolation from ._magic_ops import stack, rename_dims, concat, tree_map, value_attributes from ._ops import choose_backend_t, reshaped_native, reshaped_tensor -from ._shape import Shape, channel, batch, spatial, DimFilter, parse_dim_order, instance, dual, auto, non_batch +from ._shape import Shape, channel, batch, spatial, DimFilter, parse_dim_order, instance, dual, auto, non_batch, after_gather from ._tensors import Tensor, wrap, tensor, reshaped_numpy from .extrapolation import Extrapolation from .magic import PhiTreeNode @@ -910,6 +910,6 @@ def perform_query(np_query): def perform_query(np_vectors, np_query): return KDTree(np_vectors).query(np_query)[1] native_idx = b.numpy_call(perform_query, (query.shape.without(batch(vectors)).non_channel.volume,), DType(int, 64), native_vectors, native_query) - native_multi_idx = choose_backend(native_idx).unravel_index(native_idx, vectors.shape.after_gather(i).non_channel.sizes) + native_multi_idx = choose_backend(native_idx).unravel_index(native_idx, after_gather(vectors.shape, i).non_channel.sizes) result.append(reshaped_tensor(native_multi_idx, [query_i.shape.non_channel, index_dim or math.EMPTY_SHAPE])) return stack(result, batch(vectors)) diff --git a/phiml/math/_shape.py b/phiml/math/_shape.py index 016afa4..6274f8f 100644 --- a/phiml/math/_shape.py +++ 
b/phiml/math/_shape.py @@ -11,7 +11,7 @@ BATCH_DIM = 'batch' SPATIAL_DIM = 'spatial' CHANNEL_DIM = 'channel' -INSTANCE_DIM = 'înstance' +INSTANCE_DIM = 'instance' DUAL_DIM = 'dual' DIM_TYPES = (BATCH_DIM, DUAL_DIM, INSTANCE_DIM, SPATIAL_DIM, CHANNEL_DIM) TYPE_INDEX = {t: i for i, t in enumerate(DIM_TYPES)} @@ -57,6 +57,10 @@ def sizes(self) -> Sequence: """ ... + @property + def types(self) -> Sequence[str]: + ... + @property def item_names(self) -> Sequence[Optional[Sequence[str]]]: ... @@ -102,7 +106,7 @@ def index(self, dim: Union[str, 'Shape', None]) -> int: """ ... - def indices(self, dims: Union[tuple, list, 'Shape']) -> Tuple[int]: + def indices(self, dims: 'Shape') -> Tuple[int]: """ Finds the indices of the given dimensions within this `Shape`. @@ -411,7 +415,7 @@ def type(self) -> str: ... @property - def dim_type(self): + def dim_type(self) -> str: ... def mask(self, names: Union[tuple, list, set, 'Shape']): @@ -454,7 +458,7 @@ def __add__(self, other) -> 'Shape': def __sub__(self, other) -> 'Shape': ... - def only(self, dims: 'DimFilter', reorder=False): + def only(self, dims: 'DimFilter', reorder=False) -> 'Shape': """ Builds a new shape from this one that only contains the given dimensions. Dimensions in `dims` that are not part of this Shape are ignored. @@ -472,7 +476,7 @@ def only(self, dims: 'DimFilter', reorder=False): """ ... - def is_compatible(self, *others: 'Shape'): + def is_compatible(self, *others: 'Shape') -> bool: """ Checks if this shape and the others can be broadcast. @@ -657,7 +661,7 @@ def with_dim_size(self, dim: Union[str, 'Shape'], size: Union[int, 'math.Tensor' """ ... - def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', keep_item_names=True, replace_item_names: 'DimFilter' = None) -> 'Shape': + def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', replace_item_names: 'DimFilter' = None) -> 'Shape': """ Returns a copy of `self` with `dims` replaced by `new`. 
Dimensions that are not present in `self` are ignored. @@ -666,9 +670,8 @@ def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', keep_ite Args: dims: Dimensions to replace. - new: New dimensions, must have same length as `dims`. + new: New dimensions, must have same length as `dims` if `len(dims) > 1`. If a `Shape` is given, replaces the dimension types and item names as well. - keep_item_names: Keeps existing item names for dimensions where `new` does not specify item names if the new dimension has the same size. replace_item_names: For which dims the item names should be replaced as well. Returns: @@ -933,9 +936,8 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]: return 0 raise ValueError(f"index() requires a single dimension as input but got {dim}") - def indices(self, dims: Union[tuple, list, 'Shape']) -> Tuple[int, ...]: - names = dims.names if isinstance(dims, Shape) else dims - return tuple([self.index(n) for n in names]) + def indices(self, dims: Shape) -> Tuple[int, ...]: + return tuple([self.index(n) for n in dims.names]) def __getitem__(self, selection): if isinstance(selection, Shape): @@ -977,8 +979,6 @@ def get_item_names(self, dim: Union[str, 'Shape'], fallback_spatial=False) -> Un def __and__(self, other): if other is dual: return self & self.primal.as_dual() - if not isinstance(other, Shape): - other = shape(other) if isinstance(other, (Dim, PureShape)) and other.dim_type == self.dim_type: return pure_merge(self, other, allow_varying_sizes=False) elif isinstance(other, (Dim, PureShape)): @@ -998,6 +998,9 @@ def is_compatible(self, other): return False return self.dim_type == dim.dim_type + def isdisjoint(self, other) -> bool: + return self.name not in other + def only(self, dims: 'DimFilter', reorder=False): if dims is None: # keep none return EMPTY_SHAPE @@ -1085,6 +1088,18 @@ def with_sizes(self, sizes: Union[Sequence[int], Sequence[Tuple[str, ...]], 'Sha def without_sizes(self): return Dim(self.name, None, 
self.dim_type, None) + def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', replace_item_names: 'DimFilter' = None): + assert self.name == parse_dim_names(dims, 1)[0] + return self._replace(new, replace_item_names) + + def _replace(self, new: 'Shape', replace_item_names: 'DimFilter' = None): + if self.slice_names is None: + return new + if self.only(replace_item_names) or self.slice_names is None or len(new) != 1 or not _size_equal(len(self.slice_names), new.size): + return new + else: # keep item names from self + return Dim(new.name, new.size, new.dim_type, self.slice_names) + def as_batch(self): name = _apply_prefix(self.name, BATCH_DIM) if self.dim_type == DUAL_DIM else self.name return Dim(name, self.size, BATCH_DIM, self.slice_names) @@ -1285,15 +1300,14 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]: return self.names.index(dim.name) raise ValueError(f"index() requires a single dimension as input but got {dim}") - def indices(self, dims: Union[tuple, list, 'Shape']) -> Tuple[int, ...]: - names = dims.names if isinstance(dims, Shape) else dims - return tuple([self.index(n) for n in names]) + def indices(self, dims: Shape) -> Tuple[int, ...]: + return tuple([self.index(n) for n in dims.names]) def __getitem__(self, selection): if isinstance(selection, int): return list(self.dims.values())[selection] elif isinstance(selection, slice): - return concat_shapes(list(self.dims.values())[selection]) + return concat_shapes(*list(self.dims.values())[selection]) elif isinstance(selection, str): if ',' in selection: selection = [self.index(s.strip()) for s in selection.split(',')] @@ -1326,8 +1340,6 @@ def get_item_names(self, dim: Union[str, 'Shape'], fallback_spatial=False) -> Un def __and__(self, other): if other is dual: return concat_shapes(self, self.primal.as_dual()) - if not isinstance(other, Shape): - other = shape(other) if isinstance(other, (Dim, PureShape)) and other.dim_type == self.dim_type: return pure_merge(self, 
other, allow_varying_sizes=False) elif isinstance(other, (Dim, PureShape)): @@ -1342,6 +1354,13 @@ def is_compatible(self, other: Shape): return all(dim.is_compatible(other) for dim in self.dims.values()) + def isdisjoint(self, other) -> bool: + if isinstance(other, Dim): + return other.name not in self.dims + if isinstance(other, PureShape): + return other.dim_type != self.dim_type or self.dims.keys().isdisjoint(other.dims) + return other.isdisjoint(self) + def only(self, dims: 'DimFilter', reorder=False): if not self.dims or dims is None: return EMPTY_SHAPE @@ -1395,7 +1414,7 @@ def without(self, dims: 'DimFilter'): dims = {n: dim for n, dim in self.dims.items() if n not in names} return next(iter(dims.values())) if len(dims) == 1 else PureShape(self.dim_type, dims) if isinstance(dims, (tuple, list, set)) and all([isinstance(d, str) for d in dims]): - dims = {n: dim for n, dim in self.dims if n not in dims} + dims = {n: dim for n, dim in self.dims.items() if n not in dims} return next(iter(dims.values())) if len(dims) == 1 else PureShape(self.dim_type, dims) elif isinstance(dims, (tuple, list, set)): result = self @@ -1430,12 +1449,19 @@ def with_sizes(self, sizes: Union[Sequence[int], Sequence[Tuple[str, ...]], 'Sha if not self.dims: assert not sizes return self + if isinstance(sizes, int): + sizes = (sizes,) * len(self.dims) + elif isinstance(sizes, Shape): + sizes = tuple([sizes.get_size(dim.name) for dim in self.dims.values()]) dims = {dim.name: dim.with_size(size, keep_item_names) for dim, size in zip(self.dims.values(), sizes)} return PureShape(self.dim_type, dims) def without_sizes(self): return PureShape(self.dim_type, {n: dim.without_sizes() for n, dim in self.dims.items()}) + def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', replace_item_names: 'DimFilter' = None): + raise NotImplementedError + def as_batch(self): dims = [dim.as_batch() for dim in self.dims.values()] return PureShape(BATCH_DIM, {dim.name: dim 
for dim in dims}) @@ -1637,9 +1663,8 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]: return self.names.index(dim.name) raise ValueError(f"index() requires a single dimension as input but got {dim}") - def indices(self, dims: Union[tuple, list, 'Shape']) -> Tuple[int, ...]: - names = dims.names if isinstance(dims, Shape) else dims - return tuple([self.index(n) for n in names]) + def indices(self, dims: Shape) -> Tuple[int, ...]: + return tuple([self.index(n) for n in dims.names]) def __getitem__(self, selection): if isinstance(selection, int): @@ -1683,8 +1708,6 @@ def get_item_names(self, dim: Union[str, 'Shape'], fallback_spatial=False) -> Un def __and__(self, other): if other is dual: return self & self.primal.as_dual() - if not isinstance(other, Shape): - other = shape(other) if isinstance(other, (Dim, PureShape)): if not other: return self @@ -1699,6 +1722,18 @@ def is_compatible(self, other: Shape): return all(dim.is_compatible(other) for dim in self.dims.values()) + def isdisjoint(self, other) -> bool: + if isinstance(other, Dim): + return other.name not in self.dims + if isinstance(other, PureShape): + return other.isdisjoint(getattr(self, other.dim_type)) + assert isinstance(other, MixedShape) + return self.dims.keys().isdisjoint(other.dims) + + @property + def __empty__(self): + return EMPTY_SHAPE + def only(self, dims: 'DimFilter', reorder=False): if isinstance(dims, (Dim, PureShape)): return getattr(self, dims.dim_type).only(dims, reorder=reorder) @@ -1786,12 +1821,35 @@ def with_dim_size(self, dim: Union[str, 'Shape'], size: Union[int, 'math.Tensor' raise NotImplementedError def with_sizes(self, sizes: Union[Sequence[int], Sequence[Tuple[str, ...]], 'Shape', int], keep_item_names=True): - dims = {dim.name: dim.with_size(size, keep_item_names) for dim, size in zip(self.dims.values(), sizes)} - return PureShape(self.dim_type, ) + if isinstance(sizes, int): + raise NotImplementedError + elif isinstance(sizes, 
Shape): + raise NotImplementedError + assert len(sizes) == len(self.dims) + sizes = {n: s for n, s in zip(self.dims, sizes)} + b = self.batch.with_sizes([sizes[n] for n in self.batch.names]) + d = self.dual.with_sizes([sizes[n] for n in self.dual.names]) + i = self.instance.with_sizes([sizes[n] for n in self.instance.names]) + s = self.spatial.with_sizes([sizes[n] for n in self.spatial.names]) + c = self.channel.with_sizes([sizes[n] for n in self.channel.names]) + dims = {**b.dims, **d.dims, **i.dims, **s.dims, **c.dims} + return MixedShape(b, d, i, s, c, {n: dims[n] for n in self.dims}) def without_sizes(self): raise NotImplementedError + def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape', replace_item_names: 'DimFilter' = None): + dims = parse_dim_order(dims) + dim_list = list(self.dims.values()) + if len(dims) == len(new): + for old, new_dim in zip(dims, new): + new_dim = self.dims[old]._replace(new_dim, replace_item_names) + dim_list[self.index(old)] = new_dim + elif len(new) > 1 and len(dims) == 1: + i = self.index(dims[0]) + dim_list[i:i+1] = new + return concat_shapes(*dim_list) + def as_batch(self): dims = [dim.as_batch() for dim in self.dims.values()] return PureShape(BATCH_DIM, {dim.name: dim for dim in dims}) @@ -1811,7 +1869,7 @@ def as_type(self, new_type: Callable): return {batch: self.as_batch, dual: self.as_dual, instance: self.as_instance, spatial: self.as_spatial, channel: self.as_channel}[new_type]() -EMPTY_SHAPE = PureShape('?', {}) +EMPTY_SHAPE = PureShape('__empty__', {}) """ Empty shape, `()` """ @@ -2221,61 +2279,56 @@ class InvalidShapeSpec(ValueError): def parse_shape_spec(input_string, default_type: Callable = None) -> Shape: - results = [] + dims = list[Dim]() pos = 0 while pos < len(input_string): if match := SPEC_PATTERNS['name_type_items'].match(input_string, pos): - tilde, name, type_, values = match.groups() - if tilde and type_ not in ('d', 'dual'): + tilde, n, t, values = match.groups() + if tilde and t not in 
('d', 'dual'): raise InvalidShapeSpec(input_string, f"Dimension names starting with ~ must be of type dual. Failed at index {pos}: {input_string[pos:]}") - elif not tilde and type_ in ('d', 'dual'): + elif not tilde and t in ('d', 'dual'): raise InvalidShapeSpec(input_string, f"Dual dims must start with ~. Failed at index {pos}: {input_string[pos:]}") items = [n.strip() for n in values.split(',') if n.strip()] - results.append({'name': '~' + name if tilde else name, 'type': type_, 'values': items}) + dims.append(Dim('~' + n if tilde else n, len(items), INV_CHAR[t] if len(t) == 1 else t, tuple(items))) pos = match.end() + 1 elif match := SPEC_PATTERNS['name_type'].match(input_string, pos): - tilde, name, type_ = match.groups() - if tilde and type_ not in ('d', 'dual'): + tilde, n, t = match.groups() + if tilde and t not in ('d', 'dual'): raise InvalidShapeSpec(input_string, f"Dimension names starting with ~ must be of type dual. Failed at index {pos}: {input_string[pos:]}") - elif not tilde and type_ in ('d', 'dual'): + elif not tilde and t in ('d', 'dual'): raise InvalidShapeSpec(input_string, f"Dual dims must start with ~. 
Failed at index {pos}: {input_string[pos:]}") # Check if the next character is an equal sign followed by parentheses next_char_pos = pos + len(match.group()) if next_char_pos < len(input_string) and input_string[next_char_pos] == '=': raise ValueError(f"Invalid format at position {pos}: values must be inside parentheses") - results.append({'name': '~' + name if tilde else name, 'type': type_}) + dims.append(Dim('~' + n if tilde else n, None, INV_CHAR[t] if len(t) == 1 else t, None)) pos = match.end() + 1 elif match := SPEC_PATTERNS['name_items'].match(input_string, pos): - tilde, name, values = match.groups() + tilde, n, values = match.groups() items = [n.strip() for n in values.split(',') if n.strip()] - results.append({'name': '~' + name if tilde else name, 'type': 'd' if tilde else 'c', 'values': items}) + dims.append(Dim('~' + n if tilde else n, len(items), 'd' if tilde else 'c', tuple(items))) pos = match.end() + 1 elif match := SPEC_PATTERNS['items'].match(input_string, pos): tilde, values = match.groups() - results.append({'name': '~vector' if tilde else 'vector', 'type': 'd' if tilde else 'c', 'values': values.split(',')}) + items = [n.strip() for n in values.split(',') if n.strip()] + dims.append(Dim('~vector' if tilde else 'vector', len(items), 'd' if tilde else 'c', tuple(items))) pos = match.end() + 1 elif match := SPEC_PATTERNS['dual_name'].match(input_string, pos): - name, = match.groups() - results.append({'name': name, 'type': 'd'}) + n, = match.groups() + dims.append(Dim(n, None, DUAL_DIM, None)) pos = match.end() + 1 elif match := SPEC_PATTERNS['single_letter'].match(input_string, pos): - name, = match.groups() - results.append({'name': name, 'type': 's'}) + n, = match.groups() + dims.append(Dim(n, None, SPATIAL_DIM, None)) pos = match.end() + 1 elif default_type is not None and (match := SPEC_PATTERNS['name_only'].match(input_string, pos)): - name, = match.groups() + n, = match.groups() default_type_str = TYPE_BY_FUNCTION[default_type] - 
results.append({'name': name, 'type': default_type_str}) + dims.append(Dim(n, None, default_type_str, None)) pos = match.end() + 1 else: raise InvalidShapeSpec(input_string, f"Failed to parse from index {pos}: '{input_string[pos:]}'. Dims must be specified as name:type or name:type=(item_names...). Names and types may only be omitted if component names are given.") - names = [r['name'] for r in results] - types = [r['type'] for r in results] - types = [INV_CHAR[t] if len(t) == 1 else t for t in types] - item_names = [r.get('values', None) for r in results] - item_names = [tuple(items) if items is not None else None for items in item_names] - sizes = [len(items) if items is not None else None for items in item_names] - return Shape(tuple(sizes), tuple(names), tuple(types), tuple(item_names)) + return concat_shapes(*dims) DIM_FUNCTIONS = {BATCH_DIM: batch, SPATIAL_DIM: spatial, INSTANCE_DIM: instance, CHANNEL_DIM: channel, DUAL_DIM: dual} @@ -2521,15 +2574,29 @@ def concat_shapes(*shapes: Union[Shape, Any]) -> Shape: Returns: Combined `Shape`. """ - if len(shapes) == 0: - return EMPTY_SHAPE shapes = [obj if isinstance(obj, Shape) else shape(obj) for obj in shapes] + return concat_shapes_(*shapes) + + +def concat_shapes_(*shapes: Shape) -> Shape: + shapes = [s for s in shapes if s] + if not shapes: + return EMPTY_SHAPE if len(shapes) == 1: return shapes[0] - names = sum([s.names for s in shapes], ()) - if len(set(names)) != len(names): + all_dims = list[Dim]() + by_type = {BATCH_DIM: EMPTY_SHAPE, DUAL_DIM: EMPTY_SHAPE, INSTANCE_DIM: EMPTY_SHAPE, SPATIAL_DIM: EMPTY_SHAPE, CHANNEL_DIM: EMPTY_SHAPE} + for s in shapes: + if s: + all_dims.extend(s.dims.values()) + if isinstance(s, (Dim, PureShape)): + by_type[s.dim_type] &= s + else: + raise NotImplementedError + dims = {dim.name: dim for dim in all_dims} + if len(dims) != len(all_dims): raise IncompatibleShapes(f"Cannot concatenate shapes {list(shapes)}. 
Duplicate dimension names are not allowed.") - raise NotImplementedError + return MixedShape(dims=dims, **by_type) def shape_stack(stack_dim: Shape, *shapes: Shape, stack_dim_first=False): diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py index 45c0344..27e5361 100644 --- a/phiml/math/_sparse.py +++ b/phiml/math/_sparse.py @@ -336,8 +336,7 @@ def compress(self, dims: DimFilter): pointers = wrap(scipy_csr.indptr, instance('pointers')) return CompressedSparseMatrix(indices, pointers, values, u_dims, c_dims, self._indices_constant, uncompressed_indices=uncompressed_indices, uncompressed_indices_perm=perm, m_rank=self._matrix_rank) - def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor': - dims = self._shape.only(dims) + def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor': assert dims in self._dense_shape, f"Can only pack sparse dimensions on SparseCoordinateTensor but got {dims} of which {dims.without(self._dense_shape)} are not sparse" assert self._indices.default_backend is NUMPY, "Can only pack NumPy indices as of yet" inst_dim_order = instance(self._indices) @@ -816,9 +815,8 @@ def native(self, order: Union[str, tuple, list, Shape] = None, force_expand=True assert order is None, f"sparse matrices are always ordered (primal, dual). For custom ordering, use math.dense(tensor).native() instead." 
return native_matrix(self, NUMPY if to_numpy else self.default_backend) - def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor': + def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor': assert all(d in self._shape for d in dims) - dims = self._shape.only(dims, reorder=True) if dims.only(self._compressed_dims).is_empty: # pack cols uncompressed_dims = self._uncompressed_dims.replace(dims, packed_dim.with_size(dims.volume)) return CompressedSparseMatrix(self._indices, self._pointers, self._values, uncompressed_dims, self._compressed_dims, self._indices_constant, self._uncompressed_offset) @@ -918,9 +916,8 @@ def to_cs(self): values = pack_dims(self._values, self._uncompressed_dims + self._compressed_dims, instance('entries')) return CompressedSparseMatrix(indices, pointers, values, self._compressed_dims, self._uncompressed_dims, self._indices_constant, m_rank=self._matrix_rank) - def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor': + def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor': raise NotImplementedError - dims = self._shape.only(dims) assert dims in self._dense_shape, f"Can only pack sparse dimensions on SparseCoordinateTensor but got {dims} of which {dims.without(self._dense_shape)} are not sparse" assert self._indices.default_backend is NUMPY, "Can only pack NumPy indices as of yet" inst_dim_order = instance(self._indices) diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index 3eff9c5..2602c77 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -3,7 +3,8 @@ import traceback import warnings from contextlib import contextmanager -from typing import Union, TypeVar, Sequence, Any, Dict +import typing +from typing import Union, TypeVar, Sequence, Any from dataclasses import dataclass from typing import Tuple, Callable, List @@ 
-16,7 +17,7 @@ CHANNEL_DIM, BATCH_DIM, SPATIAL_DIM, EMPTY_SHAPE, parse_dim_order, shape_stack, merge_shapes, channel, concat_shapes, primal, SUPERSCRIPT, IncompatibleShapes, INSTANCE_DIM, batch, spatial, dual, instance, shape, shape as shape_, DimFilter, non_batch, DEBUG_CHECKS, parse_shape_spec, - prepare_renaming_gather, after_gather) + prepare_renaming_gather, after_gather, concat_shapes_, Dim, PureShape) from ..backend import NoBackendFound, choose_backend, BACKENDS, get_precision, default_backend, convert as convert_, \ Backend, ComputeDevice, OBJECTS, NUMPY from ..backend._dtype import DType, combine_types @@ -554,23 +555,21 @@ def __unpack_dim__(self, dim: str, unpacked_dims: Shape, **kwargs) -> 'Tensor': return tensors[0] raise NotImplementedError - def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor': - order = self.shape._order_group(dims) # ToDo only occurrence - # if isinstance(names, Shape): - # names = names.names - # result = [] - # for dim in self.names: - # if dim not in result: - # if dim in names: - # result.extend(names) - # else: - # result.append(dim) - + def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Tensor': if self.shape.is_uniform: + order = [] + for dim in self.shape.names: + if dim not in order: + if dim in dims: + order.extend(dims.names) + else: + order.append(dim) native = self._transposed_native(order, force_expand=True) if pos is None: pos = min(self.shape.indices(dims)) - new_shape = self.shape.without(dims)._expand(packed_dim.with_sizes([self.shape.only(dims).volume]), pos) + packed_dim = packed_dim.with_sizes([dims.volume]) + remaining = self.shape - dims + new_shape = concat_shapes_(remaining[:pos], packed_dim, remaining[pos:]) native = choose_backend(native).reshape(native, new_shape.sizes) return NativeTensor(native, new_shape) else: @@ -1120,15 +1119,15 @@ def __replace_dims__(self, dims: Tuple[str, ...], new_dims: Shape, 
**kwargs) -> new_stack_dim = self._stack_dim.replace(dims, new_dims) return Layout(self._obj, new_stack_dim) - def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Layout': - if dims == self._stack_dim.names: + def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Layout': + if dims.names == self._stack_dim.names: native = self._as_list() return Layout(native, packed_dim.with_size(len(native))) else: obj = [] - for i in self._shape.only(dims, reorder=True).meshgrid(): + for i in dims.meshgrid(): obj.append(self[i].native()) - return Layout(obj, concat_shapes(packed_dim.with_size(self.shape.only(dims).volume), self._stack_dim.without(dims))) + return Layout(obj, concat_shapes(packed_dim.with_size(dims.volume), self._stack_dim - dims)) def __unpack_dim__(self, dim: str, unpacked_dims: Shape, **kwargs) -> 'Layout': return NotImplemented @@ -1256,7 +1255,7 @@ def _transposed_native(self, order: Sequence[str], force_expand: bool): else: return backend.cast(self._native, DType(self.dtype.kind, precision=get_precision())) # --- Transpose --- - perm = [self._native_shape.index(dim) for dim in self._native_shape.only(order, reorder=True).names] + perm = [self._native_shape.index(dim) for dim in order if dim in self._native_shape] if perm != list(range(len(perm))): transposed = backend.transpose(self._native, perm) # this will cast automatically else: @@ -1306,9 +1305,9 @@ def default_backend(self) -> Backend: def _with_shape_replaced(self, new_shape): if new_shape.rank != self._shape.rank: raise IncompatibleShapes(f"Tensor {self} is not compatible with shape {new_shape}", self._shape, new_shape) - new_shape = Shape(self._shape.sizes, new_shape.names, new_shape.types, new_shape.item_names) + new_shape = new_shape.with_sizes(self._shape.sizes) native_indices = self._shape.indices(self._native_shape) - new_native_shape = new_shape[native_indices] + new_native_shape = 
concat_shapes_(*[new_shape[i] for i in native_indices]) return NativeTensor(self._native, new_native_shape, new_shape) def _with_natives_replaced(self, natives: list): @@ -1953,7 +1952,8 @@ def disassemble_tree(obj: PhiTreeNodeType, cache: bool, attr_type=variable_attri if backend == OBJECTS: return obj, [] sizes = backend.staticshape(obj) - shape = Shape(sizes, tuple([f"dim{i}" for i in range(len(sizes))]), (None,) * len(sizes), (None,) * len(sizes)) + dims = [Dim(f"dim{i}", s, CHANNEL_DIM, None) for i, s in enumerate(sizes)] + shape = PureShape(CHANNEL_DIM, {dim.name: dim for dim in dims}) return NATIVE_TENSOR, [NativeTensor(obj, shape)] except NoBackendFound: return obj, [] @@ -2877,8 +2877,8 @@ def _format_vector(self: Tensor, options: PrintOptions) -> str: colors = options.get_colors() if self.shape.rank > 1: self = flatten(self, channel('flat')) - if self.shape.get_item_names(0) is not None and options.include_shape is not False: - content = ", ".join([f"{item}={_format_number(number, options, self.dtype)}" for number, item in zip(self, self.shape.get_item_names(0))]) + if self.shape.item_names[0] is not None and options.include_shape is not False: + content = ", ".join([f"{item}={_format_number(number, options, self.dtype)}" for number, item in zip(self, self.shape.item_names[0])]) else: content = ", ".join([_format_number(num, options, self.dtype) for num in self]) return colors.value(f"({content})") diff --git a/phiml/math/magic.py b/phiml/math/magic.py index fe39398..29c6275 100644 --- a/phiml/math/magic.py +++ b/phiml/math/magic.py @@ -246,7 +246,7 @@ def __replace_dims__(self, dims: Tuple[str, ...], new_dims: Shape, **kwargs) -> """ raise NotImplementedError - def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Shapable': + def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], **kwargs) -> 'Shapable': """ Compresses multiple dimensions into a single dimension by concatenating 
the elements. Elements along the new dimensions are laid out according to the order of `dims`.