Global variables for common DTypes
holl- committed Jan 6, 2025
1 parent fe649af commit 1c33006
Showing 12 changed files with 83 additions and 70 deletions.
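In short, this commit replaces inline DType(...) constructions with shared module-level constants defined in phiml/backend/_dtype.py. A minimal usage sketch follows; the constant names and import path are taken from the diff below, while the equality checks assume DType instances compare equal by kind and bit size (not stated in the diff itself):

```python
# Sketch only: names come from the diff below; everything else is illustrative.
from phiml.backend._dtype import DType, INT32, FLOAT64, BOOL

# The constants are assumed to be interchangeable with the previous inline
# constructions, i.e. equal by kind and bit size.
assert INT32 == DType(int, 32)
assert FLOAT64 == DType(float, 64)
assert BOOL == DType(bool)
```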
13 changes: 6 additions & 7 deletions phiml/backend/_backend.py
@@ -13,8 +13,7 @@
import numpy as np
from numpy import ndarray

-from ._dtype import DType, combine_types
-
+from ._dtype import DType, combine_types, INT32, INT64

TensorType = TypeVar('TensorType')
TensorOrArray = Union[TensorType, np.ndarray]
@@ -187,7 +186,7 @@ def auto_cast(self, *tensors, bool_to_int=False, int_to_float=False) -> list:
dtypes = [self.dtype(t) for t in tensors]
result_type = self.combine_types(*dtypes)
if result_type.kind == bool and bool_to_int:
-result_type = DType(int, 32)
+result_type = INT32
if result_type.kind == int and int_to_float:
result_type = DType(float, self.precision)
if result_type.kind in (int, float, complex, bool): # do not cast everything to string!
@@ -627,7 +626,7 @@ def nonzero(self, values, length=None, fill_value=-1):
def mean(self, value, axis=None, keepdims=False):
raise NotImplementedError(self)

-def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
+def range(self, start, limit=None, delta=1, dtype: DType = INT32):
raise NotImplementedError(self)

def zeros(self, shape, dtype: DType = None):
@@ -848,10 +847,10 @@ def to_float(self, x):
return self.cast(x, self.float_type)

def to_int32(self, x):
-return self.cast(x, DType(int, 32))
+return self.cast(x, INT32)

def to_int64(self, x):
-return self.cast(x, DType(int, 64))
+return self.cast(x, INT64)

def to_complex(self, x):
return self.cast(x, DType(complex, max(64, self.precision * 2)))
@@ -1090,7 +1089,7 @@ def argsort(self, x, axis=-1):
def sort(self, x, axis=-1):
raise NotImplementedError(self)

-def searchsorted(self, sorted_sequence, search_values, side: str, dtype=DType(int, 32)):
+def searchsorted(self, sorted_sequence, search_values, side: str, dtype=INT32):
raise NotImplementedError(self)

def fft(self, x, axes: Union[tuple, list]):
41 changes: 27 additions & 14 deletions phiml/backend/_dtype.py
@@ -88,7 +88,7 @@ def as_dtype(value: Union['DType', tuple, type, None]) -> Union['DType', None]:
if isinstance(value, DType):
return value
elif value is int:
-return DType(int, 32)
+return INT32
elif value is float:
from . import get_precision
return DType(float, get_precision())
@@ -122,28 +122,41 @@ def from_numpy_dtype(np_dtype) -> DType:
else:
for base_np_dtype, dtype in _FROM_NUMPY.items():
if np_dtype == base_np_dtype:
+_FROM_NUMPY[np_dtype] = dtype
return dtype
if np_dtype.char == 'U':
return DType(str, 8 * np_dtype.itemsize)
raise ValueError(np_dtype)


+BOOL = DType(bool)
+INT8 = DType(int, 8)
+INT16 = DType(int, 16)
+INT32 = DType(int, 32)
+INT64 = DType(int, 64)
+FLOAT16 = DType(float, 16)
+FLOAT32 = DType(float, 32)
+FLOAT64 = DType(float, 64)
+COMPLEX64 = DType(complex, 64)
+COMPLEX128 = DType(complex, 128)
+OBJECT = DType(object)

_TO_NUMPY = {
-DType(float, 16): np.float16,
-DType(float, 32): np.float32,
-DType(float, 64): np.float64,
-DType(complex, 64): np.complex64,
-DType(complex, 128): np.complex128,
-DType(int, 8): np.int8,
-DType(int, 16): np.int16,
-DType(int, 32): np.int32,
-DType(int, 64): np.int64,
-DType(bool): np.bool_,
-DType(object): object,
+BOOL: np.bool_,
+INT8: np.int8,
+INT16: np.int16,
+INT32: np.int32,
+INT64: np.int64,
+FLOAT16: np.float16,
+FLOAT32: np.float32,
+FLOAT64: np.float64,
+COMPLEX64: np.complex64,
+COMPLEX128: np.complex128,
+OBJECT: object,
}
_FROM_NUMPY = {np: dtype for dtype, np in _TO_NUMPY.items()}
-_FROM_NUMPY[np.bool_] = DType(bool)
-_FROM_NUMPY[bool] = DType(bool)
+_FROM_NUMPY[np.bool_] = BOOL
+_FROM_NUMPY[bool] = BOOL


def combine_types(*dtypes: DType, fp_precision: int = None) -> DType:
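The new constants cover every entry of the numpy conversion tables, so round trips through to_numpy_dtype / from_numpy_dtype (both imported elsewhere in this commit) can be written against them. A small sketch, assuming the lookup behaviour implied by the _TO_NUMPY / _FROM_NUMPY tables above:

```python
# Assumed behaviour based on the tables above; not taken verbatim from the library.
import numpy as np
from phiml.backend._dtype import INT64, FLOAT32, to_numpy_dtype, from_numpy_dtype

assert to_numpy_dtype(FLOAT32) == np.float32   # direct _TO_NUMPY lookup
assert from_numpy_dtype(np.int64) == INT64     # resolved via _FROM_NUMPY
```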
22 changes: 11 additions & 11 deletions phiml/backend/_linalg.py
@@ -10,7 +10,7 @@
from scipy.sparse.linalg import spsolve, LinearOperator

from ._backend import Backend, SolveResult, List, DType, spatial_derivative_evaluation, combined_dim, choose_backend, TensorType, Preconditioner, ML_LOGGER, convert, disassemble_dataclass
-from ._dtype import to_numpy_dtype, combine_types
+from ._dtype import to_numpy_dtype, combine_types, INT32, BOOL
from ._numpy_backend import NUMPY


@@ -65,8 +65,8 @@ def cg(b: Backend, lin, y, x0, rtol, atol, max_iter, pre: Optional[Preconditione
delta0 = b.sum(residual * dx, -1, keepdims=True)
delta0_tol = b.sum(residual_tol * dx_tol, -1, keepdims=True)
check_progress = stop_on_l2(b, abs(delta0_tol), rtol, atol, max_iter)
-iterations = b.zeros([batch_size], DType(int, 32))
-function_evaluations = b.ones([batch_size], DType(int, 32))
+iterations = b.zeros([batch_size], INT32)
+function_evaluations = b.ones([batch_size], INT32)
continue_, converged, diverged = check_progress(iterations, delta0)

def cg_loop_body(continue_, x, dx, delta, residual, iterations, function_evaluations, _converged, _diverged):
@@ -103,8 +103,8 @@ def cg_adaptive(b, lin, y, x0, rtol, atol, max_iter, pre: Optional[Preconditione
x = x0
dx = residual = y - linear(b, lin, x, matrix_offset)
dy = linear(b, lin, dx, matrix_offset)
-iterations = b.zeros([batch_size], DType(int, 32))
-function_evaluations = b.ones([batch_size], DType(int, 32))
+iterations = b.zeros([batch_size], INT32)
+function_evaluations = b.ones([batch_size], INT32)
residual_squared = b.sum(residual ** 2, -1, keepdims=True)
check_progress = stop_on_l2(b, b.sum(y ** 2, -1), rtol, atol, max_iter)
continue_, converged, diverged = check_progress(iterations, residual_squared)
@@ -142,8 +142,8 @@ def bicg(b: Backend, lin, y, x0, rtol, atol, max_iter, pre: Optional[Preconditio
x = b.copy(b.to_float(x0), only_mutable=True)
batch_size = b.staticshape(y)[0]
r0_tild = residual = y - linear(b, lin, x, matrix_offset)
-iterations = b.zeros([batch_size], DType(int, 32))
-function_evaluations = b.ones([batch_size], DType(int, 32))
+iterations = b.zeros([batch_size], INT32)
+function_evaluations = b.ones([batch_size], INT32)
residual_squared = b.sum(residual ** 2, -1, keepdims=True)
check_progress = stop_on_l2(b, b.sum(y ** 2, -1), rtol, atol, max_iter)
continue_, converged, diverged = check_progress(iterations, residual_squared)
@@ -235,8 +235,8 @@ def bicg_stab_first_order(b: Backend, lin, y, x0, rtol, atol, max_iter, pre: Opt
batch_size = b.staticshape(y)[0]
residual = y - b.linear(lin, x)
r0_h = b.ones(x0.shape)
-iterations = b.zeros([batch_size], DType(int, 32))
-function_evaluations = b.ones([batch_size], DType(int, 32))
+iterations = b.zeros([batch_size], INT32)
+function_evaluations = b.ones([batch_size], INT32)
residual_squared = b.sum(residual ** 2, -1, keepdims=True)
check_progress = stop_on_l2(b, b.sum(y ** 2, -1), rtol, atol, max_iter)
continue_, converged, diverged = check_progress(iterations, residual_squared)
@@ -336,8 +336,8 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_pre_tensors):
npr = scipy_iterative_sparse_solve(NUMPY, lin, np_y, np_x0, np_rtol, np_atol, max_iter, np_pre, function)
return npr.x, npr.residual, npr.iterations, npr.function_evaluations, npr.converged, npr.diverged
fp = b.float_type
-i = DType(int, 32)
-bo = DType(bool)
+i = INT32
+bo = BOOL
x, residual, iterations, function_evaluations, converged, diverged = b.numpy_call(scipy_solve, (x0.shape, x0.shape, x0.shape[:1], x0.shape[:1], x0.shape[:1], x0.shape[:1]), (fp, fp, i, i, bo, bo), y, x0, rtol, atol, *lin_tensors, *pre_tensors)
if was_row_added:
residual = residual[:, :-1]
7 changes: 4 additions & 3 deletions phiml/backend/_minimize.py
@@ -3,6 +3,7 @@
import numpy

from ._backend import Backend, SolveResult, DType, ML_LOGGER
+from ._dtype import BOOL, INT32
from ._linalg import _max_iter


@@ -119,16 +120,16 @@ def gradient_descent(self: Backend, f, x0, atol, max_iter, trj: bool, step_size=
fg = self.jacobian(f, [0], get_output=True, is_f_scalar=True)
method = f"Gradient descent with {self.name}"

-iterations = self.zeros([batch_size], DType(int, 32))
-function_evaluations = self.ones([batch_size], DType(int, 32))
+iterations = self.zeros([batch_size], INT32)
+function_evaluations = self.ones([batch_size], INT32)

adaptive_step_size = step_size == 'adaptive'
if adaptive_step_size:
step_size = self.zeros([batch_size]) + 0.1

loss, grad = fg(x0) # Evaluate function and gradient
diverged = self.any(~self.isfinite(x0), axis=(1,))
-converged = self.zeros([batch_size], DType(bool))
+converged = self.zeros([batch_size], BOOL)
trajectory = [SolveResult(method, self.numpy(x0), self.numpy(loss), self.numpy(iterations), self.numpy(function_evaluations), self.numpy(converged), self.numpy(diverged), [""] * batch_size)] if trj else None
max_iter_ = self.to_int32(max_iter)
continue_ = ~converged & ~diverged & (iterations < max_iter_)
14 changes: 7 additions & 7 deletions phiml/backend/_numpy_backend.py
@@ -12,7 +12,7 @@

from . import Backend, ComputeDevice
from ._backend import combined_dim, SolveResult, TensorType
-from ._dtype import from_numpy_dtype, to_numpy_dtype, DType
+from ._dtype import from_numpy_dtype, to_numpy_dtype, DType, FLOAT64, BOOL, COMPLEX128, INT32


class NumPyBackend(Backend):
@@ -202,7 +202,7 @@ def random_normal(self, shape, dtype: DType):
def random_permutations(self, permutations: int, n: int):
return np.stack([np.random.permutation(n) for _ in range(permutations)])

-def range(self, start, limit=None, delta=1, dtype: DType = DType(int, 32)):
+def range(self, start, limit=None, delta=1, dtype: DType = INT32):
if limit is None:
start, limit = 0, start
return np.arange(start, limit, delta, to_numpy_dtype(dtype))
@@ -413,7 +413,7 @@ def argsort(self, x, axis=-1):
def sort(self, x, axis=-1):
return np.sort(x, axis)

-def searchsorted(self, sorted_sequence, search_values, side: str, dtype=DType(int, 32)):
+def searchsorted(self, sorted_sequence, search_values, side: str, dtype=INT32):
if self.ndims(sorted_sequence) == 1:
return np.searchsorted(sorted_sequence, search_values, side=side).astype(to_numpy_dtype(dtype))
else:
@@ -442,13 +442,13 @@ def ifft(self, k, axes: Union[tuple, list]):

def dtype(self, array) -> DType:
if isinstance(array, bool):
-return DType(bool)
+return BOOL
if isinstance(array, int):
-return DType(int, 32)
+return INT32
if isinstance(array, float):
-return DType(float, 64)
+return FLOAT64
if isinstance(array, complex):
-return DType(complex, 128)
+return COMPLEX128
if not hasattr(array, 'dtype'):
array = np.array(array)
return from_numpy_dtype(array.dtype)
12 changes: 6 additions & 6 deletions phiml/backend/jax/_jax_backend.py
@@ -18,7 +18,7 @@
if version.parse(jax.__version__) >= version.parse('0.2.20'):
from jax.experimental.sparse import BCOO, COO, CSR, CSC

-from .._dtype import DType, to_numpy_dtype, from_numpy_dtype
+from .._dtype import DType, to_numpy_dtype, from_numpy_dtype, COMPLEX128, FLOAT64, INT32, BOOL
from .._backend import Backend, ComputeDevice, combined_dim, ML_LOGGER, TensorType, map_structure

jax.config.update("jax_enable_x64", True)
@@ -563,7 +563,7 @@ def scatter_single(base_grid, indices, values):
result = self.vectorized_call(scatter_single, base_grid, indices, values)
if self.dtype(result).kind != out_kind:
if out_kind == bool:
-result = self.cast(result, DType(bool))
+result = self.cast(result, BOOL)
return result

def histogram1d(self, values, weights, bin_edges):
@@ -619,13 +619,13 @@ def ifft(self, k, axes: Union[tuple, list]):

def dtype(self, array) -> DType:
if isinstance(array, bool):
-return DType(bool)
+return BOOL
if isinstance(array, int):
-return DType(int, 32)
+return INT32
if isinstance(array, float):
-return DType(float, 64)
+return FLOAT64
if isinstance(array, complex):
-return DType(complex, 128)
+return COMPLEX128
if not isinstance(array, jnp.ndarray):
array = jnp.array(array)
return from_numpy_dtype(array.dtype)
10 changes: 5 additions & 5 deletions phiml/backend/tensorflow/_tf_backend.py
@@ -10,7 +10,7 @@
from tensorflow.python.framework.errors_impl import NotFoundError

from .._backend import combined_dim, TensorType
-from .._dtype import DType, to_numpy_dtype, from_numpy_dtype
+from .._dtype import DType, to_numpy_dtype, from_numpy_dtype, BOOL
from .. import Backend, ComputeDevice, NUMPY
from ._tf_cuda_resample import resample_cuda, use_cuda

@@ -563,19 +563,19 @@ def boolean_mask(self, x, mask, axis=0, new_length=None, fill_value=0):

def isfinite(self, x):
if self.dtype(x).kind in (bool, int):
-return self.ones(self.shape(x), dtype=DType(bool))
+return self.ones(self.shape(x), dtype=BOOL)
with self.device_of(x):
return tf.math.is_finite(x)

def isnan(self, x):
if self.dtype(x).kind in (bool, int):
-return self.zeros(self.shape(x), dtype=DType(bool))
+return self.zeros(self.shape(x), dtype=BOOL)
with self.device_of(x):
return tf.math.is_nan(x)

def isinf(self, x):
if self.dtype(x).kind in (bool, int):
-return self.zeros(self.shape(x), dtype=DType(bool))
+return self.zeros(self.shape(x), dtype=BOOL)
with self.device_of(x):
return tf.math.is_inf(x)

@@ -625,7 +625,7 @@ def scatter_single(b_grid, b_indices, b_values):
result = self.vectorized_call(scatter_single, base_grid, indices, values, output_dtypes=self.dtype(base_grid))
if self.dtype(result).kind != out_kind:
if out_kind == bool:
-result = self.cast(result, DType(bool))
+result = self.cast(result, BOOL)
return result

def histogram1d(self, values, weights, bin_edges):
4 changes: 2 additions & 2 deletions phiml/math/_functional.py
@@ -19,7 +19,7 @@
from ..backend import Backend
from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER
from ..backend._buffer import set_buffer_config, get_buffer_config, get_required_buffer_sizes, wasted_memory
-from ..backend._dtype import DType
+from ..backend._dtype import DType, FLOAT64

X = TypeVar('X')
Y = TypeVar('Y')
@@ -1408,4 +1408,4 @@ def perf_counter(wait_for_tensor, *wait_for_tensors: Tensor) -> Tensor:
assert natives, f"in jit mode, perf_counter must be given at least one traced tensor, as the current time is evaluated after all tensors are computed."
def perf_counter(*_wait_for_natives):
return np.asarray(time.perf_counter())
-return wrap(backend.numpy_call(perf_counter, (), DType(float, 64), *natives))
+return wrap(backend.numpy_call(perf_counter, (), FLOAT64, *natives))
6 changes: 3 additions & 3 deletions phiml/math/_nd.py
@@ -11,7 +11,7 @@
from .extrapolation import Extrapolation
from .magic import PhiTreeNode
from ..backend import choose_backend
-from ..backend._dtype import DType
+from ..backend._dtype import INT64


def vec(name: Union[str, Shape] = 'vector', *sequence, tuple_dim=spatial('sequence'), list_dim=instance('sequence'), **components) -> Tensor:
@@ -905,13 +905,13 @@ def find_closest(vectors: Tensor, query: Tensor, method='kd', index_dim=channel(
kd_tree = KDTree(vectors[i].numpy([..., channel]))
def perform_query(np_query):
return kd_tree.query(np_query)[1]
-native_idx = query.default_backend.numpy_call(perform_query, (query_i.shape.non_channel.volume,), DType(int, 64), native_query)
+native_idx = query.default_backend.numpy_call(perform_query, (query_i.shape.non_channel.volume,), INT64, native_query)
else:
b = backend_for(vectors, query)
native_vectors = vectors[i].native([..., channel])
def perform_query(np_vectors, np_query):
return KDTree(np_vectors).query(np_query)[1]
-native_idx = b.numpy_call(perform_query, (query.shape.without(batch(vectors)).non_channel.volume,), DType(int, 64), native_vectors, native_query)
+native_idx = b.numpy_call(perform_query, (query.shape.without(batch(vectors)).non_channel.volume,), INT64, native_vectors, native_query)
native_multi_idx = choose_backend(native_idx).unravel_index(native_idx, after_gather(vectors.shape, i).non_channel.sizes)
result.append(reshaped_tensor(native_multi_idx, [query_i.shape.non_channel, index_dim or math.EMPTY_SHAPE]))
return stack(result, batch(vectors))