diff --git a/phiml/math/__init__.py b/phiml/math/__init__.py
index fc1444ca..2d345b73 100644
--- a/phiml/math/__init__.py
+++ b/phiml/math/__init__.py
@@ -91,6 +91,7 @@
     identity,
     trace_check,
     map_ as map,
+    when_available,
 )
 from ._optimize import solve_linear, solve_nonlinear, minimize, Solve, SolveInfo, ConvergenceException, NotConverged, Diverged, SolveTape, factor_ilu
diff --git a/phiml/math/_functional.py b/phiml/math/_functional.py
index e8b34f3d..1337384c 100644
--- a/phiml/math/_functional.py
+++ b/phiml/math/_functional.py
@@ -6,7 +6,7 @@
 
 import numpy as np
 
-from . import _ops as math
+from . import _ops as math, all_available
 from ._magic_ops import stack, pack_dims, expand, unpack_dim
 from ._shape import EMPTY_SHAPE, Shape, spatial, instance, batch, channel, merge_shapes, DimFilter, shape
 from ._sparse import SparseCoordinateTensor
@@ -156,6 +156,8 @@ def __init__(self, f: Callable, auxiliary_args: Set[str], forget_traces: bool):
         self.traces: Dict[SignatureKey, Callable] = {}
         self.recorded_mappings: Dict[SignatureKey, SignatureKey] = {}
         self.grad_jit = GradientFunction(f.f, self.f_params, f.wrt, f.get_output, f.is_f_scalar, jit=True) if isinstance(f, GradientFunction) else None
+        self._extract_tensors: List[Tuple[Tensor]] = []
+        self._post_call: List[Callable] = []
 
     def _jit_compile(self, in_key: SignatureKey):
         ML_LOGGER.debug(f"Φ-ML-jit: '{f_name(self.f)}' called with new key. shapes={[s.volume for s in in_key.shapes]}, args={in_key.tree}")
@@ -165,8 +167,8 @@ def jit_f_native(*natives):
             _TRACING_JIT.append(self)
             in_tensors = assemble_tensors(natives, in_key.specs)
             kwargs = assemble_tree(in_key.tree, in_tensors)
-            result = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
-            tree, out_tensors = disassemble_tree(result)
+            f_output = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
+            tree, out_tensors = disassemble_tree((f_output, self._extract_tensors))
             result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
             self.recorded_mappings[in_key] = SignatureKey(jit_f_native, tree, result_shapes, specs, in_key.backend, in_key.tracing)
             assert _TRACING_JIT.pop(-1) is self
@@ -198,7 +200,14 @@ def __call__(self, *args, **kwargs):
             native_result = self.traces[key](*natives)
         output_key = match_output_signature(key, self.recorded_mappings, self)
         output_tensors = assemble_tensors(native_result, output_key.specs)
-        return assemble_tree(output_key.tree, output_tensors)
+        output, extracted_tensor_lists = assemble_tree(output_key.tree, output_tensors)
+        for extracted_tensors, runnable in zip(extracted_tensor_lists, self._post_call):
+            runnable(*extracted_tensors)
+        return output
+
+    def extract_and_call(self, tensors: Tuple[Tensor], runnable: Callable):
+        self._extract_tensors.append(tensors)
+        self._post_call.append(runnable)
 
     def __repr__(self):
         return f"jit({f_name(self.f)})"
@@ -208,9 +217,6 @@ def __name__(self):
         return f_name(self.f)
 
 
-_TRACING_JIT: List[JitFunction] = []
-
-
 def jit_compile(f: Callable = None, auxiliary_args: str = '', forget_traces: bool = None) -> Callable:
     """
     Compiles a graph based on the function `f`.
@@ -297,7 +303,9 @@ def _get_or_trace(self, key: SignatureKey, args: tuple, f_kwargs: dict):
         else:
             if self.forget_traces:
                 self.matrices_and_biases.clear()
+            _TRACING_LINEAR.append(self)
             matrix, bias = matrix_from_function(self.f, *args, **f_kwargs, auto_compress=True)
+            assert _TRACING_LINEAR.pop(-1) is self
             if not key.tracing:
                 self.matrices_and_biases[key] = matrix, bias
                 if len(self.matrices_and_biases) >= 4:
@@ -406,6 +414,30 @@ def my_linear_function(x: math.Tensor) -> math.Tensor:
     return f if isinstance(f, LinearFunction) and f.auxiliary_args == auxiliary_args else LinearFunction(f, auxiliary_args, forget_traces or False)
 
 
+_TRACING_JIT: List[JitFunction] = []
+_TRACING_LINEAR: List[LinearFunction] = []
+
+
+def when_available(runnable: Callable, *tensor_args: Tensor):
+    """
+    Calls `runnable(*tensor_args)` once the concrete values of all tensors are available.
+    In eager mode, `runnable` is called immediately.
+    When jit-compiled, `runnable` is called after the jit-compiled function has returned.
+
+    Args:
+        runnable: Function to call as `runnable(*tensor_args)`. This can be a `lambda` function.
+        *tensor_args: `Tensor` values to pass to `runnable` with concrete values.
+    """
+    if _TRACING_LINEAR:
+        raise RuntimeError(f"when_available() cannot be called inside a function marked as @jit_compile_linear")
+    if all_available(*tensor_args):  # eager or NumPy
+        runnable(*tensor_args)
+    else:
+        assert _TRACING_JIT, f"tensors are not available but no JIT function is being traced. Maybe you are using external jit?"
+        for jit_f in _TRACING_JIT:
+            jit_f.extract_and_call(tensor_args, runnable)
+
+
 def simplify_wrt(f, wrt: Union[str, int, tuple, list]):
     f_params = function_parameters(f)
     if wrt is None:  # Old default
diff --git a/tests/commit/math/test__functional.py b/tests/commit/math/test__functional.py
index 4f4f0203..0c9808b5 100644
--- a/tests/commit/math/test__functional.py
+++ b/tests/commit/math/test__functional.py
@@ -1,11 +1,12 @@
 import time
 from functools import partial
+from typing import List
 from unittest import TestCase
 
 from phiml import math
 from phiml.backend import Backend
 from phiml.backend._backend import init_installed_backends
-from phiml.math import tensor, spatial, batch, channel, wrap, dual
+from phiml.math import tensor, spatial, batch, channel, wrap, dual, Tensor
 
 BACKENDS = init_installed_backends()
 
@@ -269,3 +270,24 @@ def test_broadcast(self):
         len_ = math.broadcast(len, channel)
         strings = math.vec('vector', 'a', 'bc', '')
         math.assert_close([1, 2, 0], len_(strings))
+
+    def test_when_available(self):
+        for backend in BACKENDS:
+            with backend:
+                TRACER: List[Tensor] = []
+                CONCRETE: List[Tensor] = []
+
+                @math.jit_compile
+                def fun(x):
+                    TRACER.append(x)
+
+                    def print_x(x):
+                        CONCRETE.append(x)
+
+                    math.when_available(print_x, x)
+                    return x
+
+                fun(tensor(0))
+                assert CONCRETE
+                assert CONCRETE[0].available
+                assert TRACER
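A minimal usage sketch of the new `when_available` API, modeled on the test above. The function name `fun` and the printing lambda are illustrative only; whether the callback is deferred depends on the selected backend (NumPy runs eagerly, so the callback fires immediately; a tracing backend such as PyTorch, JAX or TensorFlow defers it until the compiled call returns).

    from phiml import math
    from phiml.backend._backend import init_installed_backends
    from phiml.math import tensor

    for backend in init_installed_backends():  # NumPy runs eagerly; Torch/JAX/TF trace
        with backend:

            @math.jit_compile
            def fun(x):
                y = x ** 2
                # While tracing, y has no concrete values yet; defer the print until it does.
                math.when_available(lambda v: print(backend, 'y =', v), y)
                return y

            fun(tensor(2.))  # callback runs immediately (eager) or after the compiled call returns (traced)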