Skip to content

Commit

Permalink
Return primitives when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 30, 2024
1 parent 167980e commit 44daf9f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
11 changes: 11 additions & 0 deletions phiml/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,17 @@ def auto_cast(self, *tensors, bool_to_int=False, int_to_float=False) -> list:
tensors = [self.cast(t, result_type) for t in tensors]
return tensors

def auto_cast1(self, tensor):
if isinstance(tensor, (bool, Number)):
return tensor
dtype = self.dtype(tensor)
if dtype.kind in {int, bool}:
return tensor
result_type = combine_types(dtype, fp_precision=self.precision)
if result_type.bits == dtype.bits:
return tensor
return self.cast(tensor, result_type)

def __str__(self):
return self.name

Expand Down
2 changes: 1 addition & 1 deletion phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def stack(values: Union[Sequence[PhiTreeNodeType], Dict[str, PhiTreeNodeType]],
return values[0]


def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], /, expand_values=False, **kwargs) -> PhiTreeNodeType:
def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], expand_values=False, **kwargs) -> PhiTreeNodeType:
"""
Concatenates a sequence of `phiml.math.magic.Shapable` objects, e.g. `Tensor`, along one dimension.
All values must have the same spatial, instance and channel dimensions and their sizes must be equal, except for `dim`.
Expand Down
4 changes: 3 additions & 1 deletion phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ def native(self, order: Union[str, tuple, list, Shape] = None, force_expand=True
if order is None:
assert len(self._shape) <= 1, f"When calling Tensor.native() or Tensor.numpy(), the dimension order must be specified for Tensors with more than one dimension, e.g. '{','.join(self._shape.names)}'. The listed default dimension order can vary depending on the chosen backend. Consider using math.reshaped_native(Tensor) instead."
if len(self._names) == len(self._shape):
return self.backend.auto_cast(self._native)[0]
return self.backend.auto_cast1(self._native)
assert len(self._names) == 0 # shape.rank is 1
return self.backend.tile(self.backend.expand_dims(self.backend.auto_cast(self._native)[0]), (self._shape.size,))
if isinstance(order, str):
Expand All @@ -1251,6 +1251,8 @@ def native(self, order: Union[str, tuple, list, Shape] = None, force_expand=True
return self._reshaped_native(groups)

def _reshaped_native(self, groups: Sequence[Shape]):
if not groups:
return self.backend.auto_cast1(self._native)
perm = []
slices = []
tile = []
Expand Down

0 comments on commit 44daf9f

Please sign in to comment.