diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 6d5df7d..b39b12a 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -1678,11 +1678,7 @@ def lookup_where(native_index_fn: Callable, value: Union[Tensor, PhiTreeNode, No idx_native = native_index_fn(key_nat) multi_idx_native = choose_backend(idx_native).unravel_index(idx_native[:, 0], dims.sizes) idx = reshaped_tensor(multi_idx_native, [key_batch, channel(_index=dims)]) - def lookup(t: Tensor): - if t is None: - return None - return gather(t, idx, pref_index_dim='_index') - return tree_map(lookup, value) + return tree_map(lambda t: gather(t, idx, pref_index_dim='_index'), value) def at_max(value, key: Tensor, dim: DimFilter = non_batch): @@ -2485,6 +2481,8 @@ def gather(values, indices: Tensor, dims: Union[DimFilter, None] = None, pref_in Returns: `Tensor` with combined batch dimensions, channel dimensions of `values` and spatial/instance dimensions of `indices`. """ + if values is None: + return None if not isinstance(values, Tensor): return tree_map(lambda v: gather(v, indices, dims), values) index_dim = channel(indices)