Skip to content

Commit

Permalink
Support flatten() for non-Shaped args
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 26, 2023
1 parent b46948f commit 2934bba
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def flatten(value, flat_dim: Shape = instance('flat'), flatten_batch=False, **kw
Args:
value: `phiml.math.magic.Shapable`, such as `Tensor`.
If a non-`phiml.math.magic.Shaped` object or one with an empty `Shape` is passed, it is returned without alteration.
flat_dim: Dimension name and type as `Shape` object. The size is ignored.
flatten_batch: Whether to flatten batch dimensions as well.
If `False`, batch dimensions are kept, only onn-batch dimensions are flattened.
Expand All @@ -602,6 +603,10 @@ def flatten(value, flat_dim: Shape = instance('flat'), flatten_batch=False, **kw
(flatⁱ=12) const 0.0
"""
assert isinstance(flat_dim, Shape) and flat_dim.rank == 1, flat_dim
if not isinstance(value, Shaped):
return value
if shape(value).is_empty:
return value
assert isinstance(value, Shapable) and isinstance(value, Shaped), f"value must be Shapable but got {type(value)}"
# --- First try __flatten__ ---
if hasattr(value, '__flatten__'):
Expand Down

0 comments on commit 2934bba

Please sign in to comment.