Skip to content

Commit

Permalink
gather(None) now returns None
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Sep 27, 2024
1 parent 71f9b02 commit 9807048
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,11 +1678,7 @@ def lookup_where(native_index_fn: Callable, value: Union[Tensor, PhiTreeNode, No
idx_native = native_index_fn(key_nat)
multi_idx_native = choose_backend(idx_native).unravel_index(idx_native[:, 0], dims.sizes)
idx = reshaped_tensor(multi_idx_native, [key_batch, channel(_index=dims)])
def lookup(t: Tensor):
if t is None:
return None
return gather(t, idx, pref_index_dim='_index')
return tree_map(lookup, value)
return tree_map(lambda t: gather(t, idx, pref_index_dim='_index'), value)


def at_max(value, key: Tensor, dim: DimFilter = non_batch):
Expand Down Expand Up @@ -2485,6 +2481,8 @@ def gather(values, indices: Tensor, dims: Union[DimFilter, None] = None, pref_in
Returns:
`Tensor` with combined batch dimensions, channel dimensions of `values` and spatial/instance dimensions of `indices`.
"""
if values is None:
return None
if not isinstance(values, Tensor):
return tree_map(lambda v: gather(v, indices, dims), values)
index_dim = channel(indices)
Expand Down

0 comments on commit 9807048

Please sign in to comment.