Skip to content

Commit

Permalink
Fix stack() for highly non-uniform tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 13, 2023
1 parent a305fd2 commit 15f9620
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,11 @@ def broadcast_op(operation: Callable,
iter_dims.update(tensor.shape.shape.without('dims').names)
if isinstance(tensor, TensorStack) and tensor.requires_broadcast:
iter_dims.add(tensor._stack_dim.name)
# --- remove iter_dims for which the sizes vary among tensors ---
for dim in tuple(iter_dims):
sizes = [t.shape.get_size(dim) if dim in t.shape else None for t in tensors]
if not all(s == sizes[0] for s in sizes[1:]):
iter_dims.remove(dim)
if len(iter_dims) == 0:
return operation(*tensors)
else:
Expand Down

0 comments on commit 15f9620

Please sign in to comment.