Skip to content

Commit

Permalink
Improve at_max, at_min
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 22, 2023
1 parent 8f3f646 commit 9fbce87
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 41 deletions.
26 changes: 13 additions & 13 deletions phiml/math/_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from . import _ops as math
from . import extrapolation as extrapolation
from ._magic_ops import stack, rename_dims, concat, variable_values
from ._magic_ops import stack, rename_dims, concat, variable_values, tree_map
from ._shape import Shape, channel, batch, spatial, DimFilter, parse_dim_order, shape, instance, dual, auto
from ._tensors import Tensor, wrap, tensor
from .extrapolation import Extrapolation
Expand Down Expand Up @@ -542,7 +542,6 @@ def join_index_offsets(offsets: Sequence[Union[int, Tensor]], negate=False):
max_by_dim = {d: max([o[d] for o in offsets]) for d in dims}
neg = -1 if negate else 1
result = [{d: (int(o[d] - min_by_dim[d]) * neg, int(max_by_dim[d] - o[d]) * neg) for d in dims} for o in offsets]
print(result)
return offsets, result, min_by_dim, max_by_dim


Expand Down Expand Up @@ -596,7 +595,7 @@ def neighbor_min(grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapo
return neighbor_reduce(math.min_, grid, dims, padding)


def at_neighbor_where(reduce_fun: Callable, values, key_grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor:
    """
    Computes neighbor pairs of `key_grid` along each dimension in `dims` and, for each pair,
    selects the entries of `values` at the position chosen by `reduce_fun` (e.g. the max/min neighbor).

    The result has one entry less than `key_grid` in each reduced dimension unless `padding` is specified.

    Args:
        reduce_fun: Reduction function, such as `at_max` or `at_min`.
        values: Values to look up and return. May be a `Tensor` or a tree of tensors.
        key_grid: Values to compare between neighbors.
        dims: Dimensions along which neighbors should be reduced.
        padding: Padding at the upper edges of `key_grid` along `dims`. If not `None`, the result will have the same shape as `key_grid`.

    Returns:
        `Tensor` or tree matching the structure of `values`.
    """
    result = key_grid
    dims = key_grid.shape.only(dims)
    for dim in dims:
        # Stack the left/right-shifted grids along a temporary batch dim, then reduce it away.
        lr = stack(shift(result, (0, 1), dim, padding, None), batch('_reduce'))
        # Shift every tensor in the (possibly nested) values structure the same way as key_grid.
        values = tree_map(lambda t: stack(shift(t, (0, 1), dim, padding, None), batch('_reduce')), values)
        result, values = reduce_fun([lr, values], lr, '_reduce')
    return values


def at_max_neighbor(values, key_grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor:
    """`at_neighbor_where` with `reduce_fun` set to `phiml.math.at_max`."""
    return at_neighbor_where(math.at_max, values, key_grid, dims, padding=padding)


def at_min_neighbor(values, key_grid: Tensor, dims: DimFilter = spatial, padding: Union[Extrapolation, float, Tensor, str, None] = None) -> Tensor:
    """`at_neighbor_where` with `reduce_fun` set to `phiml.math.at_min`."""
    return at_neighbor_where(math.at_min, values, key_grid, dims, padding=padding)



Expand Down
44 changes: 24 additions & 20 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,46 +1663,50 @@ def finite_mean(value, dim: DimFilter = non_batch, default: Union[complex, float
return where(is_finite(mean_nan), mean_nan, default)


def lookup_where(name: str, native_index_fn: Callable, value: Union[Tensor, PhiTreeNode], key: Tensor, dim: DimFilter):
    """
    Gathers entries of `value` at indices computed from `key` along `dim`.

    Args:
        name: Operation name, e.g. `"max"` or `"min"`.  # NOTE(review): not used in the visible body — presumably for error messages; confirm
        native_index_fn: Maps the native `(keep, dims)` matrix of `key` to native indices along the reduced axis.
        value: `Tensor` or tree of tensors to gather from. Each tensor is indexed at the computed positions.
        key: `Tensor` from which the lookup index along `dim` is determined.
        dim: Dimensions of `key` to reduce.

    Returns:
        Same structure as `value`, with `dim` reduced.
    """
    dims = key.shape.only(dim)
    keep = key.shape.without(dims)
    # Check the resolved dims, not the raw filter: filters like `non_batch` are always truthy
    # even when they match no dimension of `key`.
    assert dims, f"No dimensions {dim} present on key {key.shape}"
    v_native = reshaped_native(key, [keep, dims])
    idx_native = native_index_fn(v_native)
    idx = reshaped_tensor(idx_native, [keep.as_batch(), channel(_index=dims)])
    # Temporarily mark kept dims as batch so indexing broadcasts over them, then restore their types.
    return tree_map(lambda t: rename_dims(rename_dims(t, keep, batch)[idx], keep.names, keep), value)


def at_max(value, key: Tensor, dim: DimFilter = non_batch):
    """
    Looks up the values of `value` at the positions where the maximum values in `key` are located along `dim`.

    See Also:
        `at_min`, `phiml.math.max`.

    Args:
        value: Tensors or trees from which to look up and return values. These tensors are indexed at the maximum index in `key`.
        key: `Tensor` containing at least one dimension of `dim`. The maximum index of `key` is determined.
        dim: Dimensions along which to compute the maximum of `key`.

    Returns:
        The values of `value` at the positions where the maximum values in `key` are located along `dim`.
    """
    return lookup_where("max", lambda v: choose_backend(v).argmax(v, 1, keepdims=True), value, key, dim)


def at_min(value, key: Tensor, dim: DimFilter = non_batch):
    """
    Looks up the values of `value` at the positions where the minimum values in `key` are located along `dim`.

    See Also:
        `at_max`, `phiml.math.min`.

    Args:
        value: Tensors or trees from which to look up and return values. These tensors are indexed at the minimum index in `key`.
        key: `Tensor` containing at least one dimension of `dim`. The minimum index of `key` is determined.
        dim: Dimensions along which to compute the minimum of `key`.

    Returns:
        The values of `value` at the positions where the minimum values in `key` are located along `dim`.
    """
    return lookup_where("min", lambda v: choose_backend(v).argmin(v, 1, keepdims=True), value, key, dim)


def quantile(value: Tensor,
Expand Down
4 changes: 2 additions & 2 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,7 @@ def __init__(self, message, *shapes: Shape):
self.shapes = shapes


def parse_dim_names(obj: Union[str, tuple, list, Shape], count: int) -> tuple:
def parse_dim_names(obj: Union[str, Sequence[str], Shape], count: int) -> tuple:
if isinstance(obj, str):
parts = obj.split(',')
result = []
Expand All @@ -1323,7 +1323,7 @@ def parse_dim_names(obj: Union[str, tuple, list, Shape], count: int) -> tuple:
elif isinstance(obj, Shape):
assert len(obj) == count, f"Number of specified names in {obj} does not match number of dimensions ({count})"
return obj.names
elif isinstance(obj, (tuple, list)):
elif isinstance(obj, Sequence):
assert len(obj) == count, f"Number of specified names in {obj} does not match number of dimensions ({count})"
return tuple(obj)
raise ValueError(obj)
Expand Down
12 changes: 6 additions & 6 deletions tests/commit/math/test__ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,22 +157,22 @@ def test_at_max(self):
with backend:
# --- 1D ---
x = math.range(spatial(x=11))
math.assert_close(-math.max(x, 'x'), math.at_max(x, 'x', -x))
math.assert_close(-math.max(x, 'x'), math.at_max([-x, x], x, 'x')[0])
# --- 2D ---
x = math.meshgrid(x=3, y=3)
math.assert_close(-math.max(x, 'x'), math.at_max(x, 'x', -x))
math.assert_close(-math.max(x, 'vector'), math.at_max(x, 'vector', -x))
math.assert_close(-math.max(x, 'x'), math.at_max(-x, x, 'x'))
math.assert_close(-math.max(x, 'vector'), math.at_max(-x, x, 'vector'))

def test_at_min(self):
    """`at_min(value, key, dim)` must return `value` gathered at the argmin of `key`, for all backends."""
    for backend in BACKENDS:
        with backend:
            # --- 1D ---
            x = math.range(spatial(x=11))
            math.assert_close(-math.min(x, 'x'), math.at_min(-x, x, 'x'))
            # --- 2D ---
            x = math.meshgrid(x=3, y=3)
            math.assert_close(-math.min(x, 'x'), math.at_min(-x, x, 'x'))
            math.assert_close(-math.min(x, 'vector'), math.at_min(-x, x, 'vector'))

def test_unstack(self):
a = math.random_uniform(batch(b=10), spatial(x=4, y=3), channel(vector=2))
Expand Down

0 comments on commit 9fbce87

Please sign in to comment.