Skip to content

Commit

Permalink
WIP: write test for client reconnection
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Jan 17, 2025
1 parent f70f05b commit ae76e35
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
20 changes: 20 additions & 0 deletions tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest

from _ert.forward_model_runner.client import Client, ClientConnectionError
Expand Down Expand Up @@ -46,3 +48,21 @@ async def test_retry(unused_tcp_port):
assert mock_server.messages.count("test_1") == 2
assert mock_server.messages.count("test_2") == 1
assert mock_server.messages.count("test_3") == 1


async def test_reconnect_when_missing_ping(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"

async with (
MockZMQServer(unused_tcp_port, signal=3) as mock_server,
Client(url, ack_timeout=0.5) as c1,
):
await c1.send("start", retries=1)

await mock_server.do_ping()
await asyncio.sleep(2)
await mock_server.do_ping()
await c1.send("stop", retries=1)

print(mock_server.messages)
30 changes: 26 additions & 4 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import zmq
import zmq.asyncio

from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
from _ert.forward_model_runner.client import (
ACK_MSG,
CONNECT_MSG,
DISCONNECT_MSG,
PING_MSG,
)
from _ert.threading import ErtThread
from ert.scheduler.event import FinishedEvent, StartedEvent

Expand Down Expand Up @@ -67,13 +72,15 @@ def __init__(self, port, signal=0):
signal = 0: normal operation
signal = 1: don't send ACK and don't receive messages
signal = 2: don't send ACK, but receive messages
signal = 3: normal operation and store all messages
"""
self.port = port
self.messages = []
self.value = signal
self.loop = None
self.server_task = None
self.handler_task = None
self.dealers = set()

def start_event_loop(self):
asyncio.set_event_loop(self.loop)
Expand Down Expand Up @@ -116,14 +123,29 @@ async def mock_zmq_server(self):
def signal(self, value):
self.value = value

async def do_ping(self):
for dealer in self.dealers:
await self.router_socket.send_multipart([dealer, b"", PING_MSG])

async def _handler(self):
while True:
try:
dealer, __, frame = await self.router_socket.recv_multipart()
if frame in {CONNECT_MSG, DISCONNECT_MSG} or self.value == 0:
if frame == CONNECT_MSG:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
self.dealers.add(dealer)
if self.value == 3:
self.messages.append(frame.decode("utf-8"))
elif frame == DISCONNECT_MSG:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
if frame not in {CONNECT_MSG, DISCONNECT_MSG} and self.value != 1:
self.messages.append(frame.decode("utf-8"))
self.dealers.discard(dealer)
if self.value == 3:
self.messages.append(frame.decode("utf-8"))
elif self.value != 1:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
if self.value != 2:
self.messages.append(frame.decode("utf-8"))

except asyncio.CancelledError:
break

Expand Down

0 comments on commit ae76e35

Please sign in to comment.