Skip to content

Commit

Permalink
Improved gather error message
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 16, 2024
1 parent d8f00ce commit 8ba23b8
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,6 +2429,8 @@ def gather(values, indices: Tensor, dims: Union[DimFilter, None] = None):
assert dims, f"No indexing dimensions for tensor {values.shape} given indices {indices.shape}"
if dims not in values.shape:
return expand(values, non_channel(indices))
assert channel(indices).rank == 1, f"indices must have a single channel dimension listing the indexed dims {dims} but got {indices.shape}."
assert channel(indices).volume == len(dims), f"channel dimension of indices must have size equal to the number of indexed dims {dims} but got {channel(indices)}"
if values._is_tracer or is_sparse(values):
if not channel(indices):
indices = expand(indices, channel(gather=dims))
Expand Down

0 comments on commit 8ba23b8

Please sign in to comment.