Skip to content

Commit

Permalink
Create end point for events
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jan 9, 2025
1 parent 3f8b320 commit 761ae55
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 9 deletions.
9 changes: 6 additions & 3 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from _ert.events import EESnapshot, EESnapshotUpdate, Event
from ert.config import ErtConfig, ExtParamConfig
from ert.ensemble_evaluator import EnsembleSnapshot, EvaluatorServerConfig
from ert.ensemble_evaluator import EndEvent, EnsembleSnapshot, EvaluatorServerConfig
from ert.runpaths import Runpaths
from ert.storage import open_storage
from everest.config import EverestConfig
Expand Down Expand Up @@ -103,10 +103,11 @@ def __init__(
everest_config: EverestConfig,
simulation_callback: SimulationCallback | None,
optimization_callback: OptimizerCallback | None,
status_queue: queue.SimpleQueue[StatusEvents] | None = None,
):
Path(everest_config.log_dir).mkdir(parents=True, exist_ok=True)
Path(everest_config.optimization_output_dir).mkdir(parents=True, exist_ok=True)

status_queue = queue.SimpleQueue() if status_queue is None else status_queue
assert everest_config.environment is not None
logging.getLogger(EVEREST).info(
"Using random seed: %d. To deterministically reproduce this experiment, "
Expand Down Expand Up @@ -136,7 +137,6 @@ def __init__(
self._status: SimulationStatus | None = None

storage = open_storage(config.ens_path, mode="w")
status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue()
super().__init__(
config,
storage,
Expand All @@ -152,12 +152,14 @@ def create(
ever_config: EverestConfig,
simulation_callback: SimulationCallback | None = None,
optimization_callback: OptimizerCallback | None = None,
status_queue: queue.SimpleQueue[StatusEvents] | None = None,
) -> EverestRunModel:
return cls(
config=everest_to_ert_config(ever_config),
everest_config=ever_config,
simulation_callback=simulation_callback,
optimization_callback=optimization_callback,
status_queue=status_queue,
)

@classmethod
Expand Down Expand Up @@ -222,6 +224,7 @@ def run_experiment(
self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS
case _:
self._exit_code = EverestExitCode.COMPLETED
self.send_event(EndEvent(failed=bool(self.exit_code)))

def _create_optimizer(self) -> BasicOptimizer:
RESULT_COLUMNS = {
Expand Down
88 changes: 82 additions & 6 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import argparse
import asyncio
import datetime
import json
import logging
import multiprocessing as mp
import os
import queue
import socket
import ssl
import threading
import traceback
import uuid
from base64 import b64encode
from functools import partial
from pathlib import Path
Expand All @@ -19,7 +23,7 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from dns import resolver, reversename
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi import Depends, FastAPI, HTTPException, Request, WebSocket, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import (
JSONResponse,
Expand All @@ -32,7 +36,8 @@
)

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.ensemble_evaluator import EndEvent, EvaluatorServerConfig
from ert.run_models import StatusEvents
from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
from everest import export_to_csv, export_with_progress
from everest.config import EverestConfig, ServerConfig
Expand All @@ -51,6 +56,23 @@
from everest.util import makedirs_if_needed, version_info


class EndTaskEvent:
pass


class Subscriber:
def __init__(self) -> None:
self.index = 0
self._event = asyncio.Event()

def notify(self):
self._event.set()

async def wait_for_event(self):
await self._event.wait()
self._event.clear()


def _get_machine_name() -> str:
"""Returns a name that can be used to identify this machine in a network
Expand Down Expand Up @@ -153,6 +175,29 @@ def get_opt_progress(
progress = get_opt_status(server_config["optimization_output_dir"])
return JSONResponse(jsonable_encoder(progress))

@app.websocket("/events")
async def websocket_endpoint(websocket: WebSocket):
subscriber_id = str(uuid.uuid4())
await websocket.accept()
while True:
event = await get_event(subscriber_id=subscriber_id)
if isinstance(event, EndTaskEvent):
break
await websocket.send_json(event)
await asyncio.sleep(0.1)

async def get_event(subscriber_id: str) -> StatusEvents:
if subscriber_id not in shared_data["subscribers"]:
shared_data["subscribers"][subscriber_id] = Subscriber()
subscriber = shared_data["subscribers"][subscriber_id]

while subscriber.index >= len(shared_data["events"]):
await subscriber.wait_for_event()

event = shared_data["events"][subscriber.index]
shared_data["subscribers"][subscriber_id].index += 1
return event

uvicorn.run(
app,
host="0.0.0.0",
Expand Down Expand Up @@ -235,6 +280,10 @@ def make_handler_config(


def main():
asyncio.run(everserver_main())


async def everserver_main():
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--config-file", type=str)
arg_parser.add_argument("--debug", action="store_true")
Expand Down Expand Up @@ -272,6 +321,8 @@ def main():
shared_data = {
SIM_PROGRESS_ENDPOINT: {},
STOP_ENDPOINT: False,
"events": [],
"subscribers": [],
}

server_config = {
Expand All @@ -296,23 +347,48 @@ def main():
message=traceback.format_exc(),
)
return

status_queue: mp.Queue[StatusEvents] = mp.Queue()
try:
update_everserver_status(status_path, ServerStatus.running)

run_model = EverestRunModel.create(
config,
simulation_callback=partial(_sim_monitor, shared_data=shared_data),
optimization_callback=partial(_opt_monitor, shared_data=shared_data),
status_queue=status_queue,
)
if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL:
evaluator_server_config = EvaluatorServerConfig()
else:
evaluator_server_config = EvaluatorServerConfig(
custom_port_range=range(49152, 51819), use_ipc_protocol=False
)

run_model.run_experiment(evaluator_server_config)
loop = asyncio.get_running_loop()
simulation_future = loop.run_in_executor(
None,
lambda: run_model.run_experiment(evaluator_server_config),
)
events = []
while True:
try:
item: StatusEvents = status_queue.get(block=False)
except queue.Empty:
await asyncio.sleep(0.01)
continue

event = jsonable_encoder(item)
shared_data["events"].append(event)
for sub in shared_data["subscribers"]:
sub.notify()
await asyncio.sleep(0.1)

if isinstance(item, EndEvent):
events.append(EndTaskEvent())
for sub in shared_data["subscribers"]:
sub.notify()
break

await simulation_future
run_model = None

status, message = _get_optimization_status(run_model.exit_code, shared_data)
if status != ServerStatus.completed:
Expand Down

0 comments on commit 761ae55

Please sign in to comment.