Skip to content

Commit

Permalink
Fix concat() for values without concat dim
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Jan 13, 2025
1 parent 7831b6e commit 498e163
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,20 +325,22 @@ def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], expand_val
else:
dim = auto(dim, channel).name
# --- Filter 0-length values ---
values = [v for v in values if shape(v).get_size(dim) > 0]
shapes = [shape(v) for v in values]
values = [v for v, s in zip(values, shapes) if dim not in s or s.get_size(dim) > 0]
if len(values) == 1:
return values[0]
shapes = [s for s in shapes if dim not in s or s.get_size(dim) > 0]
# --- Add missing dimensions ---
if expand_values:
all_dims = merge_shapes(*values, allow_varying_sizes=True)
all_dims = merge_shapes(*shapes, allow_varying_sizes=True)
all_dims = all_dims.with_dim_size(dim, 1, keep_item_names=False)
values = [expand(v, all_dims.without(shape(v))) for v in values]
values = [expand(v, all_dims - s) for v, s in zip(values, shapes)]
else:
for v in values:
assert dim in shape(v), f"concat dim '{dim}' must be present in the shapes of all values bot got value {type(v).__name__} with shape {shape(v)}"
for v, s in zip(values, shapes):
assert dim in s, f"concat dim '{dim}' must be present in the shapes of all values bot got value {type(v).__name__} with shape {s}"
for v in values[1:]:
assert set(non_batch(v).names) == set(non_batch(values[0]).names), f"Concatenated values must have the same non-batch dims but got {non_batch(values[0])} and {non_batch(v)}"
all_batch_dims = merge_shapes(*[shape(v).batch.without(dim) for v in values])
all_batch_dims = merge_shapes(*[s.batch - dim for s in shapes])
values = [expand(v, all_batch_dims) for v in values]
# --- First try __concat__ ---
for v in values:
Expand Down Expand Up @@ -367,7 +369,7 @@ def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], expand_val
raise MagicNotImplemented(f"concat: No value implemented __concat__ and not all values were Sliceable along {dim}. values = {[type(v) for v in values]}")
if len(unstacked) > 8:
warnings.warn(f"concat() default implementation is slow on large dims ({dim}={len(unstacked)}). Please implement __concat__()", RuntimeWarning, stacklevel=2)
dim = shape(values[0])[dim].with_size(None)
dim = shapes[0][dim].with_size(None)
try:
return stack(unstacked, dim, **kwargs)
except MagicNotImplemented:
Expand Down

0 comments on commit 498e163

Please sign in to comment.