Skip to content

Commit

Permalink
Refactor implementation after review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
DanSava committed Jan 20, 2025
1 parent e5c019c commit 3436a3e
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 117 deletions.
18 changes: 10 additions & 8 deletions src/ert/analysis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,30 @@
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import pandas as pd
from pydantic import BaseModel, ConfigDict


@dataclass
class AnalysisEvent:
class AnalysisEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
pass


@dataclass
class AnalysisStatusEvent(AnalysisEvent):
event_type: Literal["AnalysisStatusEvent"] = "AnalysisStatusEvent"
msg: str


@dataclass
class AnalysisTimeEvent(AnalysisEvent):
event_type: Literal["AnalysisTimeEvent"] = "AnalysisTimeEvent"
remaining_time: float
elapsed_time: float


@dataclass
class AnalysisReportEvent(AnalysisEvent):
event_type: Literal["AnalysisReportEvent"] = "AnalysisReportEvent"
report: str


Expand Down Expand Up @@ -56,18 +58,18 @@ def to_csv(self, name: str, output_path: Path) -> None:
df.to_csv(f_path.with_suffix(".csv"))


@dataclass
class AnalysisDataEvent(AnalysisEvent):
event_type: Literal["AnalysisDataEvent"] = "AnalysisDataEvent"
name: str
data: DataSection


@dataclass
class AnalysisErrorEvent(AnalysisEvent):
event_type: Literal["AnalysisErrorEvent"] = "AnalysisErrorEvent"
error_msg: str
data: DataSection | None = None


@dataclass
class AnalysisCompleteEvent(AnalysisEvent):
event_type: Literal["AnalysisCompleteEvent"] = "AnalysisCompleteEvent"
data: DataSection
40 changes: 30 additions & 10 deletions src/ert/ensemble_evaluator/event.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,49 @@
from dataclasses import dataclass
from collections.abc import Mapping
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, field_serializer, field_validator

from .snapshot import EnsembleSnapshot


@dataclass
class _UpdateEvent:
class _UpdateEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
iteration_label: str
total_iterations: int
progress: float
realization_count: int
status_count: dict[str, int]
iteration: int
snapshot: EnsembleSnapshot | None = None

@field_serializer("snapshot")
def serialize_snapshot(
self, value: EnsembleSnapshot | None
) -> dict[str, Any] | None:
if value is None:
return None
return value.to_dict()

@field_validator("snapshot", mode="before")
@classmethod
def validate_snapshot(
cls, value: EnsembleSnapshot | Mapping[Any, Any]
) -> EnsembleSnapshot:
if isinstance(value, EnsembleSnapshot):
return value
return EnsembleSnapshot.from_nested_dict(value)


@dataclass
class FullSnapshotEvent(_UpdateEvent):
snapshot: EnsembleSnapshot | None = None
event_type: Literal["FullSnapshotEvent"] = "FullSnapshotEvent"


@dataclass
class SnapshotUpdateEvent(_UpdateEvent):
snapshot: EnsembleSnapshot | None = None
event_type: Literal["SnapshotUpdateEvent"] = "SnapshotUpdateEvent"


@dataclass
class EndEvent:
class EndEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
event_type: Literal["EndEvent"] = "EndEvent"
failed: bool
msg: str | None = None
msg: str
25 changes: 10 additions & 15 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ class UnsupportedOperationException(ValueError):


def convert_iso8601_to_datetime(
timestamp: datetime | str,
) -> datetime:
timestamp: datetime | str | None,
) -> datetime | None:
if timestamp is None:
return None
if isinstance(timestamp, datetime):
return timestamp

Expand Down Expand Up @@ -401,27 +403,20 @@ class RealizationSnapshot(TypedDict, total=False):
def _realization_dict_to_realization_snapshot(
source: dict[str, Any],
) -> RealizationSnapshot:
start_time = source.get("start_time")
if start_time and isinstance(start_time, str):
start_time = datetime.fromisoformat(start_time)
end_time = source.get("end_time")
if end_time and isinstance(end_time, str):
end_time = datetime.fromisoformat(end_time)

