From 6ade6df27875f817aeb92df7afdb26bd2f9f75df Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Mon, 6 Jan 2025 16:04:47 +0100 Subject: [PATCH] Refactor op2 --- phiml/backend/_backend.py | 48 +++++++- phiml/backend/xops.py | 33 ++++++ phiml/math/_ops.py | 19 ++-- phiml/math/_sparse.py | 53 +++++---- phiml/math/_tensors.py | 230 +++++++++++++++++--------------------- phiml/math/_trace.py | 89 +++++++-------- 6 files changed, 250 insertions(+), 222 deletions(-) create mode 100644 phiml/backend/xops.py diff --git a/phiml/backend/_backend.py b/phiml/backend/_backend.py index 169da7a..229f1dc 100644 --- a/phiml/backend/_backend.py +++ b/phiml/backend/_backend.py @@ -1,5 +1,6 @@ import dataclasses import logging +import operator import sys import warnings from builtins import ValueError @@ -13,6 +14,7 @@ import numpy as np from numpy import ndarray +from . import xops from ._dtype import DType, combine_types, INT32, INT64 TensorType = TypeVar('TensorType') @@ -194,15 +196,15 @@ def auto_cast(self, *tensors, bool_to_int=False, int_to_float=False) -> list: return tensors def auto_cast1(self, tensor): - if isinstance(tensor, (bool, Number)): + if isinstance(tensor, (bool, int, float, complex)): return tensor dtype = self.dtype(tensor) if dtype.kind in {int, bool}: return tensor - result_type = combine_types(dtype, fp_precision=self.precision) - if result_type.bits == dtype.bits: - return tensor - return self.cast(tensor, result_type) + if dtype.precision != get_precision(): + result_type = DType(dtype.kind, precision=get_precision()) + return self.cast(tensor, result_type) + return tensor def __str__(self): return self.name @@ -2114,3 +2116,39 @@ def assemble(b: Backend, *args): all_values = {f.name: re_values[f.name] if f.name in tensor_fields else getattr(data, f.name) for f in fields} return type(data)(**all_values) return assemble, tensors + + +_BACKEND_OPERATORS = { + operator.eq: Backend.equal, + operator.ne: Backend.not_equal, + operator.gt: Backend.greater_than, + operator.ge: Backend.greater_or_equal, + operator.add: Backend.add, + operator.sub: Backend.sub, + operator.mul: Backend.mul, + operator.truediv: Backend.div, + operator.pow: Backend.pow, + operator.mod: Backend.mod, + operator.and_: Backend.and_, + operator.or_: Backend.or_, + operator.xor: Backend.xor, + operator.floordiv: Backend.floordiv, + operator.lshift: Backend.shift_bits_left, + operator.rshift: Backend.shift_bits_right, + operator.inv: Backend.invert, + operator.invert: Backend.invert, + divmod: divmod, + abs: Backend.abs, + xops.save_div: Backend.divide_no_nan, + xops.gamma_inc_l: Backend.gamma_inc_l, + xops.gamma_inc_u: Backend.gamma_inc_u, + xops.arctan2: Backend.arctan2, + xops.minimum: Backend.minimum, + xops.maximum: Backend.maximum, +} + +def get_operator(op: Callable, backend: Backend): + fun = _BACKEND_OPERATORS.get(op) + if fun is not None: + return getattr(backend, fun.__name__) + return fun diff --git a/phiml/backend/xops.py b/phiml/backend/xops.py new file mode 100644 index 0000000..93e2d1f --- /dev/null +++ b/phiml/backend/xops.py @@ -0,0 +1,33 @@ +""" +Extra operators. + +This module is an extension to the built-in operators. +""" + + +class ExtraOperator(Exception): + pass + + +def save_div(numerator, denominator): + raise ExtraOperator + + +def gamma_inc_l(a, x): + raise ExtraOperator + + +def gamma_inc_u(a, x): + raise ExtraOperator + + +def arctan2(tan, divide_by): + raise ExtraOperator + + +def minimum(x1, x2): + raise ExtraOperator + + +def maximum(x1, x2): + raise ExtraOperator diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 2223e3a..3689b61 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -8,6 +8,7 @@ from ..backend import default_backend, choose_backend, Backend, get_precision, convert as b_convert, BACKENDS, NoBackendFound, ComputeDevice, NUMPY from ..backend._dtype import DType, combine_types, INT32 +from phiml.backend import xops from .magic import PhiTreeNode from ._magic_ops import expand, pack_dims, unpack_dim, cast, value_attributes, bool_to_int, tree_map, concat, stack, unstack, rename_dims, slice_, all_attributes, squeeze, ipack from ._shape import (Shape, EMPTY_SHAPE, @@ -2289,11 +2290,10 @@ def incomplete_gamma(a: TensorOrTree, x: TensorOrTree, upper=False, regularized= upper: Whether to complete the upper integral (x to infinity) or the lower integral (0 to x). regularized: Whether the integral is divided by Γ(a). """ - call = lambda a, x: incomplete_gamma(a, x, upper=upper, regularized=regularized) if upper: - reg = custom_op2(a, x, call, lambda a, x: choose_backend(a, x).gamma_inc_u(a, x), 'gamma_inc_u') + reg = custom_op2(a, x, xops.gamma_inc_u) else: - reg = custom_op2(a, x, call, lambda a, x: choose_backend(a, x).gamma_inc_l(a, x), 'gamma_inc_l') + reg = custom_op2(a, x, xops.gamma_inc_l) return reg if regularized else reg * exp(log_gamma(a)) @@ -2464,7 +2464,7 @@ def arctan(x: TensorOrTree, divide_by=None) -> TensorOrTree: return _backend_op1(x, Backend.arctan) else: divide_by = to_float(divide_by) - return custom_op2(x, divide_by, arctan, lambda a, b: choose_backend(a, b).arctan2(a, b), 'arctan') + return custom_op2(x, divide_by, xops.arctan2) def angle(x: TensorOrTree) -> TensorOrTree: @@ -2560,12 +2560,7 @@ def cast_same(*values: Tensor) -> Tuple[Tensor]: def safe_div(x: Union[Number, Tensor], y: Union[Number, Tensor]): """ Computes *x/y* with the `Tensor`s `x` and `y` but returns 0 where *y=0*. """ - return custom_op2(x, y, - l_operator=safe_div, - l_native_function=lambda x_, y_: choose_backend(x_, y_).divide_no_nan(x_, y_), - r_operator=lambda y_, x_: safe_div(x_, y_), - r_native_function=lambda y_, x_: choose_backend(x_, y_).divide_no_nan(x_, y_), - op_name='divide_no_nan') + return custom_op2(x, y, xops.save_div) def maximum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False): @@ -2575,7 +2570,7 @@ def maximum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False): return y elif y is None: return x - return custom_op2(x, y, maximum, lambda x_, y_: choose_backend(x_, y_).maximum(x_, y_), op_name='maximum') + return custom_op2(x, y, xops.maximum) def minimum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False): @@ -2585,7 +2580,7 @@ def minimum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False): return y elif y is None: return x - return custom_op2(x, y, minimum, lambda x_, y_: choose_backend(x_, y_).minimum(x_, y_), op_name='minimum') + return custom_op2(x, y, xops.minimum) def clip(x: Tensor, lower_limit: Union[float, Tensor] = 0, upper_limit: Union[float, Tensor, Shape] = 1): diff --git a/phiml/math/_sparse.py b/phiml/math/_sparse.py index a57c370..d5f5665 100644 --- a/phiml/math/_sparse.py +++ b/phiml/math/_sparse.py @@ -1,3 +1,4 @@ +import operator import warnings from functools import partial from numbers import Number @@ -11,7 +12,7 @@ from ._magic_ops import concat, pack_dims, expand, rename_dims, stack, unpack_dim, unstack from ._shape import Shape, non_batch, merge_shapes, instance, batch, non_instance, shape, channel, spatial, DimFilter, \ concat_shapes, EMPTY_SHAPE, dual, non_channel, DEBUG_CHECKS, primal, concat_shapes_ -from ._tensors import Tensor, TensorStack, Dense, cached, wrap, reshaped_tensor, tensor, backend_for +from ._tensors import Tensor, TensorStack, Dense, cached, wrap, reshaped_tensor, tensor, backend_for, custom_op2 from ..backend import choose_backend, NUMPY, Backend, get_precision from ..backend._dtype import DType, INT64 @@ -367,20 +368,20 @@ def _with_shape_replaced(self, new_shape: Shape): def _op1(self, native_function): return self._with_values(self._values._op1(native_function)) - def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor': + def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor': other_shape = shape(other) affects_only_values = self._dense_shape.isdisjoint(other_shape) if affects_only_values: - return self._with_values(operator(self._values, other)) + return self._with_values(op(self._values, other)) if isinstance(other, CompressedSparseMatrix): other = other.decompress() if isinstance(other, SparseCoordinateTensor): if same_sparsity_pattern(self, other): - return self._with_values(operator(self._values, other._values)) + return self._with_values(op(self._values, other._values)) else: - if op_name not in ['add', 'radd', 'sub', 'rsub']: + if op not in {operator.add, operator.sub}: same_sparsity_pattern(self, other) # debug checkpoint - raise AssertionError(f"Operation '{op_symbol}' ({op_name}) requires sparse matrices with the same sparsity pattern.") + raise AssertionError(f"Operation '{op}' requires sparse matrices with the same sparsity pattern.") all_sparse_dims = sparse_dims(other) & sparse_dims(self) self_indices = pack_dims(self._indices, instance, instance('sp_entries')) other_indices = pack_dims(other._indices, instance, instance('sp_entries')) @@ -389,21 +390,19 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st self_indices, self_values = with_sparsified_dim(self_indices, self_values, all_sparse_dims) other_indices, other_values = with_sparsified_dim(other_indices, other_values, all_sparse_dims) indices = concat([self_indices, other_indices], 'sp_entries') - if op_symbol == '+': + if op == operator.add: values = concat([self_values, other_values], instance(self_values), expand_values=True) - elif op_name == 'sub': - values = concat([self_values, -other_values], instance(self_values), expand_values=True) - else: # op_name == 'rsub': - values = concat([-self_values, other_values], instance(self_values), expand_values=True) + else: + values = concat([-self_values, other_values] if switch_args else [self_values, -other_values], instance(self_values), expand_values=True) return SparseCoordinateTensor(indices, values, self._dense_shape & other._dense_shape, can_contain_double_entries=True, indices_sorted=False, indices_constant=self._indices_constant) else: # other is dense if self._dense_shape in other.shape: # all dims dense -> convert to dense - return dense(self)._op2(other, operator, native_function, op_name, op_symbol) + return dense(self)._op2(other, op, switch_args) else: # only some dims dense -> stay sparse dense_dims = self._dense_shape.only(other.shape) assert instance(other).without(self._dense_shape).is_empty, f"Instance dims cannot be added to sparse tensors from sparse-dense operations but got {other.shape} for sparse tensor {self.shape}" other_values = other[self._indices.sparse_idx[dense_dims.name_list]] - values = operator(self._values, other_values) + values = custom_op2(self._values, other_values, op, switch_args) return self._with_values(values) def _getitem(self, selection: dict) -> 'Tensor': @@ -702,19 +701,19 @@ def __expand__(self, dims: Shape, **kwargs) -> 'Tensor': def _op1(self, native_function): return self._with_values(self._values._op1(native_function)) - def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor': + def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor': other_shape = shape(other) affects_only_values = self.sparse_dims.isdisjoint(other_shape) and non_instance(self._indices).isdisjoint(other_shape) if affects_only_values: - return self._with_values(operator(self._values, other)) + return self._with_values(custom_op2(self._values, other, op, switch_args)) elif isinstance(other, CompressedSparseMatrix): if same_sparsity_pattern(self, other): - result = operator(self._values, other._values) + result = op(self._values, other._values) if self._uncompressed_offset is not None: from ._ops import where result = where(self._valid_mask(), result, 0) return self._with_values(result) - elif op_symbol == '+': + elif op == operator.add: raise NotImplementedError("Compressed addition not yet implemented") else: # convert to COO, then perform operation @@ -723,19 +722,19 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st from ._ops import gather, boolean_mask, clip, where if self._uncompressed_offset is None: other_values = gather(other, self._indices, self._uncompressed_dims) - return self._with_values(operator(self._values, other_values)) + return self._with_values(op(self._values, other_values)) # if bake_slice: # baked = self._bake_slice() # other_values = gather(other, baked._indices, self._uncompressed_dims) # return baked._with_values(operator(baked._values, other_values)) indices = clip(self._indices - self._uncompressed_offset, 0, self._uncompressed_dims.volume - 1) other_values = gather(other, indices, self._uncompressed_dims) - return self._with_values(where(self._valid_mask(), operator(self._values, other_values), 0)) + return self._with_values(where(self._valid_mask(), op(self._values, other_values), 0)) elif self._compressed_dims in other_shape and self._uncompressed_dims.isdisjoint(other_shape): from ._ops import gather, boolean_mask, clip, where row_indices, _ = self._coo_indices('clamp') other_values = gather(other, row_indices, self._compressed_dims) - result_values = operator(self._values, other_values) + result_values = op(self._values, other_values) if self._uncompressed_offset is not None: result_values = where(self._valid_mask(), result_values, 0) return self._with_values(result_values) @@ -960,16 +959,16 @@ def _with_shape_replaced(self, new_shape: Shape): def _op1(self, native_function): return self._with_values(self._values._op1(native_function)) - def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor': + def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor': other_shape = shape(other) affects_only_values = self._compressed_dims.isdisjoint(other_shape) if affects_only_values: - return self._with_values(operator(self._values, other)) + return self._with_values(op(self._values, other)) elif isinstance(other, (CompressedSparseMatrix, CompactSparseTensor)): if same_sparsity_pattern(self, other): - result = operator(self._values, other._values) + result = op(self._values, other._values) return self._with_values(result) - elif op_symbol == '+': + elif op == operator.add: raise NotImplementedError("Compressed addition not yet implemented") else: # convert to COO, then perform operation @@ -978,18 +977,18 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st from ._ops import gather, boolean_mask, clip, where if self._uncompressed_offset is None: other_values = gather(other, self._indices, self._uncompressed_dims) - return self._with_values(operator(self._values, other_values)) + return self._with_values(op(self._values, other_values)) # if bake_slice: # baked = self._bake_slice() # other_values = gather(other, baked._indices, self._uncompressed_dims) # return baked._with_values(operator(baked._values, other_values)) indices = clip(self._indices - self._uncompressed_offset, 0, self._uncompressed_dims.volume - 1) other_values = gather(other, indices, self._uncompressed_dims) - return self._with_values(where(self._valid_mask(), operator(self._values, other_values), 0)) + return self._with_values(where(self._valid_mask(), op(self._values, other_values), 0)) elif self._compressed_dims in other_shape and self._uncompressed_dims.isdisjoint(other_shape): from ._ops import gather, boolean_mask, clip, where other_values = gather(other, self._indices, self._compressed_dims) - result_values = operator(self._values, other_values) + result_values = op(self._values, other_values) return self._with_values(result_values) else: raise NotImplementedError diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index 0906e8c..c247b33 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -1,4 +1,5 @@ import dataclasses +import operator from numbers import Number import traceback import warnings @@ -19,9 +20,11 @@ prepare_renaming_gather, after_gather, concat_shapes_, Dim, PureShape, SHAPE_TYPES) from ..backend import NoBackendFound, choose_backend, BACKENDS, get_precision, default_backend, convert as convert_, \ Backend, ComputeDevice, OBJECTS, NUMPY, ML_LOGGER +from ..backend._backend import get_operator from ..backend._dtype import DType, combine_types, BOOL, INT64, INT32 from .magic import BoundDim, PhiTreeNode, slicing_dict, Shaped, _BoundDims from .magic import Shapable +from ..backend.xops import ExtraOperator class Tensor: @@ -126,75 +129,38 @@ def __array__(self, dtype=None): # NumPy conversion def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # NumPy interface if len(inputs) != 2: return NotImplemented + switch_args = self is inputs[1] + other = inputs[0] if switch_args else inputs[1] if ufunc.__name__ == 'multiply': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x * y, lambda x, y: choose_backend(x, y).mul(x, y), 'mul', '*') - else: - return self._op2(inputs[0], lambda x, y: y * x, lambda x, y: choose_backend(x, y).mul(y, x), 'rmul', '*') + return self._op2(other, operator.mul, switch_args) if ufunc.__name__ == 'add': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x + y, lambda x, y: choose_backend(x, y).add(x, y), 'add', '+') - else: - return self._op2(inputs[0], lambda x, y: y + x, lambda x, y: choose_backend(x, y).add(y, x), 'radd', '+') + return self._op2(other, operator.add, switch_args) if ufunc.__name__ == 'subtract': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x - y, lambda x, y: choose_backend(x, y).sub(x, y), 'add', '-') - else: - return self._op2(inputs[0], lambda x, y: y - x, lambda x, y: choose_backend(x, y).sub(y, x), 'rsub', '-') + return self._op2(other, operator.sub, switch_args) if ufunc.__name__ in ['divide', 'true_divide']: - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x / y, lambda x, y: choose_backend(x, y).div(x, y), 'true_divide', '/') - else: - return self._op2(inputs[0], lambda x, y: y / x, lambda x, y: choose_backend(x, y).div(y, x), 'r_true_divide', '/') + return self._op2(other, operator.truediv, switch_args) if ufunc.__name__ == 'floor_divide': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x // y, lambda x, y: choose_backend(x, y).floordiv(x, y), 'floor_divide', '//') - else: - return self._op2(inputs[0], lambda x, y: y // x, lambda x, y: choose_backend(x, y).floordiv(y, x), 'r_floor_divide', '//') + return self._op2(other, operator.floordiv, switch_args) if ufunc.__name__ == 'remainder': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x % y, lambda x, y: choose_backend(x, y).mod(x, y), 'remainder', '%') - else: - return self._op2(inputs[0], lambda x, y: y % x, lambda x, y: choose_backend(x, y).mod(y, x), 'r_remainder', '%') + return self._op2(other, operator.mod, switch_args) if ufunc.__name__ == 'power': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x ** y, lambda x, y: choose_backend(x, y).pow(x, y), 'power', '**') - else: - return self._op2(inputs[0], lambda x, y: y ** x, lambda x, y: choose_backend(x, y).pow(y, x), 'r_power', '**') + return self._op2(other, operator.pow, switch_args) if ufunc.__name__ == 'equal': return self.__eq__(inputs[1] if self is inputs[0] else inputs[0]) if ufunc.__name__ == 'not_equal': return self.__ne__(inputs[1] if self is inputs[0] else inputs[0]) if ufunc.__name__ == 'greater': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x > y, lambda x, y: choose_backend(x, y).greater_than(x, y), 'greater', '>') - else: - return self._op2(inputs[0], lambda x, y: y > x, lambda x, y: choose_backend(x, y).greater_than(y, x), 'r_greater', '>') + return self._op2(other, operator.gt, switch_args) if ufunc.__name__ == 'greater_equal': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x >= y, lambda x, y: choose_backend(x, y).greater_or_equal(x, y), 'greater_equal', '>=') - else: - return self._op2(inputs[0], lambda x, y: y >= x, lambda x, y: choose_backend(x, y).greater_or_equal(y, x), 'r_greater_equal', '>=') + return self._op2(other, operator.ge, switch_args) if ufunc.__name__ == 'less': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x < y, lambda x, y: choose_backend(x, y).greater_than(y, x), 'less', '<') - else: - return self._op2(inputs[0], lambda x, y: y < x, lambda x, y: choose_backend(x, y).greater_than(x, y), 'r_less', '<') + return self._op2(other, operator.gt, not switch_args) if ufunc.__name__ == 'less_equal': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x <= y, lambda x, y: choose_backend(x, y).greater_or_equal(y, x), 'less_equal', '<=') - else: - return self._op2(inputs[0], lambda x, y: y <= x, lambda x, y: choose_backend(x, y).greater_or_equal(x, y), 'r_less_equal', '<=') + return self._op2(other, operator.ge, not switch_args) if ufunc.__name__ == 'left_shift': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x << y, lambda x, y: choose_backend(x, y).shift_bits_left(x, y), 'left_shift', '<<') - else: - return self._op2(inputs[0], lambda x, y: y << x, lambda x, y: choose_backend(x, y).shift_bits_left(y, x), 'r_left_shift', '<<') + return self._op2(other, operator.lshift, switch_args) if ufunc.__name__ == 'right_shift': - if inputs[0] is self: - return self._op2(inputs[1], lambda x, y: x >> y, lambda x, y: choose_backend(x, y).shift_bits_right(x, y), 'right_shift', '>>') - else: - return self._op2(inputs[0], lambda x, y: y >> x, lambda x, y: choose_backend(x, y).shift_bits_right(y, x), 'r_right_shift', '>>') + return self._op2(other, operator.rshift, switch_args) raise NotImplementedError(f"NumPy function '{ufunc.__name__}' is not compatible with Φ-ML tensors.") @property @@ -627,71 +593,71 @@ def __getattr__(self, name): return TensorDim(self, name) def __add__(self, other): - return self._op2(other, lambda x, y: x + y, lambda x, y: choose_backend(x, y).add(x, y), 'add', '+') + return self._op2(other, operator.add, False) def __radd__(self, other): - return self._op2(other, lambda x, y: y + x, lambda x, y: choose_backend(x, y).add(y, x), 'radd', '+') + return self._op2(other, operator.add, True) def __sub__(self, other): - return self._op2(other, lambda x, y: x - y, lambda x, y: choose_backend(x, y).sub(x, y), 'sub', '-') + return self._op2(other, operator.sub, False) def __rsub__(self, other): - return self._op2(other, lambda x, y: y - x, lambda x, y: choose_backend(x, y).sub(y, x), 'rsub', '-') + return self._op2(other, operator.sub, True) def __and__(self, other): - return self._op2(other, lambda x, y: x & y, lambda x, y: choose_backend(x, y).and_(x, y), 'and', '&') + return self._op2(other, operator.and_, False) def __rand__(self, other): - return self._op2(other, lambda x, y: y & x, lambda x, y: choose_backend(x, y).and_(y, x), 'rand', '&') + return self._op2(other, operator.and_, True) def __or__(self, other): - return self._op2(other, lambda x, y: x | y, lambda x, y: choose_backend(x, y).or_(x, y), 'or', '|') + return self._op2(other, operator.or_, False) def __ror__(self, other): - return self._op2(other, lambda x, y: y | x, lambda x, y: choose_backend(x, y).or_(y, x), 'ror', '|') + return self._op2(other, operator.or_, True) def __xor__(self, other): - return self._op2(other, lambda x, y: x ^ y, lambda x, y: choose_backend(x, y).xor(x, y), 'xor', '^') + return self._op2(other, operator.xor, False) def __rxor__(self, other): - return self._op2(other, lambda x, y: y ^ x, lambda x, y: choose_backend(x, y).xor(y, x), 'rxor', '^') + return self._op2(other, operator.xor, True) def __mul__(self, other): - return self._op2(other, lambda x, y: x * y, lambda x, y: choose_backend(x, y).mul(x, y), 'mul', '*') + return self._op2(other, operator.mul, False) def __rmul__(self, other): - return self._op2(other, lambda x, y: y * x, lambda x, y: choose_backend(x, y).mul(y, x), 'rmul', '*') + return self._op2(other, operator.mul, True) def __truediv__(self, other): - return self._op2(other, lambda x, y: x / y, lambda x, y: choose_backend(x, y).div(x, y), 'truediv', '/') + return self._op2(other, operator.truediv, False) def __rtruediv__(self, other): - return self._op2(other, lambda x, y: y / x, lambda x, y: choose_backend(x, y).div(y, x), 'rtruediv', '/') + return self._op2(other, operator.truediv, True) def __divmod__(self, other): - return self._op2(other, lambda x, y: divmod(x, y), lambda x, y: divmod(x, y), 'divmod', 'divmod') + return self._op2(other, divmod, False) def __rdivmod__(self, other): - return self._op2(other, lambda x, y: divmod(y, x), lambda x, y: divmod(y, x), 'rdivmod', 'divmod') + return self._op2(other, divmod, True) def __floordiv__(self, other): - return self._op2(other, lambda x, y: x // y, lambda x, y: choose_backend(x, y).floordiv(x, y), 'floordiv', '//') + return self._op2(other, operator.floordiv, False) def __rfloordiv__(self, other): - return self._op2(other, lambda x, y: y // x, lambda x, y: choose_backend(x, y).floordiv(y, x), 'rfloordiv', '//') + return self._op2(other, operator.floordiv, True) def __pow__(self, power, modulo=None): assert modulo is None - return self._op2(power, lambda x, y: x ** y, lambda x, y: choose_backend(x, y).pow(x, y), 'pow', '**') + return self._op2(power, operator.pow, False) def __rpow__(self, other): - return self._op2(other, lambda x, y: y ** x, lambda x, y: choose_backend(x, y).pow(y, x), 'rpow', '**') + return self._op2(other, operator.pow, True) def __mod__(self, other): - return self._op2(other, lambda x, y: x % y, lambda x, y: choose_backend(x, y).mod(x, y), 'mod', '%') + return self._op2(other, operator.mod, False) def __rmod__(self, other): - return self._op2(other, lambda x, y: y % x, lambda x, y: choose_backend(x, y).mod(y, x), 'rmod', '%') + return self._op2(other, operator.mod, True) def __eq__(self, other) -> 'Tensor': if self is other: @@ -706,7 +672,7 @@ def __eq__(self, other) -> 'Tensor': if other is None: other = float('nan') if self.shape.is_compatible(shape(other)): - return self._op2(other, lambda x, y: x == y, lambda x, y: choose_backend(x, y).equal(x, y), 'eq', '==') + return self._op2(other, operator.eq, False) else: return wrap(False) @@ -721,33 +687,33 @@ def __ne__(self, other) -> 'Tensor': if other is None: other = float('nan') if self.shape.is_compatible(shape(other)): - return self._op2(other, lambda x, y: x != y, lambda x, y: choose_backend(x, y).not_equal(x, y), 'ne', '!=') + return self._op2(other, operator.ne, False) else: return wrap(True) def __lt__(self, other): - return self._op2(other, lambda x, y: x < y, lambda x, y: choose_backend(x, y).greater_than(y, x), 'lt', '<') + return self._op2(other, operator.gt, True) def __le__(self, other): - return self._op2(other, lambda x, y: x <= y, lambda x, y: choose_backend(x, y).greater_or_equal(y, x), 'le', '<=') + return self._op2(other, operator.ge, True) def __gt__(self, other): - return self._op2(other, lambda x, y: x > y, lambda x, y: choose_backend(x, y).greater_than(x, y), 'gt', '>') + return self._op2(other, operator.gt, False) def __ge__(self, other): - return self._op2(other, lambda x, y: x >= y, lambda x, y: choose_backend(x, y).greater_or_equal(x, y), 'ge', '>=') + return self._op2(other, operator.ge, False) def __lshift__(self, other): - return self._op2(other, lambda x, y: x << y, lambda x, y: choose_backend(x, y).shift_bits_left(x, y), 'lshift', '<<') + return self._op2(other, operator.lshift, False) def __rlshift__(self, other): - return self._op2(other, lambda y, x: x << y, lambda y, x: choose_backend(x, y).shift_bits_left(x, y), 'lshift', '<<') + return self._op2(other, operator.lshift, True) def __rshift__(self, other): - return self._op2(other, lambda x, y: x >> y, lambda x, y: choose_backend(x, y).shift_bits_right(x, y), 'rshift', '>>') + return self._op2(other, operator.rshift, False) def __rrshift__(self, other): - return self._op2(other, lambda y, x: x >> y, lambda y, x: choose_backend(x, y).shift_bits_right(x, y), 'rshift', '>>') + return self._op2(other, operator.rshift, True) def __abs__(self): return self._op1(lambda t: choose_backend(t).abs(t)) @@ -762,7 +728,7 @@ def __deepcopy__(self, memodict={}): return self._op1(lambda t: choose_backend(t).copy(t, only_mutable=False)) def __neg__(self) -> 'Tensor': - return self._op1(lambda t: -t) + return self._op1(operator.neg) def __invert__(self) -> 'Tensor': return self._op1(lambda t: choose_backend(t).invert(t)) @@ -832,17 +798,13 @@ def _op1(self, native_function) -> 'Tensor': """ raise NotImplementedError(self.__class__) - def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor': + def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor': """ Apply a broadcast operation on two tensors. Args: other: second argument - operator: function (Tensor, Tensor) -> Tensor, used to propagate the operation to children tensors to have Python choose the callee - native_function: function (native tensor, native tensor) -> native tensor - op_name: Name of the python function without leading and trailing `__`. - Examples: 'add', 'radd', 'sub', 'mul', 'and', 'eq', 'ge'. - op_symbol: Operation symbol, such as '+', '-', '&', '%', '>=' + op: Operator function (a, b) -> c, used to propagate the operation to children tensors to have Python choose the callee Returns: `Tensor` @@ -1134,42 +1096,43 @@ def __iter__(self): def __eq__(self, other): if _EQUALITY_REDUCE[-1]['type'] != 'elementwise': return Tensor.__eq__(self, other) - return self._op2(other, lambda x, y: x == y, lambda x, y: x == y, 'eq', '==') + return self._op2(other, operator.eq, False) def __ne__(self, other): if _EQUALITY_REDUCE[-1]['type'] != 'elementwise': return Tensor.__ne__(self, other) - return self._op2(other, lambda x, y: x != y, lambda x, y: x != y, 'ne', '!=') + return self._op2(other, operator.ne, False) def _assert_close(self, other: Tensor, rel_tolerance: float, abs_tolerance: float, msg: str, verbose: bool): from ._ops import assert_close inner_test = lambda x, y: assert_close(x, y, rel_tolerance=rel_tolerance, abs_tolerance=abs_tolerance, msg=msg, verbose=verbose) - return self._op2(other, inner_test, inner_test, 'assert_close', '≈') + return self._op2(other, inner_test, False) - def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> Tensor: - obj = self._recursive_op2(self._obj, self._stack_dim, other, operator, native_function, op_name) + def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor': + obj = self._recursive_op2(self._obj, self._stack_dim, other, op) new_stack = self._stack_dim + (other._stack_dim - self._stack_dim) if isinstance(other, Layout) else self._stack_dim return Layout(obj, new_stack) @staticmethod - def _recursive_op2(obj, shape: Shape, other: Tensor, operator, native_function, op_name): + def _recursive_op2(obj, shape: Shape, other: Tensor, operator: Callable): if shape: dim = shape.names[0] if isinstance(other, Tensor) and dim in other.shape: - assert other.shape.get_size(dim) == len(obj), f"Shape mismatch during {op_name}: '{dim}' has size {len(obj)} on layout but {other.shape.get_size(dim)} on other tensor." + assert other.shape.get_size(dim) == len(obj), f"Shape mismatch during {operator.__name__}: '{dim}' has size {len(obj)} on layout but {other.shape.get_size(dim)} on other tensor." others = [other[{dim: i}] for i in range(len(obj))] else: others = [other] * len(obj) if isinstance(obj, (tuple, list)): - return type(obj)([Layout._recursive_op2(i, shape[1:], o, operator, native_function, op_name) for i, o in zip(obj, others)]) + return type(obj)([Layout._recursive_op2(i, shape[1:], o, operator) for i, o in zip(obj, others)]) elif isinstance(obj, dict): - return {k: Layout._recursive_op2(v, shape[1:], o, operator, native_function, op_name) for (k, v), o in zip(obj.items(), others)} + return {k: Layout._recursive_op2(v, shape[1:], o, operator) for (k, v), o in zip(obj.items(), others)} else: # leaf if isinstance(other, Layout) and not other.shape: - return native_function(obj, other.native()) + return operator(obj, other.native()) if isinstance(other, Tensor): return operator(obj, other) else: + native_function = get_operator(operator, choose_backend(obj, other)) return native_function(obj, other) def _op1(self, native_function): @@ -1396,7 +1359,15 @@ def _op1(self, native_function): native = native_function(self._native) return Dense(native, self._names, self._shape, self._backend) if native is not None else self - def _op2(self, other, operator, native_function, op_name: str = 'unknown', op_symbol: str = '?', switch_args=False): + def _op2(self, other, op: Callable, switch_args: bool): + if isinstance(other, (bool, int, float, complex)): + n1, n2 = (other, self._backend.auto_cast1(self._native)) if switch_args else (self._backend.auto_cast1(self._native), other) + try: + result = op(n1, n2) + except ExtraOperator: + b_fun = get_operator(op, self._backend) + result = b_fun(n1, n2) + return Dense(result, self._names, self._shape, self._backend) was_other_tensor = isinstance(other, Tensor) if not was_other_tensor: try: @@ -1413,7 +1384,8 @@ def _op2(self, other, operator, native_function, op_name: str = 'unknown', op_sy nat2 = other._native if switch_args: nat1, nat2 = nat2, nat1 - result_nat = native_function(nat1, nat2) + backend_op = get_operator(op, self._backend) + result_nat = backend_op(nat1, nat2) return Dense(result_nat, names, self._shape & other._shape, backend_for(self, other)) def _natives(self) -> tuple: @@ -1599,25 +1571,28 @@ def _op1(self, native_function): else: return self._contiguous()._op1(native_function) - def _op2(self, other, operator, native_function, op_name: str = 'unknown', op_symbol: str = '?'): + def _op2(self, other, op, switch_args): other = self._tensor(other) if self.requires_broadcast: if self._stack_dim.name in other.shape: other_slices = other._unstack(self._stack_dim.name) - tensors = [operator(t1, t2) for t1, t2 in zip(self._tensors, other_slices)] + tensors = [custom_op2(t1, t2, op, switch_args) for t1, t2 in zip(self._tensors, other_slices)] else: - tensors = [operator(t, other) for t in self._tensors] + tensors = [custom_op2(t, other, op, switch_args) for t in self._tensors] return TensorStack(tensors, self._stack_dim) elif isinstance(other, Dense) or (isinstance(other, TensorStack) and not other.requires_broadcast): - names, new_shape, (native1, native2) = broadcastable_native_tensors(self, other) # ToDo we don't have to expand all - native_result = native_function(native1, native2) + names, new_shape, (n1, n2) = broadcastable_native_tensors(self, other) # ToDo we don't have to expand all + if switch_args: + n1, n2 = n2, n1 + native_function = get_operator(op, self.backend) + native_result = native_function(n1, n2) return Dense(native_result, names, new_shape, backend_for(self, other)) elif isinstance(other, TensorStack) and other.requires_broadcast: if other._stack_dim.name in self.shape: self_slices = self._unstack(other._stack_dim.name) - tensors = [operator(t1, t2) for t1, t2 in zip(self_slices, other._tensors)] + tensors = [custom_op2(t1, t2, op, switch_args) for t1, t2 in zip(self_slices, other._tensors)] else: - tensors = [operator(self, t) for t in other._tensors] + tensors = [custom_op2(self, t, op, switch_args) for t in other._tensors] return TensorStack(tensors, self._stack_dim) else: return NotImplemented @@ -1743,9 +1718,9 @@ def tensor(data, return Dense(data, (), EMPTY_SHAPE, default_backend() if convert else NUMPY) if isinstance(data, (tuple, list)): if all(isinstance(d, (bool, int, float, complex, np.generic)) for d in data): - array = np.array(data) - assert array.dtype != object - data = array + data = np.array(data) + assert data.dtype != object + data = NUMPY.auto_cast1(data) elif all(isinstance(d, str) for d in data): return layout(data, shape or default_list_dim) else: @@ -1904,7 +1879,7 @@ def broadcastable_native_tensors(*tensors) -> Tuple[Sequence[str], Shape, Sequen return var_names, broadcast_shape, natives -def custom_op2(x: Union[Tensor, float], y: Union[Tensor, float], l_operator, l_native_function, r_operator=None, r_native_function=None, op_name: str = 'unknown', op_symbol: str = None) -> Tensor: +def custom_op2(x: Union[Tensor, float], y: Union[Tensor, float], op: Callable, switch_args=False) -> Tensor: """ Perform a custom operator on two tensors. This method first tries calling _op2() on the first tensor and if that fails, tries it on the second tensor. @@ -1912,27 +1887,26 @@ def custom_op2(x: Union[Tensor, float], y: Union[Tensor, float], l_operator, l_n Args: x: Left argument y: Right argument - l_operator: Operator function acting on Tensors - l_native_function: Operator function acting on natives - r_operator: Argument-reversed operator function acting on Tensors - r_native_function: Argument-reversed operator function acting on natives - op_name: Name of the operator function for debugging purposes. Leading 'r' will be added for the operand-reversed version. - op_symbol: Short name for the operator, independent of argument order. + op: Operator function taking two arguments. Should be contained in the Backend operator mapping. Returns: `Tensor` """ - if op_symbol is None: - op_symbol = op_name + if switch_args: + x, y = y, x + if isinstance(x, Tensor): + result = x._op2(y, op, False) + if result is not NotImplemented: + return result + elif isinstance(y, Tensor): + result = y._op2(x, op, True) + if result is not NotImplemented: + return result x = wrap(x) y = wrap(y) - result = x._op2(y, l_operator, l_native_function, op_name, op_symbol) + result = x._op2(y, op, False) if result is NotImplemented: - if r_operator is None: - r_operator = lambda a, b: l_operator(b, a) - if r_native_function is None: - r_native_function = lambda a, b: l_native_function(b, a) - result = y._op2(x, r_operator, r_native_function, f'r{op_name}', op_symbol) + result = y._op2(x, op, True) if result is NotImplemented: raise NotImplementedError(f"Operation not supported between {type(x)} and {type(y)}") return result diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py index 8249710..53741d1 100644 --- a/phiml/math/_trace.py +++ b/phiml/math/_trace.py @@ -1,3 +1,4 @@ +import operator from collections import namedtuple from typing import Callable, Dict, Set, Tuple, Union, Any, Optional, Sequence, List, Collection @@ -180,18 +181,14 @@ def _op1(self, native_function): else: raise NotImplementedError('Only linear operations are supported') - def _op2(self, other: Tensor, - operator: Callable, - native_function: Callable, - op_name: str = 'unknown', - op_symbol: str = '?') -> Tensor: + def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor': if is_sparse(other): return NotImplemented if isinstance(other, SparseLinTracer): - return to_sparse_tracer(self, other)._op2(other, operator, native_function, op_name, op_symbol) - assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}" - zeros_for_missing_self = op_name not in ['add', 'radd', 'rsub'] # perform `operator` where `self == 0` - zeros_for_missing_other = op_name not in ['add', 'radd', 'sub'] # perform `operator` where `other == 0` + return to_sparse_tracer(self, other)._op2(other, op, switch_args) + assert op in {operator.add, operator.sub, operator.mul, operator.truediv}, f"Unsupported operation encountered while tracing linear function: {op}" + zeros_for_missing_self = op != operator.add and not (op == operator.sub and switch_args) # perform `operator` where `self == 0` + zeros_for_missing_other = op != operator.add and not (op == operator.sub and not switch_args) # perform `operator` where `other == 0` if isinstance(other, Tensor) and other._is_tracer: if not isinstance(other, ShiftLinTracer): raise NotImplementedError @@ -201,36 +198,36 @@ def _op2(self, other: Tensor, nz_edge = {} for dim_shift in self.val.keys(): if dim_shift in other.val: - values[dim_shift] = operator(self.val[dim_shift], other.val[dim_shift]) + values[dim_shift] = op(self.val[dim_shift], other.val[dim_shift]) nz_edge[dim_shift] = self._nz_edge[dim_shift] or other._nz_edge[dim_shift] else: if zeros_for_missing_other: - values[dim_shift] = operator(self.val[dim_shift], math.zeros_like(self.val[dim_shift])) + values[dim_shift] = op(self.val[dim_shift], math.zeros_like(self.val[dim_shift])) else: values[dim_shift] = self.val[dim_shift] nz_edge[dim_shift] = self._nz_edge[dim_shift] for dim_shift, other_values in other.val.items(): if dim_shift not in self.val: if zeros_for_missing_self: - values[dim_shift] = operator(math.zeros_like(other_values), other_values) + values[dim_shift] = op(math.zeros_like(other_values), other_values) else: values[dim_shift] = other_values nz_edge[dim_shift] = other._nz_edge[dim_shift] - bias = operator(self._bias, other._bias) + bias = op(self._bias, other._bias) return ShiftLinTracer(self._source, values, self._shape, bias, self._renamed, nz_edge) else: other = self._tensor(other) - if op_symbol in '*/': + if op in {operator.mul, operator.truediv}: values = {} for dim_shift, val in self.val.items(): - values[dim_shift] = operator(val, other) - bias = operator(self._bias, other) + values[dim_shift] = op(val, other) + bias = op(self._bias, other) return ShiftLinTracer(self._source, values, self._shape & other.shape, bias, self._renamed, self._nz_edge) - elif op_symbol in '+-': - bias = operator(self._bias, other) + elif op in {operator.add, operator.sub}: + bias = op(self._bias, other) return ShiftLinTracer(self._source, self.val, self._shape & other.shape, bias, self._renamed, self._nz_edge) else: - raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}") + raise ValueError(f"Unsupported operation encountered while tracing linear function: {op}") def _natives(self) -> tuple: """ @@ -370,34 +367,30 @@ def _op1(self, native_function): else: raise NotImplementedError('Only linear operations are supported') - def _op2(self, other: Tensor, - operator: Callable, - native_function: Callable, - op_name: str = 'unknown', - op_symbol: str = '?') -> Tensor: - assert op_symbol in '+-*/', f"Unsupported operation '{op_symbol}' encountered while tracing linear function: {native_function}" + def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor': + assert op in {operator.add, operator.sub, operator.mul, operator.truediv}, f"Unsupported operation '{op}' encountered while tracing linear function" if isinstance(other, ShiftLinTracer): other = other._to_gather_tracer() if isinstance(other, GatherLinTracer): - assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix" + assert op in {operator.add, operator.sub}, f"Non-linear operation '{op}' cannot be converted to matrix" if not math.always_close(self._selection, other._selection): - return to_sparse_tracer(self, other)._op2(other, operator, native_function, op_name, op_symbol) - diag = operator(self._diag, other._diag) - bias = operator(self._bias, other._bias) + return to_sparse_tracer(self, other)._op2(other, op, switch_args) + diag = op(self._diag, other._diag) + bias = op(self._bias, other._bias) return GatherLinTracer(self._source, diag, bias, self._shape, self._selection, self._renamed) if isinstance(other, SparseLinTracer) or is_sparse(other): return NotImplemented else: other = self._tensor(other) - if op_symbol in '*/': - matrix = operator(self._diag, other) - bias = operator(self._bias, other) + if op in {operator.mul, operator.truediv}: + matrix = op(self._diag, other) + bias = op(self._bias, other) return GatherLinTracer(self._source, matrix, bias, self._shape & other.shape, self._selection, self._renamed) - elif op_symbol in '+-': - bias = operator(self._bias, other) + elif op in {operator.add, operator.sub}: + bias = op(self._bias, other) return GatherLinTracer(self._source, self._matrix, bias, self._shape & other.shape, self._selection, self._renamed) else: - raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}") + raise ValueError(f"Unsupported operation {op} encountered while tracing linear function") @property def _is_tracer(self) -> bool: @@ -522,35 +515,31 @@ def _op1(self, native_function): else: raise NotImplementedError('Only linear operations are supported') - def _op2(self, other, - operator: Callable, - native_function: Callable, - op_name: str = 'unknown', - op_symbol: str = '?') -> 'SparseLinTracer': + def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor': other = self._tensor(other) - assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}" + assert op in {operator.add, operator.sub, operator.mul, operator.truediv}, f"Unsupported operation {op} encountered while tracing linear function" if other._is_tracer and not isinstance(other, SparseLinTracer): other = to_sparse_tracer(other, self) if isinstance(other, SparseLinTracer): - assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix" - bias = operator(self._bias, other._bias) + assert op in {operator.add, operator.sub}, f"Non-linear operation '{op}' cannot be converted to matrix" + bias = op(self._bias, other._bias) matrix_dims = sparse_dims(self._matrix) & sparse_dims(other._matrix) self_matrix = expand_matrix(self._matrix, matrix_dims) other_matrix = expand_matrix(other._matrix, matrix_dims) - matrix = operator(self_matrix, other_matrix) # ToDo if other has no dependence on vector, it would also be in the output + matrix = op(self_matrix, other_matrix) # ToDo if other has no dependence on vector, it would also be in the output shape = self._shape & other._shape return SparseLinTracer(self._source, matrix, bias, shape) else: # other = self._tensor(other) - if op_symbol in '*/': - matrix = operator(self._matrix, other) - bias = operator(self._bias, other) + if op in {operator.mul, operator.truediv}: + matrix = op(self._matrix, other) + bias = op(self._bias, other) return SparseLinTracer(self._source, matrix, bias, self._shape & other.shape) - elif op_symbol in '+-': - bias = operator(self._bias, other) + elif op in {operator.add, operator.sub}: + bias = op(self._bias, other) return SparseLinTracer(self._source, self._matrix, bias, self._shape & other.shape) else: - raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}") + raise ValueError(f"Unsupported operation {op} encountered while tracing linear function") @property def _is_tracer(self) -> bool: