Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Jan 22, 2025
1 parent a2212ed commit 4d44cab
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 69 deletions.
150 changes: 89 additions & 61 deletions src/everest/config/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterator
from collections.abc import Generator, Iterator
from typing import Any

from .control_config import ControlConfig
from .control_variable_config import (
Expand All @@ -17,66 +18,18 @@ def __init__(self, controls: list[ControlConfig]) -> None:
control_sampler_idx = -1
variables = []
for variable in control.variables:
if isinstance(variable, ControlVariableConfig):
var_dict = {
key: getattr(variable, key)
for key in [
"control_type",
"enabled",
"auto_scale",
"scaled_range",
"min",
"max",
"perturbation_magnitude",
"initial_guess",
]
}

var_dict["name"] = (
(control.name, variable.name)
if variable.index is None
else (control.name, variable.name, variable.index)
)

if variable.sampler is not None:
self._samplers.append(variable.sampler)
var_dict["sampler_idx"] = len(self._samplers) - 1
else:
if control.sampler is not None and control_sampler_idx < 0:
self._samplers.append(control.sampler)
control_sampler_idx = len(self._samplers) - 1
var_dict["sampler_idx"] = control_sampler_idx

variables.append(var_dict)
elif isinstance(variable, ControlVariableGuessListConfig):
if control.sampler is not None and control_sampler_idx < 0:
self._samplers.append(control.sampler)
control_sampler_idx = len(self._samplers) - 1
variables.extend(
{
"name": (control.name, variable.name, index + 1),
"initial_guess": guess,
"sampler_idx": control_sampler_idx,
}
for index, guess in enumerate(variable.initial_guess)
)

for var_dict in variables:
for key in [
"type",
"initial_guess",
"control_type",
"enabled",
"auto_scale",
"min",
"max",
"perturbation_type",
"perturbation_magnitude",
"scaled_range",
]:
if var_dict.get(key) is None:
var_dict[key] = getattr(control, key)

match variable:
case ControlVariableConfig():
var_dict, control_sampler_idx = self._add_variable(
control, variable, control_sampler_idx
)
variables.append(var_dict)
case ControlVariableGuessListConfig():
var_dicts, control_sampler_idx = self._add_variable_guess_list(
control, variable, control_sampler_idx
)
variables.extend(var_dicts)
self._inject_defaults(control, variables)
self._controls.extend(variables)

self.names = [control["name"] for control in self._controls]
Expand All @@ -102,6 +55,81 @@ def __init__(self, controls: list[ControlConfig]) -> None:
self.sampler_indices = [control["sampler_idx"] for control in self._controls]
self.samplers = self._samplers

def _add_variable(
self,
control: ControlConfig,
variable: ControlVariableConfig,
control_sampler_idx: int,
) -> tuple[dict[str, Any], int]:
var_dict = {
key: getattr(variable, key)
for key in [
"control_type",
"enabled",
"auto_scale",
"scaled_range",
"min",
"max",
"perturbation_magnitude",
"initial_guess",
]
}
var_dict["name"] = (
(control.name, variable.name)
if variable.index is None
else (control.name, variable.name, variable.index)
)
if variable.sampler is not None:
self._samplers.append(variable.sampler)
var_dict["sampler_idx"] = len(self._samplers) - 1
else:
if control.sampler is not None and control_sampler_idx < 0:
self._samplers.append(control.sampler)
control_sampler_idx = len(self._samplers) - 1
var_dict["sampler_idx"] = control_sampler_idx
return var_dict, control_sampler_idx

def _add_variable_guess_list(
self,
control: ControlConfig,
variable: ControlVariableGuessListConfig,
control_sampler_idx: int,
) -> tuple[Generator[dict[str, Any], None, None], int]:
if control.sampler is not None and control_sampler_idx < 0:
self._samplers.append(control.sampler)
control_sampler_idx = len(self._samplers) - 1
return (
(
{
"name": (control.name, variable.name, index + 1),
"initial_guess": guess,
"sampler_idx": control_sampler_idx,
}
for index, guess in enumerate(variable.initial_guess)
),
control_sampler_idx,
)

@staticmethod
def _inject_defaults(
control: ControlConfig, variables: list[dict[str, Any]]
) -> None:
for var_dict in variables:
for key in [
"type",
"initial_guess",
"control_type",
"enabled",
"auto_scale",
"min",
"max",
"perturbation_type",
"perturbation_magnitude",
"scaled_range",
]:
if var_dict.get(key) is None:
var_dict[key] = getattr(control, key)


def control_tuples(
controls: list[ControlConfig],
Expand Down
23 changes: 15 additions & 8 deletions src/everest/optimizer/everest2ropt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from everest.config import (
ControlConfig,
EverestConfig,
InputConstraintConfig,
ModelConfig,
ObjectiveFunctionConfig,
OptimizationConfig,
Expand Down Expand Up @@ -70,9 +71,7 @@ def _parse_controls(ever_controls: list[ControlConfig], ropt_config):
}
for sampler in controls.samplers
]
ropt_config["gradient"]["samplers"] = [
max(0, idx) for idx in controls.sampler_indices
]
ropt_config["gradient"]["samplers"] = controls.sampler_indices

default_magnitude = (max(controls.upper_bounds) - min(controls.lower_bounds)) / 10.0
ropt_config["gradient"]["perturbation_magnitudes"] = [
Expand Down Expand Up @@ -139,17 +138,23 @@ def _parse_objectives(objective_functions: list[ObjectiveFunctionConfig], ropt_c
ropt_config["function_transforms"] = transforms


def _parse_input_constraints(ever_config: EverestConfig, ropt_config):
if not ever_config.input_constraints:
def _parse_input_constraints(
controls: list[ControlConfig],
input_constraints: list[InputConstraintConfig] | None,
ropt_config,
):
if not input_constraints:
return

# TODO: Issue #9816 is intended to address the need for a more general
# naming scheme. This code should be revisited once that issue is resolved.
formatted_names = [
(
f"{control_name[0]}.{control_name[1]}-{control_name[2]}"
if len(control_name) > 2
else f"{control_name[0]}.{control_name[1]}"
)
for control_name in control_tuples(ever_config.controls)
for control_name in control_tuples(controls)
]

coefficients_matrix = []
Expand All @@ -162,7 +167,7 @@ def _add_input_constraint(rhs_value, coefficients, constraint_type):
rhs_values.append(rhs_value)
types.append(constraint_type)

for constr in ever_config.input_constraints:
for constr in input_constraints:
coefficients = [0.0] * len(formatted_names)
for name, value in constr.weights.items():
coefficients[formatted_names.index(name)] = value
Expand Down Expand Up @@ -379,7 +384,9 @@ def everest2ropt(ever_config: EverestConfig) -> EnOptConfig:

_parse_controls(ever_config.controls, ropt_config)
_parse_objectives(ever_config.objective_functions, ropt_config)
_parse_input_constraints(ever_config, ropt_config)
_parse_input_constraints(
ever_config.controls, ever_config.input_constraints, ropt_config
)
_parse_output_constraints(ever_config.output_constraints, ropt_config)
_parse_optimization(
ever_opt=ever_config.optimization,
Expand Down

0 comments on commit 4d44cab

Please sign in to comment.