Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid silent preequilibration failure in JAX #2631

Merged
merged 14 commits into from
Dec 19, 2024
10 changes: 9 additions & 1 deletion python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,15 @@
event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
throw=False,
)
return sol.ys[-1, :], sol.stats
# If the event was triggered, the event mask is True and the solution is the steady state. Otherwise, the
# solution is the last state and the event mask is False. In the latter case, we return inf for the steady
# state.
ys = jnp.where(

Check warning on line 287 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L287

Added line #L287 was not covered by tests
sol.event_mask,
sol.ys[-1, :],
jnp.inf * jnp.ones_like(sol.ys[-1, :]),
)
return ys, sol.stats

Check warning on line 292 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L292

Added line #L292 was not covered by tests

def _solve(
self,
Expand Down
25 changes: 24 additions & 1 deletion python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import diffrax
import numpy as np
from beartype import beartype
from petab.v2.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID

from amici.pysb_import import pysb2amici, pysb2jax
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from amici.petab.petab_import import import_petab_problem
from amici.jax import JAXProblem, ReturnValue
from amici.jax import JAXProblem, ReturnValue, run_simulations
from numpy.testing import assert_allclose
from test_petab_objective import lotka_volterra # noqa: F401

Expand Down Expand Up @@ -268,6 +269,28 @@ def check_fields_jax(
)


def test_preequilibration_failure(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
# oscillating system, preequilibation should fail when interaction is active
with TemporaryDirectoryWinSafe(prefix="normal") as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert not np.isinf(r[0].item())
petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = (
petab_problem.measurement_df[SIMULATION_CONDITION_ID]
)
with TemporaryDirectoryWinSafe(prefix="failure") as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert np.isinf(r[0].item())


@skip_on_valgrind
def test_serialisation(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
Expand Down
Loading