Skip to content

Commit

Permalink
Fix boolean_mask() for Layouts, None
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Dec 16, 2024
1 parent 331b1a9 commit 4de3ed7
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2686,18 +2686,23 @@ def boolean_mask(x, dim: DimFilter, mask: Tensor, preserve_names=False):
Returns:
Selected values of `x` as `Tensor` with dimensions from `x` and `mask`.
"""
if x is None:
return None
dim, original_dim = shape(mask).only(dim), dim
assert dim, f"mask dimension '{original_dim}' must be present on the mask {mask.shape}"
assert dim.rank == 1, f"boolean mask only supports 1D selection"
if not isinstance(x, Tensor) and isinstance(x, PhiTreeNode):
return tree_map(boolean_mask, x, all_attributes, dim=dim, mask=mask, preserve_names=preserve_names, include_non_attrs=False, treat_layout_as_leaf=True)
if isinstance(x, Layout):
if dim.name == x._stack_dim.name:
if x._stack_dim.without(dim):
from ._functional import map_
return map_(boolean_mask, x, dims=x._stack_dim - dim, dim=dim - x._stack_dim, mask=mask, preserve_names=preserve_names)
if dim in x._stack_dim:
indices = np.nonzero(mask.numpy())[0]
gathered = [x._obj[i] for i in indices]
size = len(gathered) if not preserve_names or x._stack_dim.item_names[0] is None else [x._stack_dim.item_names[0][i] for i in indices]
return Layout(gathered, dim.with_size(size))
return tree_map(boolean_mask, x, all_attributes, dim=dim, mask=mask, preserve_names=preserve_names, include_non_attrs=False, treat_layout_as_leaf=True)
raise NotImplementedError
if is_sparse(x):
indices = nonzero(mask, list_dim=instance('_boolean_mask'))
result = x[indices]
Expand Down

0 comments on commit 4de3ed7

Please sign in to comment.