diff --git a/pytest.ini b/pytest.ini index 8cc45e0fd9..03d50d80e1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -24,5 +24,7 @@ filterwarnings = ignore:.*PyDevIPCompleter6.*:DeprecationWarning # ignore numpy log(0) warnings (np.log(0) = -inf) ignore:divide by zero encountered in log:RuntimeWarning + # ignore jax deprecation warnings + ignore:jax.* is deprecated:DeprecationWarning norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 1310091f4c..7fe21257a9 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -358,7 +358,7 @@ "simulation_condition = (\"model1_data1\",)\n", "\n", "# Load condition-specific data\n", - "ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n", + "ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n", " simulation_condition\n", "]\n", "\n", @@ -371,7 +371,6 @@ "def grad_ts_dyn(tt):\n", " return jax_problem.model.simulate_condition(\n", " p=p,\n", - " ts_init=ts_init,\n", " ts_dyn=tt,\n", " ts_posteq=ts_posteq,\n", " my=jnp.array(my),\n", diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 43edc30843..4692524070 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -466,7 +466,6 @@ def _sigmays( def simulate_condition( self, p: jt.Float[jt.Array, "np"], - ts_init: jt.Float[jt.Array, "nt_preeq"], ts_dyn: jt.Float[jt.Array, "nt_dyn"], ts_posteq: jt.Float[jt.Array, "nt_posteq"], my: jt.Float[jt.Array, "nt"], @@ -486,13 +485,9 @@ def simulate_condition( :param p: parameters for simulation ordered according to ids in :ivar parameter_ids: - :param ts_init: - time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to - the number of observables that are evaluated before dynamic simulation. :param ts_dyn: - time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order. - Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time - points. + time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are + allowed to facilitate the evaluation of multiple observables at specific time points. :param ts_posteq: time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to the number of observables that are evaluated after post-equilibration. @@ -532,8 +527,6 @@ def simulate_condition( x_solver = self._x_solver(x) tcl = self._tcl(x, p) - x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0) - # Dynamic simulation if ts_dyn.shape[0]: x_dyn, stats_dyn = self._solve( @@ -565,8 +558,8 @@ def simulate_condition( x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0 ) - ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0) - x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0) + ts = jnp.concatenate((ts_dyn, ts_posteq), axis=0) + x = jnp.concatenate((x_dyn, x_posteq), axis=0) nllhs = self._nllhs(ts, x, p, tcl, my, iys) llh = -jnp.sum(nllhs) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index b2b42f5c2a..6a7da4b42f 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -71,7 +71,7 @@ class JAXProblem(eqx.Module): :ivar _parameter_mappings: :class:`ParameterMappingForCondition` instances for each simulation condition. :ivar _measurements: - Subset measurement dataframes for each simulation condition. + Preprocessed arrays for each simulation condition. :ivar _petab_problem: PEtab problem to simulate. """ @@ -87,7 +87,6 @@ class JAXProblem(eqx.Module): np.ndarray, np.ndarray, np.ndarray, - np.ndarray, ], ] _petab_measurement_indices: dict[tuple[str, ...], tuple[int, ...]] @@ -187,7 +186,6 @@ def _get_measurements( np.ndarray, np.ndarray, np.ndarray, - np.ndarray, ], ], dict[tuple[str, ...], tuple[int, ...]], @@ -213,11 +211,9 @@ def _get_measurements( ) ts = m[petab.TIME] - ts_preeq = ts[np.isfinite(ts) & (ts == 0)] - ts_dyn = ts[np.isfinite(ts) & (ts > 0)] + ts_dyn = ts[np.isfinite(ts)] ts_posteq = ts[np.logical_not(np.isfinite(ts))] - index = pd.concat([ts_preeq, ts_dyn, ts_posteq]).index - ts_preeq = ts_preeq.values + index = pd.concat([ts_dyn, ts_posteq]).index ts_dyn = ts_dyn.values ts_posteq = ts_posteq.values my = m[petab.MEASUREMENT].values @@ -245,7 +241,6 @@ def _get_measurements( iy_trafos = np.zeros_like(iys) measurements[tuple(simulation_condition)] = ( - ts_preeq, ts_dyn, ts_posteq, my, @@ -492,7 +487,7 @@ def run_simulation( :return: Tuple of output value and simulation statistics """ - ts_preeq, ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ + ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ simulation_condition ] p = self.load_parameters(simulation_condition[0]) @@ -501,7 +496,6 @@ def run_simulation( ) return self.model.simulate_condition( p=p, - ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)), ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)), my=jax.lax.stop_gradient(jnp.array(my)), @@ -650,7 +644,7 @@ def petab_simulate( for sc, ys in y.items(): obs = [ problem.model.observable_ids[io] - for io in problem._measurements[sc][4] + for io in problem._measurements[sc][3] ] t = jnp.concat(problem._measurements[sc][:2]) df_sc = pd.DataFrame( diff --git a/python/sdist/pyproject.toml b/python/sdist/pyproject.toml index b62903240e..db48bd6766 100644 --- a/python/sdist/pyproject.toml +++ b/python/sdist/pyproject.toml @@ -83,11 +83,11 @@ examples = [ "scipy", ] jax = [ - "jax>=0.4.34,<0.4.36", - "jaxlib>=0.4.34", - "diffrax>=0.6.0", + "jax>=0.4.36", + "jaxlib>=0.4.36", + "diffrax>=0.6.1", "jaxtyping>=0.2.34", - "equinox>=0.11.8", + "equinox>=0.11.10", "optimistix>=0.0.9", "interpax>=0.3.3", ] diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 7e997cb9a3..f8261fb15d 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -180,8 +180,7 @@ def check_fields_jax( iys = iys.flatten() iy_trafos = np.zeros_like(iys) - ts_init = ts[ts == 0] - ts_dyn = ts[ts > 0] + ts_dyn = ts ts_posteq = np.array([]) par_dict = { @@ -191,7 +190,6 @@ def check_fields_jax( p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids]) kwargs = { - "ts_init": jnp.array(ts_init), "ts_dyn": jnp.array(ts_dyn), "ts_posteq": jnp.array(ts_posteq), "my": jnp.array(my),