Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 2, 2024
1 parent 6718a1e commit bbd1c9c
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 30 deletions.
34 changes: 17 additions & 17 deletions python/sdist/amici/jax/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,28 @@ def __init__(self):
super().__init__()

def _xdot(self, t, x, args):
pk, tcl = args
p, tcl = args

TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_P_SYMS = p
TPL_TCL_SYMS = tcl
TPL_W_SYMS = self._w(t, x, pk, tcl)
TPL_W_SYMS = self._w(t, x, p, tcl)

TPL_XDOT_EQ

return TPL_XDOT_RET

def _w(self, t, x, pk, tcl):
def _w(self, t, x, p, tcl):
TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_P_SYMS = p
TPL_TCL_SYMS = tcl

TPL_W_EQ

return TPL_W_RET

def _x0(self, pk):
TPL_PK_SYMS = pk
def _x0(self, p):
TPL_P_SYMS = p

TPL_X0_EQ

Expand All @@ -56,36 +56,36 @@ def _x_rdata(self, x, tcl):

return TPL_X_RDATA_RET

def _tcl(self, x, pk):
def _tcl(self, x, p):
TPL_X_RDATA_SYMS = x
TPL_PK_SYMS = pk
TPL_P_SYMS = p

TPL_TOTAL_CL_EQ

return TPL_TOTAL_CL_RET

def _y(self, t, x, pk, tcl):
def _y(self, t, x, p, tcl):
TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_W_SYMS = self._w(t, x, pk, tcl)
TPL_P_SYMS = p
TPL_W_SYMS = self._w(t, x, p, tcl)

TPL_Y_EQ

return TPL_Y_RET

def _sigmay(self, y, pk):
TPL_PK_SYMS = pk
def _sigmay(self, y, p):
TPL_P_SYMS = p

TPL_Y_SYMS = y

TPL_SIGMAY_EQ

return TPL_SIGMAY_RET

def _nllh(self, t, x, pk, tcl, my, iy):
y = self._y(t, x, pk, tcl)
def _nllh(self, t, x, p, tcl, my, iy):
y = self._y(t, x, p, tcl)
TPL_Y_SYMS = y
TPL_SIGMAY_SYMS = self._sigmay(y, pk)
TPL_SIGMAY_SYMS = self._sigmay(y, p)

TPL_JY_EQ

Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _generate_jax_code(self) -> None:
outdir.mkdir(parents=True, exist_ok=True)

apply_template(
Path(amiciModulePath) / "jax.template.py",
Path(amiciModulePath) / "jax" / "jax.template.py",
outdir / "__init__.py",
tpl_data,
)
Expand Down
2 changes: 0 additions & 2 deletions python/sdist/amici/petab/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def import_petab_problem(
petab_problem,
model_name=model_name,
model_output_dir=model_output_dir,
compile=kwargs.pop("compile", not jax),
jax=jax,
**kwargs,
)
Expand All @@ -156,7 +155,6 @@ def import_petab_problem(
model_name=model_name,
model_output_dir=model_output_dir,
non_estimated_parameters_as_constants=non_estimated_parameters_as_constants,
compile=kwargs.pop("compile", not jax),
jax=jax,
**kwargs,
)
Expand Down
1 change: 1 addition & 0 deletions python/sdist/amici/petab/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def import_model_sbml(
verbose=verbose,
**kwargs,
)
return sbml_importer
else:
sbml_importer.sbml2amici(
model_name=model_name,
Expand Down
28 changes: 18 additions & 10 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
from beartype import beartype

from amici.pysb_import import pysb2amici
from amici.pysb_import import pysb2amici, pysb2jax
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from amici.petab.petab_import import import_petab_problem
from amici.jax import JAXProblem
Expand All @@ -39,17 +39,21 @@ def test_conversion():
pysb.Rule("conv", a(s="a") >> a(s="b"), pysb.Parameter("kcat", 0.05))
pysb.Observable("ab", a(s="b"))

with TemporaryDirectoryWinSafe(prefix=model.name) as outdir:
with TemporaryDirectoryWinSafe() as outdir:
pysb2amici(model, outdir, verbose=True, observables=["ab"])
pysb2jax(model, outdir, verbose=True, observables=["ab"])

model_module = amici.import_model_module(
amici_module = amici.import_model_module(
module_name=model.name, module_path=outdir
)
jax_module = amici.import_model_module(
module_name=model.name + "_jax", module_path=outdir
)

ts = tuple(np.linspace(0, 1, 10))
p = jnp.stack((1.0, 0.1), axis=-1)
k = tuple()
_test_model(model_module, ts, p, k)
_test_model(amici_module, jax_module, ts, p, k)


@skip_on_valgrind
Expand Down Expand Up @@ -86,34 +90,38 @@ def test_dimerization():
pysb.Observable("a_obs", a())
pysb.Observable("b_obs", b())

with TemporaryDirectoryWinSafe(prefix=model.name) as outdir:
with TemporaryDirectoryWinSafe() as outdir:
pysb2amici(
model,
outdir,
verbose=True,
observables=["a_obs", "b_obs"],
constant_parameters=["ksyn_a", "ksyn_b"],
)
pysb2jax(model, outdir, verbose=True, observables=["ab"])

model_module = amici.import_model_module(
amici_module = amici.import_model_module(
module_name=model.name, module_path=outdir
)
jax_module = amici.import_model_module(
module_name=model.name + "_jax", module_path=outdir
)

ts = tuple(np.linspace(0, 1, 10))
p = jnp.stack((5, 0.5, 0.5, 0.5), axis=-1)
k = (0.5, 5)
_test_model(model_module, ts, p, k)
_test_model(amici_module, jax_module, ts, p, k)


def _test_model(model_module, ts, p, k):
amici_model = model_module.getModel()
def _test_model(amici_module, jax_module, ts, p, k):
amici_model = amici_module.getModel()

amici_model.setTimepoints(np.asarray(ts, dtype=np.float64))
sol_amici_ref = amici.runAmiciSimulation(
amici_model, amici_model.getSolver()
)

jax_model = model_module.get_jax_model()
jax_model = jax_module.Model()

amici_model.setParameters(np.asarray(p, dtype=np.float64))
amici_model.setFixedParameters(np.asarray(k, dtype=np.float64))
Expand Down

0 comments on commit bbd1c9c

Please sign in to comment.