Convert test class to threading
There's no need to create an external process to run the AMQP test class; in
fact, doing so has hidden a couple of subtle pika interaction bugs.
Consequently, though the motivating change in this commit is:

    - multiprocessing
    + threading/queue.SimpleQueue

there are also some fixes in `test/.../result_queue_subscriber` around the
proper shutdown procedure of the `pika.SelectConnection`: don't stop the loop
prematurely, and close all resources.
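
For readers unfamiliar with the pattern this commit moves to, here is a minimal, hypothetical sketch (not code from this repository) of a `threading.Thread` subscriber that hands messages back through a `queue.SimpleQueue` and shuts down via a `threading.Event`; the names `DemoSubscriber` and `poll_period` are invented for illustration.

    # Hypothetical sketch; DemoSubscriber and poll_period are illustrative only.
    import queue
    import threading
    import time


    class DemoSubscriber(threading.Thread):
        def __init__(self, external_queue: queue.SimpleQueue, kill_event: threading.Event):
            super().__init__()
            self.external_queue = external_queue
            self.kill_event = kill_event
            self.poll_period = 0.1  # seconds between kill_event checks

        def run(self) -> None:
            # Stand-in for the pika ioloop: produce messages until told to stop.
            while not self.kill_event.is_set():
                self.external_queue.put(("routing_key", b"message"))
                time.sleep(self.poll_period)

        def stop(self) -> None:
            self.kill_event.set()
            self.join()  # the thread stops here, or the caller hangs -- intentionally


    results: queue.SimpleQueue = queue.SimpleQueue()
    sub = DemoSubscriber(results, threading.Event())
    sub.start()
    assert results.get(timeout=2) == ("routing_key", b"message")
    sub.stop()

Because everything stays in one process, plain `queue.SimpleQueue` and `threading.Event` objects plus a simple `join()` suffice for cleanup, which is roughly the shape the fixture and subscriber in the diff below take on.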
khk-globus committed Jan 16, 2025
1 parent c29c0c0 commit 3852c03
Showing 5 changed files with 45 additions and 69 deletions.
@@ -153,6 +153,7 @@ def run(self) -> None:
finally:
if self._mq_conn and self._mq_conn.ioloop:
self._mq_conn.ioloop.close()
self._mq_conn = None

self._stop_event.set()

@@ -378,7 +379,6 @@ def _stop_ioloop(self):
self._mq_conn.close()
elif self._mq_conn.is_closed:
self._mq_conn.ioloop.stop()
self._mq_conn = None

def publish(self, message: bytes) -> Future[None]:
"""
33 changes: 19 additions & 14 deletions compute_endpoint/tests/integration/conftest.py
@@ -1,10 +1,10 @@
from __future__ import annotations

import multiprocessing
import os
import queue
import random
import string
import threading

import pika
import pika.exceptions
@@ -205,29 +205,34 @@ def func(


@pytest.fixture
def start_result_q_subscriber(running_subscribers, pika_conn_params):
def start_result_q_subscriber(pika_conn_params):
qs_list: list[ResultQueueSubscriber] = []

def func(
*,
queue: multiprocessing.Queue | None = None,
kill_event: multiprocessing.Event | None = None,
result_q: queue.SimpleQueue | None = None,
kill_event: threading.Event | None = None,
override_params: pika.connection.Parameters | None = None,
):
if kill_event is None:
kill_event = multiprocessing.Event()
if queue is None:
queue = multiprocessing.Queue()
result_q = ResultQueueSubscriber(
kill_event = threading.Event()
if result_q is None:
result_q = queue.SimpleQueue()
qs = ResultQueueSubscriber(
conn_params=pika_conn_params if not override_params else override_params,
external_queue=queue,
external_queue=result_q,
kill_event=kill_event,
)
result_q.start()
running_subscribers.append(result_q)
if not result_q.test_class_ready.wait(10):
qs.start()
qs_list.append(qs)
if not qs.test_class_ready.wait(10):
raise AssertionError("Result Queue subscriber failed to initialize")
return result_q
return qs

return func
yield func

while qs_list:
qs_list.pop().stop()


@pytest.fixture
@@ -1,21 +1,16 @@
from __future__ import annotations

import logging
import multiprocessing
import queue

# multiprocessing.Event is a method, not a class
# to annotate, we need the "real" class
# see: https://github.com/python/typeshed/issues/4266
from multiprocessing.synchronize import Event as EventType
import threading

import pika
from globus_compute_endpoint.endpoint.rabbit_mq.base import SubscriberProcessStatus

logger = logging.getLogger(__name__)


class ResultQueueSubscriber(multiprocessing.Process):
class ResultQueueSubscriber(threading.Thread):
"""The ResultQueueSubscriber is a direct rabbitMQ pipe subscriber that uses
the SelectConnection adaptor to enable performance consumption of messages
from the service
@@ -31,21 +26,21 @@ def __init__(
self,
*,
conn_params: pika.connection.Parameters,
external_queue: multiprocessing.Queue,
kill_event: EventType,
external_queue: queue.SimpleQueue,
kill_event: threading.Event,
):
"""
Parameters
----------
conn_params: Connection Params
conn_params
external_queue: multiprocessing.Queue
external_queue
Each incoming message will be pushed to the queue.
Please note that upon pushing a message into this queue, it will be
marked as delivered.
kill_event: multiprocessing.Event
kill_event
An event object used to signal shutdown to the subscriber process.
"""
super().__init__()
@@ -54,9 +49,7 @@ def __init__(
self.conn_params = conn_params
self.external_queue = external_queue
self.kill_event = kill_event
self.test_class_ready = multiprocessing.Event()
self._channel_closed = multiprocessing.Event()
self._cleanup_complete = multiprocessing.Event()
self.test_class_ready = threading.Event()
self._watcher_poll_period_s = 0.1 # seconds

self._connection = None
@@ -158,8 +151,6 @@ def _on_channel_closed(self, channel, exception):
f"Channel closed with code:{exception.reply_code}, "
f"error:{exception.reply_text}"
)
logger.debug("marking channel as closed")
self._channel_closed.set()

def _on_exchange_declareok(self, unused_frame):
"""Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC
@@ -280,36 +271,22 @@ def on_cancelok(self, unused_frame):
self._channel.close()

def _shutdown(self):
logger.debug("set status to 'closing'")
self.status = SubscriberProcessStatus.closing
logger.debug("closing connection")
self._connection.close()
logger.debug("stopping ioloop")
self._connection.ioloop.stop()
logger.debug("waiting until channel is closed (timeout=1 second)")
if not self._channel_closed.wait(1.0):
logger.warning("reached timeout while waiting for channel closed")
logger.debug("closing connection to mp queue")
self.external_queue.close()
logger.debug("joining mp queue background thread")
self.external_queue.join_thread()
logger.info("shutdown done, setting cleanup event")
self._cleanup_complete.set()

def event_watcher(self):
"""Polls the kill_event periodically to trigger a shutdown"""
self._connection.ioloop.call_later(
self._watcher_poll_period_s, self.event_watcher
)
if self.kill_event.is_set():
logger.info("Kill event is set. Start subscriber shutdown")
try:
if self._connection.is_open:
logger.info("Kill event is set. Start subscriber shutdown")
self._shutdown()
except Exception:
logger.exception("error while shutting down")
raise
logger.info("Shutdown complete")
else:
self._connection.ioloop.call_later(
self._watcher_poll_period_s, self.event_watcher
)
logger.info("Shutdown complete")
elif self._connection.is_closed:
self._connection.ioloop.stop()

def run(self):
"""Run the example consumer by connecting to RabbitMQ and then
@@ -331,10 +308,5 @@ def stop(self) -> None:
"""stop() is called by the parent to shutdown the subscriber"""
logger.info("Stopping")
self.kill_event.set()
logger.info("Waiting for cleanup_complete")
if not self._cleanup_complete.wait(2 * self._watcher_poll_period_s):
logger.warning("Reached timeout while waiting for cleanup complete")
# join shouldn't block if the above did not raise a timeout
self.join()
self.close()
self.join() # thread stops or test busts/hangs. Very intentional
logger.info("Cleanup done")
@@ -1,4 +1,3 @@
import multiprocessing
import queue


@@ -10,14 +9,14 @@ def test_simple_roundtrip(
start_result_q_publisher,
randomstring,
):
task_q, result_q = queue.SimpleQueue(), multiprocessing.Queue()
task_q, result_q = queue.SimpleQueue(), queue.SimpleQueue()

# Start the publishers *first* as that route creates the queues
task_pub = start_task_q_publisher()
result_pub = start_result_q_publisher()

task_sub = start_task_q_subscriber(task_queue=task_q)
result_sub = start_result_q_subscriber(queue=result_q)
result_sub = start_result_q_subscriber(result_q=result_q)

message = f"Hello test_simple_roundtrip: {randomstring()}".encode()
task_pub.publish(message)
@@ -1,6 +1,6 @@
import json
import logging
import multiprocessing
import queue
import uuid

import pika
@@ -38,7 +38,7 @@ def test_result_queue_basic(start_result_q_publisher):
def test_message_integrity_across_sizes(
size, start_result_q_publisher, start_result_q_subscriber, default_endpoint_id
):
"""Publish count messages from endpoint_1
"""Publish messages from endpoint_1 of different sizes;
Confirm that the subscriber gets all of them.
"""
result_pub = start_result_q_publisher()
@@ -47,8 +47,8 @@ def test_message_integrity_across_sizes(
b_message = json.dumps(message).encode()
result_pub.publish(b_message)

results_q = multiprocessing.Queue()
start_result_q_subscriber(queue=results_q)
results_q = queue.SimpleQueue()
start_result_q_subscriber(result_q=results_q)

result_message = results_q.get(timeout=2)
assert result_message == (result_pub.queue_info["test_routing_key"], b_message)
@@ -72,8 +72,8 @@ def test_publish_multiple_then_subscribe(
publish_messages(result_pub1, count=10)
publish_messages(result_pub2, count=10)

results_q = multiprocessing.Queue()
start_result_q_subscriber(queue=results_q)
results_q = queue.SimpleQueue()
start_result_q_subscriber(result_q=results_q)

all_results = {}
for _i in range(total_messages):