Skip to content

Commit

Permalink
Allow shape spec in layout()
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Oct 7, 2024
1 parent 9793843 commit e8e2c9d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,7 +1536,7 @@ def _simplify(self):


def tensor(data,
*shape: Union[Shape,str],
*shape: Union[Shape, str],
convert: bool = True,
default_list_dim=channel('vector')) -> Tensor: # TODO assume convert_unsupported, add convert_external=False for constants
"""
Expand Down Expand Up @@ -1673,12 +1673,12 @@ def tensor(data,
raise ValueError(f"{type(data)} is not supported. Only (Tensor, tuple, list, np.ndarray, native tensors) are allowed.\nCurrent backends: {BACKENDS}")


def wrap(data, *shape: Union[Shape,str], default_list_dim=channel('vector')) -> Tensor:
def wrap(data, *shape: Union[Shape, str], default_list_dim=channel('vector')) -> Tensor:
""" Short for `phiml.math.tensor()` with `convert=False`. """
return tensor(data, *shape, convert=False, default_list_dim=default_list_dim)


def layout(objects, *shape: Shape) -> Tensor:
def layout(objects, *shape: Union[Shape, str]) -> Tensor:
"""
Wraps a Python tree in a `Tensor`, allowing elements to be accessed via dimensions.
A python tree is a structure of nested `tuple`, `list`, `dict` and *leaf* objects where leaves can be any Python object.
Expand All @@ -1704,6 +1704,7 @@ def layout(objects, *shape: Shape) -> Tensor:
`Tensor`.
Calling `Tensor.native()` on the returned tensor will return `objects`.
"""
shape = [parse_shape_spec(s) if isinstance(s, str) else s for s in shape]
assert all(isinstance(s, Shape) for s in shape), f"shape needs to be one or multiple Shape instances but got {shape}"
shape = EMPTY_SHAPE if len(shape) == 0 else concat_shapes(*shape)
if isinstance(objects, Layout):
Expand Down

0 comments on commit e8e2c9d

Please sign in to comment.