diff --git a/src/ert/config/parsing/__init__.py b/src/ert/config/parsing/__init__.py index 6ac0114fe67..36f4e1fcd91 100644 --- a/src/ert/config/parsing/__init__.py +++ b/src/ert/config/parsing/__init__.py @@ -1,4 +1,5 @@ from .analysis_mode import AnalysisMode +from .base_model_context import BaseModelWithContextSupport from .config_dict import ConfigDict from .config_errors import ConfigValidationError, ConfigWarning from .config_keywords import ConfigKeys @@ -20,6 +21,7 @@ __all__ = [ "AnalysisMode", + "BaseModelWithContextSupport", "ConfigDict", "ConfigKeys", "ConfigValidationError", diff --git a/src/ert/config/parsing/base_model_context.py b/src/ert/config/parsing/base_model_context.py new file mode 100644 index 00000000000..29bdf17d4cb --- /dev/null +++ b/src/ert/config/parsing/base_model_context.py @@ -0,0 +1,26 @@ +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any + +from pydantic import BaseModel + +init_context_var = ContextVar("_init_context_var", default=None) + + +@contextmanager +def init_context(value: dict[str, Any]) -> Iterator[None]: + token = init_context_var.set(value) # type: ignore + try: + yield + finally: + init_context_var.reset(token) + + +class BaseModelWithContextSupport(BaseModel): + def __init__(__pydantic_self__, **data: Any) -> None: + __pydantic_self__.__pydantic_validator__.validate_python( + data, + self_instance=__pydantic_self__, + context=init_context_var.get(), + ) diff --git a/src/ert/config/queue_config.py b/src/ert/config/queue_config.py index aad424db1bb..ddc15d3aa7e 100644 --- a/src/ert/config/queue_config.py +++ b/src/ert/config/queue_config.py @@ -8,11 +8,13 @@ from typing import Annotated, Any, Literal, no_type_check import pydantic -from pydantic import BaseModel, Field +from pydantic import Field, field_validator from pydantic.dataclasses import dataclass +from pydantic_core.core_schema import ValidationInfo from ._get_num_cpu import get_num_cpu_from_data_file from .parsing import ( + BaseModelWithContextSupport, ConfigDict, ConfigKeys, ConfigValidationError, @@ -37,12 +39,24 @@ def activate_script() -> str: return "" -class QueueOptions(BaseModel, validate_assignment=True, extra="forbid"): +class QueueOptions(BaseModelWithContextSupport, validate_assignment=True, extra="forbid"): name: str max_running: pydantic.NonNegativeInt = 0 submit_sleep: pydantic.NonNegativeFloat = 0.0 project_code: str | None = None - activate_script: str = Field(default_factory=activate_script) + activate_script: str | None = Field(default=None, validate_default=True) + + @field_validator("activate_script", mode="before") + @classmethod + def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str: + # User value gets highest priority + if isinstance(v, str): + return v + # Use from plugin system if user has not specified + plugin_script = None + if info.context: + plugin_script = info.context.get(info.field_name) + return plugin_script or activate_script() # Return default value @staticmethod def create_queue_options( diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index d9dc43c9657..e8473a9f90e 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -28,6 +28,9 @@ from ruamel.yaml import YAML, YAMLError from ert.config import ErtConfig +from ert.config.parsing import BaseModelWithContextSupport +from ert.config.parsing.base_model_context import init_context +from ert.plugins import ErtPluginManager from everest.config.control_variable_config import ControlVariableGuessListConfig from everest.config.install_template_config import InstallTemplateConfig from everest.config.server_config import ServerConfig @@ -134,7 +137,7 @@ class HasName(Protocol): name: str -class EverestConfig(BaseModelWithPropertySupport): # type: ignore +class EverestConfig(BaseModelWithPropertySupport, BaseModelWithContextSupport): # type: ignore controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field( description="""Defines a list of controls. Controls should have unique names each control defines @@ -807,6 +810,15 @@ def load_file(config_file: str) -> "EverestConfig": raise exp from error + @classmethod + def with_plugins(cls, config_dict): + context = {} + activate_script = ErtPluginManager().activate_script() + if activate_script: + context["activate_script"] = ErtPluginManager().activate_script() + with init_context(context): + return cls(**config_dict) + @staticmethod def load_file_with_argparser( config_path, parser: ArgumentParser diff --git a/src/everest/config/server_config.py b/src/everest/config/server_config.py index f4f7bd9b27a..de4c9691100 100644 --- a/src/everest/config/server_config.py +++ b/src/everest/config/server_config.py @@ -2,7 +2,7 @@ import os from typing import Any -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from ert.config.queue_config import ( LocalQueueOptions, @@ -10,7 +10,6 @@ SlurmQueueOptions, TorqueQueueOptions, ) -from ert.plugins import ErtPluginManager from ..strings import ( CERTIFICATE_DIR, @@ -38,15 +37,6 @@ class ServerConfig(BaseModel): # type: ignore extra="forbid", ) - @field_validator("queue_system", mode="before") - @classmethod - def default_local_queue(cls, v): - if v is None: - return v - elif "activate_script" not in v and ErtPluginManager().activate_script(): - v["activate_script"] = ErtPluginManager().activate_script() - return v - @model_validator(mode="before") @classmethod def check_old_config(cls, data: Any) -> Any: diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index b62ace63d6b..1af76a8c471 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -15,7 +15,6 @@ SlurmQueueOptions, TorqueQueueOptions, ) -from ert.plugins import ErtPluginManager simulator_example = {"queue_system": {"name": "local", "max_running": 3}} @@ -97,10 +96,6 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore def default_local_queue(cls, v): if v is None: return LocalQueueOptions(max_running=8) - if "activate_script" not in v and ( - active_script := ErtPluginManager().activate_script() - ): - v["activate_script"] = active_script return v @model_validator(mode="before") diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index 064c881301d..91597894498 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -285,7 +285,7 @@ def test_generate_queue_options_use_simulator_values( queue_options, expected_result, monkeypatch ): monkeypatch.setattr( - everest.config.server_config.ErtPluginManager, + everest.config.everest_config.ErtPluginManager, "activate_script", MagicMock(return_value=activate_script()), ) @@ -295,6 +295,35 @@ def test_generate_queue_options_use_simulator_values( assert config.server.queue_system == expected_result +@pytest.mark.parametrize("use_plugin", (True, False)) +@pytest.mark.parametrize( + "queue_options", + [ + {"name": "slurm", "activate_script": "From user"}, + {"name": "slurm"}, + ], +) +def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_config): + plugin_result = "From plugin" + if "activate_script" in queue_options: + expected_result = queue_options["activate_script"] + elif use_plugin: + expected_result = plugin_result + else: + expected_result = activate_script() + + if use_plugin: + monkeypatch.setattr( + everest.config.everest_config.ErtPluginManager, + "activate_script", + MagicMock(return_value=plugin_result), + ) + config = EverestConfig.with_plugins( + {"simulator": {"queue_system": queue_options}} | min_config + ) + assert config.server.queue_system.activate_script == expected_result + + @pytest.mark.timeout(5) # Simulation might not finish @pytest.mark.integration_test @pytest.mark.xdist_group(name="starts_everest")