diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py index 31e3cbe..b1149f0 100644 --- a/phiml/math/_sparse.py +++ b/phiml/math/_sparse.py @@ -111,7 +111,10 @@ def tensor_like(existing_tensor: Tensor, values: Union[Tensor, Number, bool], va values = where(existing_tensor._valid_mask(), values, 0) return existing_tensor._with_values(values) if not is_sparse(existing_tensor): - return unpack_dim(values, instance, existing_tensor.shape.non_channel.non_batch) + if instance(values): + return unpack_dim(values, instance, existing_tensor.shape.non_batch) + else: + return expand(values, existing_tensor.shape.non_batch) raise NotImplementedError