Skip to content

Commit

Permalink
Add perf_counter()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 27, 2023
1 parent b27234f commit d3fa949
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 1 deletion.
2 changes: 1 addition & 1 deletion phiml/backend/_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, kind: type, bits: int = None, precision: int = None):
else:
bits = precision * 2
else:
assert isinstance(bits, int)
assert isinstance(bits, int), f"bits must be an int but got {type(bits)}"
self.kind = kind
""" Python class corresponding to the type of data, ignoring precision. One of (bool, int, float, complex, str) """
self.bits = bits
Expand Down
1 change: 1 addition & 0 deletions phiml/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
trace_check,
map_ as map,
when_available,
perf_counter,
)

from ._optimize import solve_linear, solve_nonlinear, minimize, Solve, SolveInfo, ConvergenceException, NotConverged, Diverged, SolveTape, factor_ilu
Expand Down
30 changes: 30 additions & 0 deletions phiml/math/_functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import time
import types
import warnings
from functools import wraps, partial
Expand All @@ -15,6 +16,7 @@
from ..backend import Backend, NUMPY
from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER
from .magic import PhiTreeNode, Shapable
from ..backend._dtype import DType

X = TypeVar('X')
Y = TypeVar('Y')
Expand Down Expand Up @@ -158,20 +160,23 @@ def __init__(self, f: Callable, auxiliary_args: Set[str], forget_traces: bool):
self.grad_jit = GradientFunction(f.f, self.f_params, f.wrt, f.get_output, f.is_f_scalar, jit=True) if isinstance(f, GradientFunction) else None
self._extract_tensors: List[Tuple[Tensor]] = []
self._post_call: List[Callable] = []
self._tracing_in_key: SignatureKey = None

def _jit_compile(self, in_key: SignatureKey):
ML_LOGGER.debug(f"Φ-ML-jit: '{f_name(self.f)}' called with new key. shapes={[s.volume for s in in_key.shapes]}, args={in_key.tree}")

def jit_f_native(*natives):
ML_LOGGER.debug(f"Φ-ML-jit: Tracing '{f_name(self.f)}'")
_TRACING_JIT.append(self)
self._tracing_in_key = in_key
in_tensors = assemble_tensors(natives, in_key.specs)
kwargs = assemble_tree(in_key.tree, in_tensors)
f_output = self.f(**kwargs, **in_key.auxiliary_kwargs) # Tensor or tuple/list of Tensors
tree, out_tensors = disassemble_tree((f_output, self._extract_tensors))
result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
self.recorded_mappings[in_key] = SignatureKey(jit_f_native, tree, result_shapes, specs, in_key.backend, in_key.tracing)
assert _TRACING_JIT.pop(-1) is self
self._tracing_in_key = None
return result_natives

jit_f_native.__name__ = f"native({f_name(self.f) if isinstance(self.f, types.FunctionType) else str(self.f)})"
Expand Down Expand Up @@ -1206,3 +1211,28 @@ def identity(x):
`x`
"""
return x


def perf_counter(wait_for_tensor, *wait_for_tensors: Tensor) -> Tensor:
"""
Get the time (`time.perf_counter()`) at which all `wait_for_tensors` are computed.
If all tensors are already available, returns the current `time.perf_counter()`.
Args:
wait_for_tensor: `Tensor` that need to be computed before the time is measured.
*wait_for_tensors: Additional tensors that need to be computed before the time is measured.
Returns:
Time at which all `wait_for_tensors` are ready as a scalar `Tensor`.
"""
assert not _TRACING_LINEAR, f"Cannot use perf_counter inside a function decorated with @jit_compile_linear"
if not _TRACING_JIT:
return wrap(time.perf_counter())
else: # jit
backend = _TRACING_JIT[0]._tracing_in_key.backend
natives, _, _ = disassemble_tensors([wait_for_tensor, *wait_for_tensors], expand=False)
natives = [n for n in natives if backend.is_tensor(n, only_native=True)]
assert natives, f"in jit mode, perf_counter must be given at least one traced tensor, as the current time is evaluated after all tensors are computed."
def perf_counter(*_wait_for_natives):
return np.asarray(time.perf_counter())
return wrap(backend.numpy_call(perf_counter, (), DType(float, 64), *natives))
17 changes: 17 additions & 0 deletions tests/commit/math/test__functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,20 @@ def print_x(x):
assert CONCRETE
assert CONCRETE[0].available
assert TRACER

def test_perf_counter(self):
@math.jit_compile
def fun(x):
t0 = math.perf_counter(x)
print("fun called, time=", t0)
for i in range(1000):
x *= 0.5
x += 1
dt = math.perf_counter(x) - t0
return x, dt

for backend in BACKENDS:
with backend:
result, exec_time = fun(tensor(0))
print("time taken", 1000_000 * float(exec_time), "result:", result)

0 comments on commit d3fa949

Please sign in to comment.