diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py
index 33fbc9f..fbac5b3 100644
--- a/phiml/math/_ops.py
+++ b/phiml/math/_ops.py
@@ -586,31 +586,62 @@ def swap_axes(x, axes):
     return choose_backend(x).transpose(x, axes)
 
 
-def sort(x: Tensor, dim: DimFilter = non_batch) -> Tensor:
+def sort(x: TensorOrTree, dim: DimFilter = non_batch, key: Tensor = None) -> Tensor:
     """
     Sort the values of `x` along `dim`.
+    If `key` is specified, sorts `x` according to the corresponding values in the `key` tensor.
+    When sorting by key, you can pass pytrees and dataclasses for `x`.
+    The value `range` for `x` returns the sorting permutation.
 
     In order to sort a flattened array, use `pack_dims` first.
 
+    Examples:
+
+        >>> x = tensor([1, 3, 2, -1], spatial('x'))
+        >>> math.sort(x)
+        >>> # Out: (-1, 1, 2, 3) along xˢ
+
+        >>> math.sort(range, 'x', key=x)
+        >>> # Out: (3, 0, 2, 1) along xˢ int64
+
+        >>> result, perm = math.sort((x, range), key=x)
+
     Args:
-        x: `Tensor`
+        x: `Tensor` to sort. If `key` is specified, can be a tree as well.
         dim: Dimension to sort. If not present, sorting will be skipped. Defaults to non-batch dim.
+        key: `Tensor` holding values to compare during sorting.
 
     Returns:
         Sorted `Tensor` or `x` if `x` is constant along `dims`.
     """
-    x_shape = x.shape
-    dim = x_shape.only(dim)
-    v_names = variable_dim_names(x)
-    if dim.name not in v_names:
-        return x  # nothing to do; x is constant along dim
-    assert dim.rank == 1, f"Can only sort one dimension at a time. Use pack_dims() to jointly sort over multiple dimensions."
-    axis = v_names.index(dim.name)
-    x_native = x._native if isinstance(x, Dense) else x.native(x.shape)
-    sorted_native = x.backend.sort(x_native, axis=axis)
-    if x.shape.get_item_names(dim):
-        warnings.warn(f"sort() removes item names along sorted axis '{dim}'. Was {x.shape.get_item_names(dim)}", RuntimeWarning, stacklevel=2)
-        x_shape = x_shape.with_dim_size(dim, x_shape.get_size(dim), keep_item_names=False)
-    return Dense(sorted_native, v_names, x_shape, x.backend)
+    if key is None:
+        x_shape = x.shape
+        dim = x_shape.only(dim)
+        var_names = variable_dim_names(x)
+        if not dim or dim.name not in var_names:
+            return x  # nothing to do; x is constant along dim
+        assert dim.rank == 1, f"Can only sort one dimension at a time. Use pack_dims() to jointly sort over multiple dimensions."
+        axis = var_names.index(dim.name)
+        x_native = x._native if isinstance(x, Dense) else x.native(x.shape)
+        sorted_native = x.backend.sort(x_native, axis=axis)
+        if x.shape.get_item_names(dim):
+            warnings.warn(f"sort() removes item names along sorted axis '{dim}'. Was {x.shape.get_item_names(dim)}", RuntimeWarning, stacklevel=2)
+            x_shape = x_shape.with_dim_size(dim, x_shape.get_size(dim), keep_item_names=False)
+        return Dense(sorted_native, var_names, x_shape, x.backend)
+    else:
+        k_shape = key.shape
+        dim = k_shape.only(dim)
+        var_names = variable_dim_names(key)
+        if not dim or dim.name not in var_names:
+            return x  # nothing to do; key is constant along dim
+        assert dim.rank == 1, f"Can only sort one dimension at a time. Use pack_dims() to jointly sort over multiple dimensions."
+        axis = var_names.index(dim.name)
+        x_native = key._native if isinstance(key, Dense) else key.native(key.shape)
+        native_perm = key.backend.argsort(x_native, axis=axis)
+        if key.shape.get_item_names(dim):
+            warnings.warn(f"sort() removes item names along sorted axis '{dim}'. Was {key.shape.get_item_names(dim)}", RuntimeWarning, stacklevel=2)
+        k_shape = k_shape.with_dim_size(dim, k_shape.get_size(dim), keep_item_names=False) & channel(index=dim.name)
+        perm = Dense(native_perm, var_names, k_shape, key.backend)
+        return slice_(x, perm)
 
 
 def cumulative_sum(x: Tensor, dim: DimFilter, include_0=False, include_sum=True, index_dim: Union[str, Shape, None] = None):
diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py
index 2da531f..d17ea90 100644
--- a/phiml/math/_tensors.py
+++ b/phiml/math/_tensors.py
@@ -3004,7 +3004,7 @@ def is_scalar(value) -> bool:
 
 
 def variable_shape(value):
-    return value._shape - value.collapsed_dims if isinstance(value, Dense) else shape(value)
+    return value._shape.only(value._names) if isinstance(value, Dense) else shape(value)
 
 
 def variable_dim_names(value):
diff --git a/tests/commit/math/test__ops.py b/tests/commit/math/test__ops.py
index 7393d59..87e5912 100644
--- a/tests/commit/math/test__ops.py
+++ b/tests/commit/math/test__ops.py
@@ -1052,6 +1052,14 @@ def test_sort(self):
         t = tensor([[3, 2, 1], [-1, 5, 2]], batch('b'), spatial(x='1,2,3'))
         self.assertIsNone(math.sort(t).shape.get_item_names('x'))
 
+    def test_sort_by_key(self):
+        for b in BACKENDS:
+            with b:
+                x = tensor([1, 3, 2, -1], spatial('x'))
+                result, perm = math.sort((x, range), key=x)
+                math.assert_close([-1, 1, 2, 3], result)
+                math.assert_close([3, 0, 2, 1], perm)
+
     def test_pairwise_differences(self):
         for b in BACKENDS:
             with b:
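
A minimal usage sketch of the key-based sort introduced above, assembled from the new docstring and test. It assumes phiml is importable and that the bare builtin range acts as the sentinel requesting the sorting permutation, as in test_sort_by_key:

    from phiml import math
    from phiml.math import tensor, spatial

    # Values to sort along the spatial dimension 'x'.
    x = tensor([1, 3, 2, -1], spatial('x'))

    # Existing behavior: sort values along the non-batch dim.
    sorted_x = math.sort(x)                      # (-1, 1, 2, 3) along x

    # New behavior: sort a tree (here a tuple) by a key tensor.
    # Passing `range` for an element yields the sorting permutation instead of sorted values.
    result, perm = math.sort((x, range), key=x)  # result: (-1, 1, 2, 3), perm: (3, 0, 2, 1)

    # The permutation alone can also be requested directly.
    perm_only = math.sort(range, 'x', key=x)     # (3, 0, 2, 1) along x, int64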