diff --git a/phiml/math/__init__.py b/phiml/math/__init__.py index e490e16..b56118f 100644 --- a/phiml/math/__init__.py +++ b/phiml/math/__init__.py @@ -48,7 +48,7 @@ zeros_like, ones_like, pad, swap_axes, # reshape operations - sort, + sort, dsort, isort, ssort, csort, safe_div, where, nonzero, ravel_index, sum_ as sum, finite_sum, dsum, isum, ssum, csum, diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index fbac5b3..c8a73b3 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -644,6 +644,19 @@ def sort(x: TensorOrTree, dim: DimFilter = non_batch, key: Tensor = None) -> Ten return slice_(x, perm) +dsort = functools.partial(sort, dim=dual) +dsort.__doc__ = """Sort tensor along dual dim, see `phiml.math.sort`.""" + +isort = functools.partial(sort, dim=instance) +isort.__doc__ = """Sort tensor along instance dim, see `phiml.math.sort`.""" + +ssort = functools.partial(sort, dim=spatial) +ssort.__doc__ = """Sort tensor along spatial dim, see `phiml.math.sort`.""" + +csort = functools.partial(sort, dim=channel) +csort.__doc__ = """Sort tensor along channel dim, see `phiml.math.sort`.""" + + def cumulative_sum(x: Tensor, dim: DimFilter, include_0=False, include_sum=True, index_dim: Union[str, Shape, None] = None): """ Performs a cumulative sum of `x` along `dim`. @@ -1229,7 +1242,7 @@ def where(condition: Union[Tensor, float, int], """ Builds a tensor by choosing either values from `value_true` or `value_false` depending on `condition`. If `condition` is not of type boolean, non-zero values are interpreted as True. - + This function requires non-None values for `value_true` and `value_false`. To get the indices of True / non-zero values, use :func:`nonzero`. @@ -1286,7 +1299,7 @@ def inner_where(c: Tensor, vt: Tensor, vf: Tensor): def nonzero(value: Tensor, list_dim: Union[Shape, str, int] = instance('nonzero'), index_dim: Shape = channel('vector'), element_dims: DimFilter = channel, list_dims: DimFilter = non_batch, preserve_names=False): """ Get spatial indices of non-zero / True values. - + Batch dimensions are preserved by this operation. If channel dimensions are present, this method returns the indices where any component is nonzero.