Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more unit tests #63

Merged
merged 15 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ The changelog format is based on [Keep a Changelog](https://keepachangelog.com/e
## [Unreleased]

### Changed
Jorgelmh marked this conversation as resolved.
Show resolved Hide resolved
* Added missing unit tests for the template data generated when building the FMU.
* Unit tests for the modelDescription.xml generation.
* Unit tests for the Interface JSON validation.
* Changed from `pip`/`tox` to `uv` as package manager
* README.md : Completely rewrote section "Development Setup", introducing `uv` as package manager.
* Added missing docstrings for py/cpp/h files with help of Github Copilot
Expand Down
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"sphinx.ext.napoleon",
"sphinx_argparse_cli",
"sphinx.ext.mathjax",
"matplotlib.sphinxext.plot_directive",
"sphinx.ext.autosummary",
"sphinx.ext.todo",
]
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ dev-dependencies = [
"sphinx-autodoc-typehints>=2.2",
"myst-parser>=4.0",
"furo>=2024.8",
"matplotlib>=3.9",
]
native-tls = true

Expand Down
18 changes: 9 additions & 9 deletions src/mlfmu/types/fmu_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,12 @@ def check_only_one_initialization(self) -> InternalState:
if (not start_value) and name:
raise ValueError(
"name is set without start_value being set. "
"Both fields needs to be set for the state initialization to be valid."
"Both fields need to be set for the state initialization to be valid."
)
if start_value and (not name):
raise ValueError(
"start_value is set without name being set. "
"Both fields needs to be set for the state initialization to be valid."
"Both fields need to be set for the state initialization to be valid."
)
return self

Expand Down Expand Up @@ -306,7 +306,7 @@ class FmiInputVariable(InputVariable):

causality: FmiCausality
variable_references: list[int]
agent_state_init_indexes: list[list[int]]
agent_state_init_indexes: list[list[int]] = [] # noqa: RUF008

def __init__(self, **kwargs: Any) -> None: # noqa: ANN401
super().__init__(**kwargs)
Expand Down Expand Up @@ -411,7 +411,6 @@ class ModelComponent(BaseModelConfig):
"""

name: str = Field(
default=None,
description="The name of the simulation model.",
)
version: str = Field(
Expand Down Expand Up @@ -708,8 +707,8 @@ def format_fmi_variable(

if var.is_array:
for idx, var_ref in enumerate(var.variable_references):
# Create port names that contain the index starting from 1. E.i signal[1], signal[2] ...
name = f"{var.name}[{idx+1}]"
# Create port names that contain the index starting from 1. E.i signal[0], signal[1] ...
name = f"{var.name}[{idx}]"
fmi_var = FmiVariable(
name=name,
variable_reference=var_ref,
Expand All @@ -730,7 +729,7 @@ def format_fmi_variable(
description=var.description or "",
variability=var.variability
or (FmiVariability.CONTINUOUS if var.causality != FmiCausality.PARAMETER else FmiVariability.TUNABLE),
start_value=var.start_value or 0,
start_value=var.start_value if var.start_value is not None else 0,
type=var.type or FmiVariableType.REAL,
)
variables.append(fmi_var)
Expand Down Expand Up @@ -772,7 +771,8 @@ def get_template_mapping(
for state_init_indexes in inp.agent_state_init_indexes:
num_state_init_indexes = len(state_init_indexes)
for variable_index, state_init_index in enumerate(state_init_indexes):
if variable_index >= num_variable_references:
_variable_index = variable_index
if _variable_index >= num_variable_references:
if not self.state_initialization_reuse:
warnings.warn(
f"Too few variables in {inp.name} (={num_variable_references}) "
Expand All @@ -782,7 +782,7 @@ def get_template_mapping(
stacklevel=1,
)
break
_variable_index = variable_index % num_variable_references
_variable_index = _variable_index % num_variable_references
state_init_mapping.append((state_init_index, inp.variable_references[_variable_index]))

for out in self.outputs:
Expand Down
2 changes: 1 addition & 1 deletion src/mlfmu/utils/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def format_template_data(onnx: ONNXModel, fmi_model: FmiModel, model_component:
"The number of total input indexes for all inputs and parameter in the interface file "
f"(={num_fmu_inputs}) cannot exceed the input size of the ml model (={onnx.input_size})"
)

if num_fmu_outputs > onnx.output_size:
raise ValueError(
"The number of total output indexes for all outputs in the interface file "
Expand Down Expand Up @@ -206,7 +207,6 @@ def validate_interface_spec(
The pydantic model instance that contains all the interface information.
"""
parsed_spec = ModelComponent.model_validate_json(json_data=spec, strict=True)

try:
validated_model = ModelComponent.model_validate(parsed_spec)
except ValidationError as e:
Expand Down
Binary file added tests/data/example.onnx
Binary file not shown.
166 changes: 166 additions & 0 deletions tests/utils/test_fmu_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import json
import re
from pathlib import Path

