From 15f9620259501885bb125ba5749f1309272cf9fc Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Mon, 13 Nov 2023 20:42:54 +0100 Subject: [PATCH] Fix stack() for highly non-uniform tensors --- phiml/math/_ops.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 9141f3f7..bcd973ff 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -1100,6 +1100,11 @@ def broadcast_op(operation: Callable, iter_dims.update(tensor.shape.shape.without('dims').names) if isinstance(tensor, TensorStack) and tensor.requires_broadcast: iter_dims.add(tensor._stack_dim.name) + # --- remove iter_dims for which the sizes vary among tensors --- + for dim in tuple(iter_dims): + sizes = [t.shape.get_size(dim) if dim in t.shape else None for t in tensors] + if not all(s == sizes[0] for s in sizes[1:]): + iter_dims.remove(dim) if len(iter_dims) == 0: return operation(*tensors) else: