Refactor Tensor._op2
holl- committed Jan 6, 2025
1 parent b30a72f commit b1319dd
Showing 6 changed files with 256 additions and 224 deletions.
48 changes: 43 additions & 5 deletions phiml/backend/_backend.py
@@ -1,5 +1,6 @@
 import dataclasses
 import logging
+import operator
 import sys
 import warnings
 from builtins import ValueError
@@ -13,6 +14,7 @@
 import numpy as np
 from numpy import ndarray

+from . import xops
 from ._dtype import DType, combine_types, INT32, INT64

 TensorType = TypeVar('TensorType')
@@ -194,15 +196,15 @@ def auto_cast(self, *tensors, bool_to_int=False, int_to_float=False) -> list:
         return tensors

     def auto_cast1(self, tensor):
-        if isinstance(tensor, (bool, Number)):
+        if isinstance(tensor, (bool, int, float, complex)):
             return tensor
         dtype = self.dtype(tensor)
         if dtype.kind in {int, bool}:
             return tensor
-        result_type = combine_types(dtype, fp_precision=self.precision)
-        if result_type.bits == dtype.bits:
-            return tensor
-        return self.cast(tensor, result_type)
+        if dtype.precision != get_precision():
+            result_type = DType(dtype.kind, precision=get_precision())
+            return self.cast(tensor, result_type)
+        return tensor

     def __str__(self):
         return self.name
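A minimal sketch, not part of this commit, of what the reworked auto_cast1 is expected to do: float tensors whose precision differs from the global setting are cast, while int and bool tensors and Python scalars pass through. The import path of NUMPY (the NumPy backend instance exported by phiml.backend) and the default 32-bit precision are assumptions here.

import numpy as np
from phiml.backend import NUMPY  # NumPy backend instance (assumed import path)

x64 = np.linspace(0.0, 1.0, 3)                 # float64 input
print(NUMPY.auto_cast1(x64).dtype)             # float32 under the assumed default 32-bit precision
print(NUMPY.auto_cast1(np.arange(3)).dtype)    # int arrays pass through unchanged (int32/int64)
print(NUMPY.auto_cast1(1.5))                   # Python scalars are returned as-is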
@@ -2114,3 +2116,39 @@ def assemble(b: Backend, *args):
         all_values = {f.name: re_values[f.name] if f.name in tensor_fields else getattr(data, f.name) for f in fields}
         return type(data)(**all_values)
     return assemble, tensors
+
+
+_BACKEND_OPERATORS = {
+    operator.eq: Backend.equal,
+    operator.ne: Backend.not_equal,
+    operator.gt: Backend.greater_than,
+    operator.ge: Backend.greater_or_equal,
+    operator.add: Backend.add,
+    operator.sub: Backend.sub,
+    operator.mul: Backend.mul,
+    operator.truediv: Backend.div,
+    operator.pow: Backend.pow,
+    operator.mod: Backend.mod,
+    operator.and_: Backend.and_,
+    operator.or_: Backend.or_,
+    operator.xor: Backend.xor,
+    operator.floordiv: Backend.floordiv,
+    operator.lshift: Backend.shift_bits_left,
+    operator.rshift: Backend.shift_bits_right,
+    operator.inv: Backend.invert,
+    operator.invert: Backend.invert,
+    divmod: divmod,
+    abs: Backend.abs,
+    xops.save_div: Backend.divide_no_nan,
+    xops.gamma_inc_l: Backend.gamma_inc_l,
+    xops.gamma_inc_u: Backend.gamma_inc_u,
+    xops.arctan2: Backend.arctan2,
+    xops.minimum: Backend.minimum,
+    xops.maximum: Backend.maximum,
+}
+
+def get_operator(op: Callable, backend: Backend):
+    fun = _BACKEND_OPERATORS.get(op)
+    if fun is not None:
+        return getattr(backend, fun.__name__)
+    return fun
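A small usage sketch, not taken from the repository, of how the new lookup might be exercised. It assumes only the import paths implied by the files above (get_operator in phiml/backend/_backend.py, the NUMPY backend instance exported by phiml.backend):

import operator
from phiml.backend import NUMPY                    # concrete backend instance (assumed export)
from phiml.backend._backend import get_operator    # added in this commit

add_impl = get_operator(operator.add, NUMPY)       # resolves to the bound method NUMPY.add via _BACKEND_OPERATORS
unknown = get_operator(operator.matmul, NUMPY)     # operator not in the table, so this returns None
print(add_impl, unknown)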
33 changes: 33 additions & 0 deletions phiml/backend/xops.py
@@ -0,0 +1,33 @@
+"""
+Extra operators.
+This module is an extension to the built-in operators.
+"""
+
+
+class ExtraOperator(Exception):
+    pass
+
+
+def save_div(numerator, denominator):
+    raise ExtraOperator
+
+
+def gamma_inc_l(a, x):
+    raise ExtraOperator
+
+
+def gamma_inc_u(a, x):
+    raise ExtraOperator
+
+
+def arctan2(tan, divide_by):
+    raise ExtraOperator
+
+
+def minimum(x1, x2):
+    raise ExtraOperator
+
+
+def maximum(x1, x2):
+    raise ExtraOperator
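The functions in xops are markers rather than implementations: _BACKEND_OPERATORS maps them to Backend methods, and calling one directly only raises ExtraOperator. A short sanity check, written here for illustration:

from phiml.backend import xops

try:
    xops.maximum(1.0, 2.0)
except xops.ExtraOperator:
    print("xops.maximum is only a dispatch key; backends provide the actual implementation")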
23 changes: 9 additions & 14 deletions phiml/math/_ops.py
@@ -8,6 +8,7 @@

 from ..backend import default_backend, choose_backend, Backend, get_precision, convert as b_convert, BACKENDS, NoBackendFound, ComputeDevice, NUMPY
 from ..backend._dtype import DType, combine_types, INT32
+from phiml.backend import xops
 from .magic import PhiTreeNode
 from ._magic_ops import expand, pack_dims, unpack_dim, cast, value_attributes, bool_to_int, tree_map, concat, stack, unstack, rename_dims, slice_, all_attributes, squeeze, ipack
 from ._shape import (Shape, EMPTY_SHAPE,
@@ -2137,9 +2138,9 @@ def tensor_dot(x, y):
     if is_sparse(x) or is_sparse(y):
         if x_dims.isdisjoint(sparse_dims(x)) and y_dims.isdisjoint(sparse_dims(y)):
             if is_sparse(x):
-                return x._op2(y, lambda vx, vy: dot(vx, x_dims, vy, y_dims), None, 'dot', '@')
+                return x._op2(y, lambda vx, vy: dot(vx, x_dims, vy, y_dims), False)
             else:
-                return y._op2(x, lambda vy, vx: dot(vx, x_dims, vy, y_dims), None, 'dot', '@')
+                return y._op2(x, lambda vy, vx: dot(vx, x_dims, vy, y_dims), False)
         else:
             return sparse_dot(x, x_dims, y, y_dims)
     if x._is_tracer:
@@ -2289,11 +2290,10 @@ def incomplete_gamma(a: TensorOrTree, x: TensorOrTree, upper=False, regularized=
         upper: Whether to complete the upper integral (x to infinity) or the lower integral (0 to x).
         regularized: Whether the integral is divided by Γ(a).
     """
-    call = lambda a, x: incomplete_gamma(a, x, upper=upper, regularized=regularized)
     if upper:
-        reg = custom_op2(a, x, call, lambda a, x: choose_backend(a, x).gamma_inc_u(a, x), 'gamma_inc_u')
+        reg = custom_op2(a, x, xops.gamma_inc_u)
     else:
-        reg = custom_op2(a, x, call, lambda a, x: choose_backend(a, x).gamma_inc_l(a, x), 'gamma_inc_l')
+        reg = custom_op2(a, x, xops.gamma_inc_l)
     return reg if regularized else reg * exp(log_gamma(a))


@@ -2464,7 +2464,7 @@ def arctan(x: TensorOrTree, divide_by=None) -> TensorOrTree:
         return _backend_op1(x, Backend.arctan)
     else:
         divide_by = to_float(divide_by)
-        return custom_op2(x, divide_by, arctan, lambda a, b: choose_backend(a, b).arctan2(a, b), 'arctan')
+        return custom_op2(x, divide_by, xops.arctan2)


 def angle(x: TensorOrTree) -> TensorOrTree:
@@ -2560,12 +2560,7 @@ def cast_same(*values: Tensor) -> Tuple[Tensor]:

 def safe_div(x: Union[Number, Tensor], y: Union[Number, Tensor]):
     """ Computes *x/y* with the `Tensor`s `x` and `y` but returns 0 where *y=0*. """
-    return custom_op2(x, y,
-                      l_operator=safe_div,
-                      l_native_function=lambda x_, y_: choose_backend(x_, y_).divide_no_nan(x_, y_),
-                      r_operator=lambda y_, x_: safe_div(x_, y_),
-                      r_native_function=lambda y_, x_: choose_backend(x_, y_).divide_no_nan(x_, y_),
-                      op_name='divide_no_nan')
+    return custom_op2(x, y, xops.save_div)


 def maximum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False):
@@ -2575,7 +2570,7 @@ def maximum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False):
             return y
         elif y is None:
             return x
-    return custom_op2(x, y, maximum, lambda x_, y_: choose_backend(x_, y_).maximum(x_, y_), op_name='maximum')
+    return custom_op2(x, y, xops.maximum)


 def minimum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False):
@@ -2585,7 +2580,7 @@ def minimum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False):
             return y
         elif y is None:
             return x
-    return custom_op2(x, y, minimum, lambda x_, y_: choose_backend(x_, y_).minimum(x_, y_), op_name='minimum')
+    return custom_op2(x, y, xops.minimum)


 def clip(x: Tensor, lower_limit: Union[float, Tensor] = 0, upper_limit: Union[float, Tensor, Shape] = 1):
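For context, a hedged usage sketch (assumed public API, not part of the diff) showing that the functions edited above keep their public signatures; only the internal dispatch now goes through the xops markers:

from phiml import math as pm

x = pm.wrap([1.0, 2.0, 0.0], pm.spatial('x'))
y = pm.wrap([0.5, 0.0, 2.0], pm.spatial('x'))
print(pm.safe_div(x, y))   # 0 where y == 0, routed through xops.save_div -> Backend.divide_no_nan
print(pm.maximum(x, y))    # routed through xops.maximum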
53 changes: 26 additions & 27 deletions phiml/math/_sparse.py
@@ -1,3 +1,4 @@
+import operator
 import warnings
 from functools import partial
 from numbers import Number
@@ -11,7 +12,7 @@
 from ._magic_ops import concat, pack_dims, expand, rename_dims, stack, unpack_dim, unstack
 from ._shape import Shape, non_batch, merge_shapes, instance, batch, non_instance, shape, channel, spatial, DimFilter, \
     concat_shapes, EMPTY_SHAPE, dual, non_channel, DEBUG_CHECKS, primal, concat_shapes_
-from ._tensors import Tensor, TensorStack, Dense, cached, wrap, reshaped_tensor, tensor, backend_for
+from ._tensors import Tensor, TensorStack, Dense, cached, wrap, reshaped_tensor, tensor, backend_for, custom_op2
 from ..backend import choose_backend, NUMPY, Backend, get_precision
 from ..backend._dtype import DType, INT64

@@ -367,20 +368,20 @@ def _with_shape_replaced(self, new_shape: Shape):
     def _op1(self, native_function):
         return self._with_values(self._values._op1(native_function))

-    def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor':
+    def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
         other_shape = shape(other)
         affects_only_values = self._dense_shape.isdisjoint(other_shape)
         if affects_only_values:
-            return self._with_values(operator(self._values, other))
+            return self._with_values(op(self._values, other))
         if isinstance(other, CompressedSparseMatrix):
             other = other.decompress()
         if isinstance(other, SparseCoordinateTensor):
             if same_sparsity_pattern(self, other):
-                return self._with_values(operator(self._values, other._values))
+                return self._with_values(op(self._values, other._values))
             else:
-                if op_name not in ['add', 'radd', 'sub', 'rsub']:
+                if op not in {operator.add, operator.sub}:
                     same_sparsity_pattern(self, other) # debug checkpoint
-                    raise AssertionError(f"Operation '{op_symbol}' ({op_name}) requires sparse matrices with the same sparsity pattern.")
+                    raise AssertionError(f"Operation '{op}' requires sparse matrices with the same sparsity pattern.")
                 all_sparse_dims = sparse_dims(other) & sparse_dims(self)
                 self_indices = pack_dims(self._indices, instance, instance('sp_entries'))
                 other_indices = pack_dims(other._indices, instance, instance('sp_entries'))
@@ -389,21 +390,19 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
                 self_indices, self_values = with_sparsified_dim(self_indices, self_values, all_sparse_dims)
                 other_indices, other_values = with_sparsified_dim(other_indices, other_values, all_sparse_dims)
                 indices = concat([self_indices, other_indices], 'sp_entries')
-                if op_symbol == '+':
+                if op == operator.add:
                     values = concat([self_values, other_values], instance(self_values), expand_values=True)
-                elif op_name == 'sub':
-                    values = concat([self_values, -other_values], instance(self_values), expand_values=True)
-                else: # op_name == 'rsub':
-                    values = concat([-self_values, other_values], instance(self_values), expand_values=True)
+                else:
+                    values = concat([-self_values, other_values] if switch_args else [self_values, -other_values], instance(self_values), expand_values=True)
                 return SparseCoordinateTensor(indices, values, self._dense_shape & other._dense_shape, can_contain_double_entries=True, indices_sorted=False, indices_constant=self._indices_constant)
         else: # other is dense
             if self._dense_shape in other.shape: # all dims dense -> convert to dense
-                return dense(self)._op2(other, operator, native_function, op_name, op_symbol)
+                return dense(self)._op2(other, op, switch_args)
             else: # only some dims dense -> stay sparse
                 dense_dims = self._dense_shape.only(other.shape)
                 assert instance(other).without(self._dense_shape).is_empty, f"Instance dims cannot be added to sparse tensors from sparse-dense operations but got {other.shape} for sparse tensor {self.shape}"
                 other_values = other[self._indices.sparse_idx[dense_dims.name_list]]
-                values = operator(self._values, other_values)
+                values = custom_op2(self._values, other_values, op, switch_args)
                 return self._with_values(values)

     def _getitem(self, selection: dict) -> 'Tensor':
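One detail worth noting in the rewritten branch above: sparsity-pattern mismatches are now detected by comparing operator identity rather than name strings. A trivial illustration (not repository code):

import operator

op = operator.sub
assert op in {operator.add, operator.sub}   # new-style identity check used above
# the removed code compared strings instead: op_name in ['add', 'radd', 'sub', 'rsub']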
@@ -702,19 +701,19 @@ def __expand__(self, dims: Shape, **kwargs) -> 'Tensor':
     def _op1(self, native_function):
         return self._with_values(self._values._op1(native_function))

-    def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor':
+    def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
         other_shape = shape(other)
         affects_only_values = self.sparse_dims.isdisjoint(other_shape) and non_instance(self._indices).isdisjoint(other_shape)
         if affects_only_values:
-            return self._with_values(operator(self._values, other))
+            return self._with_values(custom_op2(self._values, other, op, switch_args))
         elif isinstance(other, CompressedSparseMatrix):
             if same_sparsity_pattern(self, other):
-                result = operator(self._values, other._values)
+                result = op(self._values, other._values)
                 if self._uncompressed_offset is not None:
                     from ._ops import where
                     result = where(self._valid_mask(), result, 0)
                 return self._with_values(result)
-            elif op_symbol == '+':
+            elif op == operator.add:
                 raise NotImplementedError("Compressed addition not yet implemented")
             else:
                 # convert to COO, then perform operation
@@ -723,19 +722,19 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
             from ._ops import gather, boolean_mask, clip, where
             if self._uncompressed_offset is None:
                 other_values = gather(other, self._indices, self._uncompressed_dims)
-                return self._with_values(operator(self._values, other_values))
+                return self._with_values(op(self._values, other_values))
             # if bake_slice:
             #     baked = self._bake_slice()
             #     other_values = gather(other, baked._indices, self._uncompressed_dims)
             #     return baked._with_values(operator(baked._values, other_values))
             indices = clip(self._indices - self._uncompressed_offset, 0, self._uncompressed_dims.volume - 1)
             other_values = gather(other, indices, self._uncompressed_dims)
-            return self._with_values(where(self._valid_mask(), operator(self._values, other_values), 0))
+            return self._with_values(where(self._valid_mask(), op(self._values, other_values), 0))
         elif self._compressed_dims in other_shape and self._uncompressed_dims.isdisjoint(other_shape):
             from ._ops import gather, boolean_mask, clip, where
             row_indices, _ = self._coo_indices('clamp')
             other_values = gather(other, row_indices, self._compressed_dims)
-            result_values = operator(self._values, other_values)
+            result_values = op(self._values, other_values)
             if self._uncompressed_offset is not None:
                 result_values = where(self._valid_mask(), result_values, 0)
             return self._with_values(result_values)
@@ -960,16 +959,16 @@ def _with_shape_replaced(self, new_shape: Shape):
     def _op1(self, native_function):
         return self._with_values(self._values._op1(native_function))

-    def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor':
+    def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
         other_shape = shape(other)
         affects_only_values = self._compressed_dims.isdisjoint(other_shape)
         if affects_only_values:
-            return self._with_values(operator(self._values, other))
+            return self._with_values(op(self._values, other))
         elif isinstance(other, (CompressedSparseMatrix, CompactSparseTensor)):
             if same_sparsity_pattern(self, other):
-                result = operator(self._values, other._values)
+                result = op(self._values, other._values)
                 return self._with_values(result)
-            elif op_symbol == '+':
+            elif op == operator.add:
                 raise NotImplementedError("Compressed addition not yet implemented")
             else:
                 # convert to COO, then perform operation
@@ -978,18 +977,18 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
             from ._ops import gather, boolean_mask, clip, where
             if self._uncompressed_offset is None:
                 other_values = gather(other, self._indices, self._uncompressed_dims)
-                return self._with_values(operator(self._values, other_values))
+                return self._with_values(op(self._values, other_values))
             # if bake_slice:
             #     baked = self._bake_slice()
             #     other_values = gather(other, baked._indices, self._uncompressed_dims)
             #     return baked._with_values(operator(baked._values, other_values))
             indices = clip(self._indices - self._uncompressed_offset, 0, self._uncompressed_dims.volume - 1)
             other_values = gather(other, indices, self._uncompressed_dims)
-            return self._with_values(where(self._valid_mask(), operator(self._values, other_values), 0))
+            return self._with_values(where(self._valid_mask(), op(self._values, other_values), 0))
         elif self._compressed_dims in other_shape and self._uncompressed_dims.isdisjoint(other_shape):
             from ._ops import gather, boolean_mask, clip, where
             other_values = gather(other, self._indices, self._compressed_dims)
-            result_values = operator(self._values, other_values)
+            result_values = op(self._values, other_values)
             return self._with_values(result_values)
         else:
             raise NotImplementedError
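Across all three sparse classes, _op2 now takes the operator itself plus a switch_args flag instead of separate lambdas, names and symbols. An illustrative, stand-alone sketch (assumption, not phiml code) of what switch_args means for a non-commutative operator coming from a reflected dunder such as __rsub__:

import operator

def apply_op2(self_value, other_value, op, switch_args: bool):
    # Mirrors the new _op2 contract: op is a plain callable (e.g. operator.sub),
    # switch_args=True means the call originated from a reflected dunder,
    # so the effective operation is op(other, self) rather than op(self, other).
    return op(other_value, self_value) if switch_args else op(self_value, other_value)

print(apply_op2(10, 3, operator.sub, switch_args=False))  # 10 - 3 -> 7
print(apply_op2(10, 3, operator.sub, switch_args=True))   # 3 - 10 -> -7 (the __rsub__ case)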
(Diffs for the remaining changed files are not shown here.)
