From 3852c031981b0e1d49ed156c7487184a1d8882f3 Mon Sep 17 00:00:00 2001 From: Kevin Hunter Kesling Date: Wed, 15 Jan 2025 22:44:47 -0500 Subject: [PATCH] Convert test class to threading There's no need to create an external process to run the AMQP test class; in fact, doing so has hidden a couple of subtle pika interaction bugs. Consequently, though the motivating change in this commit is: - multiprocessing + threading/queue.SimpleQueue there are also some fixes in `test/.../result_queue_subscriber` around the proper shutdown procedure of the `pika.SelectConnection`: don't stop the loop prematurely, and close all resources. --- .../endpoint/rabbit_mq/result_publisher.py | 2 +- .../tests/integration/conftest.py | 33 +++++----- .../test_rabbit_mq/result_queue_subscriber.py | 62 +++++-------------- .../test_rabbit_mq/test_rabbit_e2e.py | 5 +- .../test_rabbit_mq/test_result_q.py | 12 ++-- 5 files changed, 45 insertions(+), 69 deletions(-) diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/rabbit_mq/result_publisher.py b/compute_endpoint/globus_compute_endpoint/endpoint/rabbit_mq/result_publisher.py index caafad54d..da6a6db4b 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/rabbit_mq/result_publisher.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/rabbit_mq/result_publisher.py @@ -153,6 +153,7 @@ def run(self) -> None: finally: if self._mq_conn and self._mq_conn.ioloop: self._mq_conn.ioloop.close() + self._mq_conn = None self._stop_event.set() @@ -378,7 +379,6 @@ def _stop_ioloop(self): self._mq_conn.close() elif self._mq_conn.is_closed: self._mq_conn.ioloop.stop() - self._mq_conn = None def publish(self, message: bytes) -> Future[None]: """ diff --git a/compute_endpoint/tests/integration/conftest.py b/compute_endpoint/tests/integration/conftest.py index b9f10b48d..66f8e1abf 100644 --- a/compute_endpoint/tests/integration/conftest.py +++ b/compute_endpoint/tests/integration/conftest.py @@ -1,10 +1,10 @@ from __future__ import annotations -import multiprocessing import os import queue import random import string +import threading import pika import pika.exceptions @@ -205,29 +205,34 @@ def func( @pytest.fixture -def start_result_q_subscriber(running_subscribers, pika_conn_params): +def start_result_q_subscriber(pika_conn_params): + qs_list: list[ResultQueueSubscriber] = [] + def func( *, - queue: multiprocessing.Queue | None = None, - kill_event: multiprocessing.Event | None = None, + result_q: queue.SimpleQueue | None = None, + kill_event: threading.Event | None = None, override_params: pika.connection.Parameters | None = None, ): if kill_event is None: - kill_event = multiprocessing.Event() - if queue is None: - queue = multiprocessing.Queue() - result_q = ResultQueueSubscriber( + kill_event = threading.Event() + if result_q is None: + result_q = queue.SimpleQueue() + qs = ResultQueueSubscriber( conn_params=pika_conn_params if not override_params else override_params, - external_queue=queue, + external_queue=result_q, kill_event=kill_event, ) - result_q.start() - running_subscribers.append(result_q) - if not result_q.test_class_ready.wait(10): + qs.start() + qs_list.append(qs) + if not qs.test_class_ready.wait(10): raise AssertionError("Result Queue subscriber failed to initialize") - return result_q + return qs - return func + yield func + + while qs_list: + qs_list.pop().stop() @pytest.fixture diff --git a/compute_endpoint/tests/integration/test_rabbit_mq/result_queue_subscriber.py b/compute_endpoint/tests/integration/test_rabbit_mq/result_queue_subscriber.py index 1a361b1ea..edec2f1ef 100644 --- a/compute_endpoint/tests/integration/test_rabbit_mq/result_queue_subscriber.py +++ b/compute_endpoint/tests/integration/test_rabbit_mq/result_queue_subscriber.py @@ -1,13 +1,8 @@ from __future__ import annotations import logging -import multiprocessing import queue - -# multiprocessing.Event is a method, not a class -# to annotate, we need the "real" class -# see: https://github.com/python/typeshed/issues/4266 -from multiprocessing.synchronize import Event as EventType +import threading import pika from globus_compute_endpoint.endpoint.rabbit_mq.base import SubscriberProcessStatus @@ -15,7 +10,7 @@ logger = logging.getLogger(__name__) -class ResultQueueSubscriber(multiprocessing.Process): +class ResultQueueSubscriber(threading.Thread): """The ResultQueueSubscriber is a direct rabbitMQ pipe subscriber that uses the SelectConnection adaptor to enable performance consumption of messages from the service @@ -31,21 +26,21 @@ def __init__( self, *, conn_params: pika.connection.Parameters, - external_queue: multiprocessing.Queue, - kill_event: EventType, + external_queue: queue.SimpleQueue, + kill_event: threading.Event, ): """ Parameters ---------- - conn_params: Connection Params + conn_params - external_queue: multiprocessing.Queue + external_queue Each incoming message will be pushed to the queue. Please note that upon pushing a message into this queue, it will be marked as delivered. - kill_event: multiprocessing.Event + kill_event An event object used to signal shutdown to the subscriber process. """ super().__init__() @@ -54,9 +49,7 @@ def __init__( self.conn_params = conn_params self.external_queue = external_queue self.kill_event = kill_event - self.test_class_ready = multiprocessing.Event() - self._channel_closed = multiprocessing.Event() - self._cleanup_complete = multiprocessing.Event() + self.test_class_ready = threading.Event() self._watcher_poll_period_s = 0.1 # seconds self._connection = None @@ -158,8 +151,6 @@ def _on_channel_closed(self, channel, exception): f"Channel closed with code:{exception.reply_code}, " f"error:{exception.reply_text}" ) - logger.debug("marking channel as closed") - self._channel_closed.set() def _on_exchange_declareok(self, unused_frame): """Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC @@ -280,36 +271,22 @@ def on_cancelok(self, unused_frame): self._channel.close() def _shutdown(self): - logger.debug("set status to 'closing'") self.status = SubscriberProcessStatus.closing logger.debug("closing connection") self._connection.close() - logger.debug("stopping ioloop") - self._connection.ioloop.stop() - logger.debug("waiting until channel is closed (timeout=1 second)") - if not self._channel_closed.wait(1.0): - logger.warning("reached timeout while waiting for channel closed") - logger.debug("closing connection to mp queue") - self.external_queue.close() - logger.debug("joining mp queue background thread") - self.external_queue.join_thread() - logger.info("shutdown done, setting cleanup event") - self._cleanup_complete.set() def event_watcher(self): """Polls the kill_event periodically to trigger a shutdown""" + self._connection.ioloop.call_later( + self._watcher_poll_period_s, self.event_watcher + ) if self.kill_event.is_set(): - logger.info("Kill event is set. Start subscriber shutdown") - try: + if self._connection.is_open: + logger.info("Kill event is set. Start subscriber shutdown") self._shutdown() - except Exception: - logger.exception("error while shutting down") - raise - logger.info("Shutdown complete") - else: - self._connection.ioloop.call_later( - self._watcher_poll_period_s, self.event_watcher - ) + logger.info("Shutdown complete") + elif self._connection.is_closed: + self._connection.ioloop.stop() def run(self): """Run the example consumer by connecting to RabbitMQ and then @@ -331,10 +308,5 @@ def stop(self) -> None: """stop() is called by the parent to shutdown the subscriber""" logger.info("Stopping") self.kill_event.set() - logger.info("Waiting for cleanup_complete") - if not self._cleanup_complete.wait(2 * self._watcher_poll_period_s): - logger.warning("Reached timeout while waiting for cleanup complete") - # join shouldn't block if the above did not raise a timeout - self.join() - self.close() + self.join() # thread stops or test busts/hangs. Very intentional logger.info("Cleanup done") diff --git a/compute_endpoint/tests/integration/test_rabbit_mq/test_rabbit_e2e.py b/compute_endpoint/tests/integration/test_rabbit_mq/test_rabbit_e2e.py index b6472231e..11edaa660 100644 --- a/compute_endpoint/tests/integration/test_rabbit_mq/test_rabbit_e2e.py +++ b/compute_endpoint/tests/integration/test_rabbit_mq/test_rabbit_e2e.py @@ -1,4 +1,3 @@ -import multiprocessing import queue @@ -10,14 +9,14 @@ def test_simple_roundtrip( start_result_q_publisher, randomstring, ): - task_q, result_q = queue.SimpleQueue(), multiprocessing.Queue() + task_q, result_q = queue.SimpleQueue(), queue.SimpleQueue() # Start the publishers *first* as that route creates the queues task_pub = start_task_q_publisher() result_pub = start_result_q_publisher() task_sub = start_task_q_subscriber(task_queue=task_q) - result_sub = start_result_q_subscriber(queue=result_q) + result_sub = start_result_q_subscriber(result_q=result_q) message = f"Hello test_simple_roundtrip: {randomstring()}".encode() task_pub.publish(message) diff --git a/compute_endpoint/tests/integration/test_rabbit_mq/test_result_q.py b/compute_endpoint/tests/integration/test_rabbit_mq/test_result_q.py index 85949c8db..08a46524b 100644 --- a/compute_endpoint/tests/integration/test_rabbit_mq/test_result_q.py +++ b/compute_endpoint/tests/integration/test_rabbit_mq/test_result_q.py @@ -1,6 +1,6 @@ import json import logging -import multiprocessing +import queue import uuid import pika @@ -38,7 +38,7 @@ def test_result_queue_basic(start_result_q_publisher): def test_message_integrity_across_sizes( size, start_result_q_publisher, start_result_q_subscriber, default_endpoint_id ): - """Publish count messages from endpoint_1 + """Publish messages from endpoint_1 of different sizes; Confirm that the subscriber gets all of them. """ result_pub = start_result_q_publisher() @@ -47,8 +47,8 @@ def test_message_integrity_across_sizes( b_message = json.dumps(message).encode() result_pub.publish(b_message) - results_q = multiprocessing.Queue() - start_result_q_subscriber(queue=results_q) + results_q = queue.SimpleQueue() + start_result_q_subscriber(result_q=results_q) result_message = results_q.get(timeout=2) assert result_message == (result_pub.queue_info["test_routing_key"], b_message) @@ -72,8 +72,8 @@ def test_publish_multiple_then_subscribe( publish_messages(result_pub1, count=10) publish_messages(result_pub2, count=10) - results_q = multiprocessing.Queue() - start_result_q_subscriber(queue=results_q) + results_q = queue.SimpleQueue() + start_result_q_subscriber(result_q=results_q) all_results = {} for _i in range(total_messages):