Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Dec 7, 2023
1 parent d241411 commit 9a446a6
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 3 deletions.
2 changes: 0 additions & 2 deletions cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,7 @@ void radix_topk(const T* in,
}
const IdxT buf_len = calc_buf_len<T>(len);

size_t req_aux = max_chunk_size * (sizeof(Counter<T, IdxT>) + num_buckets * sizeof(IdxT));
size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT));
size_t mem_req = req_aux + req_buf + 256 * 6; // might need extra memory for alignment

rmm::device_uvector<Counter<T, IdxT>> counters(max_chunk_size, stream, mr);
rmm::device_uvector<IdxT> histograms(max_chunk_size * num_buckets, stream, mr);
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1835,7 +1835,7 @@ auto build(raft::resources const& handle,
handle, kmeans_params, trainset_const_view, centers_view, utils::mapping<float>{});

// Trainset labels are needed for training PQ codebooks
rmm::device_uvector<uint32_t> labels(n_rows_train, stream, big_memory_resource);
rmm::device_uvector<uint32_t> labels(n_rows_train, stream, device_memory);
auto centers_const_view = raft::make_device_matrix_view<const float, IdxT>(
cluster_centers, index.n_lists(), index.dim());
auto labels_view = raft::make_device_vector_view<uint32_t, IdxT>(labels.data(), n_rows_train);
Expand Down

0 comments on commit 9a446a6

Please sign in to comment.