Optimize tensor arithmetic for tensors with equal shapes
holl- committed Jan 8, 2025
1 parent d9faf17 commit 1156678
Showing 4 changed files with 24 additions and 13 deletions.
2 changes: 1 addition & 1 deletion phiml/backend/_backend.py
```diff
@@ -186,7 +186,7 @@ def auto_cast(self, *tensors, bool_to_int=False, int_to_float=False) -> list:
             tensors cast to a common data type
         """
         dtypes = [self.dtype(t) for t in tensors]
-        result_type = self.combine_types(*dtypes)
+        result_type = combine_types(*dtypes, fp_precision=get_precision())
         if result_type.kind == bool and bool_to_int:
             result_type = INT32
         if result_type.kind == int and int_to_float:
```
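The call site switches from the backend method `self.combine_types` to the module-level `combine_types` with an explicit `fp_precision` argument. Passing the precision explicitly is what makes the caching added in `_dtype.py` (below) safe: a memoized function must receive everything its result depends on as arguments, or it could serve stale results after the global precision changes. A minimal sketch of that hazard and the fix, using toy names (`result_bits`, `_precision`) that are not part of phiml:

```python
from functools import lru_cache

_precision = 32  # stand-in for phiml's global float precision

@lru_cache
def result_bits(fp_precision: int) -> int:
    # Cached per fp_precision value; the global is never read here,
    # so a later precision change cannot produce a stale result.
    return fp_precision

def combine(fp_precision: int = None) -> int:
    # Resolve the global *before* the cached call, mirroring
    # combine_types(*dtypes, fp_precision=get_precision()).
    return result_bits(fp_precision if fp_precision is not None else _precision)

assert combine() == 32
_precision = 64
assert combine() == 64  # still correct: the precision is an explicit cache key
```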
5 changes: 4 additions & 1 deletion phiml/backend/_dtype.py
```diff
@@ -1,3 +1,4 @@
+from functools import lru_cache
 from typing import Union
 
 import numpy as np
@@ -49,6 +50,7 @@ def __init__(self, kind: type, bits: int = None, precision: int = None):
         """ Python class corresponding to the type of data, ignoring precision. One of (bool, int, float, complex, str) """
         self.bits = bits
         """ Number of bits used to store a single value of this type. See `DType.itemsize`. """
+        self._hash = hash(self.kind) + hash(self.bits)
 
     @property
     def precision(self):
@@ -78,7 +80,7 @@ def __ne__(self, other):
         return not self == other
 
     def __hash__(self):
-        return hash(self.kind) + hash(self.bits)
+        return self._hash
 
     def __repr__(self):
         return f"{self.kind.__name__}{self.bits}"
@@ -159,6 +161,7 @@ def from_numpy_dtype(np_dtype) -> DType:
 _FROM_NUMPY[bool] = BOOL
 
 
+@lru_cache
 def combine_types(*dtypes: DType, fp_precision: int = None) -> DType:
     # all bool?
     if all(dt.kind == bool for dt in dtypes):
```
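Two related micro-optimizations here: `DType.__hash__` now returns a hash precomputed in `__init__` instead of recomputing it on every call, and `combine_types` is memoized with `functools.lru_cache`. The two changes compound, since every cache lookup hashes the `DType` arguments. A self-contained sketch of the pattern (a simplified `SimpleDType`, not the real class):

```python
from functools import lru_cache

class SimpleDType:
    """Immutable value type with a precomputed hash, as DType now does."""
    def __init__(self, kind: type, bits: int):
        self.kind = kind
        self.bits = bits
        self._hash = hash(kind) + hash(bits)  # computed once, not per lookup

    def __eq__(self, other):
        return isinstance(other, SimpleDType) and (self.kind, self.bits) == (other.kind, other.bits)

    def __hash__(self):
        return self._hash  # O(1); called on every cache lookup below

@lru_cache
def combine(*dtypes: SimpleDType) -> SimpleDType:
    # Toy version of combine_types: promote to the widest bit count.
    return max(dtypes, key=lambda dt: dt.bits)

f32, f64 = SimpleDType(float, 32), SimpleDType(float, 64)
assert combine(f32, f64).bits == 64            # computed once
assert combine(f32, f64) is combine(f32, f64)  # second call is a cache hit
```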
26 changes: 17 additions & 9 deletions phiml/math/_tensors.py
```diff
@@ -1387,17 +1387,25 @@ def _op2(self, other, op: Callable, switch_args: bool):
             return NotImplemented
         if not isinstance(other, Dense):
             other = Dense(other.native(other.shape), other.shape.names, other.shape, self._backend)
-        backend_op = get_operator(op, backend_for(self, other))
+        backend = backend_for(self, other)
+        backend_op = get_operator(op, backend)
         if backend_op is None:
             return op(self, other) if switch_args else op(other, self)
-        first_names = [n for n in self._names if n not in other._names]
-        names = first_names + list(other._names)
-        nat1 = self._transposed_native(names, False)
-        nat2 = other._native
-        if switch_args:
-            nat1, nat2 = nat2, nat1
-        result_nat = backend_op(nat1, nat2)
-        return Dense(result_nat, names, self._shape & other._shape, backend_for(self, other))
+        if other._names == self._names:
+            nat1, nat2 = self._native, other._native
+            names = self._names
+            if other._shape == self._shape:
+                r_shape = self._shape
+            else:
+                r_shape = self._shape & other._shape
+        else:
+            first_names = [n for n in self._names if n not in other._names]
+            names = first_names + list(other._names)
+            nat1 = self._transposed_native(names, False)
+            nat2 = other._native
+            r_shape = self._shape & other._shape
+        result_nat = backend_op(nat2, nat1) if switch_args else backend_op(nat1, nat2)
+        return Dense(result_nat, names, r_shape, backend)
 
     def _natives(self) -> tuple:
         return self._native,
```
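This is the core optimization of the commit: when both operands already carry the same dimension names, `_op2` skips the name merging and transposition entirely and, when the shapes also match, reuses `self._shape` instead of computing `self._shape & other._shape`. The `switch_args` swap is also folded into one conditional expression, and `backend_for` is called once instead of twice. The idea, sketched with plain NumPy and simplified to the case where the first operand contains all dims of the second (illustrative code, not the phiml internals):

```python
import numpy as np

def binary_op(nat1, names1: tuple, nat2, names2: tuple, op=np.add):
    """Toy version of the equal-names fast path in Dense._op2."""
    if names1 == names2:
        # Fast path: identical dim order, so no transposition or name merging.
        return op(nat1, nat2), names1
    # Slow path (simplified): assumes the second operand's dims are a subset
    # of the first's; the real implementation handles the general case.
    assert set(names2) <= set(names1)
    first = [n for n in names1 if n not in names2]
    names = tuple(first) + names2
    nat1_t = np.transpose(nat1, [names1.index(n) for n in names])
    return op(nat1_t, nat2), names  # nat2 broadcasts over the leading dims

a = np.random.randn(4, 8)
b = np.random.randn(4, 8)
res, names = binary_op(a, ('x', 'y'), b, ('x', 'y'))  # takes the fast path
assert names == ('x', 'y') and res.shape == (4, 8)
```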
4 changes: 2 additions & 2 deletions tests/commit/math/test_speed.py
```diff
@@ -11,7 +11,7 @@ def rnpv(size=64, d=2):
     return np.random.randn(1, *[size] * d, d).astype(np.float32)
 
 
-def _assert_equally_fast(f1, f2, n=100, tolerance_per_round=0.001):
+def _assert_equally_fast(f1, f2, n=100, tolerance_per_round=0.01):
     start = time.perf_counter()
     for _ in range(n):
         f1()
@@ -20,7 +20,7 @@ def _assert_equally_fast(f1, f2, n=100, tolerance_per_round=0.001):
     for _ in range(n):
         f2()
     t_time = time.perf_counter() - start
-    print(np_time, t_time)
+    print("np:", np_time, "Φ:", t_time)
     assert abs(t_time - np_time) / n <= tolerance_per_round
```


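The test tolerance is relaxed from 1 ms to 10 ms per round, and the printout now labels which timing is NumPy and which is Φ-ML. For context, a harness like this would typically be driven as below; `math.wrap` and `math.spatial` follow the phiml API, but this exact benchmark is an assumed usage example, not part of the diff:

```python
import numpy as np
from phiml import math

a_np = np.random.randn(64, 64).astype(np.float32)
b_np = np.random.randn(64, 64).astype(np.float32)
# Equal dim names and shapes -> the addition hits the new fast path in _op2.
a = math.wrap(a_np, math.spatial('x,y'))
b = math.wrap(b_np, math.spatial('x,y'))

_assert_equally_fast(lambda: a_np + b_np, lambda: a + b, n=100)
```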
