Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor control handling in everest config #9805

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 31 additions & 34 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ert.ensemble_evaluator import EnsembleSnapshot, EvaluatorServerConfig
from ert.runpaths import Runpaths
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.config import ControlConfig, ControlVariableGuessListConfig, EverestConfig
from everest.optimizer.everest2ropt import everest2ropt
from everest.simulator.everest_to_ert import everest_to_ert_config
from everest.strings import EVEREST
Expand Down Expand Up @@ -433,39 +433,36 @@ def _init_batch_data(
evaluator_context: EvaluatorContext,
cached_results: dict[int, Any],
) -> dict[int, dict[str, Any]]:
def add_control(
controls: dict[str, Any],
control_name: tuple[Any, ...],
control_value: float,
) -> None:
group_name = control_name[0]
variable_name = control_name[1]
group = controls.get(group_name, {})
if len(control_name) > 2:
index_name = str(control_name[2])
if variable_name in group:
group[variable_name][index_name] = control_value
else:
group[variable_name] = {index_name: control_value}
else:
group[variable_name] = control_value
controls[group_name] = group

batch_data = {}
for control_idx in range(control_values.shape[0]):
if control_idx not in cached_results and (
evaluator_context.active is None
or evaluator_context.active[evaluator_context.realizations[control_idx]]
):
controls: dict[str, Any] = {}
for control_name, control_value in zip(
self._everest_config.control_name_tuples,
control_values[control_idx, :],
strict=False,
):
add_control(controls, control_name, control_value)
batch_data[control_idx] = controls
return batch_data
def _add_controls(
controls_config: list[ControlConfig], values: NDArray[np.float64]
) -> dict[str, Any]:
batch_data_item: dict[str, Any] = {}
value_list = values.tolist()
for control in controls_config:
control_dict: dict[str, Any] = batch_data_item.get(control.name, {})
for variable in control.variables:
variable_value = control_dict.get(variable.name, {})
if isinstance(variable, ControlVariableGuessListConfig):
for index in range(1, len(variable.initial_guess) + 1):
variable_value[str(index)] = value_list.pop(0)
elif variable.index is not None:
variable_value[str(variable.index)] = value_list.pop(0)
else:
variable_value = value_list.pop(0)
control_dict[variable.name] = variable_value
batch_data_item[control.name] = control_dict
return batch_data_item

active = evaluator_context.active
realizations = evaluator_context.realizations
return {
idx: _add_controls(self._everest_config.controls, control_values[idx, :])
for idx in range(control_values.shape[0])
if (
idx not in cached_results
and (active is None or active[realizations[idx]])
)
}

