Skip to content

Commit

Permalink
Add threadsafe_wrapper implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks committed Nov 26, 2024
1 parent c83261b commit 68765e9
Show file tree
Hide file tree
Showing 4 changed files with 339 additions and 0 deletions.
99 changes: 99 additions & 0 deletions cpp/include/cuml/experimental/ordered_mutex.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <condition_variable>
#include <mutex>
#include <queue>

namespace ML {
namespace experimental {

/* A mutex which yields to threads in the order in which they attempt to
* acquire a lock.
*
* Note that this order is somewhat ambiguously defined. If one thread has a lock on this mutex and
* several other threads simultaneously attempt to acquire it, they will do so in the order in which
* they are able to acquire the lock on the underlying raw mutex. What
* ordered_mutex ensures is that if it is locked and several threads attempt to
* acquire it in unambiguously serial fashion (i.e. one does not make the
* attempt until a previous one has released the underlying raw mutex), those
* threads will acquire the lock in the same order.
*
* In particular, this mutex is useful to ensure that a thread's acquisition of
* a lock is not indefinitely deferred by other threads' acquisitions. If N
* threads attempt to simultaneously lock the ordered_mutex, and then N-1
* threads successfully acquire it, the remaining thread is guaranteed to get
* the lock next before any of the other N-1 threads get the lock again.
*/
struct ordered_mutex {
void lock()
{
auto scoped_lock = std::unique_lock<std::mutex>{raw_mtx_};
if (locked_) {
// Another thread is using this mutex, so get in line and wait for
// another thread to notify this one to continue.
auto thread_condition = std::condition_variable{};
control_queue_.push(&thread_condition);
thread_condition.wait(scoped_lock);
} else {
// No other threads have acquired the ordered_mutex, so we will not wait
// for another thread to notify this one that it is its turn
locked_ = true;
}
}

void unlock()
{
auto scoped_lock = std::unique_lock<std::mutex>{raw_mtx_};
if (control_queue_.empty()) {
// No waiting threads, so the next thread that attempts to acquire may
// simply proceed.
locked_ = false;
} else {
// We must notify under the scoped_lock to avoid having a new thread
// acquire the raw mutex before a waiting thread gets notified.
control_queue_.front()->notify_one();
control_queue_.pop();
}
}

private:
// Use a pointer here rather than storing the object in the queue to ensure
// that the variable is not deallocated while it is being used.
std::queue<std::condition_variable*> control_queue_{};
std::mutex raw_mtx_{};
bool locked_ = false;
};

/* A scoped lock based on ordered_mutex, which will be acquired in the order in which
* threads attempt to acquire the underlying mutex */
struct ordered_lock {
explicit ordered_lock(ordered_mutex& mtx)
: mtx_{[&mtx]() {
mtx.lock();
return &mtx;
}()}
{
}

~ordered_lock() { mtx_->unlock(); }

private:
ordered_mutex* mtx_;
};

} // namespace experimental
} // namespace ML
168 changes: 168 additions & 0 deletions cpp/include/cuml/experimental/threadsafe_wrapper.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuml/experimental/ordered_mutex.hpp>

#include <algorithm>
#include <atomic>
#include <memory>

