Skip to content

Commit

Permalink
Merge pull request #133 from rapidsai/branch-0.35
Browse files Browse the repository at this point in the history
Forward-merge branch-0.35 to branch-0.36
  • Loading branch information
GPUtester authored Nov 20, 2023
2 parents a25716e + b4957da commit d0ea493
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 55 deletions.
29 changes: 27 additions & 2 deletions cpp/include/ucxx/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,21 @@ class Endpoint : public Component {
* This is usually executed by `close()`, when pending requests will no longer be able
* to complete.
*
* If the parent worker is running a progress thread, a maximum timeout may be specified
* for which the close operation will wait. This can be particularly important for cases
* where the progress thread might be attempting to acquire a resource (e.g., the Python
* GIL) while the current thread owns that resource. In particular for Python, the
* `~Endpoint()` will call this method for which we can't release the GIL when the garbage
* collector runs and destroys the object.
*
* @param[in] period maximum period, in nanoseconds, to wait for each generic pre/post
* progress thread operation, or `0` to wait indefinitely.
* @param[in] maxAttempts maximum number of attempts to cancel inflight requests, only
* applicable if worker is running a progress thread and `period > 0`.
*
* @returns Number of requests that were canceled.
*/
size_t cancelInflightRequests();
size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);

/**
* @brief Register a user-defined callback to call when endpoint closes.
Expand Down Expand Up @@ -507,8 +519,21 @@ class Endpoint : public Component {
* If the endpoint was created with error handling support, the error callback will be
* executed, implying the user-defined callback will also be executed if one was
* registered with `setCloseCallback()`.
*
* If the parent worker is running a progress thread, a maximum timeout may be specified
* for which the close operation will wait. This can be particularly important for cases
* where the progress thread might be attempting to acquire a resource (e.g., the Python
* GIL) while the current thread owns that resource. In particular for Python, the
* `~Endpoint()` will call this method for which we can't release the GIL when the garbage
* collector runs and destroys the object.
*
* @param[in] period maximum period, in nanoseconds, to wait for each generic pre/post
* progress thread operation, or `0` to wait indefinitely.
* @param[in] maxAttempts maximum number of attempts to close endpoint, only applicable
* if worker is running a progress thread and `period > 0`.
*
*/
void close();
void close(uint64_t period = 0, uint64_t maxAttempts = 1);
};

} // namespace ucxx
11 changes: 9 additions & 2 deletions cpp/include/ucxx/utils/callback_notifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,18 @@ class CallbackNotifier {
void set();

/**
* @brief Wait until `set()` has been called
* @brief Wait until `set()` has been called or period has elapsed.
*
* Wait until `set()` has been called, or period (in nanoseconds) has elapsed (only
* applicable if using glibc 2.25 and higher).
*
* See also `std::condition_variable::wait`.
*
* @param[in] period maximum period in nanoseconds to wait for or `0` to wait forever.
*
* @return `true` if waiting finished or `false` if a timeout occurred.
*/
void wait();
bool wait(uint64_t period = 0);
};

} // namespace utils
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,21 @@ class Worker : public Component {
* Cancel inflight requests, returning the total number of requests that were canceled.
* This is usually executed during the progress loop.
*
* If the worker is running a progress thread, a maximum timeout may be specified
* for which the cancel operation will wait. This can be particularly important for cases
* where the progress thread might be attempting to acquire a resource (e.g., the Python
* GIL) while the current thread owns that resource. In particular for Python, the
* `~Worker()` will call this method for which we can't release the GIL when the garbage
* collector runs and destroys the object.
*
* @param[in] period maximum period, in nanoseconds, to wait for each generic pre/post
* progress thread operation, or `0` to wait indefinitely.
* @param[in] maxAttempts maximum number of attempts to cancel inflight requests, only
* applicable if worker is running a progress thread and `period > 0`.
*
* @returns Number of requests that were canceled.
*/
size_t cancelInflightRequests();
size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);

/**
* @brief Schedule cancelation of inflight requests.
Expand Down
88 changes: 54 additions & 34 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,15 @@ std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(std::shared_ptr<Worker

// Destructor: closes the endpoint and then emits a trace message carrying this object's
// address and `_originalHandle`.
Endpoint::~Endpoint()
{
// NOTE(review): this text comes from a rendered diff; the bare `close()` on the next line
// appears to be the removed side of the change, superseded by the timed call below it --
// confirm against the real file.
close();
// Pass 10000000000 as `period` so that each wait performed by `close()` is bounded;
// `maxAttempts` is left at its declared default.
close(10000000000 /* 10s */);
ucxx_trace("Endpoint destroyed: %p, UCP handle: %p", this, _originalHandle);
}

