diff --git a/phiml/backend/_dtype.py b/phiml/backend/_dtype.py
index 49d95e7d..b537a001 100644
--- a/phiml/backend/_dtype.py
+++ b/phiml/backend/_dtype.py
@@ -44,7 +44,7 @@ def __init__(self, kind: type, bits: int = None, precision: int = None):
             else:
                 bits = precision * 2
         else:
-            assert isinstance(bits, int)
+            assert isinstance(bits, int), f"bits must be an int but got {type(bits)}"
         self.kind = kind
         """ Python class corresponding to the type of data, ignoring precision. One of (bool, int, float, complex, str) """
         self.bits = bits
diff --git a/phiml/math/__init__.py b/phiml/math/__init__.py
index 2d345b73..1cb90fe5 100644
--- a/phiml/math/__init__.py
+++ b/phiml/math/__init__.py
@@ -92,6 +92,7 @@
     trace_check,
     map_ as map,
     when_available,
+    perf_counter,
 )
 from ._optimize import solve_linear, solve_nonlinear, minimize, Solve, SolveInfo, ConvergenceException, NotConverged, Diverged, SolveTape, factor_ilu
diff --git a/phiml/math/_functional.py b/phiml/math/_functional.py
index 1337384c..a883976b 100644
--- a/phiml/math/_functional.py
+++ b/phiml/math/_functional.py
@@ -1,4 +1,5 @@
 import inspect
+import time
 import types
 import warnings
 from functools import wraps, partial
@@ -15,6 +16,7 @@
 from ..backend import Backend, NUMPY
 from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER
 from .magic import PhiTreeNode, Shapable
+from ..backend._dtype import DType
 
 X = TypeVar('X')
 Y = TypeVar('Y')
@@ -158,6 +160,7 @@ def __init__(self, f: Callable, auxiliary_args: Set[str], forget_traces: bool):
         self.grad_jit = GradientFunction(f.f, self.f_params, f.wrt, f.get_output, f.is_f_scalar, jit=True) if isinstance(f, GradientFunction) else None
         self._extract_tensors: List[Tuple[Tensor]] = []
         self._post_call: List[Callable] = []
+        self._tracing_in_key: SignatureKey = None
 
     def _jit_compile(self, in_key: SignatureKey):
         ML_LOGGER.debug(f"Φ-ML-jit: '{f_name(self.f)}' called with new key. shapes={[s.volume for s in in_key.shapes]}, args={in_key.tree}")
@@ -165,6 +168,7 @@ def _jit_compile(self, in_key: SignatureKey):
         def jit_f_native(*natives):
             ML_LOGGER.debug(f"Φ-ML-jit: Tracing '{f_name(self.f)}'")
             _TRACING_JIT.append(self)
+            self._tracing_in_key = in_key
             in_tensors = assemble_tensors(natives, in_key.specs)
             kwargs = assemble_tree(in_key.tree, in_tensors)
             f_output = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
@@ -172,6 +176,7 @@ def jit_f_native(*natives):
             result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
             self.recorded_mappings[in_key] = SignatureKey(jit_f_native, tree, result_shapes, specs, in_key.backend, in_key.tracing)
             assert _TRACING_JIT.pop(-1) is self
+            self._tracing_in_key = None
             return result_natives
         jit_f_native.__name__ = f"native({f_name(self.f) if isinstance(self.f, types.FunctionType) else str(self.f)})"
@@ -1206,3 +1211,28 @@ def identity(x):
         `x`
     """
     return x
+
+
+def perf_counter(wait_for_tensor: Tensor, *wait_for_tensors: Tensor) -> Tensor:
+    """
+    Get the time (`time.perf_counter()`) at which all `wait_for_tensors` have been computed.
+    If all tensors are already available, returns the current `time.perf_counter()`.
+
+    Args:
+        wait_for_tensor: `Tensor` that needs to be computed before the time is measured.
+        *wait_for_tensors: Additional tensors that need to be computed before the time is measured.
+
+    Returns:
+        Time at which all `wait_for_tensors` are ready, as a scalar `Tensor`.
+ """ + assert not _TRACING_LINEAR, f"Cannot use perf_counter inside a function decorated with @jit_compile_linear" + if not _TRACING_JIT: + return wrap(time.perf_counter()) + else: # jit + backend = _TRACING_JIT[0]._tracing_in_key.backend + natives, _, _ = disassemble_tensors([wait_for_tensor, *wait_for_tensors], expand=False) + natives = [n for n in natives if backend.is_tensor(n, only_native=True)] + assert natives, f"in jit mode, perf_counter must be given at least one traced tensor, as the current time is evaluated after all tensors are computed." + def perf_counter(*_wait_for_natives): + return np.asarray(time.perf_counter()) + return wrap(backend.numpy_call(perf_counter, (), DType(float, 64), *natives)) diff --git a/tests/commit/math/test__functional.py b/tests/commit/math/test__functional.py index 0c9808b5..fd6bdb7a 100644 --- a/tests/commit/math/test__functional.py +++ b/tests/commit/math/test__functional.py @@ -291,3 +291,20 @@ def print_x(x): assert CONCRETE assert CONCRETE[0].available assert TRACER + + def test_perf_counter(self): + @math.jit_compile + def fun(x): + t0 = math.perf_counter(x) + print("fun called, time=", t0) + for i in range(1000): + x *= 0.5 + x += 1 + dt = math.perf_counter(x) - t0 + return x, dt + + for backend in BACKENDS: + with backend: + result, exec_time = fun(tensor(0)) + print("time taken", 1000_000 * float(exec_time), "result:", result) +