Skip to content

Commit

Permalink
update after test refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 19, 2024
1 parent 4596dc4 commit bd103db
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 129 deletions.
9 changes: 7 additions & 2 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2345,7 +2345,10 @@ def _process_hybridisation(self, hybridisation: dict) -> None:
"""
added_expressions = False
for net_id, net in hybridisation.items():
if not (net["output"] == "ode" or net["input"] == "ode"):
if not (

Check warning on line 2348 in python/sdist/amici/de_model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/de_model.py#L2346-L2348

Added lines #L2346 - L2348 were not covered by tests
net["hybridization"]["output"] == "ode"
or net["hybridization"]["input"] == "ode"
):
continue # do not integrate into ODEs, handle in amici.jax.petab
inputs = [

Check warning on line 2353 in python/sdist/amici/de_model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/de_model.py#L2352-L2353

Added lines #L2352 - L2353 were not covered by tests
comp
Expand Down Expand Up @@ -2400,7 +2403,9 @@ def _process_hybridisation(self, hybridisation: dict) -> None:
)

# generate dummy Function
out_val = sp.Function(net_id)(*inputs, iout)
out_val = sp.Function(net_id)(

Check warning on line 2406 in python/sdist/amici/de_model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/de_model.py#L2406

Added line #L2406 was not covered by tests
*[input.get_id() for input in inputs], iout
)

# add to the model
if isinstance(comp, DifferentialState):
Expand Down
9 changes: 5 additions & 4 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,11 @@ def _generate_jax_code(self) -> None:

def _generate_nn_code(self) -> None:
for net_name, net in self.hybridisation.items():
generate_equinox(
net["model"],
self.model_path / f"{net_name}.py",
)
for model in net["model"]:
generate_equinox(

Check warning on line 260 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L259-L260

Added lines #L259 - L260 were not covered by tests
model,
self.model_path / f"{net_name}.py",
)

def set_paths(self, output_dir: str | Path | None = None) -> None:
"""
Expand Down
65 changes: 45 additions & 20 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,14 +282,36 @@ def _get_nominal_parameter_values(
}
# extract nominal values from petab problem
for pname, row in self._petab_problem.parameter_df.iterrows():
if (net := pname.split("_")[0]) in model.nns:
if (net := pname.split(".")[0]) in model.nns:
to_set = []
nn = model_pars[net]
layer = nn[pname.split("_")[1]]
attribute = pname.split("_")[2]
index = tuple(np.array(pname.split("_")[3:]).astype(int))
layer[attribute] = (
layer[attribute].at[index].set(row[petab.NOMINAL_VALUE])
)
if len(pname.split(".")) > 1:
layer = nn[pname.split(".")[1]]
if len(pname.split(".")) > 2:
to_set.append(

Check warning on line 291 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L286-L291

Added lines #L286 - L291 were not covered by tests
(pname.split(".")[1], pname.split(".")[2])
)
else:
to_set.extend(

Check warning on line 295 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L295

Added line #L295 was not covered by tests
[
(pname.split(".")[1], attribute)
for attribute in layer.keys()
]
)
else:
to_set.extend(

Check warning on line 302 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L302

Added line #L302 was not covered by tests
[
(layer_name, attribute)
for layer_name, layer in nn.items()
for attribute in layer.keys()
]
)

for layer, attribute in to_set:
nn[layer][attribute] = row[

Check warning on line 311 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L310-L311

Added lines #L310 - L311 were not covered by tests
petab.NOMINAL_VALUE
] * jnp.ones_like(nn[layer][attribute])

# set values in model
for net_id in model_pars:
for layer_id in model_pars[net_id]:
Expand All @@ -316,14 +338,9 @@ def _get_nominal_parameter_values(
), model