void Endpoint::close()
void Endpoint::close(uint64_t period, uint64_t maxAttempts)
{
if (_handle == nullptr) return;

size_t canceled = cancelInflightRequests();
size_t canceled = cancelInflightRequests(3000000000 /* 3s */, 3);
ucxx_debug("Endpoint %p canceled %lu requests", _handle, canceled);

// Close the endpoint
Expand All @@ -161,29 +161,39 @@ void Endpoint::close()
ucs_status_ptr_t status;

if (worker->isProgressThreadRunning()) {
utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([this, &callbackNotifierPre, &status, closeMode]() {
status = ucp_ep_close_nb(_handle, closeMode);
callbackNotifierPre.set();
});
callbackNotifierPre.wait();

while (UCS_PTR_IS_PTR(status)) {
utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([this, &callbackNotifierPost, &status]() {
ucs_status_t s = ucp_request_check_status(status);
if (UCS_PTR_STATUS(s) != UCS_INPROGRESS) {
ucp_request_free(status);
_callbackData->status = UCS_PTR_STATUS(s);
if (UCS_PTR_STATUS(status) != UCS_OK) {
ucxx_error("Error while closing endpoint: %s",
ucs_status_string(UCS_PTR_STATUS(status)));
bool closeSuccess = false;
for (uint64_t i = 0; i < maxAttempts && !closeSuccess; ++i) {
utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([this, &callbackNotifierPre, &status, closeMode]() {
status = ucp_ep_close_nb(_handle, closeMode);
callbackNotifierPre.set();
});
if (!callbackNotifierPre.wait(period)) continue;

while (UCS_PTR_IS_PTR(status)) {
utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([this, &callbackNotifierPost, &status]() {
ucs_status_t s = ucp_request_check_status(status);
if (UCS_PTR_STATUS(s) != UCS_INPROGRESS) {
ucp_request_free(status);
_callbackData->status = UCS_PTR_STATUS(s);
if (UCS_PTR_STATUS(status) != UCS_OK) {
ucxx_error("Error while closing endpoint: %s",
ucs_status_string(UCS_PTR_STATUS(status)));
}
}
}

callbackNotifierPost.set();
});
callbackNotifierPost.wait();
callbackNotifierPost.set();
});
if (!callbackNotifierPost.wait(period)) continue;
}

closeSuccess = true;
}

if (!closeSuccess) {
_callbackData->status = UCS_ERR_ENDPOINT_TIMEOUT;
ucxx_error("All attempts to close timed out on endpoint: %p, UCP handle: %p", this, _handle);
}
} else {
status = ucp_ep_close_nb(_handle, closeMode);
Expand Down Expand Up @@ -257,7 +267,7 @@ void Endpoint::removeInflightRequest(const Request* const request)
_inflightRequests->remove(request);
}

size_t Endpoint::cancelInflightRequests()
size_t Endpoint::cancelInflightRequests(uint64_t period, uint64_t maxAttempts)
{
auto worker = ::ucxx::getWorker(this->_parent);
size_t canceled = 0;
Expand All @@ -266,15 +276,25 @@ size_t Endpoint::cancelInflightRequests()
canceled = _inflightRequests->cancelAll();
worker->progress();
} else if (worker->isProgressThreadRunning()) {
utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([this, &callbackNotifierPre, &canceled]() {
canceled = _inflightRequests->cancelAll();
callbackNotifierPre.set();
});
callbackNotifierPre.wait();
utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
callbackNotifierPost.wait();
bool cancelSuccess = false;
for (uint64_t i = 0; i < maxAttempts && !cancelSuccess; ++i) {
utils::CallbackNotifier callbackNotifierPre{};
worker->registerGenericPre([this, &callbackNotifierPre, &canceled]() {
canceled = _inflightRequests->cancelAll();
callbackNotifierPre.set();
});
if (!callbackNotifierPre.wait(period)) continue;

utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
if (!callbackNotifierPost.wait(period)) continue;

cancelSuccess = true;
}
if (!cancelSuccess)
ucxx_error("All attempts to cancel inflight requests failed on endpoint: %p, UCP handle: %p",
this,
_handle);
} else {
canceled = _inflightRequests->cancelAll();
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/listener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ Listener::~Listener()
ucp_listener_destroy(_handle);
callbackNotifierPre.set();
});
callbackNotifierPre.wait();
callbackNotifierPre.wait(10000000000 /* 10s */);

utils::CallbackNotifier callbackNotifierPost{};
worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
callbackNotifierPost.wait();
callbackNotifierPost.wait(10000000000 /* 10s */);
} else {
ucp_listener_destroy(_handle);
worker->progress();
Expand Down
13 changes: 11 additions & 2 deletions cpp/src/utils/callback_notifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,24 @@ void CallbackNotifier::set()
_conditionVariable.notify_all();
}
}
void CallbackNotifier::wait()

