
Commit

Optimization: Use concat_shapes_ or Shape+Shape internally
holl- committed Dec 26, 2024
1 parent 71aefa9 commit e617a6b
Showing 7 changed files with 26 additions and 26 deletions.
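All changes follow the same pattern: internal call sites that already hold `Shape` objects build the concatenated shape via the internal `concat_shapes_` helper or via `Shape + Shape` instead of going through the public `concat_shapes`. The sketch below illustrates the idea; it assumes, based on the commit title, that the underscore variant and the `+` operator are the cheaper paths because they skip the argument handling of the public function while producing the same result:

```python
from phiml.math import batch, spatial
from phiml.math._shape import concat_shapes, concat_shapes_  # concat_shapes_ is internal

b, s = batch(examples=8), spatial(x=32, y=32)

# Public entry point, used by user-facing code.
full = concat_shapes(b, s)

# Internal fast paths used throughout this commit: the arguments are already known
# to be Shape instances, so the cheaper constructor or Shape.__add__ is used directly.
full_fast = concat_shapes_(b, s)
full_add = b + s

assert full == full_fast == full_add  # same resulting shape, different construction path
```

The per-file diffs below apply exactly this substitution.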
4 changes: 2 additions & 2 deletions phiml/math/_magic_ops.py
@@ -476,7 +476,7 @@ def expand(value, *dims: Union[Shape, str], **kwargs):
"""
if not dims:
return value
-dims = concat_shapes(*[d if isinstance(d, Shape) else parse_shape_spec(d) for d in dims])
+dims = concat_shapes_(*[d if isinstance(d, Shape) else parse_shape_spec(d) for d in dims])
combined = merge_shapes(value, dims) # check that existing sizes match
if not dims.without(shape(value)): # no new dims to add
if set(dims) == set(shape(value).only(dims)): # sizes and item names might differ, though
@@ -759,7 +759,7 @@ def unpack_dim(value, dim: DimFilter, *unpacked_dims: Union[Shape, Sequence[Shap
return value # Nothing to do, maybe expand?
assert dim.rank == 1, f"unpack_dim requires as single dimension to be unpacked but got {dim}"
dim = dim.name
-unpacked_dims = concat_shapes(*unpacked_dims)
+unpacked_dims = concat_shapes_(*unpacked_dims)
if unpacked_dims.rank == 0:
return value[{dim: 0}] # remove dim
elif unpacked_dims.rank == 1:
10 changes: 5 additions & 5 deletions phiml/math/_ops.py
@@ -12,7 +12,7 @@
from ._magic_ops import expand, pack_dims, unpack_dim, cast, value_attributes, bool_to_int, tree_map, concat, stack, unstack, rename_dims, slice_, all_attributes, squeeze, ipack
from ._shape import (Shape, EMPTY_SHAPE,
spatial, batch, channel, instance, merge_shapes, parse_dim_order, concat_shapes,
-IncompatibleShapes, DimFilter, non_batch, dual, shape, shape as get_shape, primal, auto, non_spatial, non_dual, resolve_index)
+IncompatibleShapes, DimFilter, non_batch, dual, shape, shape as get_shape, primal, auto, non_spatial, non_dual, resolve_index, concat_shapes_)
from . import extrapolation as e_
from ._tensors import (Tensor, wrap, tensor, broadcastable_native_tensors, Dense, TensorStack,
custom_op2, compatible_tensor, variable_attributes, disassemble_tree, assemble_tree,
@@ -350,7 +350,7 @@ def _includes_slice(s_dict: dict, dim: Shape, i: int):


def _initialize(uniform_initializer, shapes: Tuple[Shape]) -> Tensor:
-shape = concat_shapes(*shapes)
+shape = concat_shapes_(*shapes)
assert shape.well_defined, f"When creating a Tensor, shape needs to have definitive sizes but got {shape}"
if shape.is_non_uniform:
stack_dim = shape.non_uniform_shape[0]
@@ -492,7 +492,7 @@ def random_permutation(*shape: Union[Shape, Any], dims=non_batch, index_dim=chan
`Tensor`
"""
assert dims is not batch, f"dims cannot include all batch dims because that violates the batch principle. Specify batch dims by name instead."
-shape = concat_shapes(*shape)
+shape = concat_shapes_(*shape)
assert not shape.dual_rank, f"random_permutation does not support dual dims but got {shape}"
perm_dims = shape.only(dims)
batches = shape - perm_dims
@@ -839,7 +839,7 @@ def range_tensor(*shape: Shape):
Returns:
`Tensor`
"""
-shape = concat_shapes(*shape)
+shape = concat_shapes_(*shape)
data = arange(spatial('range'), 0, shape.volume)
return unpack_dim(data, 'range', shape)

@@ -2170,7 +2170,7 @@ def tensor_dot(x, y):
assert x_dims.volume == y_dims.volume, f"Failed to reduce {x_dims} against {y_dims} in dot product of {x.shape} and {y.shape}. Sizes do not match."
if remaining_shape_y.isdisjoint(remaining_shape_x): # no shared batch dimensions -> tensordot
result_native = backend.tensordot(x_native, x.shape.indices(x_dims.names), y_native, y.shape.indices(y_dims.names))
-result_shape = concat_shapes(remaining_shape_x, remaining_shape_y)
+result_shape = remaining_shape_x + remaining_shape_y
else: # shared batch dimensions -> einsum
result_shape = merge_shapes(x.shape.without(x_dims), y.shape.without(y_dims))
REDUCE_LETTERS = list('ijklmn')
12 changes: 6 additions & 6 deletions phiml/math/_sparse.py
@@ -10,7 +10,7 @@

from ._magic_ops import concat, pack_dims, expand, rename_dims, stack, unpack_dim, unstack
from ._shape import Shape, non_batch, merge_shapes, instance, batch, non_instance, shape, channel, spatial, DimFilter, \
-concat_shapes, EMPTY_SHAPE, dual, non_channel, DEBUG_CHECKS, primal
+concat_shapes, EMPTY_SHAPE, dual, non_channel, DEBUG_CHECKS, primal, concat_shapes_
from ._tensors import Tensor, TensorStack, Dense, cached, wrap, reshaped_tensor, tensor, backend_for
from ..backend import choose_backend, NUMPY, Backend, get_precision
from ..backend._dtype import DType
@@ -72,7 +72,7 @@ def sparse_tensor(indices: Optional[Tensor],
indices_constant = True
# --- type of sparse tensor ---
if dense_shape in indices: # compact
-compressed = concat_shapes([dim for dim in dense_shape if dim.size > indices.shape.get_size(dim)])
+compressed = concat_shapes_(*[dim for dim in dense_shape if dim.size > indices.shape.get_size(dim)])
values = expand(1, non_batch(indices))
sparse = CompactSparseTensor(indices, values, compressed, indices_constant)
else:
@@ -349,7 +349,7 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], *
idx_packed = np.ravel_multi_index(idx_to_pack.native([channel, instance(idx_to_pack)]), dims.sizes)
idx_packed = expand(reshaped_tensor(idx_packed, [instance('sp_entries')]), channel(sparse_idx=packed_dim.name))
indices = concat([indices.sparse_idx[list(self._dense_shape.without(dims).names)], idx_packed], 'sparse_idx')
-dense_shape = concat_shapes(self._dense_shape.without(dims), packed_dim.with_size(dims.volume))
+dense_shape = self._dense_shape.without(dims) + packed_dim.with_size(dims.volume)
idx_sorted = self._indices_sorted and False # ToDo still sorted if dims are ordered correctly and no other dim in between and inserted at right point
return SparseCoordinateTensor(indices, values, dense_shape, self._can_contain_double_entries, idx_sorted, self._indices_constant)

@@ -812,7 +812,7 @@ def decompress(self):
values = reshaped_tensor(native_values, [ind_batch & batch(self._values), instance(self._values), channel(self._values)], convert=False)
else:
raise NotImplementedError()
-return SparseCoordinateTensor(indices, values, concat_shapes(self._compressed_dims, self._uncompressed_dims), False, True, self._indices_constant, self._matrix_rank)
+return SparseCoordinateTensor(indices, values, self._compressed_dims + self._uncompressed_dims, False, True, self._indices_constant, self._matrix_rank)
if self._uncompressed_indices_perm is not None:
self._uncompressed_indices = self._uncompressed_indices[self._uncompressed_indices_perm]
self._uncompressed_indices_perm = None
@@ -938,7 +938,7 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], *
idx_packed = np.ravel_multi_index(idx_to_pack.native([channel, instance(idx_to_pack)]), dims.sizes)
idx_packed = expand(reshaped_tensor(idx_packed, [instance('sp_entries')]), channel(sparse_idx=packed_dim.name))
indices = concat([indices.sparse_idx[list(self._dense_shape.without(dims).names)], idx_packed], 'sparse_idx')
-dense_shape = concat_shapes(self._dense_shape.without(dims), packed_dim.with_size(dims.volume))
+dense_shape = self._dense_shape.without(dims) + packed_dim.with_size(dims.volume)
return CompactSparseTensor(indices, values, dense_shape, self._indices_constant)

def _with_shape_replaced(self, new_shape: Shape):
@@ -1640,7 +1640,7 @@ def add_sparse_batch_dim(matrix: Tensor, in_dims: Shape, out_dims: Shape):
# # offset = wrap([*idx.values()] * 2 + [0] * non_instance(matrix._indices).volume, channel(indices))
# # all_indices.append(indices + offset)
# # indices = concat(all_indices, instance(indices))
-# # values = pack_dims(matrix._values, concat_shapes(dims, instance(matrix._values)), instance(matrix._values))
+# # values = pack_dims(matrix._values, dims + instance(matrix._values), instance(matrix._values))
# dense_shape = in_dims & matrix._dense_shape & out_dims
# return SparseCoordinateTensor(indices, values, dense_shape, matrix._can_contain_double_entries, matrix._indices_sorted, matrix._indices_constant)
# raise NotImplementedError
12 changes: 6 additions & 6 deletions phiml/math/_tensors.py
@@ -1108,7 +1108,7 @@ def __bool__(self):

def __stack__(self, values: tuple, dim: Shape, **kwargs) -> 'Layout':
obj = [v.native(self._stack_dim) for v in values]
-new_stack_dim = concat_shapes(dim, self._stack_dim)
+new_stack_dim = dim + self._stack_dim
return Layout(obj, new_stack_dim)

@staticmethod
@@ -1126,7 +1126,7 @@ def __expand__(self, dims: Shape, **kwargs) -> 'Tensor':
for dim in reversed(new_stack_dims):
assert isinstance(dim.size, int), "Can only expand layouts by integer-sized dimensions"
obj = [obj] * dim.size
-return Layout(obj, concat_shapes(new_stack_dims, self._stack_dim))
+return Layout(obj, new_stack_dims + self._stack_dim)

def __replace_dims__(self, dims: Tuple[str, ...], new_dims: Shape, **kwargs) -> 'Tensor':
new_stack_dim = self._stack_dim.replace(dims, new_dims)
@@ -1140,7 +1140,7 @@ def __pack_dims__(self, dims: Shape, packed_dim: Shape, pos: Union[int, None], *
obj = []
for i in dims.meshgrid():
obj.append(self[i].native())
-return Layout(obj, concat_shapes(packed_dim.with_size(dims.volume), self._stack_dim - dims))
+return Layout(obj, packed_dim.with_size(dims.volume) + (self._stack_dim - dims))

def __unpack_dim__(self, dim: str, unpacked_dims: Shape, **kwargs) -> 'Layout':
return NotImplemented
@@ -1177,7 +1177,7 @@ def _assert_close(self, other: Tensor, rel_tolerance: float, abs_tolerance: floa

def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> Tensor:
obj = self._recursive_op2(self._obj, self._stack_dim, other, operator, native_function, op_name)
-new_stack = concat_shapes(self._stack_dim, other._stack_dim.without(self._stack_dim)) if isinstance(other, Layout) else self._stack_dim
+new_stack = self._stack_dim + (other._stack_dim - self._stack_dim) if isinstance(other, Layout) else self._stack_dim
return Layout(obj, new_stack)

@staticmethod
@@ -1689,7 +1689,7 @@ def tensor(data,
if len(shape) == 1 and isinstance(shape[0], list):
return reshaped_tensor(data, shape[0], convert=convert)
shape = [parse_shape_spec(s) if isinstance(s, str) else s for s in shape]
-shape = None if len(shape) == 0 else concat_shapes(*shape)
+shape = None if len(shape) == 0 else concat_shapes_(*shape)
if isinstance(data, Shape):
if shape is None:
shape = channel('dims')
@@ -1800,7 +1800,7 @@ def layout(objects, *shape: Union[Shape, str]) -> Tensor:
"""
shape = [parse_shape_spec(s) if isinstance(s, str) else s for s in shape]
assert all(isinstance(s, Shape) for s in shape), f"shape needs to be one or multiple Shape instances but got {shape}"
-shape = EMPTY_SHAPE if len(shape) == 0 else concat_shapes(*shape)
+shape = EMPTY_SHAPE if len(shape) == 0 else concat_shapes_(*shape)
if isinstance(objects, Layout):
assert objects.shape == shape
return objects
8 changes: 4 additions & 4 deletions phiml/math/_trace.py
@@ -9,7 +9,7 @@
from ..backend import choose_backend, NUMPY, Backend
from ._ops import backend_for, concat_tensor, scatter, zeros_like
from ._shape import Shape, parse_dim_order, merge_shapes, spatial, instance, batch, concat_shapes, EMPTY_SHAPE, dual, channel, non_batch, primal, non_channel, DEBUG_CHECKS, \
-after_gather
+after_gather, concat_shapes_
from ._magic_ops import stack, expand, rename_dims, unpack_dim, unstack, value_attributes
from ._tensors import Tensor, wrap, disassemble_tree, disassemble_tensors, assemble_tree, TensorStack, may_vary_along, \
discard_constant_dims, variable_shape, Dense, equality_by_shape_and_value
@@ -465,7 +465,7 @@ def _matmul(self, self_dims: Shape, matrix: Tensor, matrix_dims: Shape) -> Tenso
from ._ops import dot
missing_self_dims = self_dims.without(self._matrix.shape)
if missing_self_dims:
-new_source_dims = concat_shapes([dual(**{n + '_src': v for n, v in d.untyped_dict.items()}) for d in missing_self_dims])
+new_source_dims = concat_shapes_(*[dual(**{n + '_src': v for n, v in d.untyped_dict.items()}) for d in missing_self_dims])
batched = add_sparse_batch_dim(self._matrix, new_source_dims, missing_self_dims) # to preserve the source dim
else:
batched = self._matrix
@@ -745,7 +745,7 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
indices = math.concat_tensor(indices, 'entries')
values = math.concat_tensor([m._values for m in matrices], 'entries')
# matrix = stack(matrices, tracer._stack_dim)
-dense_shape = concat_shapes(matrices[0]._dense_shape, tracer._stack_dim)
+dense_shape = matrices[0]._dense_shape + tracer._stack_dim
matrix = SparseCoordinateTensor(indices, values, dense_shape, can_contain_double_entries=False, indices_sorted=False, indices_constant=True)
else:
matrix = stack(matrices, tracer._stack_dim)
@@ -790,7 +790,7 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
indices = wrap(indices_np, instance('entries'), channel(sparse_idx=(sliced_src_shape if separate_independent else src_shape).names + out_shape.names))
backend = choose_backend(*values)
values = math.reshaped_tensor(backend.concat(values, axis=-1), [batch_val, instance('entries')], convert=False)
-dense_shape = concat_shapes((sliced_src_shape if separate_independent else src_shape) & out_shape)
+dense_shape = (sliced_src_shape if separate_independent else src_shape) & out_shape
max_rank = out_shape.volume - tracer.min_rank_deficiency()
matrix = SparseCoordinateTensor(indices, values, dense_shape, can_contain_double_entries=False, indices_sorted=False, indices_constant=True, m_rank=max_rank)
return matrix, tracer._bias
4 changes: 2 additions & 2 deletions phiml/math/magic.py
@@ -20,7 +20,7 @@

import dataclasses

-from ._shape import Shape, shape, channel, non_batch, batch, spatial, instance, concat_shapes, dual, PureShape, Dim, MixedShape, DEBUG_CHECKS
+from ._shape import Shape, shape, channel, non_batch, batch, spatial, instance, concat_shapes, dual, PureShape, Dim, MixedShape, DEBUG_CHECKS, concat_shapes_
from ..backend._dtype import DType


@@ -745,7 +745,7 @@ def retype(self, dim_type: Callable, **kwargs):
`phiml.math.rename_dims()`
"""
s = shape(self.obj)
-new_dims = concat_shapes(*[dim_type(**{dim: s.get_item_names(dim) or s.get_size(dim)}) for dim in self.dims])
+new_dims = concat_shapes_(*[dim_type(**{dim: s.get_item_names(dim) or s.get_size(dim)}) for dim in self.dims])
from ._magic_ops import rename_dims
return rename_dims(self.obj, self.dims, new_dims, **kwargs)

2 changes: 1 addition & 1 deletion tests/commit/math/test__shape.py
@@ -108,7 +108,7 @@ def test_filters(self):
self.assertEqual(spatial(obj), s)
self.assertEqual(instance(obj), i)
self.assertEqual(channel(obj), c)
-self.assertEqual(set(non_batch(obj)), set(math.concat_shapes(s, i, c)))
+self.assertEqual(set(non_batch(obj)), set(s + i + c))
self.assertEqual(set(non_spatial(obj)), set(math.concat_shapes(b, i, c)))
self.assertEqual(set(non_instance(obj)), set(math.concat_shapes(b, s, c)))
self.assertEqual(set(non_channel(obj)), set(math.concat_shapes(b, s, i)))
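The updated test relies on `Shape + Shape` concatenating dims the same way `math.concat_shapes` does. A minimal stand-alone check of that identity (dimension names and sizes here are arbitrary examples, not taken from the test):

```python
from phiml import math
from phiml.math import spatial, instance, channel

s, i, c = spatial(x=4, y=3), instance(points=5), channel(vector='x,y')

# Shape.__add__ concatenates shapes, so the operator form matches the explicit helper.
assert math.concat_shapes(s, i, c) == s + i + c
assert set(s + i + c) == set(math.concat_shapes(s, i, c))  # same dims either way
```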