realization = RealizationSnapshot(
status=source.get("status"),
active=source.get("active"),
start_time=start_time,
end_time=end_time,
start_time=convert_iso8601_to_datetime(source.get("start_time")),
end_time=convert_iso8601_to_datetime(source.get("end_time")),
exec_hosts=source.get("exec_hosts"),
message=source.get("message"),
fm_steps=source.get("fm_steps", {}),
)
for step in realization["fm_steps"].values():
if "start_time" in step and isinstance(step["start_time"], str):
step["start_time"] = datetime.fromisoformat(step["start_time"])
if "end_time" in step and isinstance(step["end_time"], str):
step["end_time"] = datetime.fromisoformat(step["end_time"])
if step.get("start_time"):
step["start_time"] = convert_iso8601_to_datetime(step["start_time"])
if step.get("end_time"):
step["end_time"] = convert_iso8601_to_datetime(step["end_time"])
return _filter_nones(realization)


Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
AnalysisCompleteEvent,
AnalysisDataEvent,
AnalysisErrorEvent,
AnalysisEvent,
)
from ert.config import HookRuntime, QueueSystem
from ert.config.analysis_module import BaseSettings
Expand Down Expand Up @@ -60,7 +61,6 @@
from ..config.analysis_config import UpdateSettings
from ..run_arg import RunArg
from .event import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
EndEvent,
Expand Down
101 changes: 21 additions & 80 deletions src/ert/run_models/event.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Annotated, Literal
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter

from ert.analysis import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
)
Expand All @@ -17,33 +16,32 @@
FullSnapshotEvent,
SnapshotUpdateEvent,
)
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot


@dataclass
class RunModelEvent:
class RunModelEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
iteration: int
run_id: UUID


@dataclass
class RunModelStatusEvent(RunModelEvent):
event_type: Literal["RunModelStatusEvent"] = "RunModelStatusEvent"
msg: str


@dataclass
class RunModelTimeEvent(RunModelEvent):
event_type: Literal["RunModelTimeEvent"] = "RunModelTimeEvent"
remaining_time: float
elapsed_time: float


@dataclass
class RunModelUpdateBeginEvent(RunModelEvent):
event_type: Literal["RunModelUpdateBeginEvent"] = "RunModelUpdateBeginEvent"
pass


@dataclass
class RunModelDataEvent(RunModelEvent):
event_type: Literal["RunModelDataEvent"] = "RunModelDataEvent"
name: str
data: DataSection

Expand All @@ -52,28 +50,27 @@ def write_as_csv(self, output_path: Path | None) -> None:
self.data.to_csv(self.name, output_path / str(self.run_id))


@dataclass
class RunModelUpdateEndEvent(RunModelEvent):
event_type: Literal["RunModelUpdateEndEvent"] = "RunModelUpdateEndEvent"
data: DataSection

def write_as_csv(self, output_path: Path | None) -> None:
if output_path and self.data:
self.data.to_csv("Report", output_path / str(self.run_id))


@dataclass
class RunModelErrorEvent(RunModelEvent):
event_type: Literal["RunModelErrorEvent"] = "RunModelErrorEvent"
error_msg: str
data: DataSection | None = None
data: DataSection

def write_as_csv(self, output_path: Path | None) -> None:
if output_path and self.data:
self.data.to_csv("Report", output_path / str(self.run_id))


StatusEvents = (
AnalysisEvent
| AnalysisStatusEvent
AnalysisStatusEvent
| AnalysisTimeEvent
| EndEvent
| FullSnapshotEvent
Expand All @@ -87,70 +84,14 @@ def write_as_csv(self, output_path: Path | None) -> None:
)


EVENT_MAPPING = {
"AnalysisEvent": AnalysisEvent,
"AnalysisStatusEvent": AnalysisStatusEvent,
"AnalysisTimeEvent": AnalysisTimeEvent,
"EndEvent": EndEvent,
"FullSnapshotEvent": FullSnapshotEvent,
"SnapshotUpdateEvent": SnapshotUpdateEvent,
"RunModelErrorEvent": RunModelErrorEvent,
"RunModelStatusEvent": RunModelStatusEvent,
"RunModelTimeEvent": RunModelTimeEvent,
"RunModelUpdateBeginEvent": RunModelUpdateBeginEvent,
"RunModelDataEvent": RunModelDataEvent,
"RunModelUpdateEndEvent": RunModelUpdateEndEvent,
}


