Skip to content

Commit

Permalink
Sort by key
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 4, 2025
1 parent 98afea1 commit f093f47
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 16 deletions.
61 changes: 46 additions & 15 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,31 +586,62 @@ def swap_axes(x, axes):
return choose_backend(x).transpose(x, axes)


def sort(x: TensorOrTree, dim: DimFilter = non_batch, key: Tensor = None) -> TensorOrTree:
    """
    Sort the values of `x` along `dim`.
    If `key` is specified, sorts `x` according to the corresponding values in the `key` tensor.
    When sorting by key, you can pass pytrees and dataclasses for `x`. The value `range` for `x` returns the sorting permutation.

    In order to sort a flattened array, use `pack_dims` first.

    Examples:
        >>> x = tensor([1, 3, 2, -1], spatial('x'))
        >>> math.sort(x)
        >>> # Out: (-1, 1, 2, 3) along xˢ
        >>> math.sort(range, 'x', key=x)
        >>> # Out: (3, 0, 2, 1) along xˢ int64
        >>> result, perm = math.sort((x, range), key=x)

    Args:
        x: `Tensor` to sort. If `key` is specified, can be a tree as well.
        dim: Dimension to sort. If not present, sorting will be skipped. Defaults to non-batch dim.
        key: `Tensor` holding values to compare during sorting.

    Returns:
        Sorted `Tensor` (or tree when sorting by `key`).
        `x` is returned unchanged if the tensor being compared (`x`, or `key` when given) is constant along `dim`.
    """
    if key is None:
        x_shape = x.shape
        dim = x_shape.only(dim)
        var_names = variable_dim_names(x)
        if not dim or dim.name not in var_names:
            return x  # nothing to do; x is constant along dim
        assert dim.rank == 1, "Can only sort one dimension at a time. Use pack_dims() to jointly sort over multiple dimensions."
        axis = var_names.index(dim.name)
        x_native = x._native if isinstance(x, Dense) else x.native(x.shape)
        sorted_native = x.backend.sort(x_native, axis=axis)
        if x.shape.get_item_names(dim):
            # Item names label positions, which no longer match once values are reordered.
            warnings.warn(f"sort() removes item names along sorted axis '{dim}'. Was {x.shape.get_item_names(dim)}", RuntimeWarning, stacklevel=2)
            x_shape = x_shape.with_dim_size(dim, x_shape.get_size(dim), keep_item_names=False)
        return Dense(sorted_native, var_names, x_shape, x.backend)
    else:
        k_shape = key.shape
        dim = k_shape.only(dim)
        var_names = variable_dim_names(key)
        if not dim or dim.name not in var_names:
            return x  # nothing to do; key is constant along dim
        assert dim.rank == 1, "Can only sort one dimension at a time. Use pack_dims() to jointly sort over multiple dimensions."
        axis = var_names.index(dim.name)
        key_native = key._native if isinstance(key, Dense) else key.native(key.shape)
        native_perm = key.backend.argsort(key_native, axis=axis)
        if key.shape.get_item_names(dim):
            # Report the item names of `key`, not `x`: in this branch `x` may be a tree or
            # the builtin `range`, neither of which has a `.shape` attribute.
            warnings.warn(f"sort() removes item names along sorted axis '{dim}'. Was {key.shape.get_item_names(dim)}", RuntimeWarning, stacklevel=2)
        # NOTE(review): kept unconditional on the assumption that the 'index' channel dim is what
        # lets slice_() treat `perm` as indices along `dim`, also for keys without item names - confirm.
        k_shape = k_shape.with_dim_size(dim, k_shape.get_size(dim), keep_item_names=False) & channel(index=dim.name)
        perm = Dense(native_perm, var_names, k_shape, key.backend)
        return slice_(x, perm)


def cumulative_sum(x: Tensor, dim: DimFilter, include_0=False, include_sum=True, index_dim: Union[str, Shape, None] = None):
Expand Down
2 changes: 1 addition & 1 deletion phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3004,7 +3004,7 @@ def is_scalar(value) -> bool:


def variable_shape(value):
    """
    Shape of the dims along which `value` is actually stored.

    For `Dense` tensors, this is the stored shape restricted to the dims listed in `value._names`.
    For every other input, this is `shape(value)`.

    Args:
        value: `Dense` tensor or any object accepted by `shape()`.

    Returns:
        The restricted shape for `Dense` inputs, otherwise `shape(value)`.
    """
    # NOTE(review): presumably `_names` holds the dims backing the native array, so dims missing
    # from it are constant (collapsed) - confirm against the Dense class.
    return value._shape.only(value._names) if isinstance(value, Dense) else shape(value)


def variable_dim_names(value):
Expand Down
8 changes: 8 additions & 0 deletions tests/commit/math/test__ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,14 @@ def test_sort(self):
t = tensor([[3, 2, 1], [-1, 5, 2]], batch('b'), spatial(x='1,2,3'))
self.assertIsNone(math.sort(t).shape.get_item_names('x'))

def test_sort_by_key(self):
    """Sorting the tuple `(values, range)` by key returns the sorted values and the permutation."""
    for backend in BACKENDS:
        with backend:
            keys = tensor([1, 3, 2, -1], spatial('x'))
            sorted_values, permutation = math.sort((keys, range), key=keys)
            math.assert_close([-1, 1, 2, 3], sorted_values)
            math.assert_close([3, 0, 2, 1], permutation)

def test_pairwise_differences(self):
for b in BACKENDS:
with b:
Expand Down

0 comments on commit f093f47

Please sign in to comment.