From 498e163814d8019587d231e3b7accf5273d3a07d Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Mon, 13 Jan 2025 17:57:12 +0100 Subject: [PATCH] Fix concat() for values without concat dim --- phiml/math/_magic_ops.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/phiml/math/_magic_ops.py b/phiml/math/_magic_ops.py index 547310a..7825582 100644 --- a/phiml/math/_magic_ops.py +++ b/phiml/math/_magic_ops.py @@ -325,20 +325,22 @@ def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], expand_val else: dim = auto(dim, channel).name # --- Filter 0-length values --- - values = [v for v in values if shape(v).get_size(dim) > 0] + shapes = [shape(v) for v in values] + values = [v for v, s in zip(values, shapes) if dim not in s or s.get_size(dim) > 0] if len(values) == 1: return values[0] + shapes = [s for s in shapes if dim not in s or s.get_size(dim) > 0] # --- Add missing dimensions --- if expand_values: - all_dims = merge_shapes(*values, allow_varying_sizes=True) + all_dims = merge_shapes(*shapes, allow_varying_sizes=True) all_dims = all_dims.with_dim_size(dim, 1, keep_item_names=False) - values = [expand(v, all_dims.without(shape(v))) for v in values] + values = [expand(v, all_dims - s) for v, s in zip(values, shapes)] else: - for v in values: - assert dim in shape(v), f"concat dim '{dim}' must be present in the shapes of all values bot got value {type(v).__name__} with shape {shape(v)}" + for v, s in zip(values, shapes): + assert dim in s, f"concat dim '{dim}' must be present in the shapes of all values bot got value {type(v).__name__} with shape {s}" for v in values[1:]: assert set(non_batch(v).names) == set(non_batch(values[0]).names), f"Concatenated values must have the same non-batch dims but got {non_batch(values[0])} and {non_batch(v)}" - all_batch_dims = merge_shapes(*[shape(v).batch.without(dim) for v in values]) + all_batch_dims = merge_shapes(*[s.batch - dim for s in shapes]) values = [expand(v, all_batch_dims) for v in values] # --- First try __concat__ --- for v in values: @@ -367,7 +369,7 @@ def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], expand_val raise MagicNotImplemented(f"concat: No value implemented __concat__ and not all values were Sliceable along {dim}. values = {[type(v) for v in values]}") if len(unstacked) > 8: warnings.warn(f"concat() default implementation is slow on large dims ({dim}={len(unstacked)}). Please implement __concat__()", RuntimeWarning, stacklevel=2) - dim = shape(values[0])[dim].with_size(None) + dim = shapes[0][dim].with_size(None) try: return stack(unstacked, dim, **kwargs) except MagicNotImplemented: