Skip to content

Commit

Permalink
Fix non-uniform boolean_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Dec 4, 2024
1 parent 95808c6 commit b6e0d17
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2678,8 +2678,7 @@ def boolean_mask(x, dim: DimFilter, mask: Tensor, preserve_names=False):
return tree_map(boolean_mask, x, all_attributes, dim=dim, mask=mask, preserve_names=preserve_names, include_non_attrs=False, treat_layout_as_leaf=True)
if isinstance(x, Layout):
if dim.name == x._stack_dim.name:
np_mask = mask.numpy()
indices = np.nonzero(np_mask)[0]
indices = np.nonzero(mask.numpy())[0]
gathered = [x._obj[i] for i in indices]
size = len(gathered) if not preserve_names or x._stack_dim.item_names[0] is None else [x._stack_dim.item_names[0][i] for i in indices]
return Layout(gathered, dim.with_size(size))
Expand All @@ -2692,6 +2691,10 @@ def boolean_mask(x, dim: DimFilter, mask: Tensor, preserve_names=False):
keep_slices = nonzero_slices(mask)
x_slices = [x[s] for s in keep_slices]
return concat(x_slices, dim.name)
if isinstance(x, TensorStack) and dim.name in broadcast_dims(x):
indices = np.nonzero(mask.numpy())[0]
items = x._unstack(dim.name)
return TensorStack([items[i] for i in indices], dim)

def uniform_boolean_mask(x: Tensor, mask_1d: Tensor):
if dim in x.shape:
Expand Down

0 comments on commit b6e0d17

Please sign in to comment.