Skip to content

Commit

Permalink
Add function to propagate context to validation
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jan 17, 2025
1 parent 930da71 commit 004309c
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 21 deletions.
2 changes: 2 additions & 0 deletions src/ert/config/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .analysis_mode import AnalysisMode
from .base_model_context import BaseModelWithContextSupport
from .config_dict import ConfigDict
from .config_errors import ConfigValidationError, ConfigWarning
from .config_keywords import ConfigKeys
Expand All @@ -20,6 +21,7 @@

__all__ = [
"AnalysisMode",
"BaseModelWithContextSupport",
"ConfigDict",
"ConfigKeys",
"ConfigValidationError",
Expand Down
26 changes: 26 additions & 0 deletions src/ert/config/parsing/base_model_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any

from pydantic import BaseModel

init_context_var = ContextVar("_init_context_var", default=None)


@contextmanager
def init_context(value: dict[str, Any]) -> Iterator[None]:
token = init_context_var.set(value) # type: ignore
try:
yield
finally:
init_context_var.reset(token)


class BaseModelWithContextSupport(BaseModel):
def __init__(__pydantic_self__, **data: Any) -> None:
__pydantic_self__.__pydantic_validator__.validate_python(
data,
self_instance=__pydantic_self__,
context=init_context_var.get(),
)
20 changes: 17 additions & 3 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from typing import Annotated, Any, Literal, no_type_check

import pydantic
from pydantic import BaseModel, Field
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from pydantic_core.core_schema import ValidationInfo

from ._get_num_cpu import get_num_cpu_from_data_file
from .parsing import (
BaseModelWithContextSupport,
ConfigDict,
ConfigKeys,
ConfigValidationError,
Expand All @@ -37,12 +39,24 @@ def activate_script() -> str:
return ""


class QueueOptions(BaseModel, validate_assignment=True, extra="forbid"):
class QueueOptions(BaseModelWithContextSupport, validate_assignment=True, extra="forbid"):
name: str
max_running: pydantic.NonNegativeInt = 0
submit_sleep: pydantic.NonNegativeFloat = 0.0
project_code: str | None = None
activate_script: str = Field(default_factory=activate_script)
activate_script: str | None = Field(default=None, validate_default=True)

@field_validator("activate_script", mode="before")
@classmethod
def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str:
# User value gets highest priority
if isinstance(v, str):
return v
# Use from plugin system if user has not specified
plugin_script = None
if info.context:
plugin_script = info.context.get(info.field_name)
return plugin_script or activate_script() # Return default value

@staticmethod
def create_queue_options(
Expand Down
14 changes: 13 additions & 1 deletion src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from ruamel.yaml import YAML, YAMLError

from ert.config import ErtConfig
from ert.config.parsing import BaseModelWithContextSupport
from ert.config.parsing.base_model_context import init_context
from ert.plugins import ErtPluginManager
from everest.config.control_variable_config import ControlVariableGuessListConfig
from everest.config.install_template_config import InstallTemplateConfig
from everest.config.server_config import ServerConfig
Expand Down Expand Up @@ -134,7 +137,7 @@ class HasName(Protocol):
name: str


class EverestConfig(BaseModelWithPropertySupport): # type: ignore
class EverestConfig(BaseModelWithPropertySupport, BaseModelWithContextSupport): # type: ignore
controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field(
description="""Defines a list of controls.
Controls should have unique names each control defines
Expand Down Expand Up @@ -807,6 +810,15 @@ def load_file(config_file: str) -> "EverestConfig":

raise exp from error

@classmethod
def with_plugins(cls, config_dict):
context = {}
activate_script = ErtPluginManager().activate_script()
if activate_script:
context["activate_script"] = ErtPluginManager().activate_script()
with init_context(context):
return cls(**config_dict)

@staticmethod
def load_file_with_argparser(
config_path, parser: ArgumentParser
Expand Down
12 changes: 1 addition & 11 deletions src/everest/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator

from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager

from ..strings import (
CERTIFICATE_DIR,
Expand Down Expand Up @@ -38,15 +37,6 @@ class ServerConfig(BaseModel): # type: ignore
extra="forbid",
)

@field_validator("queue_system", mode="before")
@classmethod
def default_local_queue(cls, v):
if v is None:
return v
elif "activate_script" not in v and ErtPluginManager().activate_script():
v["activate_script"] = ErtPluginManager().activate_script()
return v

@model_validator(mode="before")
@classmethod
def check_old_config(cls, data: Any) -> Any:
Expand Down
5 changes: 0 additions & 5 deletions src/everest/config/simulator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager

simulator_example = {"queue_system": {"name": "local", "max_running": 3}}

Expand Down Expand Up @@ -97,10 +96,6 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore
def default_local_queue(cls, v):
if v is None:
return LocalQueueOptions(max_running=8)
if "activate_script" not in v and (
active_script := ErtPluginManager().activate_script()
):
v["activate_script"] = active_script
return v

@model_validator(mode="before")
Expand Down
31 changes: 30 additions & 1 deletion tests/everest/test_detached.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_generate_queue_options_use_simulator_values(
queue_options, expected_result, monkeypatch
):
monkeypatch.setattr(
everest.config.server_config.ErtPluginManager,
everest.config.everest_config.ErtPluginManager,
"activate_script",
MagicMock(return_value=activate_script()),
)
Expand All @@ -295,6 +295,35 @@ def test_generate_queue_options_use_simulator_values(
assert config.server.queue_system == expected_result


@pytest.mark.parametrize("use_plugin", (True, False))
@pytest.mark.parametrize(
"queue_options",
[
{"name": "slurm", "activate_script": "From user"},
{"name": "slurm"},
],
)
def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_config):
plugin_result = "From plugin"
if "activate_script" in queue_options:
expected_result = queue_options["activate_script"]
elif use_plugin:
expected_result = plugin_result
else:
expected_result = activate_script()

if use_plugin:
monkeypatch.setattr(
everest.config.everest_config.ErtPluginManager,
"activate_script",
MagicMock(return_value=plugin_result),
)
config = EverestConfig.with_plugins(
{"simulator": {"queue_system": queue_options}} | min_config
)
assert config.server.queue_system.activate_script == expected_result


@pytest.mark.timeout(5) # Simulation might not finish
@pytest.mark.integration_test
@pytest.mark.xdist_group(name="starts_everest")
Expand Down

0 comments on commit 004309c

Please sign in to comment.