diff --git a/phiml/backend/_backend.py b/phiml/backend/_backend.py index 4e9e9a4..1e15b28 100644 --- a/phiml/backend/_backend.py +++ b/phiml/backend/_backend.py @@ -186,7 +186,7 @@ def auto_cast(self, *tensors, bool_to_int=False, int_to_float=False) -> list: tensors cast to a common data type """ dtypes = [self.dtype(t) for t in tensors] - result_type = self.combine_types(*dtypes) + result_type = combine_types(*dtypes, fp_precision=get_precision()) if result_type.kind == bool and bool_to_int: result_type = INT32 if result_type.kind == int and int_to_float: diff --git a/phiml/backend/_dtype.py b/phiml/backend/_dtype.py index ecc3700..04fa3a0 100644 --- a/phiml/backend/_dtype.py +++ b/phiml/backend/_dtype.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Union import numpy as np @@ -49,6 +50,7 @@ def __init__(self, kind: type, bits: int = None, precision: int = None): """ Python class corresponding to the type of data, ignoring precision. One of (bool, int, float, complex, str) """ self.bits = bits """ Number of bits used to store a single value of this type. See `DType.itemsize`. """ + self._hash = hash(self.kind) + hash(self.bits) @property def precision(self): @@ -78,7 +80,7 @@ def __ne__(self, other): return not self == other def __hash__(self): - return hash(self.kind) + hash(self.bits) + return self._hash def __repr__(self): return f"{self.kind.__name__}{self.bits}" @@ -159,6 +161,7 @@ def from_numpy_dtype(np_dtype) -> DType: _FROM_NUMPY[bool] = BOOL +@lru_cache def combine_types(*dtypes: DType, fp_precision: int = None) -> DType: # all bool? if all(dt.kind == bool for dt in dtypes): diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index e4bd391..148aa0c 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -1387,17 +1387,25 @@ def _op2(self, other, op: Callable, switch_args: bool): return NotImplemented if not isinstance(other, Dense): other = Dense(other.native(other.shape), other.shape.names, other.shape, self._backend) - backend_op = get_operator(op, backend_for(self, other)) + backend = backend_for(self, other) + backend_op = get_operator(op, backend) if backend_op is None: return op(self, other) if switch_args else op(other, self) - first_names = [n for n in self._names if n not in other._names] - names = first_names + list(other._names) - nat1 = self._transposed_native(names, False) - nat2 = other._native - if switch_args: - nat1, nat2 = nat2, nat1 - result_nat = backend_op(nat1, nat2) - return Dense(result_nat, names, self._shape & other._shape, backend_for(self, other)) + if other._names == self._names: + nat1, nat2 = self._native, other._native + names = self._names + if other._shape == self._shape: + r_shape = self._shape + else: + r_shape = self._shape & other._shape + else: + first_names = [n for n in self._names if n not in other._names] + names = first_names + list(other._names) + nat1 = self._transposed_native(names, False) + nat2 = other._native + r_shape = self._shape & other._shape + result_nat = backend_op(nat2, nat1) if switch_args else backend_op(nat1, nat2) + return Dense(result_nat, names, r_shape, backend) def _natives(self) -> tuple: return self._native, diff --git a/tests/commit/math/test_speed.py b/tests/commit/math/test_speed.py index a494a1f..487be59 100644 --- a/tests/commit/math/test_speed.py +++ b/tests/commit/math/test_speed.py @@ -11,7 +11,7 @@ def rnpv(size=64, d=2): return np.random.randn(1, *[size] * d, d).astype(np.float32) -def _assert_equally_fast(f1, f2, n=100, tolerance_per_round=0.001): +def _assert_equally_fast(f1, f2, n=100, tolerance_per_round=0.01): start = time.perf_counter() for _ in range(n): f1() @@ -20,7 +20,7 @@ def _assert_equally_fast(f1, f2, n=100, tolerance_per_round=0.001): for _ in range(n): f2() t_time = time.perf_counter() - start - print(np_time, t_time) + print("np:", np_time, "Φ:", t_time) assert abs(t_time - np_time) / n <= tolerance_per_round