Skip to content

Commit

Permalink
Make scheduler cancel() and kill_all_jobs async
Browse files Browse the repository at this point in the history
This commit makes the methods async, so we can await them instead of
fire-and-forgetting them through asyncio.run_coroutine_threadsafe.
  • Loading branch information
jonathan-eq authored and berland committed Jan 21, 2025
1 parent 36345a6 commit bc8a611
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 34 deletions.
15 changes: 4 additions & 11 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from functools import partialmethod
from typing import (
Any,
Protocol,
)
from typing import Any

from _ert.events import (
Event,
Expand Down Expand Up @@ -112,7 +109,7 @@ class LegacyEnsemble:
id_: str

def __post_init__(self) -> None:
self._scheduler: _KillAllJobs | None = None
self._scheduler: Scheduler | None = None
self._config: EvaluatorServerConfig | None = None
self.snapshot: EnsembleSnapshot = self._create_snapshot()
self.status = self.snapshot.status
Expand Down Expand Up @@ -322,16 +319,12 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
def cancellable(self) -> bool:
return True

def cancel(self) -> None:
async def cancel(self) -> None:
if self._scheduler is not None:
self._scheduler.kill_all_jobs()
await self._scheduler.kill_all_jobs()
logger.debug("evaluator cancelled")


class _KillAllJobs(Protocol):
def kill_all_jobs(self) -> None: ...


@dataclass
class Realization:
iens: int
Expand Down
12 changes: 4 additions & 8 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._config: EvaluatorServerConfig = config
self._ensemble: Ensemble = ensemble

self._loop: asyncio.AbstractEventLoop | None = None

self._clients: set[ServerConnection] = set()
self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue()

Expand Down Expand Up @@ -198,7 +196,7 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None:
if len(events) == 0:
events = [EnsembleFailed(ensemble=self.ensemble.id_)]
await self._append_message(self.ensemble.update_snapshot(events))
self._signal_cancel() # let ensemble know it should stop
await self._signal_cancel() # let ensemble know it should stop

@property
def ensemble(self) -> Ensemble:
Expand All @@ -223,7 +221,7 @@ async def handle_client(self, websocket: ServerConnection) -> None:
logger.debug(f"got message from client: {event}")
if type(event) is EEUserCancel:
logger.debug(f"Client {websocket.remote_address} asked to cancel.")
self._signal_cancel()
await self._signal_cancel()

elif type(event) is EEUserDone:
logger.debug(f"Client {websocket.remote_address} signalled done.")
Expand Down Expand Up @@ -342,7 +340,7 @@ async def _server(self) -> None:
def stop(self) -> None:
self._server_done.set()

def _signal_cancel(self) -> None:
async def _signal_cancel(self) -> None:
"""
This is just a wrapper around logic for whether to signal cancel via
a cancellable ensemble or to use internal stop-mechanism directly
Expand All @@ -353,16 +351,14 @@ def _signal_cancel(self) -> None:
"""
if self._ensemble.cancellable:
logger.debug("Cancelling current ensemble")
assert self._loop is not None
self._loop.run_in_executor(None, self._ensemble.cancel)
await self._ensemble.cancel()
else:
logger.debug("Stopping current ensemble")
self.stop()

async def _start_running(self) -> None:
if not self._config:
raise ValueError("no config for evaluator")
self._loop = asyncio.get_running_loop()
self._ee_tasks = [
asyncio.create_task(self._server(), name="server_task"),
asyncio.create_task(
Expand Down
1 change: 0 additions & 1 deletion src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ async def signal_cancel(self) -> None:
return
await self._event_queue.put(Monitor._sentinel)
logger.debug(f"monitor-{self._id} asking server to cancel...")

cancel_event = EEUserCancel(monitor=self._id)
await self._connection.send(event_to_json(cancel_event))
logger.debug(f"monitor-{self._id} asked server to cancel")
Expand Down
8 changes: 2 additions & 6 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,12 +508,8 @@ async def run_monitor(
EESnapshotUpdate,
}:
event = cast(EESnapshot | EESnapshotUpdate, event)
await asyncio.get_running_loop().run_in_executor(
None,
self.send_snapshot_event,
event,
iteration,
)

self.send_snapshot_event(event, iteration)

if event.snapshot.get(STATUS) in {
ENSEMBLE_STATE_STOPPED,
Expand Down
10 changes: 2 additions & 8 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import orjson
from pydantic.dataclasses import dataclass

from _ert.async_utils import get_running_loop
from _ert.events import Event, ForwardModelStepChecksum, Id, event_from_dict
from ert.constant_filenames import CERT_FILE

Expand Down Expand Up @@ -86,7 +85,6 @@ def __init__(
real.iens: Job(self, real) for real in (realizations or [])
}

self._loop = get_running_loop()
self._events: asyncio.Queue[Any] = asyncio.Queue()
self._running: asyncio.Event = asyncio.Event()

Expand All @@ -108,12 +106,8 @@ def __init__(

self.checksum: dict[str, dict[str, Any]] = {}

def kill_all_jobs(self) -> None:
assert self._loop
# Checking that the loop is running is required because everest is closing the
# simulation context whenever an optimization simulation batch is done
if self._loop.is_running():
asyncio.run_coroutine_threadsafe(self.cancel_all_jobs(), self._loop)
async def kill_all_jobs(self) -> None:
await self.cancel_all_jobs()

async def cancel_all_jobs(self) -> None:
await self._running.wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ async def test_run_and_cancel_legacy_ensemble(
# and the ensemble is set to STOPPED
monitor._receiver_timeout = 10.0
cancel = True
await evaluator._ensemble._scheduler._running.wait()
with contextlib.suppress(
ConnectionClosed
): # monitor throws some variant of CC if dispatcher dies
Expand Down

0 comments on commit bc8a611

Please sign in to comment.