def status_event_from_json(json_str: str) -> StatusEvents:
json_dict = json.loads(json_str)
event_type = json_dict.pop("event_type", None)

match event_type:
case FullSnapshotEvent.__name__:
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"])
json_dict["snapshot"] = snapshot
return FullSnapshotEvent(**json_dict)
case SnapshotUpdateEvent.__name__:
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"])
json_dict["snapshot"] = snapshot
return SnapshotUpdateEvent(**json_dict)
case RunModelDataEvent.__name__ | RunModelUpdateEndEvent.__name__:
if "run_id" in json_dict and isinstance(json_dict["run_id"], str):
json_dict["run_id"] = UUID(json_dict["run_id"])
if json_dict.get("data"):
json_dict["data"] = DataSection(**json_dict["data"])
return EVENT_MAPPING[event_type](**json_dict)
case _:
if event_type in EVENT_MAPPING:
if "run_id" in json_dict and isinstance(json_dict["run_id"], str):
json_dict["run_id"] = UUID(json_dict["run_id"])
return EVENT_MAPPING[event_type](**json_dict)
else:
raise TypeError(f"Unknown status event type {event_type}")
STATUS_EVENTS_ANNOTATION = Annotated[StatusEvents, Field(discriminator="event_type")]

StatusEventAdapter: TypeAdapter[StatusEvents] = TypeAdapter(STATUS_EVENTS_ANNOTATION)


def status_event_from_json(raw_msg: str | bytes) -> StatusEvents:
return StatusEventAdapter.validate_json(raw_msg)


def status_event_to_json(event: StatusEvents) -> str:
match event:
case FullSnapshotEvent() | SnapshotUpdateEvent():
assert event.snapshot is not None
event_dict = asdict(event)
event_dict.update(
{
"snapshot": event.snapshot.to_dict(),
"event_type": event.__class__.__name__,
}
)
return json.dumps(
event_dict,
default=lambda o: o.strftime("%Y-%m-%dT%H:%M:%S")
if isinstance(o, datetime)
else None,
)
case StatusEvents:
event_dict = asdict(event)
event_dict["event_type"] = StatusEvents.__class__.__name__
return json.dumps(
event_dict, default=lambda o: str(o) if isinstance(o, UUID) else None
)
return event.model_dump_json()
11 changes: 10 additions & 1 deletion tests/ert/unit_tests/cli/test_model_hook_order.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from unittest.mock import ANY, MagicMock, PropertyMock, call, patch

import pytest
Expand Down Expand Up @@ -48,6 +49,7 @@ def test_hook_call_order_ensemble_smoother(monkeypatch):

ens_mock = MagicMock()
ens_mock.iteration = 0
ens_mock.id = uuid.uuid1()
storage_mock = MagicMock()
storage_mock.create_ensemble.return_value = ens_mock

Expand Down Expand Up @@ -86,6 +88,7 @@ def test_hook_call_order_es_mda(monkeypatch):

ens_mock = MagicMock()
ens_mock.iteration = 0
ens_mock.id = uuid.uuid1()
storage_mock = MagicMock()
storage_mock.create_ensemble.return_value = ens_mock
test_class = MultipleDataAssimilation(
Expand Down Expand Up @@ -115,9 +118,15 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch):
monkeypatch.setattr(base_run_model, "_seed_sequence", MagicMock(return_value=0))
monkeypatch.setattr(base_run_model.BaseRunModel, "run_workflows", run_wfs_mock)

ens_mock = MagicMock()
ens_mock.iteration = 2
ens_mock.id = uuid.uuid1()
storage_mock = MagicMock()
storage_mock.create_ensemble.return_value = ens_mock

test_class = IteratedEnsembleSmoother(*[MagicMock()] * 13)
test_class.run_ensemble_evaluator = MagicMock(return_value=[0])

test_class._storage = storage_mock
# Mock the return values of iterative_smoother_update
# Mock the iteration property of IteratedEnsembleSmoother
with (
Expand Down
Loading

0 comments on commit 3436a3e

Please sign in to comment.