Skip to content

Commit

Permalink
Add *sort aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 4, 2025
1 parent f093f47 commit f289dec
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion phiml/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
zeros_like, ones_like,
pad,
swap_axes, # reshape operations
sort,
sort, dsort, isort, ssort, csort,
safe_div,
where, nonzero, ravel_index,
sum_ as sum, finite_sum, dsum, isum, ssum, csum,
Expand Down
17 changes: 15 additions & 2 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,19 @@ def sort(x: TensorOrTree, dim: DimFilter = non_batch, key: Tensor = None) -> Ten
return slice_(x, perm)


dsort = functools.partial(sort, dim=dual)
dsort.__doc__ = """Sort tensor along dual dim, see `phiml.math.sort`."""

isort = functools.partial(sort, dim=instance)
isort.__doc__ = """Sort tensor along instance dim, see `phiml.math.sort`."""

ssort = functools.partial(sort, dim=spatial)
ssort.__doc__ = """Sort tensor along spatial dim, see `phiml.math.sort`."""

csort = functools.partial(sort, dim=channel)
csort.__doc__ = """Sort tensor along channel dim, see `phiml.math.sort`."""


def cumulative_sum(x: Tensor, dim: DimFilter, include_0=False, include_sum=True, index_dim: Union[str, Shape, None] = None):
"""
Performs a cumulative sum of `x` along `dim`.
Expand Down Expand Up @@ -1229,7 +1242,7 @@ def where(condition: Union[Tensor, float, int],
"""
Builds a tensor by choosing either values from `value_true` or `value_false` depending on `condition`.
If `condition` is not of type boolean, non-zero values are interpreted as True.
This function requires non-None values for `value_true` and `value_false`.
To get the indices of True / non-zero values, use :func:`nonzero`.
Expand Down Expand Up @@ -1286,7 +1299,7 @@ def inner_where(c: Tensor, vt: Tensor, vf: Tensor):
def nonzero(value: Tensor, list_dim: Union[Shape, str, int] = instance('nonzero'), index_dim: Shape = channel('vector'), element_dims: DimFilter = channel, list_dims: DimFilter = non_batch, preserve_names=False):
"""
Get spatial indices of non-zero / True values.
Batch dimensions are preserved by this operation.
If channel dimensions are present, this method returns the indices where any component is nonzero.
Expand Down

0 comments on commit f289dec

Please sign in to comment.