Skip to content

Commit

Permalink
zmq: replace server_started Event with Future to handle exceptions
Browse files Browse the repository at this point in the history
In a very special case, the zmq server might fail during initialization, and all callers of
server_started.wait() would then wait indefinitely. Therefore, the Event is replaced
with an asyncio.Future, which can additionally be completed with an exception so that waiters fail instead of hanging.
  • Loading branch information
xjules committed Jan 21, 2025
1 parent 0b80ac1 commit 24635d0
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 9 deletions.
11 changes: 6 additions & 5 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._max_batch_size: int = 500
self._batching_interval: float = 2.0
self._complete_batch: asyncio.Event = asyncio.Event()
self._server_started: asyncio.Event = asyncio.Event()
self._server_started: asyncio.Future[None] = asyncio.Future()
self._clients_connected: set[bytes] = set()
self._clients_empty: asyncio.Event = asyncio.Event()
self._clients_empty.set()
Expand All @@ -73,7 +73,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._dispatchers_empty.set()

async def _publisher(self) -> None:
await self._server_started.wait()
await self._server_started
while True:
event = await self._events_to_send.get()
for identity in self._clients_connected:
Expand Down Expand Up @@ -243,7 +243,7 @@ async def handle_dispatch(self, dealer: bytes, frame: bytes) -> None:
await self._events.put(event)

async def listen_for_messages(self) -> None:
await self._server_started.wait()
await self._server_started
while True:
try:
dealer, _, frame = await self._router_socket.recv_multipart()
Expand Down Expand Up @@ -285,9 +285,10 @@ async def _server(self) -> None:
self._router_socket.bind(f"tcp://*:{self._config.router_port}")
else:
self._router_socket.bind(self._config.url)
self._server_started.set()
self._server_started.set_result(None)
except zmq.error.ZMQError as e:
logger.error(f"ZMQ error encountered {e} during evaluator initialization")
self._server_started.set_exception(e)
raise
try:
await self._server_done.wait()
Expand Down Expand Up @@ -350,7 +351,7 @@ async def _start_running(self) -> None:
asyncio.create_task(self.listen_for_messages(), name="listener_task"),
]

await self._server_started.wait()
await self._server_started
self._ee_tasks.append(
asyncio.create_task(
self._ensemble.evaluate(
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ async def run_ensemble_evaluator_async(
evaluator_task = asyncio.create_task(
evaluator.run_and_get_successful_realizations()
)
await evaluator._server_started.wait()
await evaluator._server_started
if not (await self.run_monitor(ee_config, ensemble.iteration)):
return []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def evaluator_to_use_fixture(make_ee_config):
evaluator = EnsembleEvaluator(ensemble, make_ee_config(use_token=False))
evaluator._batching_interval = 0.5 # batching can be faster for tests
run_task = asyncio.create_task(evaluator.run_and_get_successful_realizations())
await evaluator._server_started.wait()
await evaluator._server_started
yield evaluator
evaluator.stop()
await run_task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def evaluator_to_use():
async def run_evaluator(ensemble, ee_config):
evaluator = EnsembleEvaluator(ensemble, ee_config)
run_task = asyncio.create_task(evaluator.run_and_get_successful_realizations())
await evaluator._server_started.wait()
await evaluator._server_started
try:
yield evaluator
finally:
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def create_manifest_file():
run_task = asyncio.create_task(
evaluator.run_and_get_successful_realizations()
)
await evaluator._server_started.wait()
await evaluator._server_started
await _run_monitor()
await run_task
assert "Waiting for disk synchronization" in caplog.messages
Expand Down

0 comments on commit 24635d0

Please sign in to comment.