Skip to content

Commit

Permalink
Fix at_max, at_min for multi-dim reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 24, 2023
1 parent e952a35 commit 1c4d8ff
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,13 +1663,14 @@ def finite_mean(value, dim: DimFilter = non_batch, default: Union[complex, float
return where(is_finite(mean_nan), mean_nan, default)


def lookup_where(name: str, native_index_fn: Callable, value: Union[Tensor, PhiTreeNode], key: Tensor, dim: DimFilter):
def lookup_where(native_index_fn: Callable, value: Union[Tensor, PhiTreeNode], key: Tensor, dim: DimFilter):
dims = key.shape.only(dim)
keep = key.shape.without(dims)
assert dim, f"No dimensions {dim} present on key {key.shape}"
v_native = reshaped_native(key, [keep, dims])
idx_native = native_index_fn(v_native)
idx = reshaped_tensor(idx_native, [keep.as_batch(), channel(_index=dims)])
multi_idx_native = choose_backend(idx_native).unravel_index(idx_native[:, 0], dims.sizes)
idx = reshaped_tensor(multi_idx_native, [keep.as_batch(), channel(_index=dims)])
def lookup(t: Tensor):
keep_t = t.shape.without(dims)
sel = rename_dims(t, keep_t, batch)[idx]
Expand All @@ -1693,7 +1694,7 @@ def at_max(value, key: Tensor, dim: DimFilter = non_batch):
Returns:
The values of `other_tensors` at the positions where the maximum values in `value` are located along `dim`.
"""
return lookup_where("max", lambda v: choose_backend(v).argmax(v, 1, keepdims=True), value, key, dim)
return lookup_where(lambda v: choose_backend(v).argmax(v, 1, keepdims=True), value, key, dim)


def at_min(value, key: Tensor, dim: DimFilter = non_batch):
Expand All @@ -1711,7 +1712,7 @@ def at_min(value, key: Tensor, dim: DimFilter = non_batch):
Returns:
The values of `other_tensors` at the positions where the minimum values in `value` are located along `dim`.
"""
return lookup_where("min", lambda v: choose_backend(v).argmin(v, 1, keepdims=True), value, key, dim)
return lookup_where(lambda v: choose_backend(v).argmin(v, 1, keepdims=True), value, key, dim)


def quantile(value: Tensor,
Expand Down

0 comments on commit 1c4d8ff

Please sign in to comment.