Skip to content

Commit

Permalink
Refactor everest2ropt control parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Jan 20, 2025
1 parent 941a21b commit 3632cb5
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 289 deletions.
65 changes: 31 additions & 34 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ert.ensemble_evaluator import EnsembleSnapshot, EvaluatorServerConfig
from ert.runpaths import Runpaths
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.config import ControlConfig, ControlVariableGuessListConfig, EverestConfig
from everest.optimizer.everest2ropt import everest2ropt
from everest.simulator.everest_to_ert import everest_to_ert_config
from everest.strings import EVEREST
Expand Down Expand Up @@ -433,39 +433,36 @@ def _init_batch_data(
evaluator_context: EvaluatorContext,
cached_results: dict[int, Any],
) -> dict[int, dict[str, Any]]:
def add_control(
controls: dict[str, Any],
control_name: tuple[Any, ...],
control_value: float,
) -> None:
group_name = control_name[0]
variable_name = control_name[1]
group = controls.get(group_name, {})
if len(control_name) > 2:
index_name = str(control_name[2])
if variable_name in group:
group[variable_name][index_name] = control_value
else:
group[variable_name] = {index_name: control_value}
else:
group[variable_name] = control_value
controls[group_name] = group

batch_data = {}
for control_idx in range(control_values.shape[0]):
if control_idx not in cached_results and (
evaluator_context.active is None
or evaluator_context.active[evaluator_context.realizations[control_idx]]
):
controls: dict[str, Any] = {}
for control_name, control_value in zip(
self._everest_config.control_name_tuples,
control_values[control_idx, :],
strict=False,
):
add_control(controls, control_name, control_value)
batch_data[control_idx] = controls
return batch_data
def _add_controls(
controls_config: list[ControlConfig], values: NDArray[np.float64]
) -> dict[str, Any]:
batch_data_item: dict[str, Any] = {}
value_list = values.tolist()
for control in controls_config:
control_dict: dict[str, Any] = batch_data_item.get(control.name, {})
for variable in control.variables:
variable_value = control_dict.get(variable.name, {})
if isinstance(variable, ControlVariableGuessListConfig):
for index in range(1, len(variable.initial_guess) + 1):
variable_value[index] = value_list.pop(0)
elif variable.index is not None:
variable_value[variable.index] = value_list.pop(0)
else:
variable_value = value_list.pop(0)
control_dict[variable.name] = variable_value
batch_data_item[control.name] = control_dict
return batch_data_item

active = evaluator_context.active
realizations = evaluator_context.realizations
return {
idx: _add_controls(self._everest_config.controls, control_values[idx, :])
for idx in range(control_values.shape[0])
if (
idx not in cached_results
and (active is None or active[realizations[idx]])
)
}

