Skip to content

Commit

Permalink
update reset_index
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Nov 16, 2023
1 parent 91e17c2 commit a2d4575
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions cpp/include/raft/neighbors/ivf_pq_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -705,9 +705,9 @@ void set_centers(raft::resources const& res,
device_matrix_view<const float, uint32_t> cluster_centers)
{
RAFT_EXPECTS(cluster_centers.extent(0) == index->n_lists(),
"Number of rows in cluster centers and IVF centers are different");
"Number of rows in the new centers must be equal to the number of IVF lists");
RAFT_EXPECTS(cluster_centers.extent(1) == index->dim(),
"Number of columns in cluster centers and index dim are different");
"Number of columns in the new cluster centers and index dim are different");
RAFT_EXPECTS(index->size() == 0, "Index must be empty");
ivf_pq::detail::set_centers(res, index, cluster_centers.data_handle());
}
Expand Down Expand Up @@ -795,13 +795,21 @@ void recompute_internal_state(const raft::resources& res, index<IdxT>* index)
*
* @param[in] res raft resource
* @param[in] index IVF-PQ index (passed by reference)
* @param[out] cluster_centers the new cluster centers [index.n_lists(), index.dim]
* @param[out] cluster_centers IVF cluster centers [index.n_lists(), index.dim]
*/
template <typename IdxT>
void extract_centers(raft::resources const& res, const index<IdxT>& index, float* cluster_centers)
void extract_centers(raft::resources const& res,
const index<IdxT>& index,
raft::device_matrix_view<float> cluster_centers)
{
RAFT_EXPECTS(cluster_centers.extent(0) == index.n_lists(),
"Number of rows in the output buffer for cluster centers must be equal to the "
"number of IVF lists");
RAFT_EXPECTS(
cluster_centers.extent(1) == index.dim(),
"Number of columns in the output buffer for cluster centers and index dim are different");
auto stream = resource::get_cuda_stream(res);
RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers,
RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data_handle(),
sizeof(float) * index.dim(),
index.centers().data_handle(),
sizeof(float) * index.dim_ext(),
Expand Down

0 comments on commit a2d4575

Please sign in to comment.