From 93edd75122dd21c1963f42ecbc447b1658393940 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 Jan 2025 12:17:19 +0100 Subject: [PATCH] Improve `tagProbe` and AM receiver callback (#348) Improve `tagProbe` by accepting a tag mask for matching and return probed tag information. Expose also the sender endpoint handle to AM receive callback so that the callback is capable of knowing the origin of the message. Additionally, fix C++ request tests that were being unintentionally skipped. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Mads R. B. Kristensen (https://github.com/madsbk) URL: https://github.com/rapidsai/ucxx/pull/348 --- cpp/include/ucxx/typedefs.h | 19 +++++++++++++++++-- cpp/include/ucxx/worker.h | 19 ++++++++++++++----- cpp/src/internal/request_am.cpp | 2 +- cpp/src/worker.cpp | 11 ++++++++--- cpp/tests/request.cpp | 25 +++++++++++++------------ cpp/tests/worker.cpp | 13 +++++++++---- python/ucxx/ucxx/_lib/libucxx.pyx | 19 ++++++++++++++++--- python/ucxx/ucxx/_lib/ucxx_api.pxd | 12 +++++++++--- 8 files changed, 87 insertions(+), 33 deletions(-) diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 2cbf7960..63a0f968 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -72,6 +72,20 @@ enum TagMask : ucp_tag_t {}; */ static constexpr TagMask TagMaskFull{std::numeric_limits>::max()}; +/** + * @brief Information about probed tag message. + * + * Contains information returned when probing by a tag message received by the worker but + * not yet consumed. + */ +class TagRecvInfo { + public: + Tag senderTag; ///< Sender tag + size_t length; ///< The size of the received data + + explicit TagRecvInfo(const ucp_tag_recv_info_t&); +}; + /** * @brief A UCP configuration map. * @@ -124,9 +138,10 @@ typedef std::function(size_t)> AmAllocatorType; * @brief Active Message receiver callback. * * Type for a custom Active Message receiver callback, executed by the remote worker upon - * Active Message request completion. + * Active Message request completion. The first parameter is the request that completed, + * the second is the handle of the UCX endpoint of the sender. */ -typedef std::function)> AmReceiverCallbackType; +typedef std::function, ucp_ep_h)> AmReceiverCallbackType; /** * @brief Active Message receiver callback owner name. diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index 5443ff6c..656f58e1 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -10,6 +10,7 @@ #include #include #include +#include #include @@ -684,21 +685,29 @@ class Worker : public Component { * * Checks the worker for any uncaught tag messages. An uncaught tag message is any * tag message that has been fully or partially received by the worker, but not matched - * by a corresponding `ucp_tag_recv_*` call. + * by a corresponding `ucp_tag_recv_*` call. Additionally, returns information about the + * tag message. * * @code{.cpp} * // `worker` is `std::shared_ptr` - * assert(!worker->tagProbe(0)); + * auto probe = worker->tagProbe(0); + * assert(!probe.first) * * // `ep` is a remote `std::shared_ptrtagSend(buffer, length, 0); * - * assert(worker->tagProbe(0)); + * probe = worker->tagProbe(0); + * assert(probe.first); + * assert(probe.second.tag == 0); + * assert(probe.second.length == length); * @endcode * - * @returns `true` if any uncaught messages were received, `false` otherwise. + * @returns pair where first elements is `true` if any uncaught messages were received, + * `false` otherwise, and second element contain the information from the tag + * receive. */ - [[nodiscard]] bool tagProbe(const Tag tag); + [[nodiscard]] std::pair tagProbe(const Tag tag, + const TagMask tagMask = TagMaskFull); /** * @brief Enqueue a tag receive operation. diff --git a/cpp/src/internal/request_am.cpp b/cpp/src/internal/request_am.cpp index ff03e111..d45dee65 100644 --- a/cpp/src/internal/request_am.cpp +++ b/cpp/src/internal/request_am.cpp @@ -29,7 +29,7 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData, if (receiverCallback) { _request->_callback = [this, receiverCallback](ucs_status_t, std::shared_ptr) { - receiverCallback(_request); + receiverCallback(_request, _ep); }; } } diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 3e19a7d6..d3657377 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -576,7 +576,12 @@ void Worker::removeInflightRequest(const Request* const request) } } -bool Worker::tagProbe(const Tag tag) +TagRecvInfo::TagRecvInfo(const ucp_tag_recv_info_t& info) + : senderTag(Tag(info.sender_tag)), length(info.length) +{ +} + +std::pair Worker::tagProbe(const Tag tag, const TagMask tagMask) { if (!isProgressThreadRunning()) { progress(); @@ -592,9 +597,9 @@ bool Worker::tagProbe(const Tag tag) } ucp_tag_recv_info_t info; - ucp_tag_message_h tag_message = ucp_tag_probe_nb(_handle, tag, TagMaskFull, 0, &info); + ucp_tag_message_h tag_message = ucp_tag_probe_nb(_handle, tag, tagMask, 0, &info); - return tag_message != NULL; + return {tag_message != NULL, TagRecvInfo(info)}; } std::shared_ptr Worker::tagRecv(void* buffer, diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index f9ac5437..5349bdc8 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -59,17 +59,18 @@ class RequestTest : public ::testing::TestWithParam< void SetUp() { + std::tie(_bufferType, + _registerCustomAmAllocator, + _enableDelayedSubmission, + _progressMode, + _messageLength) = GetParam(); + if (_bufferType == ucxx::BufferType::RMM) { #if !UCXX_ENABLE_RMM GTEST_SKIP() << "UCXX was not built with RMM support"; #endif } - std::tie(_bufferType, - _registerCustomAmAllocator, - _enableDelayedSubmission, - _progressMode, - _messageLength) = GetParam(); _memoryType = (_bufferType == ucxx::BufferType::RMM) ? UCS_MEMORY_TYPE_CUDA : UCS_MEMORY_TYPE_HOST; _messageSize = _messageLength * sizeof(int); @@ -168,13 +169,14 @@ TEST_P(RequestTest, ProgressAm) GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; } + if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { #if !UCXX_ENABLE_RMM - GTEST_SKIP() << "UCXX was not built with RMM support"; + GTEST_SKIP() << "UCXX was not built with RMM support"; #else - if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { return std::make_shared(length); }); +#endif } allocate(1, false); @@ -198,7 +200,6 @@ TEST_P(RequestTest, ProgressAm) // Assert data correctness ASSERT_THAT(_recv[0], ContainerEq(_send[0])); -#endif } TEST_P(RequestTest, ProgressAmReceiverCallback) @@ -207,13 +208,14 @@ TEST_P(RequestTest, ProgressAmReceiverCallback) GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; } + if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { #if !UCXX_ENABLE_RMM - GTEST_SKIP() << "UCXX was not built with RMM support"; + GTEST_SKIP() << "UCXX was not built with RMM support"; #else - if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { return std::make_shared(length); }); +#endif } // Define AM receiver callback's owner and id for callback @@ -226,7 +228,7 @@ TEST_P(RequestTest, ProgressAmReceiverCallback) // Define AM receiver callback and register with worker std::vector> receivedRequests; auto callback = ucxx::AmReceiverCallbackType( - [this, &receivedRequests, &mutex](std::shared_ptr req) { + [this, &receivedRequests, &mutex](std::shared_ptr req, ucp_ep_h) { { std::lock_guard lock(mutex); receivedRequests.push_back(req); @@ -260,7 +262,6 @@ TEST_P(RequestTest, ProgressAmReceiverCallback) // Assert data correctness ASSERT_THAT(_recv[0], ContainerEq(_send[0])); -#endif } TEST_P(RequestTest, ProgressStream) diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index c941c20a..9bc6bfbb 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -110,7 +110,8 @@ TEST_F(WorkerTest, TagProbe) auto progressWorker = getProgressFunction(_worker, ProgressMode::Polling); auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); - ASSERT_FALSE(_worker->tagProbe(ucxx::Tag{0})); + auto probed = _worker->tagProbe(ucxx::Tag{0}); + ASSERT_FALSE(probed.first); std::vector buf{123}; std::vector> requests; @@ -119,10 +120,14 @@ TEST_F(WorkerTest, TagProbe) loopWithTimeout(std::chrono::milliseconds(5000), [this, progressWorker]() { progressWorker(); - return _worker->tagProbe(ucxx::Tag{0}); + auto probed = _worker->tagProbe(ucxx::Tag{0}); + return probed.first; }); - ASSERT_TRUE(_worker->tagProbe(ucxx::Tag{0})); + probed = _worker->tagProbe(ucxx::Tag{0}); + ASSERT_TRUE(probed.first); + ASSERT_EQ(probed.second.senderTag, ucxx::Tag{0}); + ASSERT_EQ(probed.second.length, buf.size() * sizeof(int)); } TEST_F(WorkerTest, AmProbe) @@ -189,7 +194,7 @@ TEST_P(WorkerProgressTest, ProgressAmReceiverCallback) // Define AM receiver callback and register with worker std::vector> receivedRequests; auto callback = ucxx::AmReceiverCallbackType( - [this, &receivedRequests, &mutex](std::shared_ptr req) { + [this, &receivedRequests, &mutex](std::shared_ptr req, ucp_ep_h) { { std::lock_guard lock(mutex); receivedRequests.push_back(req); diff --git a/python/ucxx/ucxx/_lib/libucxx.pyx b/python/ucxx/ucxx/_lib/libucxx.pyx index 434bef16..df177632 100644 --- a/python/ucxx/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/ucxx/_lib/libucxx.pyx @@ -25,6 +25,7 @@ from libcpp.memory cimport ( unique_ptr, ) from libcpp.optional cimport nullopt +from libcpp.pair cimport pair from libcpp.string cimport string from libcpp.utility cimport move from libcpp.vector cimport vector @@ -684,13 +685,25 @@ cdef class UCXWorker(): return num_canceled - def tag_probe(self, UCXXTag tag) -> bool: - cdef bint tag_matched + def tag_probe(self, UCXXTag tag, UCXXTagMask tag_mask = UCXXTagMaskFull) -> bool: cdef Tag cpp_tag = tag.value + cdef TagMask cpp_tag_mask = tag_mask.value + cdef ucp_tag_recv_info_t empty_tag_recv_info + cdef pair[bint, TagRecvInfo]* probed + cdef bint tag_matched = False with nogil: - tag_matched = self._worker.get().tagProbe(cpp_tag) + # TagRecvInfo is not default-construtible, therefore we need to use a + # pointer, allocating it using a temporary ucp_tag_recv_info_t object + probed = new pair[bint, TagRecvInfo]( + False, + TagRecvInfo(empty_tag_recv_info) + ) + probed[0] = self._worker.get().tagProbe(cpp_tag, cpp_tag_mask) + tag_matched = probed[0].first + del probed + # TODO: Come up with good interface to expose TagRecvInfo as well return tag_matched def set_progress_thread_start_callback( diff --git a/python/ucxx/ucxx/_lib/ucxx_api.pxd b/python/ucxx/ucxx/_lib/ucxx_api.pxd index 88512a4f..93d2683a 100644 --- a/python/ucxx/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/ucxx/_lib/ucxx_api.pxd @@ -9,6 +9,7 @@ from libcpp cimport bool as cpp_bool from libcpp.functional cimport function from libcpp.memory cimport shared_ptr, unique_ptr from libcpp.optional cimport nullopt_t, optional +from libcpp.pair cimport pair from libcpp.string cimport string from libcpp.unordered_map cimport unordered_map as cpp_unordered_map from libcpp.vector cimport vector @@ -54,6 +55,9 @@ cdef extern from "ucp/api/ucp.h" nogil: ctypedef uint64_t ucp_tag_t + ctypedef struct ucp_tag_recv_info_t: + pass + ctypedef enum ucs_status_t: pass @@ -174,10 +178,12 @@ cdef extern from "" namespace "ucxx" nogil: pass cdef enum TagMask: pass + cdef cppclass TagRecvInfo: + TagRecvInfo(const ucp_tag_recv_info_t&) + Tag senderTag + size_t length cdef cppclass AmReceiverCallbackInfo: pass - # ctypedef Tag CppTag - # ctypedef TagMask CppTagMask # Using function[Buffer] here doesn't seem possible due to Cython bugs/limitations. # The workaround is to use a raw C function pointer and let it be parsed by the @@ -241,7 +247,7 @@ cdef extern from "" namespace "ucxx" nogil: size_t cancelInflightRequests( uint64_t period, uint64_t maxAttempts ) except +raise_py_error - bint tagProbe(const Tag) const + pair[bint, TagRecvInfo] tagProbe(const Tag, const TagMask) const void setProgressThreadStartCallback( function[void(void*)] callback, void* callbackArg )