import pytest

from mlfmu.types.fmu_component import FmiModel
from mlfmu.types.onnx_model import ONNXModel
from mlfmu.utils.builder import format_template_data, validate_interface_spec


@pytest.fixture(scope="session")
def wind_generator_onnx() -> ONNXModel:
return ONNXModel(Path.cwd().parent / "data" / "example.onnx", time_input=True)


def test_valid_template_data(wind_generator_onnx: ONNXModel):
valid_spec = {
"name": "example",
"version": "1.0",
"inputs": [
{"name": "inputs", "description": "My inputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"outputs": [
{
"name": "outputs",
"description": "My outputs",
"agentOutputIndexes": ["0:2"],
"isArray": True,
"length": 2,
}
],
"states": [
{"agentOutputIndexes": ["2:130"]},
{"name": "state1", "startValue": 10.0, "agentOutputIndexes": ["0"]},
{"name": "state2", "startValue": 180.0, "agentOutputIndexes": ["1"]},
],
}
_, model = validate_interface_spec(json.dumps(valid_spec))
assert model is not None

fmi_model = FmiModel(model=model)
template_data = format_template_data(onnx=wind_generator_onnx, fmi_model=fmi_model, model_component=model)

assert template_data["FmuName"] == "example"
assert template_data["numFmuVariables"] == "6"
assert template_data["numOnnxInputs"] == "2"
assert template_data["numOnnxOutputs"] == "130"
assert template_data["numOnnxStates"] == "130"
assert template_data["onnxInputValueReferences"] == "0, 0, 1, 1"
assert template_data["onnxOutputValueReferences"] == "0, 2, 1, 3"


def test_template_data_invalid_input_size(wind_generator_onnx: ONNXModel):
valid_spec = {
"name": "example",
"version": "1.0",
"inputs": [
{"name": "inputs", "description": "My inputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2},
{
"name": "inputs2",
"description": "My inputs 2",
"agentInputIndexes": ["0:10"],
"isArray": True,
"length": 10,
},
],
"outputs": [
{"name": "outputs", "description": "My outputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"states": [
{"agentOutputIndexes": ["2:130"]},
{"name": "state1", "startValue": 10.0, "agentOutputIndexes": ["0"]},
{"name": "state2", "startValue": 180.0, "agentOutputIndexes": ["1"]},
],
}

_, model = validate_interface_spec(json.dumps(valid_spec))
assert model is not None

fmi_model = FmiModel(model=model)

with pytest.raises(ValueError) as exc_info:
_ = format_template_data(onnx=wind_generator_onnx, fmi_model=fmi_model, model_component=model)

assert exc_info.match(
re.escape(
"The number of total input indexes for all inputs and parameter in the interface file (=12) \
cannot exceed the input size of the ml model (=2)"
)
)


def test_template_data_invalid_output_size(wind_generator_onnx: ONNXModel):
valid_spec = {
"name": "example",
"version": "1.0",
"inputs": [
{"name": "inputs", "description": "My inputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"outputs": [
{
"name": "outputs",
"description": "My outputs",
"agentOutputIndexes": ["0:2"],
"isArray": True,
"length": 2,
},
{
"name": "outputs2",
"description": "My outputs 2",
"agentOutputIndexes": ["0:200"],
"isArray": True,
"length": 200,
},
],
"states": [
{"agentOutputIndexes": ["2:130"]},
{"name": "state1", "startValue": 10.0, "agentOutputIndexes": ["0"]},
{"name": "state2", "startValue": 180.0, "agentOutputIndexes": ["1"]},
],
}

_, model = validate_interface_spec(json.dumps(valid_spec))
fmi_model = FmiModel(model=model)

with pytest.raises(ValueError) as exc_info:
_ = format_template_data(onnx=wind_generator_onnx, fmi_model=fmi_model, model_component=model)

assert exc_info.match(
re.escape(
"The number of total output indexes for all outputs in the interface file (=202) \
cannot exceed the output size of the ml model (=130)"
)
)


def test_template_data_invalid_state_size(wind_generator_onnx: ONNXModel):
valid_spec = {
"name": "example",
"version": "1.0",
"inputs": [
{"name": "inputs", "description": "My inputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"outputs": [
{"name": "outputs", "description": "My outputs", "agentInputIndexes": ["0:2"], "isArray": True, "length": 2}
],
"states": [
{"agentOutputIndexes": ["2:200"]},
],
}

_, model = validate_interface_spec(json.dumps(valid_spec))
assert model is not None

fmi_model = FmiModel(model=model)

with pytest.raises(ValueError) as exc_info:
_ = format_template_data(onnx=wind_generator_onnx, fmi_model=fmi_model, model_component=model)

assert exc_info.match(
re.escape(
"The number of total output indexes for all states in the interface file (=198) \
cannot exceed either the state input size (=130)"
)
)
Loading