Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do heartbeat on clients #9798

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class ClientConnectionError(Exception):
CONNECT_MSG = b"CONNECT"
DISCONNECT_MSG = b"DISCONNECT"
ACK_MSG = b"ACK"
HEARTBEAT_MSG = b"BEAT"
HEARTBEAT_TIMEOUT = 5.0


class Client:
Expand Down Expand Up @@ -83,7 +85,7 @@ async def connect(self) -> None:
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self.send(CONNECT_MSG, retries=1)
await self.send(CONNECT_MSG)
except ClientConnectionError:
await self._term_receiver_task()
self.term()
Expand All @@ -93,11 +95,23 @@ async def process_message(self, msg: str) -> None:
raise NotImplementedError("Only monitor can receive messages!")

async def _receiver(self) -> None:
last_heartbeat_time: float | None = None
while True:
try:
_, raw_msg = await self.socket.recv_multipart()
if raw_msg == ACK_MSG:
self._ack_event.set()
elif raw_msg == HEARTBEAT_MSG:
if (
last_heartbeat_time
and (asyncio.get_running_loop().time() - last_heartbeat_time)
> 2 * HEARTBEAT_TIMEOUT
):
await self.socket.send_multipart([b"", CONNECT_MSG])
logger.warning(
f"{self.dealer_id} heartbeat failed - reconnecting."
)
last_heartbeat_time = asyncio.get_running_loop().time()
else:
await self.process_message(raw_msg.decode("utf-8"))
except zmq.ZMQError as exc:
Expand Down Expand Up @@ -144,5 +158,5 @@ async def send(self, message: str | bytes, retries: int | None = None) -> None:
self.socket.connect(self.url)
backoff = min(backoff * 2, 10) # Exponential backoff
raise ClientConnectionError(
f"{self.dealer_id} Failed to send {message!r} after retries!"
f"{self.dealer_id} Failed to send {message!r} to {self.url} after retries!"
)
31 changes: 26 additions & 5 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
event_from_json,
event_to_json,
)
from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
from _ert.forward_model_runner.client import (
ACK_MSG,
CONNECT_MSG,
DISCONNECT_MSG,
HEARTBEAT_MSG,
HEARTBEAT_TIMEOUT,
)
from ert.ensemble_evaluator import identifiers as ids

from ._ensemble import FMStepSnapshot
Expand All @@ -51,7 +57,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._ensemble: Ensemble = ensemble

self._events: asyncio.Queue[Event] = asyncio.Queue()
self._events_to_send: asyncio.Queue[Event] = asyncio.Queue()
self._events_to_send: asyncio.Queue[Event | None] = asyncio.Queue()
self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue()

self._ee_tasks: list[asyncio.Task[None]] = []
Expand All @@ -72,14 +78,26 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._dispatchers_empty: asyncio.Event = asyncio.Event()
self._dispatchers_empty.set()

async def _do_heartbeat_clients(self) -> None:
    """Queue a heartbeat marker for the publisher at a fixed interval.

    Waits for ``self._server_started`` first. Then, once every
    ``HEARTBEAT_TIMEOUT`` seconds, puts ``None`` on ``self._events_to_send``
    when at least one client is connected. The loop never returns on its
    own.
    """
    await self._server_started
    while True:
        has_clients = bool(self._clients_connected)
        if has_clients:
            # None is the in-queue stand-in for "send a heartbeat".
            await self._events_to_send.put(None)
        # Sleep regardless of whether a heartbeat was queued, so the loop
        # always yields to the event loop.
        await asyncio.sleep(HEARTBEAT_TIMEOUT)

async def _publisher(self) -> None:
    """Forward queued events (and heartbeats) to every connected client.

    Waits for ``self._server_started``, then loops forever: takes one item
    from ``self._events_to_send`` and sends it to each identity in
    ``self._clients_connected`` through ``self._router_socket``. A ``None``
    item is sent as ``HEARTBEAT_MSG``; any other item is serialised with
    ``event_to_json`` and UTF-8 encoded. ``task_done()`` is called once per
    dequeued item.
    """
    await self._server_started
    while True:
        event = await self._events_to_send.get()
        # Snapshot the recipients: every send is awaited, and handle_client
        # may add to the live set meanwhile, which would raise
        # "Set changed size during iteration".
        recipients = tuple(self._clients_connected)
        if recipients:
            # The payload is the same for all recipients, so build it once.
            if event is None:
                payload = HEARTBEAT_MSG
            else:
                payload = event_to_json(event).encode("utf-8")
            for identity in recipients:
                await self._router_socket.send_multipart([identity, b"", payload])
        self._events_to_send.task_done()

async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None:
Expand Down Expand Up @@ -197,6 +215,8 @@ def ensemble(self) -> Ensemble:

