Add when_available()

holl- committed Oct 27, 2023
1 parent ab1ffbc commit 1d39fe7
Showing 3 changed files with 63 additions and 8 deletions.
1 change: 1 addition & 0 deletions phiml/math/__init__.py
@@ -91,6 +91,7 @@
     identity,
     trace_check,
     map_ as map,
+    when_available,
 )
 
 from ._optimize import solve_linear, solve_nonlinear, minimize, Solve, SolveInfo, ConvergenceException, NotConverged, Diverged, SolveTape, factor_ilu
46 changes: 39 additions & 7 deletions phiml/math/_functional.py
@@ -6,7 +6,7 @@
 
 import numpy as np
 
-from . import _ops as math
+from . import _ops as math, all_available
 from ._magic_ops import stack, pack_dims, expand, unpack_dim
 from ._shape import EMPTY_SHAPE, Shape, spatial, instance, batch, channel, merge_shapes, DimFilter, shape
 from ._sparse import SparseCoordinateTensor
@@ -156,6 +156,8 @@ def __init__(self, f: Callable, auxiliary_args: Set[str], forget_traces: bool):
         self.traces: Dict[SignatureKey, Callable] = {}
         self.recorded_mappings: Dict[SignatureKey, SignatureKey] = {}
         self.grad_jit = GradientFunction(f.f, self.f_params, f.wrt, f.get_output, f.is_f_scalar, jit=True) if isinstance(f, GradientFunction) else None
+        self._extract_tensors: List[Tuple[Tensor, ...]] = []
+        self._post_call: List[Callable] = []
 
     def _jit_compile(self, in_key: SignatureKey):
         ML_LOGGER.debug(f"Φ-ML-jit: '{f_name(self.f)}' called with new key. shapes={[s.volume for s in in_key.shapes]}, args={in_key.tree}")
@@ -165,8 +167,8 @@ def jit_f_native(*natives):
             _TRACING_JIT.append(self)
             in_tensors = assemble_tensors(natives, in_key.specs)
             kwargs = assemble_tree(in_key.tree, in_tensors)
-            result = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
-            tree, out_tensors = disassemble_tree(result)
+            f_output = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
+            tree, out_tensors = disassemble_tree((f_output, self._extract_tensors))
             result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
             self.recorded_mappings[in_key] = SignatureKey(jit_f_native, tree, result_shapes, specs, in_key.backend, in_key.tracing)
             assert _TRACING_JIT.pop(-1) is self
@@ -198,7 +200,14 @@ def __call__(self, *args, **kwargs):
         native_result = self.traces[key](*natives)
         output_key = match_output_signature(key, self.recorded_mappings, self)
         output_tensors = assemble_tensors(native_result, output_key.specs)
-        return assemble_tree(output_key.tree, output_tensors)
+        output, extracted_tensor_lists = assemble_tree(output_key.tree, output_tensors)
+        for extracted_tensors, runnable in zip(extracted_tensor_lists, self._post_call):
+            runnable(*extracted_tensors)
+        return output
+
+    def extract_and_call(self, tensors: Tuple[Tensor, ...], runnable: Callable):
+        self._extract_tensors.append(tensors)
+        self._post_call.append(runnable)
 
     def __repr__(self):
         return f"jit({f_name(self.f)})"
@@ -208,9 +217,6 @@ def __name__(self):
         return f_name(self.f)
 
 
-_TRACING_JIT: List[JitFunction] = []
-
-
 def jit_compile(f: Callable = None, auxiliary_args: str = '', forget_traces: bool = None) -> Callable:
     """
     Compiles a graph based on the function `f`.
@@ -297,7 +303,9 @@ def _get_or_trace(self, key: SignatureKey, args: tuple, f_kwargs: dict):
         else:
             if self.forget_traces:
                 self.matrices_and_biases.clear()
+            _TRACING_LINEAR.append(self)
             matrix, bias = matrix_from_function(self.f, *args, **f_kwargs, auto_compress=True)
+            assert _TRACING_LINEAR.pop(-1) is self
             if not key.tracing:
                 self.matrices_and_biases[key] = matrix, bias
                 if len(self.matrices_and_biases) >= 4:
@@ -406,6 +414,30 @@ def my_linear_function(x: math.Tensor) -> math.Tensor:
     return f if isinstance(f, LinearFunction) and f.auxiliary_args == auxiliary_args else LinearFunction(f, auxiliary_args, forget_traces or False)
 
 
+_TRACING_JIT: List[JitFunction] = []
+_TRACING_LINEAR: List[LinearFunction] = []
+
+
+def when_available(runnable: Callable, *tensor_args: Tensor):
+    """
+    Calls `runnable(*tensor_args)` once the concrete values of all tensors are available.
+    In eager mode, `runnable` is called immediately.
+    When jit-compiled, `runnable` is called after the jit-compiled function has returned.
+
+    Args:
+        runnable: Function to call as `runnable(*tensor_args)`. This can be a `lambda` function.
+        *tensor_args: `Tensor` values to pass to `runnable` with concrete values.
+    """
+    if _TRACING_LINEAR:
+        raise RuntimeError("when_available() cannot be called inside a function marked as @jit_compile_linear")
+    if all_available(*tensor_args):  # eager or NumPy
+        runnable(*tensor_args)
+    else:
+        assert _TRACING_JIT, "Tensors are not available, but no jit function is being traced. Are you using an external jit compiler?"
+        for jit_f in _TRACING_JIT:
+            jit_f.extract_and_call(tensor_args, runnable)
+
+
 def simplify_wrt(f, wrt: Union[str, int, tuple, list]):
     f_params = function_parameters(f)
     if wrt is None:  # Old default
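For illustration, a minimal usage sketch of the new API (not part of the commit; names other than `math.jit_compile`, `math.when_available`, and `tensor` are made up): inside a function traced by `@math.jit_compile`, the tensors hold no concrete values yet, so `when_available` defers the callback until the jit-compiled call has returned.

```python
from phiml import math
from phiml.math import tensor

@math.jit_compile
def step(x):
    # During tracing, x is abstract; when_available registers the
    # callback to run once the traced call returns concrete values.
    math.when_available(lambda v: print("x =", v), x)
    return x * 2

step(tensor(1.))  # the callback prints the concrete value of x afterwards
```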
24 changes: 23 additions & 1 deletion tests/commit/math/test__functional.py
@@ -1,11 +1,12 @@
 import time
 from functools import partial
+from typing import List
 from unittest import TestCase
 
 from phiml import math
 from phiml.backend import Backend
 from phiml.backend._backend import init_installed_backends
-from phiml.math import tensor, spatial, batch, channel, wrap, dual
+from phiml.math import tensor, spatial, batch, channel, wrap, dual, Tensor
 
 BACKENDS = init_installed_backends()
 
@@ -269,3 +270,24 @@ def test_broadcast(self):
         len_ = math.broadcast(len, channel)
         strings = math.vec('vector', 'a', 'bc', '')
         math.assert_close([1, 2, 0], len_(strings))
+
+    def test_when_available(self):
+        for backend in BACKENDS:
+            with backend:
+                TRACER: List[Tensor] = []
+                CONCRETE: List[Tensor] = []
+
+                @math.jit_compile
+                def fun(x):
+                    TRACER.append(x)
+
+                    def print_x(x):
+                        CONCRETE.append(x)
+
+                    math.when_available(print_x, x)
+                    return x
+
+                fun(tensor(0))
+                assert CONCRETE
+                assert CONCRETE[0].available
+                assert TRACER

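As the docstring states, in eager mode the callback fires immediately. A minimal eager-mode sketch (not part of the commit; assumes the default NumPy backend so values are concrete right away):

```python
from phiml import math
from phiml.math import tensor

values = []
math.when_available(values.append, tensor(3))  # eager: called immediately
assert values and values[0].available
```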