diff --git a/launch_testing_ros/launch_testing_ros/wait_for_topics.py b/launch_testing_ros/launch_testing_ros/wait_for_topics.py index 9736e5d8..5cf2c9cf 100644 --- a/launch_testing_ros/launch_testing_ros/wait_for_topics.py +++ b/launch_testing_ros/launch_testing_ros/wait_for_topics.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import deque import random import string from threading import Event @@ -47,12 +48,14 @@ def method_2(): print('Given topics are receiving messages !') print(wait_for_topics.topics_not_received()) # Should be an empty set print(wait_for_topics.topics_received()) # Should be {'topic_1', 'topic_2'} + print(wait_for_topics.messages_received('topic_1')) # Should be [message_1, ...] wait_for_topics.shutdown() """ - def __init__(self, topic_tuples, timeout=5.0): + def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10): self.topic_tuples = topic_tuples self.timeout = timeout + self.messages_received_buffer_length = messages_received_buffer_length self.__ros_context = rclpy.Context() rclpy.init(context=self.__ros_context) self.__ros_executor = SingleThreadedExecutor(context=self.__ros_context) @@ -64,9 +67,14 @@ def __init__(self, topic_tuples, timeout=5.0): self.__ros_spin_thread.start() def _prepare_ros_node(self): - node_name = '_test_node_' +\ - ''.join(random.choices(string.ascii_uppercase + string.digits, k=10)) - self.__ros_node = _WaitForTopicsNode(name=node_name, node_context=self.__ros_context) + node_name = '_test_node_' + ''.join( + random.choices(string.ascii_uppercase + string.digits, k=10) + ) + self.__ros_node = _WaitForTopicsNode( + name=node_name, + node_context=self.__ros_context, + messages_received_buffer_length=self.messages_received_buffer_length, + ) self.__ros_executor.add_node(self.__ros_node) def wait(self): @@ -87,6 +95,12 @@ def topics_not_received(self): """Topics that did not receive any messages.""" return self.__ros_node.expected_topics - self.__ros_node.received_topics + def received_messages(self, topic_name): + """List of received messages of a specific topic.""" + if topic_name not in self.__ros_node.received_messages_buffer: + raise KeyError('No messages received for topic: ' + topic_name) + return list(self.__ros_node.received_messages_buffer[topic_name]) + def __enter__(self): if not self.wait(): raise RuntimeError('Did not receive messages on these topics: ', @@ -100,31 +114,49 @@ def __exit__(self, exep_type, exep_value, trace): class _WaitForTopicsNode(Node): """Internal node used for subscribing to a set of topics.""" - def __init__(self, name='test_node', node_context=None): - super().__init__(node_name=name, context=node_context) + def __init__( + self, name='test_node', node_context=None, messages_received_buffer_length=None + ): + super().__init__(node_name=name, context=node_context) # type: ignore self.msg_event_object = Event() - - def start_subscribers(self, topic_tuples): + self.messages_received_buffer_length = messages_received_buffer_length self.subscriber_list = [] - self.expected_topics = {name for name, _ in topic_tuples} + self.topic_tuples = [] + self.expected_topics = set() + self.received_topics = set() + self.received_messages_buffer = {} + + def _reset(self): + self.msg_event_object.clear() self.received_topics = set() + for buffer in self.received_messages_buffer.values(): + buffer.clear() + def start_subscribers(self, topic_tuples): + self._reset() for topic_name, topic_type in topic_tuples: - # Create a subscriber - self.subscriber_list.append( - self.create_subscription( - topic_type, - topic_name, - self.callback_template(topic_name), - 10 + if (topic_name, topic_type) not in self.topic_tuples: + self.topic_tuples.append((topic_name, topic_type)) + self.expected_topics.add(topic_name) + # Initialize ring buffer of received messages + self.received_messages_buffer[topic_name] = deque( + maxlen=self.messages_received_buffer_length + ) + # Create a subscriber + self.subscriber_list.append( + self.create_subscription( + topic_type, + topic_name, + self.callback_template(topic_name), + 10 + ) ) - ) def callback_template(self, topic_name): - def topic_callback(data): + self.get_logger().debug('Message received for ' + topic_name) + self.received_messages_buffer[topic_name].append(data) if topic_name not in self.received_topics: - self.get_logger().debug('Message received for ' + topic_name) self.received_topics.add(topic_name) if self.received_topics == self.expected_topics: self.msg_event_object.set() diff --git a/launch_testing_ros/test/examples/wait_for_topic_launch_test.py b/launch_testing_ros/test/examples/wait_for_topic_launch_test.py index 3485843f..ad7288f0 100644 --- a/launch_testing_ros/test/examples/wait_for_topic_launch_test.py +++ b/launch_testing_ros/test/examples/wait_for_topic_launch_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import re import sys import unittest @@ -22,7 +23,9 @@ import launch_testing.actions import launch_testing.markers from launch_testing_ros import WaitForTopics + import pytest + from std_msgs.msg import String @@ -57,11 +60,18 @@ def test_topics_successful(self, count): """All the supplied topics should be read successfully.""" topic_list = [('chatter_' + str(i), String) for i in range(count)] expected_topics = {'chatter_' + str(i) for i in range(count)} + message_pattern = re.compile(r'Hello World: \d+') # Method 1 : Using the magic methods and 'with' keyword - with WaitForTopics(topic_list, timeout=10.0) as wait_for_node_object_1: + with WaitForTopics( + topic_list, timeout=2.0, messages_received_buffer_length=10 + ) as wait_for_node_object_1: assert wait_for_node_object_1.topics_received() == expected_topics assert wait_for_node_object_1.topics_not_received() == set() + for topic_name, _ in topic_list: + assert len(wait_for_node_object_1.received_messages(topic_name)) >= 1 + message = wait_for_node_object_1.received_messages(topic_name).pop().data + assert message_pattern.match(message) # Multiple instances of WaitForNode() can be created safely as # their internal nodes spin in separate contexts @@ -70,6 +80,10 @@ def test_topics_successful(self, count): assert wait_for_node_object_2.wait() assert wait_for_node_object_2.topics_received() == expected_topics assert wait_for_node_object_2.topics_not_received() == set() + for topic_name, _ in topic_list: + assert len(wait_for_node_object_1.received_messages(topic_name)) >= 1 + message = wait_for_node_object_2.received_messages(topic_name).pop().data + assert message_pattern.match(message) wait_for_node_object_2.shutdown() def test_topics_unsuccessful(self, count):