def _setup_sim(
self,
Expand Down
27 changes: 4 additions & 23 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from argparse import ArgumentParser
from copy import copy
from functools import cached_property
from io import StringIO
from itertools import chain
from pathlib import Path
Expand Down Expand Up @@ -623,25 +622,11 @@ def control_names(self):
controls = self.controls or []
return [control.name for control in controls]

@cached_property
def control_name_tuples(self) -> list[tuple[str, str, int | tuple[str, str]]]:
tuples = []
for control in self.controls:
for variable in control.variables:
if isinstance(variable, ControlVariableGuessListConfig):
for index in range(1, len(variable.initial_guess) + 1):
tuples.append((control.name, variable.name, index))
elif variable.index is not None:
tuples.append((control.name, variable.name, variable.index))
else:
tuples.append((control.name, variable.name))
return tuples

@property
def objective_names(self) -> list[str]:
return [objective.name for objective in self.objective_functions]

@cached_property
@property
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering, method seems unchanged and result doesn't seem to change, is there a reason we remove caching here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point I ran into an issue where @cached_property did not work properly. EverestConfig is not immutable and if you change it afterwards by assigning to it the cache is not updated. So that is a bit dangerous. There is currently an effort to make it immutable again, then we could use @cached_property.

def constraint_names(self) -> list[str]:
names: list[str] = []

Expand All @@ -651,16 +636,12 @@ def _add_output_constraint(rhs_value: float | None, suffix=None):
names.append(name if suffix is None else f"{name}:{suffix}")

for constr in self.output_constraints or []:
_add_output_constraint(constr.target)
_add_output_constraint(
constr.target,
)
_add_output_constraint(
constr.upper_bound,
None if constr.lower_bound is None else "upper",
constr.upper_bound, None if constr.lower_bound is None else "upper"
)
_add_output_constraint(
constr.lower_bound,
None if constr.upper_bound is None else "lower",
constr.lower_bound, None if constr.upper_bound is None else "lower"
)

return names
Expand Down
145 changes: 145 additions & 0 deletions src/everest/config/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from collections.abc import Generator, Iterator
from typing import Any

from .control_config import ControlConfig
from .control_variable_config import (
ControlVariableConfig,
ControlVariableGuessListConfig,
)
from .sampler_config import SamplerConfig


class FlattenedControls:
def __init__(self, controls: list[ControlConfig]) -> None:
self._controls = []
self._samplers: list[SamplerConfig] = []

for control in controls:
control_sampler_idx = -1
variables = []
for variable in control.variables:
match variable:
case ControlVariableConfig():
var_dict, control_sampler_idx = self._add_variable(
control, variable, control_sampler_idx
)
variables.append(var_dict)
case ControlVariableGuessListConfig():
var_dicts, control_sampler_idx = self._add_variable_guess_list(
control, variable, control_sampler_idx
)
variables.extend(var_dicts)
self._inject_defaults(control, variables)
self._controls.extend(variables)

self.names = [control["name"] for control in self._controls]
self.types = [
None if control["control_type"] is None else control["control_type"]
for control in self._controls
]
self.initial_guesses = [control["initial_guess"] for control in self._controls]
self.lower_bounds = [control["min"] for control in self._controls]
self.upper_bounds = [control["max"] for control in self._controls]
self.auto_scales = [control["auto_scale"] for control in self._controls]
self.scaled_ranges = [
(0.0, 1.0) if control["scaled_range"] is None else control["scaled_range"]
for control in self._controls
]
self.enabled = [control["enabled"] for control in self._controls]
self.perturbation_magnitudes = [
control["perturbation_magnitude"] for control in self._controls
]
self.perturbation_types = [
control["perturbation_type"] for control in self._controls
]
self.sampler_indices = [control["sampler_idx"] for control in self._controls]
self.samplers = self._samplers

def _add_variable(
self,
control: ControlConfig,
variable: ControlVariableConfig,
control_sampler_idx: int,
) -> tuple[dict[str, Any], int]:
var_dict = {
key: getattr(variable, key)
for key in [
"control_type",
"enabled",
"auto_scale",
"scaled_range",
"min",
"max",
"perturbation_magnitude",
"initial_guess",
]
}
var_dict["name"] = (
(control.name, variable.name)
if variable.index is None
else (control.name, variable.name, variable.index)
)
if variable.sampler is not None:
self._samplers.append(variable.sampler)
var_dict["sampler_idx"] = len(self._samplers) - 1
else:
if control.sampler is not None and control_sampler_idx < 0:
self._samplers.append(control.sampler)
control_sampler_idx = len(self._samplers) - 1
var_dict["sampler_idx"] = control_sampler_idx
return var_dict, control_sampler_idx

def _add_variable_guess_list(
self,
control: ControlConfig,
variable: ControlVariableGuessListConfig,
control_sampler_idx: int,
) -> tuple[Generator[dict[str, Any], None, None], int]:
if control.sampler is not None and control_sampler_idx < 0:
self._samplers.append(control.sampler)
control_sampler_idx = len(self._samplers) - 1
return (
(
{
"name": (control.name, variable.name, index + 1),
"initial_guess": guess,
"sampler_idx": control_sampler_idx,
}
for index, guess in enumerate(variable.initial_guess)
),
control_sampler_idx,
)

@staticmethod
def _inject_defaults(
control: ControlConfig, variables: list[dict[str, Any]]
) -> None:
for var_dict in variables:
for key in [
"type",
"initial_guess",
"control_type",
"enabled",
"auto_scale",
"min",
"max",
"perturbation_type",
"perturbation_magnitude",
"scaled_range",
]:
if var_dict.get(key) is None:
var_dict[key] = getattr(control, key)


def control_tuples(
controls: list[ControlConfig],
) -> Iterator[tuple[str, str, int] | tuple[str, str]]:
for control in controls:
for variable in control.variables:
if isinstance(variable, ControlVariableGuessListConfig):
for index in range(1, len(variable.initial_guess) + 1):
yield (control.name, variable.name, index)
elif variable.index is not None:
yield (control.name, variable.name, variable.index)
else:
yield (control.name, variable.name)
Loading
Loading