diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 62fdd4056fb..cf9f6ea7edd 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -18,6 +18,8 @@ class ClientConnectionError(Exception): CONNECT_MSG = b"CONNECT" DISCONNECT_MSG = b"DISCONNECT" ACK_MSG = b"ACK" +HEARTBEAT_MSG = b"BEAT" +HEARTBEAT_TIMEOUT = 5.0 class Client: @@ -83,7 +85,7 @@ async def connect(self) -> None: await self._term_receiver_task() self._receiver_task = asyncio.create_task(self._receiver()) try: - await self.send(CONNECT_MSG, retries=1) + await self.send(CONNECT_MSG) except ClientConnectionError: await self._term_receiver_task() self.term() @@ -93,11 +95,23 @@ async def process_message(self, msg: str) -> None: raise NotImplementedError("Only monitor can receive messages!") async def _receiver(self) -> None: + last_heartbeat_time: float | None = None while True: try: _, raw_msg = await self.socket.recv_multipart() if raw_msg == ACK_MSG: self._ack_event.set() + elif raw_msg == HEARTBEAT_MSG: + if ( + last_heartbeat_time + and (asyncio.get_running_loop().time() - last_heartbeat_time) + > 2 * HEARTBEAT_TIMEOUT + ): + await self.socket.send_multipart([b"", CONNECT_MSG]) + logger.warning( + f"{self.dealer_id} heartbeat failed - reconnecting." + ) + last_heartbeat_time = asyncio.get_running_loop().time() else: await self.process_message(raw_msg.decode("utf-8")) except zmq.ZMQError as exc: @@ -144,5 +158,5 @@ async def send(self, message: str | bytes, retries: int | None = None) -> None: self.socket.connect(self.url) backoff = min(backoff * 2, 10) # Exponential backoff raise ClientConnectionError( - f"{self.dealer_id} Failed to send {message!r} after retries!" + f"{self.dealer_id} Failed to send {message!r} to {self.url} after retries!" ) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 9737f26e56c..bf755158552 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -5,6 +5,7 @@ import logging import traceback from collections.abc import Awaitable, Callable, Iterable, Sequence +from enum import Enum from typing import Any, get_args import zmq.asyncio @@ -27,7 +28,12 @@ event_from_json, event_to_json, ) -from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG +from _ert.forward_model_runner.client import ( + ACK_MSG, + CONNECT_MSG, + DISCONNECT_MSG, + HEARTBEAT_TIMEOUT, +) from ert.ensemble_evaluator import identifiers as ids from ._ensemble import FMStepSnapshot @@ -45,13 +51,17 @@ EVENT_HANDLER = Callable[[list[Event]], Awaitable[None]] +class HeartbeatEvent(Enum): + event = b"BEAT" + + class EnsembleEvaluator: def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._config: EvaluatorServerConfig = config self._ensemble: Ensemble = ensemble self._events: asyncio.Queue[Event] = asyncio.Queue() - self._events_to_send: asyncio.Queue[Event] = asyncio.Queue() + self._events_to_send: asyncio.Queue[Event | HeartbeatEvent] = asyncio.Queue() self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue() self._ee_tasks: list[asyncio.Task[None]] = [] @@ -72,14 +82,26 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._dispatchers_empty: asyncio.Event = asyncio.Event() self._dispatchers_empty.set() + async def _do_heartbeat_clients(self) -> None: + await self._server_started + while True: + if self._clients_connected: + await self._events_to_send.put(HeartbeatEvent.event) + await asyncio.sleep(HEARTBEAT_TIMEOUT) + async def _publisher(self) -> None: await self._server_started while True: event = await self._events_to_send.get() for identity in self._clients_connected: - await self._router_socket.send_multipart( - [identity, b"", event_to_json(event).encode("utf-8")] - ) + if isinstance(event, HeartbeatEvent): + await self._router_socket.send_multipart( + [identity, b"", event.value] + ) + else: + await self._router_socket.send_multipart( + [identity, b"", event_to_json(event).encode("utf-8")] + ) self._events_to_send.task_done() async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None: @@ -197,6 +219,8 @@ def ensemble(self) -> Ensemble: async def handle_client(self, dealer: bytes, frame: bytes) -> None: if frame == CONNECT_MSG: + if dealer in self._clients_connected: + logger.warning(f"{dealer!r} wants to reconnect.") self._clients_connected.add(dealer) self._clients_empty.clear() current_snapshot_dict = self._ensemble.snapshot.to_dict() @@ -343,6 +367,7 @@ async def _start_running(self) -> None: raise ValueError("no config for evaluator") self._ee_tasks = [ asyncio.create_task(self._server(), name="server_task"), + asyncio.create_task(self._do_heartbeat_clients(), name="heartbeat_task"), asyncio.create_task( self._batch_events_into_buffer(), name="dispatcher_task" ), diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py index 1a0ab130f13..d09f89099eb 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py @@ -1,5 +1,8 @@ +import asyncio + import pytest +import _ert.forward_model_runner.client from _ert.forward_model_runner.client import Client, ClientConnectionError from tests.ert.utils import MockZMQServer @@ -18,12 +21,12 @@ async def test_invalid_server(): async def test_successful_sending(unused_tcp_port): host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" - messages_c1 = ["test_1", "test_2", "test_3"] - async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as c1: - for message in messages_c1: - await c1.send(message) + messages = ["test_1", "test_2", "test_3"] + async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as client: + for message in messages: + await client.send(message) - for msg in messages_c1: + for msg in messages: assert msg in mock_server.messages @@ -32,14 +35,14 @@ async def test_retry(unused_tcp_port): host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" client_connection_error_set = False - messages_c1 = ["test_1", "test_2", "test_3"] + messages = ["test_1", "test_2", "test_3"] async with ( MockZMQServer(unused_tcp_port, signal=2) as mock_server, - Client(url, ack_timeout=0.5) as c1, + Client(url, ack_timeout=0.5) as client, ): - for message in messages_c1: + for message in messages: try: - await c1.send(message, retries=1) + await client.send(message, retries=1) except ClientConnectionError: client_connection_error_set = True mock_server.signal(0) @@ -47,3 +50,26 @@ async def test_retry(unused_tcp_port): assert mock_server.messages.count("test_1") == 2 assert mock_server.messages.count("test_2") == 1 assert mock_server.messages.count("test_3") == 1 + + +async def test_reconnect_when_missing_heartbeat(unused_tcp_port, monkeypatch): + host = "localhost" + url = f"tcp://{host}:{unused_tcp_port}" + monkeypatch.setattr(_ert.forward_model_runner.client, "HEARTBEAT_TIMEOUT", 0.1) + + async with ( + MockZMQServer(unused_tcp_port, signal=3) as mock_server, + Client(url) as client, + ): + await client.send("start", retries=1) + + await mock_server.do_heartbeat() + await asyncio.sleep(1) + await mock_server.do_heartbeat() + await client.send("stop", retries=1) + + # when reconnection happens CONNECT message is sent again + assert mock_server.messages.count("CONNECT") == 2 + assert mock_server.messages.count("DISCONNECT") == 1 + assert "start" in mock_server.messages + assert "stop" in mock_server.messages diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py index bbe114d15f3..2adc385f3e1 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py @@ -55,8 +55,9 @@ async def mock_event_handler(router_socket): assert msg == DISCONNECT_MSG -async def test_no_connection_established(make_ee_config): +async def test_no_connection_established(monkeypatch, make_ee_config): ee_config = make_ee_config() + monkeypatch.setattr(Monitor, "DEFAULT_MAX_RETRIES", 0) monitor = Monitor(ee_config.get_connection_info()) monitor._ack_timeout = 0.1 with pytest.raises(ClientConnectionError): diff --git a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py index 3d5753f5595..02c49c96dba 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py +++ b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py @@ -213,9 +213,7 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): # this is to show when the reporter fails but reconnects # reporter still manages to send events and completes fine - # see assert reporter._timeout_timestamp is not None - # meaning Finish event initiated _timeout but timeout wasn't reached since - # it finished succesfully + # see reporter._event_publisher for more details. host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" diff --git a/tests/ert/utils.py b/tests/ert/utils.py index 4f66f93e157..4f5bbdbbc55 100644 --- a/tests/ert/utils.py +++ b/tests/ert/utils.py @@ -9,7 +9,12 @@ import zmq import zmq.asyncio -from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG +from _ert.forward_model_runner.client import ( + ACK_MSG, + CONNECT_MSG, + DISCONNECT_MSG, + HEARTBEAT_MSG, +) from _ert.threading import ErtThread from ert.scheduler.event import FinishedEvent, StartedEvent @@ -64,9 +69,10 @@ def wait_until(func, interval=0.5, timeout=30): class MockZMQServer: def __init__(self, port, signal=0): """Mock ZMQ server for testing - signal = 0: normal operation + signal = 0: normal operation, receive messages but don't store CONNECT and DISCONNECT messages signal = 1: don't send ACK and don't receive messages signal = 2: don't send ACK, but receive messages + signal = 3: normal operation, and store also CONNECT and DISCONNECT messages """ self.port = port self.messages = [] @@ -74,6 +80,7 @@ def __init__(self, port, signal=0): self.loop = None self.server_task = None self.handler_task = None + self.dealers = set() def start_event_loop(self): asyncio.set_event_loop(self.loop) @@ -116,13 +123,25 @@ async def mock_zmq_server(self): def signal(self, value): self.value = value + async def do_heartbeat(self): + for dealer in self.dealers: + await self.router_socket.send_multipart([dealer, b"", HEARTBEAT_MSG]) + async def _handler(self): while True: try: dealer, __, frame = await self.router_socket.recv_multipart() - if frame in {CONNECT_MSG, DISCONNECT_MSG} or self.value == 0: + if frame == CONNECT_MSG: + await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) + self.dealers.add(dealer) + elif frame == DISCONNECT_MSG: + await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) + self.dealers.discard(dealer) + elif self.value in {0, 3}: await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) - if frame not in {CONNECT_MSG, DISCONNECT_MSG} and self.value != 1: + if ( + self.value in {0, 2} and frame not in {CONNECT_MSG, DISCONNECT_MSG} + ) or self.value == 3: self.messages.append(frame.decode("utf-8")) except asyncio.CancelledError: break