From 9fbce871264c561bac250511ae5889994b08972c Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Wed, 22 Nov 2023 20:56:19 +0100 Subject: [PATCH] Improve at_max, at_min --- phiml/math/_nd.py | 26 ++++++++++---------- phiml/math/_ops.py | 44 ++++++++++++++++++---------------- phiml/math/_shape.py | 4 ++-- tests/commit/math/test__ops.py | 12 +++++----- 4 files changed, 45 insertions(+), 41 deletions(-) diff --git a/phiml/math/_nd.py b/phiml/math/_nd.py index b83d0741..3cf4a6ec 100644 --- a/phiml/math/_nd.py +++ b/phiml/math/_nd.py @@ -4,7 +4,7 @@ from . import _ops as math from . import extrapolation as extrapolation -from ._magic_ops import stack, rename_dims, concat, variable_values +from ._magic_ops import stack, rename_dims, concat, variable_values, tree_map from ._shape import Shape, channel, batch, spatial, DimFilter, parse_dim_order, shape, instance, dual, auto from ._tensors import Tensor, wrap, tensor from .extrapolation import Extrapolation @@ -542,7 +542,6 @@ def join_index_offsets(offsets: Sequence[Union[int, Tensor]], negate=False): max_by_dim = {d: max([o[d] for o in offsets]) for d in dims} neg = -1 if negate else 1 result = [{d: (int(o[d] - min_by_dim[d]) * neg, int(max_by_dim[d] - o[d]) * neg) for d in dims} for o in offsets] - print(result) return offsets, result, min_by_dim, max_by_dim @@ -596,7 +595,7 @@ def neighbor_min(grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapo return neighbor_reduce(math.min_, grid, dims, padding) -def at_neighbor_where(reduce_fun: Callable, grid: Tensor, dims: DimFilter = spatial, *other_tensors: Tensor, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor: +def at_neighbor_where(reduce_fun: Callable, values, key_grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor: """ Computes the mean of two neighboring values along each dimension in `dim`. The result tensor has one entry less than `grid` in each averaged dimension unless `padding` is specified.. @@ -605,30 +604,31 @@ def at_neighbor_where(reduce_fun: Callable, grid: Tensor, dims: DimFilter = spat Args: reduce_fun: Reduction function, such as `at_max`, `at_min`. - grid: Values to average. + values: Values to look up and return. + key_grid: Values to compare. dims: Dimensions along which neighbors should be averaged. padding: Padding at the upper edges of `grid` along `dims'. If not `None`, the result tensor will have the same shape as `grid`. Returns: `Tensor` """ - result = grid - dims = grid.shape.only(dims) + result = key_grid + dims = key_grid.shape.only(dims) for dim in dims: lr = stack(shift(result, (0, 1), dim, padding, None), batch('_reduce')) - other_tensors = [stack(shift(t, (0, 1), dim, padding, None), batch('_reduce')) for t in other_tensors] - result, *other_tensors = reduce_fun(lr, '_reduce', lr, *other_tensors) - return other_tensors[0] if len(other_tensors) == 1 else other_tensors + values = tree_map(lambda t: stack(shift(t, (0, 1), dim, padding, None), batch('_reduce')), values) + result, values = reduce_fun([lr, values], lr, '_reduce') + return values -def at_max_neighbor(grid: Tensor, dims: DimFilter = spatial, *other_tensors: Tensor, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor: +def at_max_neighbor(values, key_grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor: """`at_neighbor_where` with `reduce_fun` set to `phiml.math.at_max`.""" - return at_neighbor_where(math.at_max, grid, dims, *other_tensors, padding=padding) + return at_neighbor_where(math.at_max, values, key_grid, dims, padding=padding) -def at_min_neighbor(grid: Tensor, dims: DimFilter = spatial, *other_tensors: Tensor, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor: +def at_min_neighbor(values, key_grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor: """`at_neighbor_where` with `reduce_fun` set to `phiml.math.at_min`.""" - return at_neighbor_where(math.at_min, grid, dims, *other_tensors, padding=padding) + return at_neighbor_where(math.at_min, values, key_grid, dims, padding=padding) diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 337e0e45..d7090990 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -1663,46 +1663,50 @@ def finite_mean(value, dim: DimFilter = non_batch, default: Union[complex, float return where(is_finite(mean_nan), mean_nan, default) -def lookup_where(name: str, native_index_fn: Callable, value: Tensor, dim: DimFilter, other_tensors: Tuple[Tensor]): - assert other_tensors, f"Pass at least one additional tensor from which values will be returned" - dims = value.shape.only(dim) - keep = value.shape.without(dims) - assert dim, f"No dimensions {dim} present on value {value.shape}" - v_native = reshaped_native(value, [keep, dims]) +def lookup_where(name: str, native_index_fn: Callable, value: Union[Tensor, PhiTreeNode], key: Tensor, dim: DimFilter): + dims = key.shape.only(dim) + keep = key.shape.without(dims) + assert dim, f"No dimensions {dim} present on key {key.shape}" + v_native = reshaped_native(key, [keep, dims]) idx_native = native_index_fn(v_native) idx = reshaped_tensor(idx_native, [keep.as_batch(), channel(_index=dims)]) - result = tuple([rename_dims(rename_dims(t, keep, batch)[idx], keep.names, keep) for t in other_tensors]) - return result[0] if len(other_tensors) == 1 else result + return tree_map(lambda t: rename_dims(rename_dims(t, keep, batch)[idx], keep.names, keep), value) -def at_max(value: Tensor, dim: DimFilter = non_batch, *other_tensors: Tensor): +def at_max(value, key: Tensor, dim: DimFilter = non_batch): """ - Looks up the values of `other_tensors` at the positions where the maximum values in `value` are located along `dim`. + Looks up the values of `value` at the positions where the maximum values in `key` are located along `dim`. + + See Also: + `at_min`, `phiml.math.max`. Args: - value: `Tensor` containing at least one dimension of `dim`. - dim: Dimensions along which to compute the maximum of `value`. - *other_tensors: Tensors from which to lookup and return values. Must also contain `dim`. + value: Tensors or trees from which to lookup and return values. These tensors are indexed at the maximum index in `key´. + key: `Tensor` containing at least one dimension of `dim`. The maximum index of `key` is determined. + dim: Dimensions along which to compute the maximum of `key`. Returns: The values of `other_tensors` at the positions where the maximum values in `value` are located along `dim`. """ - return lookup_where("max", lambda v: choose_backend(v).argmax(v, 1, keepdims=True), value, dim, other_tensors) + return lookup_where("max", lambda v: choose_backend(v).argmax(v, 1, keepdims=True), value, key, dim) -def at_min(value: Tensor, dim: DimFilter = non_batch, *other_tensors: Tensor): +def at_min(value, key: Tensor, dim: DimFilter = non_batch): """ - Looks up the values of `other_tensors` at the positions where the minimum values in `value` are located along `dim`. + Looks up the values of `value` at the positions where the minimum values in `key` are located along `dim`. + + See Also: + `at_max`, `phiml.math.min`. Args: - value: `Tensor` containing at least one dimension of `dim`. - dim: Dimensions along which to compute the minimum of `value`. - *other_tensors: Tensors from which to lookup and return values. Must also contain `dim`. + value: Tensors or trees from which to lookup and return values. These tensors are indexed at the minimum index in `key´. + key: `Tensor` containing at least one dimension of `dim`. The minimum index of `key` is determined. + dim: Dimensions along which to compute the minimum of `key`. Returns: The values of `other_tensors` at the positions where the minimum values in `value` are located along `dim`. """ - return lookup_where("min", lambda v: choose_backend(v).argmin(v, 1, keepdims=True), value, dim, other_tensors) + return lookup_where("min", lambda v: choose_backend(v).argmin(v, 1, keepdims=True), value, key, dim) def quantile(value: Tensor, diff --git a/phiml/math/_shape.py b/phiml/math/_shape.py index 1c718e6d..bd7b3d49 100644 --- a/phiml/math/_shape.py +++ b/phiml/math/_shape.py @@ -1306,7 +1306,7 @@ def __init__(self, message, *shapes: Shape): self.shapes = shapes -def parse_dim_names(obj: Union[str, tuple, list, Shape], count: int) -> tuple: +def parse_dim_names(obj: Union[str, Sequence[str], Shape], count: int) -> tuple: if isinstance(obj, str): parts = obj.split(',') result = [] @@ -1323,7 +1323,7 @@ def parse_dim_names(obj: Union[str, tuple, list, Shape], count: int) -> tuple: elif isinstance(obj, Shape): assert len(obj) == count, f"Number of specified names in {obj} does not match number of dimensions ({count})" return obj.names - elif isinstance(obj, (tuple, list)): + elif isinstance(obj, Sequence): assert len(obj) == count, f"Number of specified names in {obj} does not match number of dimensions ({count})" return tuple(obj) raise ValueError(obj) diff --git a/tests/commit/math/test__ops.py b/tests/commit/math/test__ops.py index 2ca591f7..bc00605f 100644 --- a/tests/commit/math/test__ops.py +++ b/tests/commit/math/test__ops.py @@ -157,22 +157,22 @@ def test_at_max(self): with backend: # --- 1D --- x = math.range(spatial(x=11)) - math.assert_close(-math.max(x, 'x'), math.at_max(x, 'x', -x)) + math.assert_close(-math.max(x, 'x'), math.at_max([-x, x], x, 'x')[0]) # --- 2D --- x = math.meshgrid(x=3, y=3) - math.assert_close(-math.max(x, 'x'), math.at_max(x, 'x', -x)) - math.assert_close(-math.max(x, 'vector'), math.at_max(x, 'vector', -x)) + math.assert_close(-math.max(x, 'x'), math.at_max(-x, x, 'x')) + math.assert_close(-math.max(x, 'vector'), math.at_max(-x, x, 'vector')) def test_at_min(self): for backend in BACKENDS: with backend: # --- 1D --- x = math.range(spatial(x=11)) - math.assert_close(-math.min(x, 'x'), math.at_min(x, 'x', -x)) + math.assert_close(-math.min(x, 'x'), math.at_min(-x, x, 'x')) # --- 2D --- x = math.meshgrid(x=3, y=3) - math.assert_close(-math.min(x, 'x'), math.at_min(x, 'x', -x)) - math.assert_close(-math.min(x, 'vector'), math.at_min(x, 'vector', -x)) + math.assert_close(-math.min(x, 'x'), math.at_min(-x, x, 'x')) + math.assert_close(-math.min(x, 'vector'), math.at_min(-x, x, 'vector')) def test_unstack(self): a = math.random_uniform(batch(b=10), spatial(x=4, y=3), channel(vector=2))