diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/rabbit_mq/task_queue_subscriber.py b/compute_endpoint/globus_compute_endpoint/endpoint/rabbit_mq/task_queue_subscriber.py index c71a53216..2a82359a7 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/rabbit_mq/task_queue_subscriber.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/rabbit_mq/task_queue_subscriber.py @@ -151,6 +151,7 @@ def run(self): finally: if self._connection and self._connection.ioloop: self._connection.ioloop.close() + self._connection = None self._stop_event.set() logger.debug("%s Shutdown complete", self) @@ -346,7 +347,6 @@ def _stop_ioloop(self): self._connection.close() elif self._connection.is_closed: self._connection.ioloop.stop() - self._connection = None def _event_watcher(self): """Polls the stop_event periodically to trigger a shutdown""" diff --git a/compute_endpoint/tests/integration/conftest.py b/compute_endpoint/tests/integration/conftest.py index 66f8e1abf..f55489044 100644 --- a/compute_endpoint/tests/integration/conftest.py +++ b/compute_endpoint/tests/integration/conftest.py @@ -10,7 +10,6 @@ import pika.exceptions import pytest from globus_compute_endpoint.endpoint.rabbit_mq import ( - RabbitPublisherStatus, ResultPublisher, TaskQueueSubscriber, ) @@ -115,29 +114,6 @@ def task_queue_info(rabbitmq_conn_url, tod_session_num, request) -> dict: } -@pytest.fixture -def running_subscribers(request): - run_list = [] - - def cleanup(): - for x in run_list: - try: # cannot check is_alive on closed proc - is_alive = x.is_alive() - except ValueError: - is_alive = False - if is_alive: - try: - x.stop() - except Exception as e: - x.terminate() - raise Exception( - f"{x.__class__.__name__} did not shutdown correctly" - ) from e - - request.addfinalizer(cleanup) - return run_list - - @pytest.fixture(scope="session") def ensure_result_queue(pika_conn_params): queues_created = [] @@ -180,7 +156,7 @@ def start_task_q_subscriber( task_queue_info, ensure_task_queue, ): - running_subscribers: list[TaskQueueSubscriber] = [] + qs_list: list[TaskQueueSubscriber] = [] def func( *, @@ -192,16 +168,17 @@ def func( q_info = task_queue_info if override_params is None else override_params ensure_task_queue(queue_opts={"queue": q_info["queue"]}) - tqs = TaskQueueSubscriber(queue_info=q_info, pending_task_queue=task_queue) - tqs.start() - running_subscribers.append(tqs) - return tqs + qs = TaskQueueSubscriber(queue_info=q_info, pending_task_queue=task_queue) + qs.start() + qs_list.append(qs) + return qs yield func - for sub in running_subscribers: - sub._stop_event.set() - sub.join() + while qs_list: + qs = qs_list.pop() + qs.stop() + qs.join() @pytest.fixture @@ -235,28 +212,13 @@ def func( qs_list.pop().stop() -@pytest.fixture -def running_publishers(request): - run_list = [] - - def cleanup(): - for x in run_list: - if x.status is RabbitPublisherStatus.connected: - if hasattr(x, "stop"): - x.stop() # ResultPublisher - else: - x.close() # TaskQueuePublisher (from tests) - - request.addfinalizer(cleanup) - return run_list - - @pytest.fixture def start_result_q_publisher( - running_publishers, result_queue_info, ensure_result_queue, ): + qp_list: list[ResultPublisher] = [] + def func( *, override_params: dict | None = None, @@ -273,24 +235,28 @@ def func( queue_opts = {"queue": queue_name, "durable": True} ensure_result_queue(exchange_opts=exchange_opts, queue_opts=queue_opts) - result_pub = ResultPublisher(queue_info=q_info) - result_pub.start() + qp = ResultPublisher(queue_info=q_info) + qp.start() + qp_list.append(qp) if queue_purge: # Make sure queue is empty - try_assert(lambda: result_pub._mq_chan is not None) - result_pub._mq_chan.queue_purge(q_info["queue"]) - running_publishers.append(result_pub) - return result_pub + try_assert(lambda: qp._mq_chan is not None) + qp._mq_chan.queue_purge(q_info["queue"]) + return qp - return func + yield func + + while qp_list: + qp_list.pop().stop(timeout=None) @pytest.fixture def start_task_q_publisher( - running_publishers, task_queue_info, ensure_task_queue, default_endpoint_id, ): + qp_list: list[TaskQueuePublisher] = [] + def func( *, override_params: pika.connection.Parameters | None = None, @@ -306,14 +272,17 @@ def func( queue_opts = {"queue": queue_name, "arguments": {"x-expires": 30 * 1000}} ensure_task_queue(exchange_opts=exchange_opts, queue_opts=queue_opts) - task_pub = TaskQueuePublisher(queue_info=q_info) - task_pub.connect() + qp = TaskQueuePublisher(queue_info=q_info) + qp.connect() + qp_list.append(qp) if queue_purge: # Make sure queue is empty - task_pub._channel.queue_purge(q_info["queue"]) - running_publishers.append(task_pub) - return task_pub + qp._channel.queue_purge(q_info["queue"]) + return qp + + yield func - return func + while qp_list: + qp_list.pop().close() @pytest.fixture(scope="session") diff --git a/compute_endpoint/tests/integration/test_rabbit_mq/result_queue_subscriber.py b/compute_endpoint/tests/integration/test_rabbit_mq/result_queue_subscriber.py index edec2f1ef..79346e79d 100644 --- a/compute_endpoint/tests/integration/test_rabbit_mq/result_queue_subscriber.py +++ b/compute_endpoint/tests/integration/test_rabbit_mq/result_queue_subscriber.py @@ -5,6 +5,7 @@ import threading import pika +import pika.exceptions from globus_compute_endpoint.endpoint.rabbit_mq.base import SubscriberProcessStatus logger = logging.getLogger(__name__) @@ -88,24 +89,8 @@ def _on_connection_closed(self, connection, exception): exception, pika.exceptions.ConnectionClosedByClient ): logger.info("Closing connection from client") - elif isinstance(exception, pika.exceptions.ConnectionClosedByBroker): - logger.warning(f"Connection closed, reopening in 5 seconds: {exception}") - self._connection.ioloop.call_later(5, self.reconnect) - - def reconnect(self): - """Will be invoked by the IOLoop timer if the connection is - closed. See the on_connection_closed method. - - """ - # This is the old connection IOLoop instance, stop its ioloop - self._connection.ioloop.stop() - - if self.status is not SubscriberProcessStatus.closing: - # Create a new connection - self._connection = self._connect() - - # There is now a new connection, needs a new ioloop to run - self._connection.ioloop.start() + else: + logger.warning(f"Connection unexpectedly closed: {exception}") def _on_channel_open(self, channel): """This method is invoked by pika when the channel has been opened. @@ -223,7 +208,7 @@ def on_consumer_cancelled(self, method_frame): if self._channel: self._channel.close() - def on_message(self, _unused_channel, basic_deliver, _properties, body): + def on_message(self, _unused_channel, basic_deliver, _properties, body: bytes): """on_message pushes a message upon receipt into the external_queue for async consumption. The message pushed to the external_queue is of the following type : Tuple[str, bytes] @@ -303,6 +288,7 @@ def run(self): finally: if self._connection and self._connection.ioloop: self._connection.ioloop.close() + self._connection = None def stop(self) -> None: """stop() is called by the parent to shutdown the subscriber""" diff --git a/compute_endpoint/tests/integration/test_rabbit_mq/task_queue_publisher.py b/compute_endpoint/tests/integration/test_rabbit_mq/task_queue_publisher.py index 8a17ea473..7ce2efd79 100644 --- a/compute_endpoint/tests/integration/test_rabbit_mq/task_queue_publisher.py +++ b/compute_endpoint/tests/integration/test_rabbit_mq/task_queue_publisher.py @@ -1,13 +1,9 @@ from __future__ import annotations -import logging - import pika import pika.channel from globus_compute_endpoint.endpoint.rabbit_mq.base import RabbitPublisherStatus -logger = logging.getLogger(__name__) - class TaskQueuePublisher: """The TaskQueue is a direct rabbitMQ pipe that runs from the service @@ -37,7 +33,6 @@ def __init__( self.status = RabbitPublisherStatus.closed def connect(self): - logger.debug("Connecting as server") params = pika.URLParameters(self.queue_info["connection_url"]) self._connection = pika.BlockingConnection(params) self._channel = self._connection.channel() @@ -62,5 +57,8 @@ def publish(self, payload: bytes) -> None: def close(self): """Close the connection and channels""" - self._connection.close() + if self._connection and self._connection.is_open: + self._connection.close() + self._channel = None + self._connection = None self.status = RabbitPublisherStatus.closed diff --git a/compute_endpoint/tests/integration/test_rabbit_mq/test_rabbit_e2e.py b/compute_endpoint/tests/integration/test_rabbit_mq/test_rabbit_e2e.py index 11edaa660..7ee2daecc 100644 --- a/compute_endpoint/tests/integration/test_rabbit_mq/test_rabbit_e2e.py +++ b/compute_endpoint/tests/integration/test_rabbit_mq/test_rabbit_e2e.py @@ -4,9 +4,9 @@ def test_simple_roundtrip( flush_results, start_task_q_publisher, + start_result_q_publisher, start_task_q_subscriber, start_result_q_subscriber, - start_result_q_publisher, randomstring, ): task_q, result_q = queue.SimpleQueue(), queue.SimpleQueue() @@ -15,8 +15,8 @@ def test_simple_roundtrip( task_pub = start_task_q_publisher() result_pub = start_result_q_publisher() - task_sub = start_task_q_subscriber(task_queue=task_q) - result_sub = start_result_q_subscriber(result_q=result_q) + start_task_q_subscriber(task_queue=task_q) + start_result_q_subscriber(result_q=result_q) message = f"Hello test_simple_roundtrip: {randomstring()}".encode() task_pub.publish(message) @@ -26,8 +26,5 @@ def test_simple_roundtrip( result_pub.publish(task_message) _, result_message = result_q.get(timeout=2) - task_sub._stop_event.set() - result_sub.kill_event.set() - _, expected = (result_pub.queue_info["test_routing_key"], message) assert result_message == expected diff --git a/compute_endpoint/tests/integration/test_rabbit_mq/test_task_q.py b/compute_endpoint/tests/integration/test_rabbit_mq/test_task_q.py index 274855458..9c1284cf4 100644 --- a/compute_endpoint/tests/integration/test_rabbit_mq/test_task_q.py +++ b/compute_endpoint/tests/integration/test_rabbit_mq/test_task_q.py @@ -165,7 +165,7 @@ def test_connection_closed_shuts_down(start_task_q_subscriber): try_assert(lambda: tqs._connection, "Ensure we establish a connection") assert not tqs._stop_event.is_set() - tqs._on_connection_closed(tqs._connection, MemoryError()) + tqs._connection.close() try_assert(lambda: tqs._stop_event.is_set()) tqs.join(timeout=3)