fix device memory allocation strategy
rhdong committed Feb 4, 2025
1 parent 6592d20 commit afb6026
Showing 2 changed files with 15 additions and 11 deletions.
cpp/src/neighbors/detail/cagra/cagra_merge.cuh (3 additions, 10 deletions)
@@ -30,8 +30,6 @@
 #include <cuvs/neighbors/ivf_pq.hpp>
 #include <cuvs/neighbors/refine.hpp>
 
-#include <cuvs/neighbors/nn_descent.hpp>
-
 #include <rmm/resource_ref.hpp>
 
 #include <chrono>
@@ -75,10 +73,6 @@ index<T, IdxT> merge(raft::resources const& handle,
 
   IdxT offset = 0;
 
-  // Allocate the new dataset on device
-  bool dataset_on_device = cuvs::neighbors::nn_descent::has_enough_device_memory(
-    handle, raft::make_extents<std::int64_t>(new_dataset_size, dim), sizeof(IdxT));
-
   auto merge_dataset = [&](T* dst) {
     for (auto index : indices) {
       using ds_idx_type = decltype(index->data().n_rows());
@@ -99,9 +93,7 @@ index<T, IdxT> merge(raft::resources const& handle,
 
   cagra::index_params output_index_params = params.output_index_params;
 
-  if (dataset_on_device) {
-    RAFT_LOG_DEBUG("cagra merge: using device memory for merged dataset");
-
+  try {
     auto updated_dataset = raft::make_device_matrix<T, std::int64_t>(
       handle, std::int64_t(new_dataset_size), std::int64_t(dim));
 
@@ -118,9 +110,10 @@ index<T, IdxT> merge(raft::resources const& handle,
                                  std::array<int64_t, 2>{stride, 1});
       merged_index.update_dataset(handle, owning_t{std::move(updated_dataset), out_layout});
     }
+    RAFT_LOG_DEBUG("cagra merge: using device memory for merged dataset");
     return merged_index;
 
-  } else {
+  } catch (std::bad_alloc& e) {
     RAFT_LOG_DEBUG("cagra::merge: using host memory for merged dataset");
 
     auto updated_dataset =
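For context, this file's change replaces the up-front nn_descent::has_enough_device_memory capacity check with an optimistic device allocation that falls back to host memory when std::bad_alloc is thrown. Below is a minimal, self-contained sketch of that pattern, using the plain CUDA runtime in place of the RAFT/RMM types from the diff; build_merged_dataset and everything inside it are illustrative stand-ins, not cuVS API.

// Minimal sketch of the allocate-on-device-or-fall-back-to-host pattern.
// Assumption: plain cudaMalloc stands in for raft::make_device_matrix / RMM.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <new>  // std::bad_alloc
#include <vector>

void build_merged_dataset(std::int64_t n_rows, std::int64_t dim)
{
  const std::size_t bytes = sizeof(float) * static_cast<std::size_t>(n_rows) * dim;
  try {
    // Optimistic device path. cudaMalloc reports failure via an error code,
    // so translate it into the std::bad_alloc that an RMM-backed allocation
    // raises on exhaustion; both paths then share one recovery route.
    float* d_ptr = nullptr;
    if (cudaMalloc(&d_ptr, bytes) != cudaSuccess) { throw std::bad_alloc{}; }
    // ... copy each source index's dataset to d_ptr at its row offset ...
    std::puts("merge: merged dataset placed in device memory");
    cudaFree(d_ptr);
  } catch (std::bad_alloc&) {
    // Fallback path, mirroring the new `catch (std::bad_alloc& e)` branch:
    // keep the merged dataset in host memory rather than failing the merge.
    std::vector<float> h_buf(static_cast<std::size_t>(n_rows) * dim);
    // ... copy each source index's dataset to h_buf at its row offset ...
    std::puts("merge: merged dataset placed in host memory");
  }
}

int main() { build_merged_dataset(std::int64_t{1} << 20, 128); }

One advantage of asking forgiveness rather than permission here: a capacity probe can go stale between the check and the allocation, while the try/catch reacts to the allocation that actually failed.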
cpp/tests/neighbors/ann_cagra.cuh (12 additions, 1 deletion)
@@ -936,10 +936,21 @@ class AnnCagraIndexMergeTest : public ::testing::TestWithParam<AnnCagraInputs> {
 protected:
  void testCagra()
  {
-    // TODO (rhdong): remove when NN Descent index building support InnerProduct. Reference
+    // TODO (tarang-jain): remove when NN Descent index building support InnerProduct. Reference
    // issue: https://github.com/rapidsai/raft/issues/2276
    if (ps.metric == InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) GTEST_SKIP();
    if (ps.compression != std::nullopt) GTEST_SKIP();
+    // IVF_PQ and NN_DESCENT graph builds do not support BitwiseHamming
+    if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming &&
+        ((!std::is_same_v<DataT, uint8_t>) ||
+         (ps.build_algo != graph_build_algo::ITERATIVE_CAGRA_SEARCH)))
+      GTEST_SKIP();
+    // If the dataset dimension is small and the dataset size is large, there can be a lot of
+    // dataset vectors that have the same distance to the query, especially in the binary Hamming
+    // distance, making it impossible to make a top-k ground truth.
+    if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming &&
+        (ps.k * ps.dim * 8 / 5 /*(=magic number)*/ < ps.n_rows))
+      GTEST_SKIP();
 
    size_t queries_size = ps.n_queries * ps.k;
    std::vector<IdxT> indices_Cagra(queries_size);
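The second new guard rests on a counting argument: a BitwiseHamming distance over a dim-byte uint8_t vector can only take dim * 8 + 1 distinct values, so once n_rows dwarfs that range (the commit scales it by k and its own "magic" divisor of 5), ties at rank k are forced by pigeonholing and no unique top-k ground truth exists. A small sketch of the same inequality follows; topk_ground_truth_is_reliable is a hypothetical helper name, not part of the test suite.

// Sketch of the ground-truth feasibility guard; the divisor 5 is the
// commit's own "magic number", and the helper name is hypothetical.
#include <cstdint>
#include <cstdio>

// BitwiseHamming distances over dim bytes fall in [0, dim * 8]. When n_rows
// far exceeds that range, many vectors tie with the k-th neighbor and the
// exact top-k ground truth becomes ambiguous.
bool topk_ground_truth_is_reliable(std::int64_t n_rows, std::int64_t dim, std::int64_t k)
{
  return k * dim * 8 / 5 >= n_rows;  // negation of the GTEST_SKIP condition
}

int main()
{
  // Small dim, large dataset: unreliable ground truth (the test would skip).
  std::printf("%d\n", topk_ground_truth_is_reliable(10000, 8, 10));   // 0
  // Large dim, modest dataset: reliable ground truth (the test would run).
  std::printf("%d\n", topk_ground_truth_is_reliable(1000, 256, 10));  // 1
}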
