Skip to content

Commit

Permalink
Merge pull request #63 from dnv-innersource/enhance/unit-tests
Browse files Browse the repository at this point in the history
Add more unit tests
  • Loading branch information
Jorgelmh authored Oct 30, 2024
2 parents 791788d + 87d94d1 commit 9caf053
Show file tree
Hide file tree
Showing 9 changed files with 516 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ The changelog format is based on [Keep a Changelog](https://keepachangelog.com/e
## [Unreleased]

### Changed
* Added missing unit tests for the template data generated when building the FMU.
* Unit tests for the modelDescription.xml generation.
* Unit tests for the Interface JSON validation.
* Changed from `pip`/`tox` to `uv` as package manager
* README.md : Completely rewrote section "Development Setup", introducing `uv` as package manager.
* Added missing docstrings for py/cpp/h files with help of Github Copilot
Expand Down
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"sphinx.ext.napoleon",
"sphinx_argparse_cli",
"sphinx.ext.mathjax",
"matplotlib.sphinxext.plot_directive",
"sphinx.ext.autosummary",
"sphinx.ext.todo",
]
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ dev-dependencies = [
"sphinx-autodoc-typehints>=2.2",
"myst-parser>=4.0",
"furo>=2024.8",
"matplotlib>=3.9",
]
native-tls = true

Expand Down
18 changes: 9 additions & 9 deletions src/mlfmu/types/fmu_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,12 @@ def check_only_one_initialization(self) -> InternalState:
if (not start_value) and name:
raise ValueError(
"name is set without start_value being set. "
"Both fields needs to be set for the state initialization to be valid."
"Both fields need to be set for the state initialization to be valid."
)
if start_value and (not name):
raise ValueError(
"start_value is set without name being set. "
"Both fields needs to be set for the state initialization to be valid."
"Both fields need to be set for the state initialization to be valid."
)
return self

Expand Down Expand Up @@ -306,7 +306,7 @@ class FmiInputVariable(InputVariable):

causality: FmiCausality
variable_references: list[int]
agent_state_init_indexes: list[list[int]]
agent_state_init_indexes: list[list[int]] = [] # noqa: RUF008

def __init__(self, **kwargs: Any) -> None: # noqa: ANN401
super().__init__(**kwargs)
Expand Down Expand Up @@ -411,7 +411,6 @@ class ModelComponent(BaseModelConfig):
"""

name: str = Field(
default=None,
description="The name of the simulation model.",
)
version: str = Field(
Expand Down Expand Up @@ -708,8 +707,8 @@ def format_fmi_variable(

if var.is_array:
for idx, var_ref in enumerate(var.variable_references):
# Create port names that contain the index starting from 1. E.i signal[1], signal[2] ...
name = f"{var.name}[{idx+1}]"
# Create port names that contain the index starting from 1. E.i signal[0], signal[1] ...
name = f"{var.name}[{idx}]"
fmi_var = FmiVariable(
name=name,
variable_reference=var_ref,
Expand All @@ -730,7 +729,7 @@ def format_fmi_variable(
description=var.description or "",
variability=var.variability
or (FmiVariability.CONTINUOUS if var.causality != FmiCausality.PARAMETER else FmiVariability.TUNABLE),
start_value=var.start_value or 0,
start_value=var.start_value if var.start_value is not None else 0,
type=var.type or FmiVariableType.REAL,
)
variables.append(fmi_var)
Expand Down Expand Up @@ -772,7 +771,8 @@ def get_template_mapping(
for state_init_indexes in inp.agent_state_init_indexes:
num_state_init_indexes = len(state_init_indexes)
for variable_index, state_init_index in enumerate(state_init_indexes):
if variable_index >= num_variable_references:
_variable_index = variable_index
if _variable_index >= num_variable_references:
if not self.state_initialization_reuse:
warnings.warn(
f"Too few variables in {inp.name} (={num_variable_references}) "
Expand All @@ -782,7 +782,7 @@ def get_template_mapping(
stacklevel=1,
)
break
_variable_index = variable_index % num_variable_references
_variable_index = _variable_index % num_variable_references
state_init_mapping.append((state_init_index, inp.variable_references[_variable_index]))

for out in self.outputs:
Expand Down
2 changes: 1 addition & 1 deletion src/mlfmu/utils/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def format_template_data(onnx: ONNXModel, fmi_model: FmiModel, model_component:
"The number of total input indexes for all inputs and parameter in the interface file "
f"(={num_fmu_inputs}) cannot exceed the input size of the ml model (={onnx.input_size})"
)

if num_fmu_outputs > onnx.output_size:
raise ValueError(
"The number of total output indexes for all outputs in the interface file "
Expand Down Expand Up @@ -206,7 +207,6 @@ def validate_interface_spec(
The pydantic model instance that contains all the interface information.
"""
parsed_spec = ModelComponent.model_validate_json(json_data=spec, strict=True)

