Skip to content

Commit

Permalink
Check before test ioloop close
Browse files Browse the repository at this point in the history
- remove an unwarranted test `reconnect()` (makes flow much simpler to verify)
- don't forget to remove reference to connection so it can be GC'd
- remove redundant shutdown requests, given fixtures
- fix test to let ioloop know that connection is closing, which allows for a
  cleaner shutdown in the ioloop internals.
  • Loading branch information
khk-globus committed Jan 16, 2025
1 parent 6d384e7 commit 6e15241
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def run(self):
finally:
if self._connection and self._connection.ioloop:
self._connection.ioloop.close()
self._connection = None

self._stop_event.set()
logger.debug("%s Shutdown complete", self)
Expand Down Expand Up @@ -346,7 +347,6 @@ def _stop_ioloop(self):
self._connection.close()
elif self._connection.is_closed:
self._connection.ioloop.stop()
self._connection = None

def _event_watcher(self):
"""Polls the stop_event periodically to trigger a shutdown"""
Expand Down
95 changes: 32 additions & 63 deletions compute_endpoint/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pika.exceptions
import pytest
from globus_compute_endpoint.endpoint.rabbit_mq import (
RabbitPublisherStatus,
ResultPublisher,
TaskQueueSubscriber,
)
Expand Down Expand Up @@ -115,29 +114,6 @@ def task_queue_info(rabbitmq_conn_url, tod_session_num, request) -> dict:
}


@pytest.fixture
def running_subscribers(request):
run_list = []

def cleanup():
for x in run_list:
try: # cannot check is_alive on closed proc
is_alive = x.is_alive()
except ValueError:
is_alive = False
if is_alive:
try:
x.stop()
except Exception as e:
x.terminate()
raise Exception(
f"{x.__class__.__name__} did not shutdown correctly"
) from e

request.addfinalizer(cleanup)
return run_list


@pytest.fixture(scope="session")
def ensure_result_queue(pika_conn_params):
queues_created = []
Expand Down Expand Up @@ -180,7 +156,7 @@ def start_task_q_subscriber(
task_queue_info,
ensure_task_queue,
):
running_subscribers: list[TaskQueueSubscriber] = []
qs_list: list[TaskQueueSubscriber] = []