def _get_inputs(self):
if (
self._petab_problem.mapping_df is None
or "netId" not in self._petab_problem.mapping_df.columns
):
if self._petab_problem.mapping_df is None:
return {}
inputs = {
net: {} for net in self._petab_problem.mapping_df["netId"].unique()
}
inputs = {net: {} for net in self.model.nns.keys()}
for petab_id, row in self._petab_problem.mapping_df.iterrows():
if (filepath := Path(petab_id)).is_file():
data_flat = pd.read_csv(filepath, sep="\t").sort_values(

Check warning on line 346 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L343-L346

Added lines #L343 - L346 were not covered by tests
Expand Down Expand Up @@ -368,9 +385,10 @@ def nn_output_ids(self) -> list[str]:
if self._petab_problem.mapping_df is None:
return []
return self._petab_problem.mapping_df[

Check warning on line 387 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L387

Added line #L387 was not covered by tests
self._petab_problem.mapping_df[
petab.MODEL_ENTITY_ID
].str.startswith("output")
self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID]
.str.split(".")
.str[1]
.str.startswith("output")
].index.tolist()

def get_petab_parameter_by_id(self, name: str) -> jnp.float_:
Expand Down Expand Up @@ -402,11 +420,18 @@ def _unscale(
)

def _eval_nn(self, output_par: str):
net_id = self._petab_problem.mapping_df.loc[output_par, "netId"]
net_id = self._petab_problem.mapping_df.loc[

Check warning on line 423 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L423

Added line #L423 was not covered by tests
output_par, petab.MODEL_ENTITY_ID
].split(".")[0]
nn = self.model.nns[net_id]

Check warning on line 426 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L426

Added line #L426 was not covered by tests

model_id_map = (

Check warning on line 428 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L428

Added line #L428 was not covered by tests
self._petab_problem.mapping_df.query(f'netId == "{net_id}"')
self._petab_problem.mapping_df[
self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID]
.str.split(".")
.str[0]
== net_id
]
.reset_index()
.set_index(petab.MODEL_ENTITY_ID)[petab.PETAB_ENTITY_ID]
.to_dict()
Expand All @@ -422,7 +447,7 @@ def _eval_nn(self, output_par: str):
petab_id, petab.NOMINAL_VALUE
]
for model_id, petab_id in model_id_map.items()
if model_id.startswith("input")
if model_id.split(".")[1].startswith("input")
]
)
return nn.forward(net_input).squeeze()

Check warning on line 453 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L453

Added line #L453 was not covered by tests
Expand Down
47 changes: 22 additions & 25 deletions python/sdist/amici/petab/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,43 +150,40 @@ def import_petab_problem(
from petab_sciml import PetabScimlStandard

config = petab_problem.extensions_config["petab_sciml"]
net_files = config.get("net_files", [])
# TODO: net files need to be absolute paths
ml_models = [
model
for net_file in net_files
for model in PetabScimlStandard.load_data(
Path() / net_file
).models
]
hybridisation = {
net: {
"model": next(
ml_model
for ml_model in ml_models
if ml_model.mlmodel_id == net
),
net_id: {
"model": PetabScimlStandard.load_data(
Path() / net_config["file"]
).models,
"input_vars": [
petab_id
for petab_id, model_id in petab_problem.mapping_df.query(
f"netId == '{net}'"
)[petab.MODEL_ENTITY_ID]
for petab_id, model_id in petab_problem.mapping_df.loc[
petab_problem.mapping_df[petab.MODEL_ENTITY_ID]
.str.split(".")
.str[0]
== net_id,
petab.MODEL_ENTITY_ID,
]
.to_dict()
.items()
if model_id.startswith("input")
if model_id.split(".")[1].startswith("input")
],
"output_vars": [
petab_id
for petab_id, model_id in petab_problem.mapping_df.query(
f"netId == '{net}'"
)[petab.MODEL_ENTITY_ID]
for petab_id, model_id in petab_problem.mapping_df.loc[
petab_problem.mapping_df[petab.MODEL_ENTITY_ID]
.str.split(".")
.str[0]
== net_id,
petab.MODEL_ENTITY_ID,
]
.to_dict()
.items()
if model_id.startswith("output")
if model_id.split(".")[1].startswith("output")
],
**hybrid,
**net_config,
}
for net, hybrid in config["hybridization"].items()
for net_id, net_config in config.items()
}
if not jax or petab_problem.model.type_id == MODEL_TYPE_PYSB:
raise NotImplementedError(
Expand Down
Loading

0 comments on commit bd103db

Please sign in to comment.