// Block the calling thread until `_flag` is observed to be true, optionally giving up
// after `period` has elapsed.
//
// `period` is interpreted as a count of nanoseconds; `0` requests an unbounded wait.
// Returns `false` only when the timed wait ends with `_flag` still unset, `true` in every
// other case.
bool CallbackNotifier::wait(uint64_t period)
{
if (_useSpinlock) {
// Busy-wait on the flag. `period` is not consulted on this path, so a spinlock-based
// notifier never times out and always falls through to `return true`.
while (!_flag.load(std::memory_order_acquire)) {}
} else {
std::unique_lock lock(_mutex);
// Likewise here, the mutex provides ordering.
// NOTE(review): this text comes from a rendered diff; the unconditional wait on the next
// line appears to be the removed side of the change. As literally written it would block
// before the timed branch is reached -- confirm against the real file.
_conditionVariable.wait(lock, [this]() { return _flag.load(std::memory_order_relaxed); });
if (period > 0) {
// Timed wait: the result is the predicate's final value, i.e. `false` when `period`
// elapsed while `_flag` was still unset.
return _conditionVariable.wait_for(
lock, std::chrono::duration<uint64_t, std::nano>(period), [this]() {
return _flag.load(std::memory_order_relaxed);
});
} else {
// `period == 0`: wait with no deadline until the predicate holds.
_conditionVariable.wait(lock, [this]() { return _flag.load(std::memory_order_relaxed); });
}
}
return true;
}

} // namespace utils
Expand Down
34 changes: 22 additions & 12 deletions cpp/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,

Worker::~Worker()
{
size_t canceled = cancelInflightRequests();
size_t canceled = cancelInflightRequests(3000000000 /* 3s */, 3);
ucxx_debug("Worker %p canceled %lu requests", _handle, canceled);

stopProgressThreadNoWarn();
Expand Down Expand Up @@ -266,7 +266,7 @@ bool Worker::progress()
if (progressScheduledCancel) ret |= progressPending();

// Requests that were not completed now must be canceled.
if (cancelInflightRequests() > 0) ret |= progressPending();
if (cancelInflightRequests(3000000000 /* 3s */, 3) > 0) ret |= progressPending();

return ret;
}
Expand Down Expand Up @@ -399,7 +399,7 @@ bool Worker::isProgressThreadRunning() { return _progressThread != nullptr; }

std::thread::id Worker::getProgressThreadId() { return _progressThreadId; }

size_t Worker::cancelInflightRequests()
size_t Worker::cancelInflightRequests(uint64_t period, uint64_t maxAttempts)
{
size_t canceled = 0;

Expand All @@ -413,16 +413,26 @@ size_t Worker::cancelInflightRequests()
canceled = inflightRequestsToCancel->cancelAll();
progressPending();
} else if (isProgressThreadRunning()) {
utils::CallbackNotifier callbackNotifierPre{};
registerGenericPre([&callbackNotifierPre, &canceled, &inflightRequestsToCancel]() {
canceled = inflightRequestsToCancel->cancelAll();
callbackNotifierPre.set();
});
callbackNotifierPre.wait();
bool cancelSuccess = false;
for (uint64_t i = 0; i < maxAttempts && !cancelSuccess; ++i) {
utils::CallbackNotifier callbackNotifierPre{};
registerGenericPre([&callbackNotifierPre, &canceled, &inflightRequestsToCancel]() {
canceled = inflightRequestsToCancel->cancelAll();
callbackNotifierPre.set();
});
if (!callbackNotifierPre.wait(period)) continue;

utils::CallbackNotifier callbackNotifierPost{};
registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
if (!callbackNotifierPost.wait(period)) continue;

cancelSuccess = true;
}

utils::CallbackNotifier callbackNotifierPost{};
registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); });
callbackNotifierPost.wait();
if (!cancelSuccess)
ucxx_error("All attempts to cancel inflight requests failed on worker: %p, UCP handle: %p",
this,
_handle);
} else {
canceled = inflightRequestsToCancel->cancelAll();
}
Expand Down

0 comments on commit d0ea493

Please sign in to comment.