Skip to content

Commit

Permalink
Fix Shape.after_gather()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 26, 2024
1 parent 5d94879 commit a0a15e3
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,18 +1302,19 @@ def resolve_index(self, index: Dict[str, Union[slice, int, 'Shape', str, tuple,
return {dim: self.prepare_gather(dim, s) for dim, s in index.items()}

def after_gather(self, selection: dict) -> 'Shape':
result = self
if self.is_non_uniform:
from . import Tensor
sizes = [(s[selection] if isinstance(s, Tensor) else s) for s in self.sizes]
sizes = [(int(s) if isinstance(s, Tensor) and s.rank == 0 else s) for s in sizes]
result = self.with_sizes(sizes)
else:
result = self
for sel_dim, selection in selection.items():
# Even if sel_dim not part of self, can still be part of non_uniform_shape
if sel_dim not in self.names:
continue
selection = self.prepare_gather(sel_dim, selection)
if isinstance(selection, int):
if result.is_uniform:
result = result.without(sel_dim)
else:
from . import Tensor
gathered_sizes = [(s[{sel_dim: selection}] if isinstance(s, Tensor) else s) for s in result.sizes]
gathered_sizes = [(int(s) if isinstance(s, Tensor) and s.rank == 0 else s) for s in gathered_sizes]
result = result.with_sizes(gathered_sizes, keep_item_names=True).without(sel_dim)
result = result.without(sel_dim)
elif isinstance(selection, slice):
step = int(selection.step) if selection.step is not None else 1
start = int(selection.start) if selection.start is not None else (0 if step > 0 else self.get_size(sel_dim)-1)
Expand Down

0 comments on commit a0a15e3

Please sign in to comment.