Skip to content

Commit

Permalink
Add tcat, improved type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl authored and holl- committed Oct 25, 2024
1 parent 22929ac commit 3146ea1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
2 changes: 1 addition & 1 deletion phiml/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from ._magic_ops import (
slice_ as slice, unstack,
stack, concat, ccat, expand,
stack, concat, ccat, tcat, expand,
rename_dims, rename_dims as replace_dims, pack_dims, unpack_dim, flatten,
b2i, c2b, c2d, i2b, s2b, si2d, d2i, d2s,
copy_with, replace, find_differences
Expand Down
37 changes: 33 additions & 4 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from functools import partial
from numbers import Number
from typing import TypeVar, Tuple, Set, Dict, Union, Optional, Sequence, Any, get_origin, List, Iterable, get_args
from typing import TypeVar, Tuple, Set, Dict, Union, Optional, Sequence, Any, get_origin, List, Iterable, get_args, Callable

import dataclasses

Expand Down Expand Up @@ -116,7 +116,7 @@ def _any_uniform_dim(dims: Shape):
raise ValueError(f"Uniform dimension required but found only non-uniform dimensions {dims}")


def stack(values: Union[tuple, list, dict], dim: Union[Shape, str], expand_values=False, simplify=False, **kwargs):
def stack(values: Union[Sequence[PhiTreeNodeType], Dict[str, PhiTreeNodeType]], dim: Union[Shape, str], expand_values=False, simplify=False, **kwargs) -> PhiTreeNodeType:
"""
Stacks `values` along the new dimension `dim`.
All values must have the same spatial, instance and channel dimensions. If the dimension sizes vary, the resulting tensor will be non-uniform.
Expand Down Expand Up @@ -263,7 +263,7 @@ def stack(values: Union[tuple, list, dict], dim: Union[Shape, str], expand_value
return values[0]


def concat(values: Union[tuple, list], dim: Union[str, Shape], expand_values=False, **kwargs):
def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], expand_values=False, **kwargs) -> PhiTreeNodeType:
"""
Concatenates a sequence of `phiml.math.magic.Shapable` objects, e.g. `Tensor`, along one dimension.
All values must have the same spatial, instance and channel dimensions and their sizes must be equal, except for `dim`.
Expand Down Expand Up @@ -350,7 +350,7 @@ def concat(values: Union[tuple, list], dim: Union[str, Shape], expand_values=Fal
raise MagicNotImplemented(f"concat: No value implemented __concat__ and slices could not be stacked. values = {[type(v) for v in values]}")


def ccat(values: Sequence, dim: Shape, expand_values=False):
def ccat(values: Sequence[PhiTreeNodeType], dim: Shape, expand_values=False) -> PhiTreeNodeType:
"""
Concatenate components along `dim`.
Expand Down Expand Up @@ -384,6 +384,35 @@ def ccat(values: Sequence, dim: Shape, expand_values=False):
return stack(components, dim, expand_values=expand_values)


def tcat(values: Sequence[PhiTreeNodeType], dim_type: Callable, expand_values=False, default_name='tcat') -> PhiTreeNodeType:
"""
Concatenate values by dim type.
This function first packs all dimensions of `dim_type` into one dim, then concatenates all `values`.
Values that do not have a dim of `dim_type` are considered a size-1 slice.
The name of the first matching dim of `dim_type` is used as the concatenated output dim name.
If no value has a matching dim, `default_name` is used instead.
Args:
values: Values to be concatenated.
dim_type: Dimension type along which to concatenate.
expand_values: Whether to add missing other non-batch dims to values as needed.
default_name: Concatenation dim name if none of the values have a matching dim.
Returns:
Same type as any value.
"""
dims = [dim_type(v) for v in values]
present_dims = [s for s in dims if s]
if present_dims:
dim_name = present_dims[0].name
else:
dim_name = default_name
single = dim_type(**{dim_name: 1})
flat_values = [pack_dims(v, dim_type, dim_type(dim_name)) if dim_name in s else expand(v, single) for v, s in zip(values, dims)]
return concat(flat_values, dim_name, expand_values=expand_values)


def expand(value, *dims: Union[Shape, str], **kwargs):
"""
Adds dimensions to a `Tensor` or tensor-like object by implicitly repeating the tensor values along the new dimensions.
Expand Down

0 comments on commit 3146ea1

Please sign in to comment.