Skip to content

Commit

Permalink
Update Python tag_probe
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Jan 15, 2025
1 parent a772d87 commit 1199f53
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
19 changes: 16 additions & 3 deletions python/ucxx/ucxx/_lib/libucxx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ from libcpp.memory cimport (
unique_ptr,
)
from libcpp.optional cimport nullopt
from libcpp.pair cimport pair
from libcpp.string cimport string
from libcpp.utility cimport move
from libcpp.vector cimport vector
Expand Down Expand Up @@ -684,13 +685,25 @@ cdef class UCXWorker():

return num_canceled

def tag_probe(self, UCXXTag tag) -> bool:
cdef bint tag_matched
def tag_probe(self, UCXXTag tag, UCXXTagMask tag_mask = UCXXTagMaskFull) -> bool:
cdef Tag cpp_tag = <Tag><size_t>tag.value
cdef TagMask cpp_tag_mask = <TagMask><size_t>tag_mask.value
cdef ucp_tag_recv_info_t empty_tag_recv_info
cdef pair[bint, TagRecvInfo]* probed
cdef bint tag_matched = False

with nogil:
tag_matched = self._worker.get().tagProbe(cpp_tag)
# TagRecvInfo is not default-construtible, therefore we need to use a
# pointer, allocating it using a temporary ucp_tag_recv_info_t object
probed = new pair[bint, TagRecvInfo](
False,
TagRecvInfo(empty_tag_recv_info)
)
probed[0] = self._worker.get().tagProbe(cpp_tag, cpp_tag_mask)
tag_matched = probed[0].first
del probed

# TODO: Come up with good interface to expose TagRecvInfo as well
return tag_matched

def set_progress_thread_start_callback(
Expand Down
13 changes: 10 additions & 3 deletions python/ucxx/ucxx/_lib/ucxx_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ from libcpp cimport bool as cpp_bool
from libcpp.functional cimport function
from libcpp.memory cimport shared_ptr, unique_ptr
from libcpp.optional cimport nullopt_t, optional
from libcpp.pair cimport pair
from libcpp.string cimport string
from libcpp.unordered_map cimport unordered_map as cpp_unordered_map
from libcpp.vector cimport vector
Expand Down Expand Up @@ -54,6 +55,9 @@ cdef extern from "ucp/api/ucp.h" nogil:

ctypedef uint64_t ucp_tag_t

ctypedef struct ucp_tag_recv_info_t:
pass

ctypedef enum ucs_status_t:
pass

Expand Down Expand Up @@ -174,10 +178,13 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
pass
cdef enum TagMask:
pass
cdef cppclass TagRecvInfo:
TagRecvInfo(const ucp_tag_recv_info_t&)

Tag senderTag
size_t length
cdef cppclass AmReceiverCallbackInfo:
pass
# ctypedef Tag CppTag
# ctypedef TagMask CppTagMask

# Using function[Buffer] here doesn't seem possible due to Cython bugs/limitations.
# The workaround is to use a raw C function pointer and let it be parsed by the
Expand Down Expand Up @@ -241,7 +248,7 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
size_t cancelInflightRequests(
uint64_t period, uint64_t maxAttempts
) except +raise_py_error
bint tagProbe(const Tag) const
pair[bint, TagRecvInfo] tagProbe(const Tag, const TagMask) const
void setProgressThreadStartCallback(
function[void(void*)] callback, void* callbackArg
)
Expand Down

0 comments on commit 1199f53

Please sign in to comment.