diff --git a/phiml/math/_magic_ops.py b/phiml/math/_magic_ops.py index 1790d6f4..5d00c618 100644 --- a/phiml/math/_magic_ops.py +++ b/phiml/math/_magic_ops.py @@ -476,7 +476,7 @@ def expand(value, *dims: Union[Shape, str], **kwargs): """ if not dims: return value - dims = concat_shapes(*[d if isinstance(d, Shape) else parse_shape_spec(d) for d in dims]) + dims = concat_shapes_(*[d if isinstance(d, Shape) else parse_shape_spec(d) for d in dims]) combined = merge_shapes(value, dims) # check that existing sizes match if not dims.without(shape(value)): # no new dims to add if set(dims) == set(shape(value).only(dims)): # sizes and item names might differ, though @@ -759,7 +759,7 @@ def unpack_dim(value, dim: DimFilter, *unpacked_dims: Union[Shape, Sequence[Shap return value # Nothing to do, maybe expand? assert dim.rank == 1, f"unpack_dim requires as single dimension to be unpacked but got {dim}" dim = dim.name - unpacked_dims = concat_shapes(*unpacked_dims) + unpacked_dims = concat_shapes_(*unpacked_dims) if unpacked_dims.rank == 0: return value[{dim: 0}] # remove dim elif unpacked_dims.rank == 1: diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 6c266bc1..f2ebf1da 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -12,7 +12,7 @@ from ._magic_ops import expand, pack_dims, unpack_dim, cast, value_attributes, bool_to_int, tree_map, concat, stack, unstack, rename_dims, slice_, all_attributes, squeeze, ipack from ._shape import (Shape, EMPTY_SHAPE, spatial, batch, channel, instance, merge_shapes, parse_dim_order, concat_shapes, - IncompatibleShapes, DimFilter, non_batch, dual, shape, shape as get_shape, primal, auto, non_spatial, non_dual, resolve_index) + IncompatibleShapes, DimFilter, non_batch, dual, shape, shape as get_shape, primal, auto, non_spatial, non_dual, resolve_index, concat_shapes_) from . import extrapolation as e_ from ._tensors import (Tensor, wrap, tensor, broadcastable_native_tensors, Dense, TensorStack, custom_op2, compatible_tensor, variable_attributes, disassemble_tree, assemble_tree, @@ -350,7 +350,7 @@ def _includes_slice(s_dict: dict, dim: Shape, i: int): def _initialize(uniform_initializer, shapes: Tuple[Shape]) -> Tensor: - shape = concat_shapes(*shapes) + shape = concat_shapes_(*shapes) assert shape.well_defined, f"When creating a Tensor, shape needs to have definitive sizes but got {shape}" if shape.is_non_uniform: stack_dim = shape.non_uniform_shape[0] @@ -492,7 +492,7 @@ def random_permutation(*shape: Union[Shape, Any], dims=non_batch, index_dim=chan `Tensor` """ assert dims is not batch, f"dims cannot include all batch dims because that violates the batch principle. Specify batch dims by name instead." - shape = concat_shapes(*shape) + shape = concat_shapes_(*shape) assert not shape.dual_rank, f"random_permutation does not support dual dims but got {shape}" perm_dims = shape.only(dims) batches = shape - perm_dims @@ -839,7 +839,7 @@ def range_tensor(*shape: Shape): Returns: `Tensor` """ - shape = concat_shapes(*shape) + shape = concat_shapes_(*shape) data = arange(spatial('range'), 0, shape.volume) return unpack_dim(data, 'range', shape) @@ -2170,7 +2170,7 @@ def tensor_dot(x, y): assert x_dims.volume == y_dims.volume, f"Failed to reduce {x_dims} against {y_dims} in dot product of {x.shape} and {y.shape}. Sizes do not match." if remaining_shape_y.isdisjoint(remaining_shape_x): # no shared batch dimensions -> tensordot result_native = backend.tensordot(x_native, x.shape.indices(x_dims.names), y_native, y.shape.indices(y_dims.names)) - result_shape = concat_shapes(remaining_shape_x, remaining_shape_y) + result_shape = remaining_shape_x + remaining_shape_y else: # shared batch dimensions -> einsum result_shape = merge_shapes(x.shape.without(x_dims), y.shape.without(y_dims)) REDUCE_LETTERS = list('ijklmn') diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py index 6d1a7cf7..5c0b398f 100644 --- a/phiml/math/_sparse.py +++ b/phiml/math/_sparse.py @@ -10,7 +10,7 @@ from ._magic_ops import concat, pack_dims, expand, rename_dims, stack, unpack_dim, unstack from ._shape import Shape, non_batch, merge_shapes, instance, batch, non_instance, shape, channel, spatial, DimFilter, \ - concat_shapes, EMPTY_SHAPE, dual, non_channel, DEBUG_CHECKS, primal + concat_shapes, EMPTY_SHAPE, dual, non_channel, DEBUG_CHECKS, primal, concat_shapes_ from ._tensors import Tensor, TensorStack, Dense, cached, wrap, reshaped_tensor, tensor, backend_for from ..backend import choose_backend, NUMPY, Backend, get_precision from ..backend._dtype import DType @@ -72,7 +72,7 @@ def sparse_tensor(indices: Optional[Tensor], indices_constant = True # --- type of sparse tensor --- if dense_shape in indices: # compact - compressed = concat_shapes([dim for dim in dense_shape if dim.size > indices.shape.get_size(dim)]) + compressed = concat_shapes_(*[dim for dim in dense_shape if dim.size > indices.shape.get_size(dim)]) values = expand(1, non_batch(indices)) sparse = CompactSparseTensor(indices, values, compressed, indices_constant) else: @@ -349,7 +349,7 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], * idx_packed = np.ravel_multi_index(idx_to_pack.native([channel, instance(idx_to_pack)]), dims.sizes) idx_packed = expand(reshaped_tensor(idx_packed, [instance('sp_entries')]), channel(sparse_idx=packed_dim.name)) indices = concat([indices.sparse_idx[list(self._dense_shape.without(dims).names)], idx_packed], 'sparse_idx') - dense_shape = concat_shapes(self._dense_shape.without(dims), packed_dim.with_size(dims.volume)) + dense_shape = self._dense_shape.without(dims) + packed_dim.with_size(dims.volume) idx_sorted = self._indices_sorted and False # ToDo still sorted if dims are ordered correctly and no other dim in between and inserted at right point return SparseCoordinateTensor(indices, values, dense_shape, self._can_contain_double_entries, idx_sorted, self._indices_constant) @@ -812,7 +812,7 @@ def decompress(self): values = reshaped_tensor(native_values, [ind_batch & batch(self._values), instance(self._values), channel(self._values)], convert=False) else: raise NotImplementedError() - return SparseCoordinateTensor(indices, values, concat_shapes(self._compressed_dims, self._uncompressed_dims), False, True, self._indices_constant, self._matrix_rank) + return SparseCoordinateTensor(indices, values, self._compressed_dims + self._uncompressed_dims, False, True, self._indices_constant, self._matrix_rank) if self._uncompressed_indices_perm is not None: self._uncompressed_indices = self._uncompressed_indices[self._uncompressed_indices_perm] self._uncompressed_indices_perm = None @@ -938,7 +938,7 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], * idx_packed = np.ravel_multi_index(idx_to_pack.native([channel, instance(idx_to_pack)]), dims.sizes) idx_packed = expand(reshaped_tensor(idx_packed, [instance('sp_entries')]), channel(sparse_idx=packed_dim.name)) indices = concat([indices.sparse_idx[list(self._dense_shape.without(dims).names)], idx_packed], 'sparse_idx') - dense_shape = concat_shapes(self._dense_shape.without(dims), packed_dim.with_size(dims.volume)) + dense_shape = self._dense_shape.without(dims) + packed_dim.with_size(dims.volume) return CompactSparseTensor(indices, values, dense_shape, self._indices_constant) def _with_shape_replaced(self, new_shape: Shape): @@ -1640,7 +1640,7 @@ def add_sparse_batch_dim(matrix: Tensor, in_dims: Shape, out_dims: Shape): # # offset = wrap([*idx.values()] * 2 + [0] * non_instance(matrix._indices).volume, channel(indices)) # # all_indices.append(indices + offset) # # indices = concat(all_indices, instance(indices)) -# # values = pack_dims(matrix._values, concat_shapes(dims, instance(matrix._values)), instance(matrix._values)) +# # values = pack_dims(matrix._values, dims + instance(matrix._values), instance(matrix._values)) # dense_shape = in_dims & matrix._dense_shape & out_dims # return SparseCoordinateTensor(indices, values, dense_shape, matrix._can_contain_double_entries, matrix._indices_sorted, matrix._indices_constant) # raise NotImplementedError diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index 20a16051..75da90af 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -1108,7 +1108,7 @@ def __bool__(self): def __stack__(self, values: tuple, dim: Shape, **kwargs) -> 'Layout': obj = [v.native(self._stack_dim) for v in values] - new_stack_dim = concat_shapes(dim, self._stack_dim) + new_stack_dim = dim + self._stack_dim return Layout(obj, new_stack_dim) @staticmethod @@ -1126,7 +1126,7 @@ def __expand__(self, dims: Shape, **kwargs) -> 'Tensor': for dim in reversed(new_stack_dims): assert isinstance(dim.size, int), "Can only expand layouts by integer-sized dimensions" obj = [obj] * dim.size - return Layout(obj, concat_shapes(new_stack_dims, self._stack_dim)) + return Layout(obj, new_stack_dims + self._stack_dim) def __replace_dims__(self, dims: Tuple[str, ...], new_dims: Shape, **kwargs) -> 'Tensor': new_stack_dim = self._stack_dim.replace(dims, new_dims) @@ -1140,7 +1140,7 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], * obj = [] for i in dims.meshgrid(): obj.append(self[i].native()) - return Layout(obj, concat_shapes(packed_dim.with_size(dims.volume), self._stack_dim - dims)) + return Layout(obj, packed_dim.with_size(dims.volume) + (self._stack_dim - dims)) def __unpack_dim__(self, dim: str, unpacked_dims: Shape, **kwargs) -> 'Layout': return NotImplemented @@ -1177,7 +1177,7 @@ def _assert_close(self, other: Tensor, rel_tolerance: float, abs_tolerance: floa def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> Tensor: obj = self._recursive_op2(self._obj, self._stack_dim, other, operator, native_function, op_name) - new_stack = concat_shapes(self._stack_dim, other._stack_dim.without(self._stack_dim)) if isinstance(other, Layout) else self._stack_dim + new_stack = self._stack_dim + (other._stack_dim - self._stack_dim) if isinstance(other, Layout) else self._stack_dim return Layout(obj, new_stack) @staticmethod @@ -1689,7 +1689,7 @@ def tensor(data, if len(shape) == 1 and isinstance(shape[0], list): return reshaped_tensor(data, shape[0], convert=convert) shape = [parse_shape_spec(s) if isinstance(s, str) else s for s in shape] - shape = None if len(shape) == 0 else concat_shapes(*shape) + shape = None if len(shape) == 0 else concat_shapes_(*shape) if isinstance(data, Shape): if shape is None: shape = channel('dims') @@ -1800,7 +1800,7 @@ def layout(objects, *shape: Union[Shape, str]) -> Tensor: """ shape = [parse_shape_spec(s) if isinstance(s, str) else s for s in shape] assert all(isinstance(s, Shape) for s in shape), f"shape needs to be one or multiple Shape instances but got {shape}" - shape = EMPTY_SHAPE if len(shape) == 0 else concat_shapes(*shape) + shape = EMPTY_SHAPE if len(shape) == 0 else concat_shapes_(*shape) if isinstance(objects, Layout): assert objects.shape == shape return objects diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py index 80dc4388..1ed21bc7 100644 --- a/phiml/math/_trace.py +++ b/phiml/math/_trace.py @@ -9,7 +9,7 @@ from ..backend import choose_backend, NUMPY, Backend from ._ops import backend_for, concat_tensor, scatter, zeros_like from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal, non_channel, DEBUG_CHECKS, \ - after_gather + after_gather, concat_shapes_ from ._magic_ops import stack, expand, rename_dims, unpack_dim, unstack, value_attributes from ._tensors import Tensor, wrap, disassemble_tree, disassemble_tensors, assemble_tree, TensorStack, may_vary_along, \ discard_constant_dims, variable_shape, Dense, equality_by_shape_and_value @@ -465,7 +465,7 @@ def _matmul(self, self_dims: Shape, matrix: Tensor, matrix_dims: Shape) -> Tenso from ._ops import dot missing_self_dims = self_dims.without(self._matrix.shape) if missing_self_dims: - new_source_dims = concat_shapes([dual(**{n + '_src': v for n, v in d.untyped_dict.items()}) for d in missing_self_dims]) + new_source_dims = concat_shapes_(*[dual(**{n + '_src': v for n, v in d.untyped_dict.items()}) for d in missing_self_dims]) batched = add_sparse_batch_dim(self._matrix, new_source_dims, missing_self_dims) # to preserve the source dim else: batched = self._matrix @@ -745,7 +745,7 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo indices = math.concat_tensor(indices, 'entries') values = math.concat_tensor([m._values for m in matrices], 'entries') # matrix = stack(matrices, tracer._stack_dim) - dense_shape = concat_shapes(matrices[0]._dense_shape, tracer._stack_dim) + dense_shape = matrices[0]._dense_shape + tracer._stack_dim matrix = SparseCoordinateTensor(indices, values, dense_shape, can_contain_double_entries=False, indices_sorted=False, indices_constant=True) else: matrix = stack(matrices, tracer._stack_dim) @@ -790,7 +790,7 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo indices = wrap(indices_np, instance('entries'), channel(sparse_idx=(sliced_src_shape if separate_independent else src_shape).names + out_shape.names)) backend = choose_backend(*values) values = math.reshaped_tensor(backend.concat(values, axis=-1), [batch_val, instance('entries')], convert=False) - dense_shape = concat_shapes((sliced_src_shape if separate_independent else src_shape) & out_shape) + dense_shape = (sliced_src_shape if separate_independent else src_shape) & out_shape max_rank = out_shape.volume - tracer.min_rank_deficiency() matrix = SparseCoordinateTensor(indices, values, dense_shape, can_contain_double_entries=False, indices_sorted=False, indices_constant=True, m_rank=max_rank) return matrix, tracer._bias diff --git a/phiml/math/magic.py b/phiml/math/magic.py index 510c3f0a..5380fd5e 100644 --- a/phiml/math/magic.py +++ b/phiml/math/magic.py @@ -20,7 +20,7 @@ import dataclasses -from ._shape import Shape, shape, channel, non_batch, batch, spatial, instance, concat_shapes, dual, PureShape, Dim, MixedShape, DEBUG_CHECKS +from ._shape import Shape, shape, channel, non_batch, batch, spatial, instance, concat_shapes, dual, PureShape, Dim, MixedShape, DEBUG_CHECKS, concat_shapes_ from ..backend._dtype import DType @@ -745,7 +745,7 @@ def retype(self, dim_type: Callable, **kwargs): `phiml.math.rename_dims()` """ s = shape(self.obj) - new_dims = concat_shapes(*[dim_type(**{dim: s.get_item_names(dim) or s.get_size(dim)}) for dim in self.dims]) + new_dims = concat_shapes_(*[dim_type(**{dim: s.get_item_names(dim) or s.get_size(dim)}) for dim in self.dims]) from ._magic_ops import rename_dims return rename_dims(self.obj, self.dims, new_dims, **kwargs) diff --git a/tests/commit/math/test__shape.py b/tests/commit/math/test__shape.py index 9d10273c..78d74f72 100644 --- a/tests/commit/math/test__shape.py +++ b/tests/commit/math/test__shape.py @@ -108,7 +108,7 @@ def test_filters(self): self.assertEqual(spatial(obj), s) self.assertEqual(instance(obj), i) self.assertEqual(channel(obj), c) - self.assertEqual(set(non_batch(obj)), set(math.concat_shapes(s, i, c))) + self.assertEqual(set(non_batch(obj)), set(s + i + c)) self.assertEqual(set(non_spatial(obj)), set(math.concat_shapes(b, i, c))) self.assertEqual(set(non_instance(obj)), set(math.concat_shapes(b, s, c))) self.assertEqual(set(non_channel(obj)), set(math.concat_shapes(b, s, i)))