diff --git a/cpp/include/raft/core/pinned_container_policy.hpp b/cpp/include/raft/core/pinned_container_policy.hpp index 4870e2c5dc..b661fa8860 100644 --- a/cpp/include/raft/core/pinned_container_policy.hpp +++ b/cpp/include/raft/core/pinned_container_policy.hpp @@ -19,9 +19,9 @@ #include #ifndef RAFT_DISABLE_CUDA -#include -#include -#include +#include + +#include #else #include #endif @@ -30,20 +30,16 @@ namespace raft { #ifndef RAFT_DISABLE_CUDA /** - * @brief A thin wrapper over thrust::host_vector for implementing the pinned mdarray container - * policy. + * @brief A thin wrapper over cudaMallocHost/cudaFreeHost for implementing the pinned mdarray + * container policy. * */ template struct pinned_container { - using value_type = T; - using allocator_type = - thrust::mr::stateless_resource_allocator; + using value_type = std::remove_cv_t; private: - using underlying_container_type = thrust::host_vector; - underlying_container_type data_; + value_type* data_ = nullptr; public: using size_type = std::size_t; @@ -57,21 +53,24 @@ struct pinned_container { using iterator = pointer; using const_iterator = const_pointer; - ~pinned_container() = default; - pinned_container(pinned_container&&) noexcept = default; - pinned_container(pinned_container const& that) : data_{that.data_} {} + explicit pinned_container(std::size_t size) + { + RAFT_CUDA_TRY(cudaMallocHost(&data_, size * sizeof(value_type))); + } + ~pinned_container() noexcept + { + if (data_ != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaFreeHost(data_)); } + } - auto operator=(pinned_container const& that) -> pinned_container& + pinned_container(pinned_container&& other) { std::swap(this->data_, other.data_); } + pinned_container& operator=(pinned_container&& other) { - data_ = underlying_container_type{that.data_}; + std::swap(this->data_, other.data_); return *this; } - auto operator=(pinned_container&& that) noexcept -> pinned_container& = default; + pinned_container(pinned_container const&) = delete; // Copying disallowed: one array one owner + pinned_container& operator=(pinned_container const&) = delete; - /** - * @brief Ctor that accepts a size. - */ - explicit pinned_container(std::size_t size, allocator_type const& alloc) : data_{size, alloc} {} /** * @brief Index operator that returns a reference to the actual data. */ @@ -84,15 +83,13 @@ struct pinned_container { * @brief Index operator that returns a reference to the actual data. */ template - auto operator[](Index i) const noexcept + auto operator[](Index i) const noexcept -> const_reference { return data_[i]; } - void resize(size_type size) { data_.resize(size, data_.stream()); } - - [[nodiscard]] auto data() noexcept -> pointer { return data_.data().get(); } - [[nodiscard]] auto data() const noexcept -> const_pointer { return data_.data().get(); } + [[nodiscard]] auto data() noexcept -> pointer { return data_; } + [[nodiscard]] auto data() const noexcept -> const_pointer { return data_; } }; /** @@ -102,7 +99,6 @@ template struct pinned_vector_policy { using element_type = ElementType; using container_type = pinned_container; - using allocator_type = typename container_type::allocator_type; using pointer = typename container_type::pointer; using const_pointer = typename container_type::const_pointer; using reference = typename container_type::reference; @@ -110,15 +106,7 @@ struct pinned_vector_policy { using accessor_policy = std::experimental::default_accessor; using const_accessor_policy = std::experimental::default_accessor; - auto create(raft::resources const&, size_t n) -> container_type - { - return container_type(n, allocator_); - } - - constexpr pinned_vector_policy() noexcept(std::is_nothrow_default_constructible_v) - : allocator_{} - { - } + auto create(raft::resources const&, size_t n) -> container_type { return container_type(n); } [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference { @@ -132,9 +120,6 @@ struct pinned_vector_policy { [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } - - private: - allocator_type allocator_; }; #else template