From bc8a6118a15816f0dc7e1817216b0ca1c897aeea Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Mon, 6 Jan 2025 14:05:22 +0100 Subject: [PATCH] Make scheduler cancel() and kill_all_jobs async This commit makes the methods async, so we can await them instead fire-and-forgetting them through asyncio.run_coroutine_threadsafe. --- src/ert/ensemble_evaluator/_ensemble.py | 15 ++++----------- src/ert/ensemble_evaluator/evaluator.py | 12 ++++-------- src/ert/ensemble_evaluator/monitor.py | 1 - src/ert/run_models/base_run_model.py | 8 ++------ src/ert/scheduler/scheduler.py | 10 ++-------- .../ensemble_evaluator/test_ensemble_legacy.py | 1 + 6 files changed, 13 insertions(+), 34 deletions(-) diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index b09a0a81536..c4bb692ec9a 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -6,10 +6,7 @@ from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from functools import partialmethod -from typing import ( - Any, - Protocol, -) +from typing import Any from _ert.events import ( Event, @@ -112,7 +109,7 @@ class LegacyEnsemble: id_: str def __post_init__(self) -> None: - self._scheduler: _KillAllJobs | None = None + self._scheduler: Scheduler | None = None self._config: EvaluatorServerConfig | None = None self.snapshot: EnsembleSnapshot = self._create_snapshot() self.status = self.snapshot.status @@ -322,16 +319,12 @@ async def _evaluate_inner( # pylint: disable=too-many-branches def cancellable(self) -> bool: return True - def cancel(self) -> None: + async def cancel(self) -> None: if self._scheduler is not None: - self._scheduler.kill_all_jobs() + await self._scheduler.kill_all_jobs() logger.debug("evaluator cancelled") -class _KillAllJobs(Protocol): - def kill_all_jobs(self) -> None: ... - - @dataclass class Realization: iens: int diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 7078ce412e4..8ec6c477e44 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -62,8 +62,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._config: EvaluatorServerConfig = config self._ensemble: Ensemble = ensemble - self._loop: asyncio.AbstractEventLoop | None = None - self._clients: set[ServerConnection] = set() self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue() @@ -198,7 +196,7 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: if len(events) == 0: events = [EnsembleFailed(ensemble=self.ensemble.id_)] await self._append_message(self.ensemble.update_snapshot(events)) - self._signal_cancel() # let ensemble know it should stop + await self._signal_cancel() # let ensemble know it should stop @property def ensemble(self) -> Ensemble: @@ -223,7 +221,7 @@ async def handle_client(self, websocket: ServerConnection) -> None: logger.debug(f"got message from client: {event}") if type(event) is EEUserCancel: logger.debug(f"Client {websocket.remote_address} asked to cancel.") - self._signal_cancel() + await self._signal_cancel() elif type(event) is EEUserDone: logger.debug(f"Client {websocket.remote_address} signalled done.") @@ -342,7 +340,7 @@ async def _server(self) -> None: def stop(self) -> None: self._server_done.set() - def _signal_cancel(self) -> None: + async def _signal_cancel(self) -> None: """ This is just a wrapper around logic for whether to signal cancel via a cancellable ensemble or to use internal stop-mechanism directly @@ -353,8 +351,7 @@ def _signal_cancel(self) -> None: """ if self._ensemble.cancellable: logger.debug("Cancelling current ensemble") - assert self._loop is not None - self._loop.run_in_executor(None, self._ensemble.cancel) + await self._ensemble.cancel() else: logger.debug("Stopping current ensemble") self.stop() @@ -362,7 +359,6 @@ def _signal_cancel(self) -> None: async def _start_running(self) -> None: if not self._config: raise ValueError("no config for evaluator") - self._loop = asyncio.get_running_loop() self._ee_tasks = [ asyncio.create_task(self._server(), name="server_task"), asyncio.create_task( diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index d3f549377c6..a9d045063a3 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -71,7 +71,6 @@ async def signal_cancel(self) -> None: return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} asking server to cancel...") - cancel_event = EEUserCancel(monitor=self._id) await self._connection.send(event_to_json(cancel_event)) logger.debug(f"monitor-{self._id} asked server to cancel") diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 4fb4894b1b6..1b859277cef 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -508,12 +508,8 @@ async def run_monitor( EESnapshotUpdate, }: event = cast(EESnapshot | EESnapshotUpdate, event) - await asyncio.get_running_loop().run_in_executor( - None, - self.send_snapshot_event, - event, - iteration, - ) + + self.send_snapshot_event(event, iteration) if event.snapshot.get(STATUS) in { ENSEMBLE_STATE_STOPPED, diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index a1610930b26..15432290302 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -15,7 +15,6 @@ import orjson from pydantic.dataclasses import dataclass -from _ert.async_utils import get_running_loop from _ert.events import Event, ForwardModelStepChecksum, Id, event_from_dict from ert.constant_filenames import CERT_FILE @@ -86,7 +85,6 @@ def __init__( real.iens: Job(self, real) for real in (realizations or []) } - self._loop = get_running_loop() self._events: asyncio.Queue[Any] = asyncio.Queue() self._running: asyncio.Event = asyncio.Event() @@ -108,12 +106,8 @@ def __init__( self.checksum: dict[str, dict[str, Any]] = {} - def kill_all_jobs(self) -> None: - assert self._loop - # Checking that the loop is running is required because everest is closing the - # simulation context whenever an optimization simulation batch is done - if self._loop.is_running(): - asyncio.run_coroutine_threadsafe(self.cancel_all_jobs(), self._loop) + async def kill_all_jobs(self) -> None: + await self.cancel_all_jobs() async def cancel_all_jobs(self) -> None: await self._running.wait() diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py index a657a872571..fe412676f0b 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py @@ -93,6 +93,7 @@ async def test_run_and_cancel_legacy_ensemble( # and the ensemble is set to STOPPED monitor._receiver_timeout = 10.0 cancel = True + await evaluator._ensemble._scheduler._running.wait() with contextlib.suppress( ConnectionClosed ): # monitor throws some variant of CC if dispatcher dies