Skip to content

Commit

Permalink
Trace sparse matrices in linear functions (part 4)
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 10, 2023
1 parent e1fdadf commit c87517a
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 119 deletions.
2 changes: 1 addition & 1 deletion phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2314,7 +2314,7 @@ def gather(values: Tensor, indices: Tensor, dims: Union[DimFilter, None] = None)
indices = expand(indices, channel(gather=dims))
if not channel(indices).item_names[0]:
indices = indices._with_shape_replaced(indices.shape.with_dim_size(channel(indices), dims))
return values.to_sparse_tracer().gather(indices)
return values._gather(indices)
treat_as_batch = non_channel(indices).only(values.shape).without(dims)
batch_ = (values.shape.batch & indices.shape.batch).without(dims) & treat_as_batch
channel_ = values.shape.without(dims).without(batch_)
Expand Down
Loading

0 comments on commit c87517a

Please sign in to comment.