Skip to content

Commit

Permalink
Avoid silent preequilibration failure in JAX (#2631)
Browse files Browse the repository at this point in the history
* add event check in equilibration, add test

* Update test_jax.py

* Update model.py

* Update model.py

* decrease integration tols

* decrease integration tols

* decrease integration tols

* decrease integration tols

* revert scaling factor, add ss event API

* Update test_jax.py

* Update ExampleJaxPEtab.ipynb

* Update petab.py
  • Loading branch information
FFroehlich authored Dec 18, 2024
1 parent 3d73624 commit 1716181
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 11 deletions.
1 change: 1 addition & 0 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@
" iy_trafos=jnp.array(iy_trafos),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" steady_state_event=diffrax.steady_state_event(),\n",
" max_steps=2**10,\n",
" adjoint=diffrax.DirectAdjoint(),\n",
" ret=ReturnValue.y, # Return observables\n",
Expand Down
52 changes: 48 additions & 4 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import jax
import jaxtyping as jt

from collections.abc import Callable


class ReturnValue(enum.Enum):
llh = "log-likelihood"
Expand All @@ -32,6 +34,13 @@ class JAXModel(eqx.Module):
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
routines for simulation and evaluation of derived quantities, model specific implementations need to be provided by
classes inheriting from JAXModel.
:ivar api_version:
API version of the derived class. Needs to match the API version of the base class (MODEL_API_VERSION).
:ivar MODEL_API_VERSION:
API version of the base class.
:ivar jax_py_file:
Path to the JAX model file.
"""

MODEL_API_VERSION = "0.0.2"
Expand Down Expand Up @@ -248,6 +257,9 @@ def _eq(
x0: jt.Float[jt.Array, "nxs"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "1 nxs"], dict]:
"""
Expand Down Expand Up @@ -278,10 +290,20 @@ def _eq(
stepsize_controller=controller,
max_steps=max_steps,
adjoint=diffrax.DirectAdjoint(),
event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
event=diffrax.Event(
cond_fn=steady_state_event,
),
throw=False,
)
return sol.ys[-1, :], sol.stats
# If the event was triggered, the event mask is True and the solution is the steady state. Otherwise, the
# solution is the last state and the event mask is False. In the latter case, we return inf for the steady
# state.
ys = jnp.where(
sol.event_mask,
sol.ys[-1, :],
jnp.inf * jnp.ones_like(sol.ys[-1, :]),
)
return ys, sol.stats

def _solve(
self,
Expand Down Expand Up @@ -450,6 +472,9 @@ def simulate_condition(
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
Expand Down Expand Up @@ -525,7 +550,13 @@ def simulate_condition(
# Post-equilibration
if ts_posteq.shape[0]:
x_solver, stats_posteq = self._eq(
p, tcl, x_solver, solver, controller, max_steps
p,
tcl,
x_solver,
solver,
controller,
steady_state_event,
max_steps,
)
else:
stats_posteq = None
Expand Down Expand Up @@ -596,13 +627,20 @@ def preequilibrate_condition(
mask_reinit: jt.Bool[jt.Array, "*nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:
r"""
Simulate a condition.
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param mask_reinit:
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param solver:
ODE solver
:param controller:
Expand All @@ -619,7 +657,13 @@ def preequilibrate_condition(
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p, tcl, current_x, solver, controller, max_steps
p,
tcl,
current_x,
solver,
controller,
steady_state_event,
max_steps,
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)
Expand Down
26 changes: 24 additions & 2 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numbers import Number
from collections.abc import Iterable
from pathlib import Path
from collections.abc import Callable


import diffrax
Expand Down Expand Up @@ -465,6 +466,9 @@ def run_simulation(
simulation_condition: tuple[str, ...],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
ret: ReturnValue = ReturnValue.llh,
Expand Down Expand Up @@ -507,6 +511,7 @@ def run_simulation(
solver=solver,
controller=controller,
max_steps=max_steps,
steady_state_event=steady_state_event,
adjoint=diffrax.RecursiveCheckpointAdjoint()
if ret in (ReturnValue.llh, ReturnValue.chi2)
else diffrax.DirectAdjoint(),
Expand All @@ -518,6 +523,9 @@ def run_preequilibration(
simulation_condition: str,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821
"""
Expand All @@ -539,12 +547,13 @@ def run_preequilibration(
simulation_condition, p
)
return self.model.preequilibrate_condition(
p=eqx.debug.backward_nan(p),
p=p,
mask_reinit=mask_reinit,
x_reinit=x_reinit,
solver=solver,
controller=controller,
max_steps=max_steps,
steady_state_event=steady_state_event,
)


Expand All @@ -555,6 +564,9 @@ def run_simulations(
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
] = diffrax.steady_state_event(),
max_steps: int = 2**10,
ret: ReturnValue | str = ReturnValue.llh,
):
Expand All @@ -569,6 +581,9 @@ def run_simulations(
ODE solver to use for simulation.
:param controller:
Step size controller to use for simulation.
:param steady_state_event:
Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state
condition, see :func:`diffrax.steady_state_event` for details.
:param max_steps:
Maximum number of steps to take during simulation.
:param ret:
Expand All @@ -583,7 +598,9 @@ def run_simulations(
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {
sc: problem.run_preequilibration(sc, solver, controller, max_steps)
sc: problem.run_preequilibration(
sc, solver, controller, steady_state_event, max_steps
)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}
Expand All @@ -593,6 +610,7 @@ def run_simulations(
sc,
solver,
controller,
steady_state_event,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
ret=ret,
Expand All @@ -617,6 +635,9 @@ def petab_simulate(
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
] = diffrax.steady_state_event(),
max_steps: int = 2**10,
):
"""
Expand All @@ -637,6 +658,7 @@ def petab_simulate(
problem,
solver=solver,
controller=controller,
steady_state_event=steady_state_event,
max_steps=max_steps,
ret=ReturnValue.y,
)
Expand Down
26 changes: 25 additions & 1 deletion python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import diffrax
import numpy as np
from beartype import beartype
from petab.v1.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID

from amici.pysb_import import pysb2amici, pysb2jax
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from amici.petab.petab_import import import_petab_problem
from amici.jax import JAXProblem, ReturnValue
from amici.jax import JAXProblem, ReturnValue, run_simulations
from numpy.testing import assert_allclose
from test_petab_objective import lotka_volterra # noqa: F401

Expand Down Expand Up @@ -198,6 +199,7 @@ def check_fields_jax(
"solver": diffrax.Kvaerno5(),
"controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),
"adjoint": diffrax.RecursiveCheckpointAdjoint(),
"steady_state_event": diffrax.steady_state_event(),
"max_steps": 2**8, # max_steps
}
fun = beartype(jax_model.simulate_condition)
Expand Down Expand Up @@ -266,6 +268,28 @@ def check_fields_jax(
)


def test_preequilibration_failure(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
# oscillating system, preequilibation should fail when interaction is active
with TemporaryDirectoryWinSafe(prefix="normal") as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert not np.isinf(r[0].item())
petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = (
petab_problem.measurement_df[SIMULATION_CONDITION_ID]
)
with TemporaryDirectoryWinSafe(prefix="failure") as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert np.isinf(r[0].item())


@skip_on_valgrind
def test_serialisation(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
Expand Down
4 changes: 3 additions & 1 deletion tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from functools import partial
from pathlib import Path

import fiddy
import amici
import numpy as np
Expand Down Expand Up @@ -342,8 +343,9 @@ def test_jax_llh(benchmark_problem):
[problem_parameters[pid] for pid in jax_problem.parameter_ids]
),
)
llh_jax, _ = beartype(run_simulations)(jax_problem)

if problem_id in problems_for_gradient_check:
beartype(run_simulations)(jax_problem)
(llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
run_simulations, has_aux=True
)(jax_problem)
Expand Down
15 changes: 12 additions & 3 deletions tests/petab_test_suite/test_petab_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import sys

import diffrax

import amici
import pandas as pd
import petab.v1 as petab
Expand Down Expand Up @@ -68,10 +70,17 @@ def _test_case(case, model_type, version, jax):
if jax:
from amici.jax import JAXProblem, run_simulations, petab_simulate

steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
jax_problem = JAXProblem(model, problem)
llh, ret = run_simulations(jax_problem)
chi2, _ = run_simulations(jax_problem, ret="chi2")
simulation_df = petab_simulate(jax_problem)
llh, ret = run_simulations(
jax_problem, steady_state_event=steady_state_event
)
chi2, _ = run_simulations(
jax_problem, ret="chi2", steady_state_event=steady_state_event
)
simulation_df = petab_simulate(
jax_problem, steady_state_event=steady_state_event
)
simulation_df.rename(
columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True
)
Expand Down

0 comments on commit 1716181

Please sign in to comment.