def func(
*,
Expand All @@ -192,16 +168,17 @@ def func(
q_info = task_queue_info if override_params is None else override_params
ensure_task_queue(queue_opts={"queue": q_info["queue"]})

tqs = TaskQueueSubscriber(queue_info=q_info, pending_task_queue=task_queue)
tqs.start()
running_subscribers.append(tqs)
return tqs
qs = TaskQueueSubscriber(queue_info=q_info, pending_task_queue=task_queue)
qs.start()
qs_list.append(qs)
return qs

yield func

for sub in running_subscribers:
sub._stop_event.set()
sub.join()
while qs_list:
qs = qs_list.pop()
qs.stop()
qs.join()


@pytest.fixture
Expand Down Expand Up @@ -235,28 +212,13 @@ def func(
qs_list.pop().stop()


@pytest.fixture
def running_publishers(request):
run_list = []

def cleanup():
for x in run_list:
if x.status is RabbitPublisherStatus.connected:
if hasattr(x, "stop"):
x.stop() # ResultPublisher
else:
x.close() # TaskQueuePublisher (from tests)

request.addfinalizer(cleanup)
return run_list


@pytest.fixture
def start_result_q_publisher(
running_publishers,
result_queue_info,
ensure_result_queue,
):
qp_list: list[ResultPublisher] = []

def func(
*,
override_params: dict | None = None,
Expand All @@ -273,24 +235,28 @@ def func(
queue_opts = {"queue": queue_name, "durable": True}
ensure_result_queue(exchange_opts=exchange_opts, queue_opts=queue_opts)

result_pub = ResultPublisher(queue_info=q_info)
result_pub.start()
qp = ResultPublisher(queue_info=q_info)
qp.start()
qp_list.append(qp)
if queue_purge: # Make sure queue is empty
try_assert(lambda: result_pub._mq_chan is not None)
result_pub._mq_chan.queue_purge(q_info["queue"])
running_publishers.append(result_pub)
return result_pub
try_assert(lambda: qp._mq_chan is not None)
qp._mq_chan.queue_purge(q_info["queue"])
return qp

return func
yield func

while qp_list:
qp_list.pop().stop(timeout=None)


@pytest.fixture
def start_task_q_publisher(
running_publishers,
task_queue_info,
ensure_task_queue,
default_endpoint_id,
):
qp_list: list[TaskQueuePublisher] = []

def func(
*,
override_params: pika.connection.Parameters | None = None,
Expand All @@ -306,14 +272,17 @@ def func(
queue_opts = {"queue": queue_name, "arguments": {"x-expires": 30 * 1000}}
ensure_task_queue(exchange_opts=exchange_opts, queue_opts=queue_opts)

task_pub = TaskQueuePublisher(queue_info=q_info)
task_pub.connect()
qp = TaskQueuePublisher(queue_info=q_info)
qp.connect()
qp_list.append(qp)
if queue_purge: # Make sure queue is empty
task_pub._channel.queue_purge(q_info["queue"])
running_publishers.append(task_pub)
return task_pub
qp._channel.queue_purge(q_info["queue"])
return qp

yield func

return func
while qp_list:
qp_list.pop().close()


@pytest.fixture(scope="session")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading

import pika
import pika.exceptions
from globus_compute_endpoint.endpoint.rabbit_mq.base import SubscriberProcessStatus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,24 +89,8 @@ def _on_connection_closed(self, connection, exception):
exception, pika.exceptions.ConnectionClosedByClient
):
logger.info("Closing connection from client")
elif isinstance(exception, pika.exceptions.ConnectionClosedByBroker):
logger.warning(f"Connection closed, reopening in 5 seconds: {exception}")
self._connection.ioloop.call_later(5, self.reconnect)

def reconnect(self):
"""Will be invoked by the IOLoop timer if the connection is
closed. See the on_connection_closed method.
"""
# This is the old connection IOLoop instance, stop its ioloop
self._connection.ioloop.stop()

if self.status is not SubscriberProcessStatus.closing:
# Create a new connection
self._connection = self._connect()

# There is now a new connection, needs a new ioloop to run
self._connection.ioloop.start()
else:
logger.warning(f"Connection unexpectedly closed: {exception}")

def _on_channel_open(self, channel):
"""This method is invoked by pika when the channel has been opened.
Expand Down Expand Up @@ -223,7 +208,7 @@ def on_consumer_cancelled(self, method_frame):
if self._channel:
self._channel.close()

def on_message(self, _unused_channel, basic_deliver, _properties, body):
def on_message(self, _unused_channel, basic_deliver, _properties, body: bytes):
"""on_message pushes a message upon receipt into the external_queue
for async consumption. The message pushed to the external_queue
is of the following type : Tuple[str, bytes]
Expand Down Expand Up @@ -303,6 +288,7 @@ def run(self):
finally:
if self._connection and self._connection.ioloop:
self._connection.ioloop.close()
self._connection = None

def stop(self) -> None:
"""stop() is called by the parent to shutdown the subscriber"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from __future__ import annotations

import logging

import pika
import pika.channel
from globus_compute_endpoint.endpoint.rabbit_mq.base import RabbitPublisherStatus

logger = logging.getLogger(__name__)


class TaskQueuePublisher:
"""The TaskQueue is a direct rabbitMQ pipe that runs from the service
Expand Down Expand Up @@ -37,7 +33,6 @@ def __init__(
self.status = RabbitPublisherStatus.closed

def connect(self):
logger.debug("Connecting as server")
params = pika.URLParameters(self.queue_info["connection_url"])
self._connection = pika.BlockingConnection(params)
self._channel = self._connection.channel()
Expand All @@ -62,5 +57,8 @@ def publish(self, payload: bytes) -> None:

def close(self):
"""Close the connection and channels"""
self._connection.close()
if self._connection and self._connection.is_open:
self._connection.close()
self._channel = None
self._connection = None
self.status = RabbitPublisherStatus.closed
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
def test_simple_roundtrip(
flush_results,
start_task_q_publisher,
start_result_q_publisher,
start_task_q_subscriber,
start_result_q_subscriber,
start_result_q_publisher,
randomstring,
):
task_q, result_q = queue.SimpleQueue(), queue.SimpleQueue()
Expand All @@ -15,8 +15,8 @@ def test_simple_roundtrip(
task_pub = start_task_q_publisher()
result_pub = start_result_q_publisher()

task_sub = start_task_q_subscriber(task_queue=task_q)
result_sub = start_result_q_subscriber(result_q=result_q)
start_task_q_subscriber(task_queue=task_q)
start_result_q_subscriber(result_q=result_q)

message = f"Hello test_simple_roundtrip: {randomstring()}".encode()
task_pub.publish(message)
Expand All @@ -26,8 +26,5 @@ def test_simple_roundtrip(
result_pub.publish(task_message)
_, result_message = result_q.get(timeout=2)

task_sub._stop_event.set()
result_sub.kill_event.set()

_, expected = (result_pub.queue_info["test_routing_key"], message)
assert result_message == expected
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_connection_closed_shuts_down(start_task_q_subscriber):
try_assert(lambda: tqs._connection, "Ensure we establish a connection")

assert not tqs._stop_event.is_set()
tqs._on_connection_closed(tqs._connection, MemoryError())
tqs._connection.close()
try_assert(lambda: tqs._stop_event.is_set())

tqs.join(timeout=3)
Expand Down

0 comments on commit 6e15241

Please sign in to comment.