namespace ML {
namespace experimental {

/* A proxy to an underlying object that holds a lock for its lifetime. This
* ensures that the underlying object cannot be accessed unless the lock has
* been acquired.
*/
template <typename T, typename L>
struct threadsafe_proxy {
// Acquire ownership of the lock on construction
threadsafe_proxy(T* wrapped, L&& lock) : wrapped_{wrapped}, lock_{std::move(lock)} {}
// Do not allow copy or move. Proxy object is intended to be used
// immediately.
threadsafe_proxy(threadsafe_proxy const&) = delete;
threadsafe_proxy(threadsafe_proxy&&) = delete;
// Access the wrapped object via -> operator
auto* operator->() { return wrapped_; }

private:
T* wrapped_;
L lock_;
};

/* This struct wraps an object which may be modified from some host threads
* but accessed without modification from others. Because multiple users can safely
* access the object simultaneously so long as it is not being modified, any
* const access to a threadsafe_wrapper<T> will acquire a lock solely to
* increment an atomic counter indicating that it is currently accessing the
* underlying object. It will then decrement that counter once the const call
* to the underlying object has been completed. Non-const access will
* acquire a lock on the same underlying mutex but not proceed with the
* non-const call until the counter reaches 0.
*
* A special lock (ordered_lock) ensures that the mutex is acquired in the
* order that threads attempt to acquire it. This ensures that
* modifying threads are not indefinitely delayed.
*
* Example usage:
*
* struct foo() {
* foo(int data) : data_{data} {}
* auto get_data() const { return data_; }
* void set_data(int new_data) { data_ = new_data; }
* private:
* int data_;
* };
*
* auto f = threadsafe_wrapper<foo>{5};
* f->set_data(6);
* f->get_data(); // Safe but inefficient. Returns 6.
* std::as_const(f)->get_data(); // Safe and efficient. Returns 6.
* std::as_const(f)->set_data(7); // Fails to compile.
*/
template <typename T>
struct threadsafe_wrapper {
template <typename... Args>
threadsafe_wrapper(Args&&... args) : wrapped{std::make_unique<T>(std::forward<Args>(args)...)}
{
}
auto operator->()
{
return threadsafe_proxy<T*, modifier_lock>{wrapped.get(), modifier_lock{mtx_}};
}
auto operator->() const
{
return threadsafe_proxy<T const*, accessor_lock>{wrapped.get(), accessor_lock{mtx_}};
}

private:
// A class for coordinating access to a resource that may be modified by some
// threads and accessed without modification by others.
class modification_mutex {
// Wait until all ongoing const access has completed and do not allow
// additional const or non-const access to begin until the modifying lock on this mutex has been
// released.
void acquire_for_modifier()
{
// Prevent any new users from incrementing work counter
lock_ = std::make_unique<ordered_lock>(mtx_);
// Wait until all work in progress is done
while (currently_using_.load() != 0)
;
std::atomic_thread_fence(std::memory_order_acquire);
}
// Allow other threads to initiate const or non-const access again
void release_from_modifier() { lock_.reset(); }
// Wait until ongoing non-const access has completed, then increment a
// counter indicating the number of threads performing const access
void acquire_for_access() const
{
auto tmp_lock = ordered_lock{mtx_};
++currently_using_;
}
// Decrement counter of the number of threads performing const access
void release_from_accessor() const
{
std::atomic_thread_fence(std::memory_order_release);
--currently_using_;
}
mutable ordered_mutex mtx_{};
mutable std::atomic<int> currently_using_{};
mutable std::unique_ptr<ordered_lock> lock_{nullptr};
friend struct modifier_lock;
friend struct accessor_lock;
};

// A lock acquired to modify the wrapped object. While this lock is acquired,
// no other thread can perform const or non-const access to the underlying
// object.
struct modifier_lock {
modifier_lock(modification_mutex& mtx)
: mtx_{[&mtx]() {
mtx.acquire_for_modifier();
return &mtx;
}()}
{
}
~modifier_lock() { mtx_->release_from_modifier(); }

private:
modification_mutex* mtx_;
};

// A lock acquired to access but not modify the wrapped object. We ensure that
// only const methods can be accessed while protected by this lock. While
// this lock is acquired, no other thread can perform non-const access, but
// other threads may perform const access.
struct accessor_lock {
accessor_lock(modification_mutex const& mtx)
: mtx_{[&mtx]() {
mtx.acquire_for_access();
return &mtx;
}()}
{
}
~accessor_lock() { mtx_->release_from_accessor(); }

private:
modification_mutex const* mtx_;
};
modification_mutex mtx_;
std::unique_ptr<T> wrapped;
};

} // namespace experimental
} // namespace ML
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ endfunction()
# - build ml_test executable -------------------------------------------------
if(all_algo)
ConfigureTest(PREFIX SG NAME LOGGER_TEST sg/logger.cpp ML_INCLUDE)
ConfigureTest(PREFIX SG NAME THREADSAFE_WRAPPER_TEST sg/experimental/threadsafe_wrapper.cpp ML_INCLUDE)
endif()

if(all_algo OR dbscan_algo)
Expand Down
71 changes: 71 additions & 0 deletions cpp/test/sg/experimental/threadsafe_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuml/experimental/threadsafe_wrapper.hpp>

#include <gtest/gtest.h>

#include <atomic>
#include <thread>

namespace ML {
namespace experimental {

struct threadsafe_test_struct {
auto access() const
{
access_counter_.fetch_add(1);
return access_counter_.fetch_sub(1) > 0 && !modification_in_progress_.load();
}

auto modify()
{
auto being_modified = modification_in_progress_.exchange(true);
return !being_modified && access_counter_.load() == 0;
}

private:
mutable std::atomic<int> access_counter_ = int{};
std::atomic<bool> modification_in_progress_ = false;
};

TEST(ThreadsafeWrapper, threadsafe_wrapper)
{
auto test_obj = threadsafe_wrapper<threadsafe_test_struct>{};
// Choose a prime number large enough to cause contention. We use a prime
// number to allow us to easily produce different patterns of access in
// each thread.
auto const num_threads = 61;
auto threads = std::vector<std::thread>{};
for (auto thread_id = 0; thread_id < num_threads; ++thread_id) {
threads.emplace_back(
[thread_id](auto& obj) {
for (auto i = 0; i < num_threads; ++i) {
if (i % (thread_id + 1) == 0) {
EXPECT(obj->modify());
} else {
EXPECT(std::as_const(obj)->access());
}
}
},
test_obj);
}
for (auto thread_id = 0; thread_id < num_threads; ++thread_id) {
threads[thread_id].join();
}
}

} // namespace experimental
} // namespace ML

0 comments on commit 68765e9

Please sign in to comment.