try:
validated_model = ModelComponent.model_validate(parsed_spec)
except ValidationError as e:
Expand Down
Binary file added tests/data/example.onnx
Binary file not shown.
166 changes: 166 additions & 0 deletions tests/utils/test_fmu_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import json
import re
from pathlib import Path

import pytest

from mlfmu.types.fmu_component import FmiModel
from mlfmu.types.onnx_model import ONNXModel
from mlfmu.utils.builder import format_template_data, validate_interface_spec


@pytest.fixture(scope="session")
def wind_generator_onnx() -> ONNXModel:
return ONNXModel(Path.cwd().parent / "data" / "example.onnx", time_input=True)


def test_valid_template_data(wind_generator_onnx: ONNXModel):
valid_spec = {
"name": "example",
"version": "1.0",
"inputs": [
{"name": "inputs", "description": "My inputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"outputs": [
{
"name": "outputs",
"description": "My outputs",
"agentOutputIndexes": ["0:2"],
"isArray": True,
"length": 2,
}
],
"states": [
{"agentOutputIndexes": ["2:130"]},
{"name": "state1", "startValue": 10.0, "agentOutputIndexes": ["0"]},
{"name": "state2", "startValue": 180.0, "agentOutputIndexes": ["1"]},
],
}
_, model = validate_interface_spec(json.dumps(valid_spec))
assert model is not None

fmi_model = FmiModel(model=model)
template_data = format_template_data(onnx=wind_generator_onnx, fmi_model=fmi_model, model_component=model)

assert template_data["FmuName"] == "example"
assert template_data["numFmuVariables"] == "6"
assert template_data["numOnnxInputs"] == "2"
assert template_data["numOnnxOutputs"] == "130"
assert template_data["numOnnxStates"] == "130"
assert template_data["onnxInputValueReferences"] == "0, 0, 1, 1"
assert template_data["onnxOutputValueReferences"] == "0, 2, 1, 3"


def test_template_data_invalid_input_size(wind_generator_onnx: ONNXModel):
valid_spec = {
"name": "example",
"version": "1.0",
"inputs": [
{"name": "inputs", "description": "My inputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2},
{
"name": "inputs2",
"description": "My inputs 2",
"agentInputIndexes": ["0:10"],
"isArray": True,
"length": 10,
},
],
"outputs": [
{"name": "outputs", "description": "My outputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"states": [
{"agentOutputIndexes": ["2:130"]},
{"name": "state1", "startValue": 10.0, "agentOutputIndexes": ["0"]},
{"name": "state2", "startValue": 180.0, "agentOutputIndexes": ["1"]},
],
}

_, model = validate_interface_spec(json.dumps(valid_spec))
assert model is not None

fmi_model = FmiModel(model=model)

with pytest.raises(ValueError) as exc_info:
_ = format_template_data(onnx=wind_generator_onnx, fmi_model=fmi_model, model_component=model)

assert exc_info.match(
re.escape(
"The number of total input indexes for all inputs and parameter in the interface file (=12) \
cannot exceed the input size of the ml model (=2)"
)
)


def test_template_data_invalid_output_size(wind_generator_onnx: ONNXModel):
valid_spec = {
"name": "example",
"version": "1.0",
"inputs": [
{"name": "inputs", "description": "My inputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"outputs": [
{
"name": "outputs",
"description": "My outputs",
"agentOutputIndexes": ["0:2"],
"isArray": True,
"length": 2,
},
{
"name": "outputs2",
"description": "My outputs 2",
"agentOutputIndexes": ["0:200"],
"isArray": True,
"length": 200,
},
],
"states": [
{"agentOutputIndexes": ["2:130"]},
{"name": "state1", "startValue": 10.0, "agentOutputIndexes": ["0"]},
{"name": "state2", "startValue": 180.0, "agentOutputIndexes": ["1"]},
],
}

_, model = validate_interface_spec(json.dumps(valid_spec))
fmi_model = FmiModel(model=model)

with pytest.raises(ValueError) as exc_info:
_ = format_template_data(onnx=wind_generator_onnx, fmi_model=fmi_model, model_component=model)

assert exc_info.match(
re.escape(
"The number of total output indexes for all outputs in the interface file (=202) \
cannot exceed the output size of the ml model (=130)"
)
)


def test_template_data_invalid_state_size(wind_generator_onnx: ONNXModel):
valid_spec = {
"name": "example",
"version": "1.0",
"inputs": [
{"name": "inputs", "description": "My inputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"outputs": [
{"name": "outputs", "description": "My outputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"states": [
{"agentOutputIndexes": ["2:200"]},
],
}

_, model = validate_interface_spec(json.dumps(valid_spec))
assert model is not None

fmi_model = FmiModel(model=model)

with pytest.raises(ValueError) as exc_info:
_ = format_template_data(onnx=wind_generator_onnx, fmi_model=fmi_model, model_component=model)

assert exc_info.match(
re.escape(
"The number of total output indexes for all states in the interface file (=198) \
cannot exceed either the state input size (=130)"
)
)
Loading

0 comments on commit 9caf053

Please sign in to comment.