def _setup_sim(
self,
Expand Down
27 changes: 4 additions & 23 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from argparse import ArgumentParser
from copy import copy
from functools import cached_property
from io import StringIO
from itertools import chain
from pathlib import Path
Expand Down Expand Up @@ -647,25 +646,11 @@ def control_names(self):
controls = self.controls or []
return [control.name for control in controls]

@cached_property
def control_name_tuples(self) -> list[tuple[str, str, int | tuple[str, str]]]:
tuples = []
for control in self.controls:
for variable in control.variables:
if isinstance(variable, ControlVariableGuessListConfig):
for index in range(1, len(variable.initial_guess) + 1):
tuples.append((control.name, variable.name, index))
elif variable.index is not None:
tuples.append((control.name, variable.name, variable.index))
else:
tuples.append((control.name, variable.name))
return tuples

@property
def objective_names(self) -> list[str]:
return [objective.name for objective in self.objective_functions]

@cached_property
@property
def constraint_names(self) -> list[str]:
names: list[str] = []

Expand All @@ -675,16 +660,12 @@ def _add_output_constraint(rhs_value: float | None, suffix=None):
names.append(name if suffix is None else f"{name}:{suffix}")

for constr in self.output_constraints or []:
_add_output_constraint(constr.target)
_add_output_constraint(
constr.target,
)
_add_output_constraint(
constr.upper_bound,
None if constr.lower_bound is None else "upper",
constr.upper_bound, None if constr.lower_bound is None else "upper"
)
_add_output_constraint(
constr.lower_bound,
None if constr.upper_bound is None else "lower",
constr.lower_bound, None if constr.upper_bound is None else "lower"
)

return names
Expand Down
117 changes: 117 additions & 0 deletions src/everest/config/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from collections.abc import Iterator

from .control_config import ControlConfig
from .control_variable_config import (
ControlVariableConfig,
ControlVariableGuessListConfig,
)
from .sampler_config import SamplerConfig


class FlattenedControls:
def __init__(self, controls: list[ControlConfig]) -> None:
self._controls = []
self._samplers: list[SamplerConfig] = []

for control in controls:
control_sampler_idx = -1
variables = []
for variable in control.variables:
if isinstance(variable, ControlVariableConfig):
var_dict = {
key: getattr(variable, key)
for key in [
"control_type",
"enabled",
"auto_scale",
"scaled_range",
"min",
"max",
"perturbation_magnitude",
"initial_guess",
]
}

var_dict["name"] = (
(control.name, variable.name)
if variable.index is None
else (control.name, variable.name, variable.index)
)

if variable.sampler is not None:
self._samplers.append(variable.sampler)
var_dict["sampler_idx"] = len(self._samplers) - 1
else:
if control.sampler is not None and control_sampler_idx < 0:
self._samplers.append(control.sampler)
control_sampler_idx = len(self._samplers) - 1
var_dict["sampler_idx"] = control_sampler_idx

variables.append(var_dict)
elif isinstance(variable, ControlVariableGuessListConfig):
if control.sampler is not None and control_sampler_idx < 0:
self._samplers.append(control.sampler)
control_sampler_idx = len(self._samplers) - 1
variables.extend(
{
"name": (control.name, variable.name, index + 1),
"initial_guess": guess,
"sampler_idx": control_sampler_idx,
}
for index, guess in enumerate(variable.initial_guess)
)

for var_dict in variables:
for key in [
"type",
"initial_guess",
"control_type",
"enabled",
"auto_scale",
"min",
"max",
"perturbation_type",
"perturbation_magnitude",
"scaled_range",
]:
if var_dict.get(key) is None:
var_dict[key] = getattr(control, key)

self._controls.extend(variables)

self.names = [control["name"] for control in self._controls]
self.types = [
None if control["control_type"] is None else control["control_type"]
for control in self._controls
]
self.initial_guesses = [control["initial_guess"] for control in self._controls]
self.lower_bounds = [control["min"] for control in self._controls]
self.upper_bounds = [control["max"] for control in self._controls]
self.auto_scales = [control["auto_scale"] for control in self._controls]
self.scaled_ranges = [
(0.0, 1.0) if control["scaled_range"] is None else control["scaled_range"]
for control in self._controls
]
self.enabled = [control["enabled"] for control in self._controls]
self.perturbation_magnitudes = [
control["perturbation_magnitude"] for control in self._controls
]
self.perturbation_types = [
control["perturbation_type"] for control in self._controls
]
self.sampler_indices = [control["sampler_idx"] for control in self._controls]
self.samplers = self._samplers


def control_tuples(
controls: list[ControlConfig],
) -> Iterator[tuple[str, str, int] | tuple[str, str]]:
for control in controls:
for variable in control.variables:
if isinstance(variable, ControlVariableGuessListConfig):
for index in range(1, len(variable.initial_guess) + 1):
yield (control.name, variable.name, index)
elif variable.index is not None:
yield (control.name, variable.name, variable.index)
else:
yield (control.name, variable.name)
Loading

0 comments on commit 3632cb5

Please sign in to comment.