Skip to content

Commit

Permalink
Store ert<->everest realization mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Jan 20, 2025
1 parent 1528d2e commit cde99c8
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
35 changes: 35 additions & 0 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from everest.strings import EVEREST

from ..run_arg import RunArg, create_run_arguments
from ..storage.local_ensemble import EverestRealizationInfo
from .base_run_model import BaseRunModel, StatusEvents

if TYPE_CHECKING:
Expand Down Expand Up @@ -375,6 +376,40 @@ def _forward_model_evaluator(
name=f"batch_{self._batch_id}",
ensemble_size=len(batch_data),
)

realizations = self._everest_config.model.realizations
num_perturbations = self._everest_config.optimization.perturbation_num
realization_mapping: dict[int, EverestRealizationInfo] = {}

if len(evaluator_context.realizations) == len(realizations):
# Function evaluation
realization_mapping = {
i: EverestRealizationInfo(geo_realization=real, perturbation=None)
for i, real in enumerate(realizations)
}
elif len(evaluator_context.realizations) == num_perturbations:
realization_mapping = {
p: EverestRealizationInfo(geo_realization=real, perturbation=p)
for p, real in enumerate(realizations)
}
else:
# Function and gradient
realization_mapping = {}
for i, real in enumerate(realizations):
realization_mapping[i] = EverestRealizationInfo(
geo_realization=real, perturbation=None
)

i = len(realization_mapping)
for real in realizations:
for p in range(num_perturbations):
realization_mapping[i] = EverestRealizationInfo(
geo_realization=real, perturbation=p
)
i += 1

# Fill in data from ROPT here
ensemble.save_everest_metadata(realization_mapping=realization_mapping)
for sim_id, controls in enumerate(batch_data.values()):
self._setup_sim(sim_id, controls, ensemble)

Expand Down
43 changes: 42 additions & 1 deletion src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import datetime
from functools import cache, lru_cache
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypedDict
from uuid import UUID

import numpy as np
Expand All @@ -32,6 +32,13 @@
import polars


class EverestRealizationInfo(TypedDict):
geo_realization: int
perturbation: int | None # None means it stems from unperturbed controls
# Q: Maybe we also need result ID, or no? Ref if we have multiple evaluations
# for unperturbed values, though the ERT real id will also differentiate them


class _Index(BaseModel):
id: UUID
experiment_id: UUID
Expand All @@ -41,6 +48,8 @@ class _Index(BaseModel):
prior_ensemble_id: UUID | None
started_at: datetime

everest_realizations: dict[int, EverestRealizationInfo] | None = None


class _Failure(BaseModel):
type: RealizationStorageState
Expand Down Expand Up @@ -1039,3 +1048,35 @@ def get_observations_and_responses(
return polars.concat(dfs_per_response_type, how="vertical").with_columns(
polars.col("response_key").cast(polars.String).alias("response_key")
)

def save_everest_metadata(
self, realization_mapping: dict[int, EverestRealizationInfo]
):
self._index.everest_realizations = realization_mapping
self._storage._write_transaction(
self._path / "index.json", self._index.model_dump_json().encode("utf-8")
)
self._index = _Index.model_validate_json(
(self._path / "index.json").read_text(encoding="utf-8")
)

def _everest_to_ert_realizations(
self, geo_realization, perturbation: int | None = None
) -> tuple[int, ...]:
pass

def load_everest_parameters(
self, geo_realization, perturbation: int | None = None
) -> polars.DataFrame:
# ert_realizations = self._everest_to_ert_realizations(
# geo_realization, perturbation
# )
pass

def load_everest_responses(
self, geo_realization: int, perturbation: int | None = None
) -> polars.DataFrame:
ert_realizations = self._everest_to_ert_realizations(
geo_realization, perturbation
)
return self.load_responses("gen_data", ert_realizations)

0 comments on commit cde99c8

Please sign in to comment.