Skip to content

Commit

Permalink
Remove redundant shape() calls in stack()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- authored and Philipp Holl committed Oct 18, 2024
1 parent 9970c66 commit 433b86c
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,24 +153,26 @@ def stack(values: Union[tuple, list, dict], dim: Union[Shape, str], expand_value
if not isinstance(dim, Shape):
dim = auto(dim)
values_ = tuple(values.values()) if isinstance(values, dict) else values
shapes = [shape(v) for v in values_]
if not expand_values:
for v in values_[1:]:
if set(non_batch(v).names) != set(non_batch(values_[0]).names): # shapes don't match
v0_dims = set(shapes[0].non_batch.names)
for s in shapes[1:]:
if set(s.non_batch.names) != v0_dims: # shapes don't match
from ._tensors import layout
return layout(values, dim)
# --- Add missing dimensions ---
if expand_values:
all_dims = merge_shapes(*values_, allow_varying_sizes=True)
all_dims = merge_shapes(*shapes, allow_varying_sizes=True)
if isinstance(values, dict):
values = {k: expand(v, all_dims.without(shape(v))) for k, v in values.items()}
values = {k: expand(v, all_dims - s) for (k, v), s in zip(values.items(), shapes)}
else:
values = [expand(v, all_dims.without(shape(v))) for v in values]
values = [expand(v, all_dims - s) for v, s in zip(values, shapes)]
else:
all_batch_dims = merge_shapes(*[shape(v).batch for v in values_], allow_varying_sizes=True)
all_batch_dims = merge_shapes(*[s.batch for s in shapes], allow_varying_sizes=True)
if isinstance(values, dict):
values = {k: expand(v, all_batch_dims.without(shape(v))) for k, v in values.items()}
values = {k: expand(v, all_batch_dims - s) for (k, v), s in zip(values.items(), shapes)}
else:
values = [expand(v, all_batch_dims.without(shape(v))) for v in values]
values = [expand(v, all_batch_dims - s) for v, s in zip(values, shapes)]
if dim.rank == 1:
assert dim.size == len(values) or dim.size is None, f"stack dim size must match len(values) or be undefined but got {dim} for {len(values)} values"
if dim.size is None:
Expand Down Expand Up @@ -355,6 +357,8 @@ def expand(value, *dims: Union[Shape, str], **kwargs):
Returns:
Same type as `value`.
"""
if not dims:
return value
dims = concat_shapes(*[d if isinstance(d, Shape) else parse_shape_spec(d) for d in dims])
combined = merge_shapes(value, dims) # check that existing sizes match
if not dims.without(shape(value)): # no new dims to add
Expand Down

0 comments on commit 433b86c

Please sign in to comment.