Skip to content

Commit

Permalink
Fix for tensor(Shape)
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Dec 4, 2024
1 parent b6e0d17 commit e21b7ab
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,7 +1488,7 @@ def _getitem(self, selection: dict):
else:
return TensorStack(tensors, self._stack_dim)

def _unstack(self, dim):
def _unstack(self, dim: str):
if dim == self._stack_dim.name:
return self._tensors
else:
Expand Down Expand Up @@ -1613,18 +1613,7 @@ def tensor(data,
"""
shape = [parse_shape_spec(s) if isinstance(s, str) else s for s in shape]
shape = None if len(shape) == 0 else concat_shapes(*shape)
if isinstance(data, Tensor):
if convert:
backend = data.default_backend
if backend != default_backend():
data = data._op1(lambda n: convert_(n, use_dlpack=False))
if shape is None:
return data
else:
if None in shape.sizes:
shape = shape.with_sizes(data.shape)
return data._with_shape_replaced(shape)
elif isinstance(data, Shape):
if isinstance(data, Shape):
if shape is None:
shape = channel('dims')
shape = shape.with_size(data.names)
Expand All @@ -1636,6 +1625,17 @@ def tensor(data,
assert shape.rank == 1, "Can only convert 1D shapes to Tensors"
shape = shape.with_size(data.names)
data = data.sizes
if isinstance(data, Tensor):
if convert:
backend = data.default_backend
if backend != default_backend():
data = data._op1(lambda n: convert_(n, use_dlpack=False))
if shape is None:
return data
else:
if None in shape.sizes:
shape = shape.with_sizes(data.shape)
return data._with_shape_replaced(shape)
elif isinstance(data, str) or data is None:
return layout(data)
elif isinstance(data, (Number, bool)):
Expand Down

0 comments on commit e21b7ab

Please sign in to comment.