Skip to content

Commit

Permalink
Add squeeze()
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Dec 10, 2024
1 parent 14ffa96 commit c64a95a
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
2 changes: 1 addition & 1 deletion phiml/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ._magic_ops import (
slice_ as slice, unstack,
stack, concat, ncat, tcat, ccat, scat, icat, dcat, expand,
rename_dims, rename_dims as replace_dims, pack_dims, dpack, ipack, spack, cpack, unpack_dim, flatten,
rename_dims, rename_dims as replace_dims, pack_dims, dpack, ipack, spack, cpack, unpack_dim, flatten, squeeze,
b2i, c2b, c2d, i2b, s2b, si2d, d2i, d2s,
copy_with, replace, find_differences
)
Expand Down
18 changes: 18 additions & 0 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,24 @@ def expand(value, *dims: Union[Shape, str], **kwargs):
return expand_tensor(value, dims)


def squeeze(x: PhiTreeNodeType, dims: DimFilter) -> PhiTreeNodeType:
"""
Remove specific singleton (volume=1) dims from `x`.
Args:
x: Tensor or composite type / tree.
dims: Singleton dims to remove.
Returns:
Same type as `x`.
"""
dims = shape(x).only(dims)
if not dims:
return x
assert dims.volume == 1, f"Cannot squeeze non-singleton dims {dims} from {x}"
return x[{d: 0 for d in dims.names}]


def rename_dims(value: PhiTreeNodeType,
dims: DimFilter,
names: DimFilter,
Expand Down
3 changes: 2 additions & 1 deletion phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..backend import default_backend, choose_backend, Backend, get_precision, convert as b_convert, BACKENDS, NoBackendFound, ComputeDevice, NUMPY
from ..backend._dtype import DType, combine_types
from .magic import PhiTreeNode
from ._magic_ops import expand, pack_dims, unpack_dim, cast, value_attributes, bool_to_int, tree_map, concat, stack, unstack, rename_dims, slice_, all_attributes
from ._magic_ops import expand, pack_dims, unpack_dim, cast, value_attributes, bool_to_int, tree_map, concat, stack, unstack, rename_dims, slice_, all_attributes, squeeze
from ._shape import (Shape, EMPTY_SHAPE,
spatial, batch, channel, instance, merge_shapes, parse_dim_order, concat_shapes,
IncompatibleShapes, DimFilter, non_batch, dual, shape, shape as get_shape, primal, auto, non_spatial, non_dual)
Expand Down Expand Up @@ -861,6 +861,7 @@ def stack_tensors(values: Union[tuple, list], dim: Shape):
if len(values) == 1 and not dim:
return values[0]
values = [wrap(v) for v in values]
values = [squeeze(v, dim) for v in values]
values = cast_same(*values)
# --- sparse to dense ---
if any(isinstance(t, (SparseCoordinateTensor, CompressedSparseMatrix)) for t in values) and not all(isinstance(t, (SparseCoordinateTensor, CompressedSparseMatrix)) for t in values):
Expand Down

0 comments on commit c64a95a

Please sign in to comment.