From 1c4d8ff6b490761f3bb6e72c77bbf89bd15e0042 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Fri, 24 Nov 2023 17:55:41 +0100 Subject: [PATCH] Fix at_max, at_min for multi-dim reduce --- phiml/math/_ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index b7df58d1..ddfd304d 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -1663,13 +1663,14 @@ def finite_mean(value, dim: DimFilter = non_batch, default: Union[complex, float return where(is_finite(mean_nan), mean_nan, default) -def lookup_where(name: str, native_index_fn: Callable, value: Union[Tensor, PhiTreeNode], key: Tensor, dim: DimFilter): +def lookup_where(native_index_fn: Callable, value: Union[Tensor, PhiTreeNode], key: Tensor, dim: DimFilter): dims = key.shape.only(dim) keep = key.shape.without(dims) assert dim, f"No dimensions {dim} present on key {key.shape}" v_native = reshaped_native(key, [keep, dims]) idx_native = native_index_fn(v_native) - idx = reshaped_tensor(idx_native, [keep.as_batch(), channel(_index=dims)]) + multi_idx_native = choose_backend(idx_native).unravel_index(idx_native[:, 0], dims.sizes) + idx = reshaped_tensor(multi_idx_native, [keep.as_batch(), channel(_index=dims)]) def lookup(t: Tensor): keep_t = t.shape.without(dims) sel = rename_dims(t, keep_t, batch)[idx] @@ -1693,7 +1694,7 @@ def at_max(value, key: Tensor, dim: DimFilter = non_batch): Returns: The values of `other_tensors` at the positions where the maximum values in `value` are located along `dim`. """ - return lookup_where("max", lambda v: choose_backend(v).argmax(v, 1, keepdims=True), value, key, dim) + return lookup_where(lambda v: choose_backend(v).argmax(v, 1, keepdims=True), value, key, dim) def at_min(value, key: Tensor, dim: DimFilter = non_batch): @@ -1711,7 +1712,7 @@ def at_min(value, key: Tensor, dim: DimFilter = non_batch): Returns: The values of `other_tensors` at the positions where the minimum values in `value` are located along `dim`. """ - return lookup_where("min", lambda v: choose_backend(v).argmin(v, 1, keepdims=True), value, key, dim) + return lookup_where(lambda v: choose_backend(v).argmin(v, 1, keepdims=True), value, key, dim) def quantile(value: Tensor,