Skip to content

Commit

Permalink
Fix pick_random for count=1
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Jan 13, 2025
1 parent 498e163 commit 6adb86e
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion phiml/math/perm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..backend import default_backend
from ._shape import concat_shapes_, batch, DimFilter, Shape, SHAPE_TYPES, shape, non_batch, channel, dual, primal
from ._magic_ops import unpack_dim, expand, stack, slice_
from ._magic_ops import unpack_dim, expand, stack, slice_, squeeze
from ._tensors import reshaped_tensor, TensorOrTree, Tensor, wrap
from ._ops import unravel_index, psum, dmin

Expand Down Expand Up @@ -84,6 +84,8 @@ def pick_random(value: TensorOrTree, dim: DimFilter, count: Union[int, Shape, No
else:
raise ValueError(f"Cannot pick random from empty tensor {u_dim}")
idx = wrap(np_idx, count if isinstance(count, SHAPE_TYPES) else u_dim.without_sizes())
if count == 1:
idx = squeeze(count, u_dim)
# idx = ravel_index()
idx_slices.append(expand(idx, channel(index=u_dim.name)))
idx = stack(idx_slices, nu_dims)
Expand Down

0 comments on commit 6adb86e

Please sign in to comment.