From 44daf9f81df9afb267820be53cb70e6d9c95ad28 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Mon, 30 Dec 2024 11:34:04 +0100 Subject: [PATCH] Return primitives when possible --- phiml/backend/_backend.py | 11 +++++++++++ phiml/math/_magic_ops.py | 2 +- phiml/math/_tensors.py | 4 +++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/phiml/backend/_backend.py b/phiml/backend/_backend.py index ca438ec7..eda959a2 100644 --- a/phiml/backend/_backend.py +++ b/phiml/backend/_backend.py @@ -194,6 +194,17 @@ def auto_cast(self, *tensors, bool_to_int=False, int_to_float=False) -> list: tensors = [self.cast(t, result_type) for t in tensors] return tensors + def auto_cast1(self, tensor): + if isinstance(tensor, (bool, Number)): + return tensor + dtype = self.dtype(tensor) + if dtype.kind in {int, bool}: + return tensor + result_type = combine_types(dtype, fp_precision=self.precision) + if result_type.bits == dtype.bits: + return tensor + return self.cast(tensor, result_type) + def __str__(self): return self.name diff --git a/phiml/math/_magic_ops.py b/phiml/math/_magic_ops.py index 3c9032d4..a186d1bd 100644 --- a/phiml/math/_magic_ops.py +++ b/phiml/math/_magic_ops.py @@ -283,7 +283,7 @@ def stack(values: Union[Sequence[PhiTreeNodeType], Dict[str, PhiTreeNodeType]], return values[0] -def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], /, expand_values=False, **kwargs) -> PhiTreeNodeType: +def concat(values: Sequence[PhiTreeNodeType], dim: Union[str, Shape], expand_values=False, **kwargs) -> PhiTreeNodeType: """ Concatenates a sequence of `phiml.math.magic.Shapable` objects, e.g. `Tensor`, along one dimension. All values must have the same spatial, instance and channel dimensions and their sizes must be equal, except for `dim`. diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index c6c89348..71d511fa 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -1240,7 +1240,7 @@ def native(self, order: Union[str, tuple, list, Shape] = None, force_expand=True if order is None: assert len(self._shape) <= 1, f"When calling Tensor.native() or Tensor.numpy(), the dimension order must be specified for Tensors with more than one dimension, e.g. '{','.join(self._shape.names)}'. The listed default dimension order can vary depending on the chosen backend. Consider using math.reshaped_native(Tensor) instead." if len(self._names) == len(self._shape): - return self.backend.auto_cast(self._native)[0] + return self.backend.auto_cast1(self._native) assert len(self._names) == 0 # shape.rank is 1 return self.backend.tile(self.backend.expand_dims(self.backend.auto_cast(self._native)[0]), (self._shape.size,)) if isinstance(order, str): @@ -1251,6 +1251,8 @@ def native(self, order: Union[str, tuple, list, Shape] = None, force_expand=True return self._reshaped_native(groups) def _reshaped_native(self, groups: Sequence[Shape]): + if not groups: + return self.backend.auto_cast1(self._native) perm = [] slices = [] tile = []