async def handle_client(self, dealer: bytes, frame: bytes) -> None:
if frame == CONNECT_MSG:
if dealer in self._clients_connected:
logger.warning(f"{dealer!r} wants to reconnect.")
self._clients_connected.add(dealer)
self._clients_empty.clear()
current_snapshot_dict = self._ensemble.snapshot.to_dict()
Expand Down Expand Up @@ -343,6 +363,7 @@ async def _start_running(self) -> None:
raise ValueError("no config for evaluator")
self._ee_tasks = [
asyncio.create_task(self._server(), name="server_task"),
asyncio.create_task(self._do_heartbeat_clients(), name="heartbeat_task"),
asyncio.create_task(
self._batch_events_into_buffer(), name="dispatcher_task"
),
Expand Down
44 changes: 35 additions & 9 deletions tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio

import pytest

import _ert.forward_model_runner.client
from _ert.forward_model_runner.client import Client, ClientConnectionError
from tests.ert.utils import MockZMQServer

Expand All @@ -18,12 +21,12 @@ async def test_invalid_server():
async def test_successful_sending(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
messages_c1 = ["test_1", "test_2", "test_3"]
async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as c1:
for message in messages_c1:
await c1.send(message)
messages = ["test_1", "test_2", "test_3"]
async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as client:
for message in messages:
await client.send(message)

for msg in messages_c1:
for msg in messages:
assert msg in mock_server.messages


Expand All @@ -32,18 +35,41 @@ async def test_retry(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
client_connection_error_set = False
messages_c1 = ["test_1", "test_2", "test_3"]
messages = ["test_1", "test_2", "test_3"]
async with (
MockZMQServer(unused_tcp_port, signal=2) as mock_server,
Client(url, ack_timeout=0.5) as c1,
Client(url, ack_timeout=0.5) as client,
):
for message in messages_c1:
for message in messages:
try:
await c1.send(message, retries=1)
await client.send(message, retries=1)
except ClientConnectionError:
client_connection_error_set = True
mock_server.signal(0)
assert client_connection_error_set
assert mock_server.messages.count("test_1") == 2
assert mock_server.messages.count("test_2") == 1
assert mock_server.messages.count("test_3") == 1


async def test_reconnect_when_missing_heartbeat(unused_tcp_port, monkeypatch):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
monkeypatch.setattr(_ert.forward_model_runner.client, "HEARTBEAT_TIMEOUT", 0.1)

async with (
MockZMQServer(unused_tcp_port, signal=3) as mock_server,
Client(url) as client,
):
await client.send("start", retries=1)

await mock_server.do_heartbeat()
await asyncio.sleep(1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we sleep here to force the monitor to reconnect?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is correct. HEARTBEAT_TIMEOUT is patched to 0.1 in this test, so sleeping for 1 second exceeds the client's 2 * HEARTBEAT_TIMEOUT threshold and guarantees reconnection.

await mock_server.do_heartbeat()
await client.send("stop", retries=1)

# when reconnection happens CONNECT message is sent again
assert mock_server.messages.count("CONNECT") == 2
assert mock_server.messages.count("DISCONNECT") == 1
assert "start" in mock_server.messages
assert "stop" in mock_server.messages
3 changes: 2 additions & 1 deletion tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ async def mock_event_handler(router_socket):
assert msg == DISCONNECT_MSG


async def test_no_connection_established(make_ee_config):
async def test_no_connection_established(monkeypatch, make_ee_config):
ee_config = make_ee_config()
monkeypatch.setattr(Monitor, "DEFAULT_MAX_RETRIES", 0)
monitor = Monitor(ee_config.get_connection_info())
monitor._ack_timeout = 0.1
with pytest.raises(ClientConnectionError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port):
def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port):
# this is to show when the reporter fails but reconnects
# reporter still manages to send events and completes fine
# see assert reporter._timeout_timestamp is not None
# meaning Finish event initiated _timeout but timeout wasn't reached since
# it finished successfully
# see reporter._event_publisher for more details.

host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
Expand Down
27 changes: 23 additions & 4 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import zmq
import zmq.asyncio

from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
from _ert.forward_model_runner.client import (
ACK_MSG,
CONNECT_MSG,
DISCONNECT_MSG,
HEARTBEAT_MSG,
)
from _ert.threading import ErtThread
from ert.scheduler.event import FinishedEvent, StartedEvent

Expand Down Expand Up @@ -64,16 +69,18 @@ def wait_until(func, interval=0.5, timeout=30):
class MockZMQServer:
def __init__(self, port, signal=0):
"""Mock ZMQ server for testing
signal = 0: normal operation
signal = 0: normal operation, receive messages but don't store CONNECT and DISCONNECT messages
signal = 1: don't send ACK and don't receive messages
signal = 2: don't send ACK, but receive messages
signal = 3: normal operation, and store also CONNECT and DISCONNECT messages
"""
self.port = port
self.messages = []
self.value = signal
self.loop = None
self.server_task = None
self.handler_task = None
self.dealers = set()

def start_event_loop(self):
asyncio.set_event_loop(self.loop)
Expand Down Expand Up @@ -116,13 +123,25 @@ async def mock_zmq_server(self):
def signal(self, value):
    """Change the mock server's operating mode.

    Overwrites ``self.value``, the same attribute that the constructor's
    ``signal`` argument initialises.
    """
    self.value = value

async def do_heartbeat(self):
    """Send one ``HEARTBEAT_MSG`` frame to every registered dealer.

    Iterates over a snapshot of ``self.dealers``. Each send is awaited, and
    ``_handler`` adds and discards dealers as CONNECT/DISCONNECT frames
    arrive, so iterating the live set could raise
    "RuntimeError: Set changed size during iteration".
    """
    for dealer in tuple(self.dealers):
        await self.router_socket.send_multipart([dealer, b"", HEARTBEAT_MSG])

async def _handler(self):
while True:
try:
dealer, __, frame = await self.router_socket.recv_multipart()
if frame in {CONNECT_MSG, DISCONNECT_MSG} or self.value == 0:
if frame == CONNECT_MSG:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
self.dealers.add(dealer)
elif frame == DISCONNECT_MSG:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
self.dealers.discard(dealer)
elif self.value in {0, 3}:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
if frame not in {CONNECT_MSG, DISCONNECT_MSG} and self.value != 1:
if (
self.value in {0, 2} and frame not in {CONNECT_MSG, DISCONNECT_MSG}
) or self.value == 3:
self.messages.append(frame.decode("utf-8"))
except asyncio.CancelledError:
break
Expand Down
Loading