From 9c853663a90f53c20003a6e2e13551ed25dab619 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 14 Jan 2025 16:20:51 +0100 Subject: [PATCH] Fix sparse utilities --- .../raft/cluster/detail/connectivities.cuh | 2 +- cpp/include/raft/cluster/detail/mst.cuh | 4 +- cpp/include/raft/sparse/convert/csr.cuh | 10 ++--- .../raft/sparse/convert/detail/csr.cuh | 8 ++-- cpp/include/raft/sparse/detail/coo.cuh | 15 +++---- cpp/include/raft/sparse/detail/utils.h | 6 +-- cpp/include/raft/sparse/linalg/degree.cuh | 6 +-- .../raft/sparse/linalg/detail/degree.cuh | 20 ++++----- .../raft/sparse/linalg/detail/norm.cuh | 26 +++++------ .../raft/sparse/linalg/detail/symmetrize.cuh | 22 +++++----- cpp/include/raft/sparse/linalg/norm.cuh | 10 ++--- .../neighbors/detail/cross_component_nn.cuh | 2 +- cpp/include/raft/sparse/op/detail/filter.cuh | 43 ++++++++++--------- cpp/include/raft/sparse/op/detail/sort.h | 2 +- cpp/include/raft/sparse/op/sort.cuh | 2 +- .../raft/spatial/knn/detail/ball_cover.cuh | 2 +- cpp/test/sparse/solver/lanczos.cu | 4 +- cpp/test/sparse/symmetrize.cu | 4 +- 18 files changed, 94 insertions(+), 94 deletions(-) diff --git a/cpp/include/raft/cluster/detail/connectivities.cuh b/cpp/include/raft/cluster/detail/connectivities.cuh index c527b754c3..fdb8af9171 100644 --- a/cpp/include/raft/cluster/detail/connectivities.cuh +++ b/cpp/include/raft/cluster/detail/connectivities.cuh @@ -95,7 +95,7 @@ struct distance_graph_impl indptr2(m + 1, stream); raft::sparse::convert::sorted_coo_to_csr( - connected_edges.rows(), connected_edges.nnz, indptr2.data(), m + 1, stream); + connected_edges.rows(), (value_idx)connected_edges.nnz, indptr2.data(), m + 1, stream); // On the second call, we hand the MST the original colors // and the new set of edges and let it restart the optimization process @@ -204,4 +204,4 @@ void build_sorted_mst( raft::copy_async(mst_weight, mst_coo.weights.data(), mst_coo.n_edges, stream); } -}; // namespace raft::cluster::detail \ No newline at end of file +}; // namespace raft::cluster::detail diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 081192ed44..cbe20d8d3a 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -52,8 +52,8 @@ void coo_to_csr(raft::resources const& handle, * @param m: number of rows in dense matrix * @param stream: cuda stream to use */ -template -void sorted_coo_to_csr(const T* rows, int nnz, T* row_ind, int m, cudaStream_t stream) +template +void sorted_coo_to_csr(const T* rows, outT nnz, outT* row_ind, int m, cudaStream_t stream) { detail::sorted_coo_to_csr(rows, nnz, row_ind, m, stream); } @@ -65,10 +65,10 @@ void sorted_coo_to_csr(const T* rows, int nnz, T* row_ind, int m, cudaStream_t s * @param row_ind: output row indices array * @param stream: cuda stream to use */ -template -void sorted_coo_to_csr(COO* coo, int* row_ind, cudaStream_t stream) +template +void sorted_coo_to_csr(COO* coo, outT* row_ind, cudaStream_t stream) { - detail::sorted_coo_to_csr(coo->rows(), coo->nnz, row_ind, coo->n_rows, stream); + detail::sorted_coo_to_csr(coo->rows(), (outT)coo->nnz, row_ind, coo->n_rows, stream); } /** diff --git a/cpp/include/raft/sparse/convert/detail/csr.cuh b/cpp/include/raft/sparse/convert/detail/csr.cuh index a5d7de9a07..3cd01898bb 100644 --- a/cpp/include/raft/sparse/convert/detail/csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/csr.cuh @@ -84,18 +84,18 @@ void coo_to_csr(raft::resources const& handle, * @param m: number of rows in dense matrix * @param stream: cuda stream to use */ -template -void sorted_coo_to_csr(const T* rows, int nnz, T* row_ind, int m, cudaStream_t stream) +template +void sorted_coo_to_csr(const T* rows, outT nnz, outT* row_ind, int m, cudaStream_t stream) { rmm::device_uvector row_counts(m, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(row_counts.data(), 0, m * sizeof(T), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(row_counts.data(), 0, (uint64_t)m * sizeof(T), stream)); linalg::coo_degree(rows, nnz, row_counts.data(), stream); // create csr compressed row index from row counts thrust::device_ptr row_counts_d = thrust::device_pointer_cast(row_counts.data()); - thrust::device_ptr c_ind_d = thrust::device_pointer_cast(row_ind); + thrust::device_ptr c_ind_d = thrust::device_pointer_cast(row_ind); exclusive_scan(rmm::exec_policy(stream), row_counts_d, row_counts_d + m, c_ind_d); } diff --git a/cpp/include/raft/sparse/detail/coo.cuh b/cpp/include/raft/sparse/detail/coo.cuh index 9a38c11a07..c41af76243 100644 --- a/cpp/include/raft/sparse/detail/coo.cuh +++ b/cpp/include/raft/sparse/detail/coo.cuh @@ -52,7 +52,7 @@ class COO { rmm::device_uvector vals_arr; public: - Index_Type nnz; + uint64_t nnz; Index_Type n_rows; Index_Type n_cols; @@ -75,7 +75,7 @@ class COO { COO(rmm::device_uvector& rows, rmm::device_uvector& cols, rmm::device_uvector& vals, - Index_Type nnz, + uint64_t nnz, Index_Type n_rows = 0, Index_Type n_cols = 0) : rows_arr(rows), cols_arr(cols), vals_arr(vals), nnz(nnz), n_rows(n_rows), n_cols(n_cols) @@ -90,7 +90,7 @@ class COO { * @param init: initialize arrays with zeros */ COO(cudaStream_t stream, - Index_Type nnz, + uint64_t nnz, Index_Type n_rows = 0, Index_Type n_cols = 0, bool init = true) @@ -121,7 +121,7 @@ class COO { */ bool validate_size() const { - if (this->nnz < 0 || n_rows < 0 || n_cols < 0) return false; + if (this->nnz <= 0 || n_rows <= 0 || n_cols <= 0) return false; return true; } @@ -204,7 +204,7 @@ class COO { * @param init: should values be initialized to 0? * @param stream: CUDA stream to use */ - void allocate(Index_Type nnz, bool init, cudaStream_t stream) + void allocate(uint64_t nnz, bool init, cudaStream_t stream) { this->allocate(nnz, 0, init, stream); } @@ -216,7 +216,7 @@ class COO { * @param init: should values be initialized to 0? * @param stream: CUDA stream to use */ - void allocate(Index_Type nnz, Index_Type size, bool init, cudaStream_t stream) + void allocate(uint64_t nnz, Index_Type size, bool init, cudaStream_t stream) { this->allocate(nnz, size, size, init, stream); } @@ -229,8 +229,7 @@ class COO { * @param init: should values be initialized to 0? * @param stream: stream to use for init */ - void allocate( - Index_Type nnz, Index_Type n_rows, Index_Type n_cols, bool init, cudaStream_t stream) + void allocate(uint64_t nnz, Index_Type n_rows, Index_Type n_cols, bool init, cudaStream_t stream) { this->n_rows = n_rows; this->n_cols = n_cols; diff --git a/cpp/include/raft/sparse/detail/utils.h b/cpp/include/raft/sparse/detail/utils.h index 3eed74f3b4..16db863a2d 100644 --- a/cpp/include/raft/sparse/detail/utils.h +++ b/cpp/include/raft/sparse/detail/utils.h @@ -103,10 +103,10 @@ void iota_fill(value_idx* indices, value_idx nrows, value_idx ncols, cudaStream_ iota_fill_block_kernel<<>>(indices, ncols); } -template -__device__ int get_stop_idx(T row, T m, T nnz, const T* ind) +template +__device__ indT get_stop_idx(T row, T m, indT nnz, const indT* ind) { - int stop_idx = 0; + indT stop_idx = 0; if (row < (m - 1)) stop_idx = ind[row + 1]; else diff --git a/cpp/include/raft/sparse/linalg/degree.cuh b/cpp/include/raft/sparse/linalg/degree.cuh index 57c9b986b4..5da4c9a30d 100644 --- a/cpp/include/raft/sparse/linalg/degree.cuh +++ b/cpp/include/raft/sparse/linalg/degree.cuh @@ -34,7 +34,7 @@ namespace linalg { * @param stream: cuda stream to use */ template -void coo_degree(const T* rows, int nnz, T* results, cudaStream_t stream) +void coo_degree(const T* rows, uint64_t nnz, T* results, cudaStream_t stream) { detail::coo_degree<64, T>(rows, nnz, results, stream); } @@ -66,7 +66,7 @@ void coo_degree(COO* in, int* results, cudaStream_t stream) */ template void coo_degree_scalar( - const int* rows, const T* vals, int nnz, T scalar, int* results, cudaStream_t stream = 0) + const int* rows, const T* vals, uint64_t nnz, T scalar, int* results, cudaStream_t stream = 0) { detail::coo_degree_scalar<64>(rows, vals, nnz, scalar, results, stream); } @@ -120,4 +120,4 @@ void coo_degree_nz(COO* in, int* results, cudaStream_t stream) }; // end NAMESPACE sparse }; // end NAMESPACE raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/sparse/linalg/detail/degree.cuh b/cpp/include/raft/sparse/linalg/detail/degree.cuh index df31192cf7..6338b8eb00 100644 --- a/cpp/include/raft/sparse/linalg/detail/degree.cuh +++ b/cpp/include/raft/sparse/linalg/detail/degree.cuh @@ -39,10 +39,10 @@ namespace detail { * @param nnz the size of the rows array * @param results array to place results */ -template -RAFT_KERNEL coo_degree_kernel(const T* rows, int nnz, T* results) +template +RAFT_KERNEL coo_degree_kernel(const T* rows, uint64_t nnz, T* results) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; + uint64_t row = (blockIdx.x * TPB_X) + threadIdx.x; if (row < nnz) { atomicAdd(results + rows[row], (T)1); } } @@ -54,8 +54,8 @@ RAFT_KERNEL coo_degree_kernel(const T* rows, int nnz, T* results) * @param results: output result array * @param stream: cuda stream to use */ -template -void coo_degree(const T* rows, int nnz, T* results, cudaStream_t stream) +template +void coo_degree(const T* rows, uint64_t nnz, T* results, cudaStream_t stream) { dim3 grid_rc(raft::ceildiv(nnz, TPB_X), 1, 1); dim3 blk_rc(TPB_X, 1, 1); @@ -71,11 +71,11 @@ RAFT_KERNEL coo_degree_nz_kernel(const int* rows, const T* vals, int nnz, int* r if (row < nnz && vals[row] != 0.0) { raft::myAtomicAdd(results + rows[row], 1); } } -template +template RAFT_KERNEL coo_degree_scalar_kernel( - const int* rows, const T* vals, int nnz, T scalar, int* results) + const int* rows, const T* vals, uint64_t nnz, T scalar, int* results) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; + uint64_t row = (blockIdx.x * TPB_X) + threadIdx.x; if (row < nnz && vals[row] != scalar) { raft::myAtomicAdd(results + rows[row], 1); } } @@ -90,9 +90,9 @@ RAFT_KERNEL coo_degree_scalar_kernel( * @param results: output row counts * @param stream: cuda stream to use */ -template +template void coo_degree_scalar( - const int* rows, const T* vals, int nnz, T scalar, int* results, cudaStream_t stream = 0) + const int* rows, const T* vals, uint64_t nnz, T scalar, int* results, cudaStream_t stream = 0) { dim3 grid_rc(raft::ceildiv(nnz, TPB_X), 1, 1); dim3 blk_rc(TPB_X, 1, 1); diff --git a/cpp/include/raft/sparse/linalg/detail/norm.cuh b/cpp/include/raft/sparse/linalg/detail/norm.cuh index 3702111f83..0390fb5f69 100644 --- a/cpp/include/raft/sparse/linalg/detail/norm.cuh +++ b/cpp/include/raft/sparse/linalg/detail/norm.cuh @@ -40,15 +40,15 @@ namespace sparse { namespace linalg { namespace detail { -template +template RAFT_KERNEL csr_row_normalize_l1_kernel( // @TODO: This can be done much more parallel by // having threads in a warp compute the sum in parallel // over each row and then divide the values in parallel. - const int* ia, // csr row ex_scan (sorted by row) + const indT* ia, // csr row ex_scan (sorted by row) const T* vals, - int nnz, // array of values and number of non-zeros - int m, // num rows in csr + indT nnz, // array of values and number of non-zeros + int m, // num rows in csr T* result) { // output array @@ -57,19 +57,19 @@ RAFT_KERNEL csr_row_normalize_l1_kernel( // sum all vals_arr for row and divide each val by sum if (row < m) { - int start_idx = ia[row]; - int stop_idx = 0; + indT start_idx = ia[row]; + indT stop_idx = 0; if (row < m - 1) { stop_idx = ia[row + 1]; } else stop_idx = nnz; T sum = T(0.0); - for (int j = start_idx; j < stop_idx; j++) { + for (indT j = start_idx; j < stop_idx; j++) { sum = sum + fabs(vals[j]); } - for (int j = start_idx; j < stop_idx; j++) { + for (indT j = start_idx; j < stop_idx; j++) { if (sum != 0.0) { T val = vals[j]; result[j] = val / sum; @@ -90,11 +90,11 @@ RAFT_KERNEL csr_row_normalize_l1_kernel( * @param result: l1 normalized data array * @param stream: cuda stream to use */ -template -void csr_row_normalize_l1(const int* ia, // csr row ex_scan (sorted by row) +template +void csr_row_normalize_l1(const indT* ia, // csr row ex_scan (sorted by row) const T* vals, - int nnz, // array of values and number of non-zeros - int m, // num rows in csr + indT nnz, // array of values and number of non-zeros + int m, // num rows in csr T* result, cudaStream_t stream) { // output array @@ -232,4 +232,4 @@ void rowNormCsrCaller(const IdxType* ia, }; // end NAMESPACE detail }; // end NAMESPACE linalg }; // end NAMESPACE sparse -}; // end NAMESPACE raft \ No newline at end of file +}; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/linalg/detail/symmetrize.cuh b/cpp/include/raft/sparse/linalg/detail/symmetrize.cuh index d343bcbf66..0bacbd01e4 100644 --- a/cpp/include/raft/sparse/linalg/detail/symmetrize.cuh +++ b/cpp/include/raft/sparse/linalg/detail/symmetrize.cuh @@ -48,7 +48,7 @@ namespace detail { // TODO: value_idx param needs to be used for this once FAISS is updated to use float32 // for indices so that the index types can be uniform template -RAFT_KERNEL coo_symmetrize_kernel(int* row_ind, +RAFT_KERNEL coo_symmetrize_kernel(uint64_t* row_ind, int* rows, int* cols, T* vals, @@ -56,31 +56,31 @@ RAFT_KERNEL coo_symmetrize_kernel(int* row_ind, int* ocols, T* ovals, int n, - int cnnz, + uint64_t cnnz, Lambda reduction_op) { int row = (blockIdx.x * TPB_X) + threadIdx.x; if (row < n) { - int start_idx = row_ind[row]; // each thread processes one row - int stop_idx = get_stop_idx(row, n, cnnz, row_ind); + uint64_t start_idx = row_ind[row]; // each thread processes one row + uint64_t stop_idx = get_stop_idx(row, n, cnnz, row_ind); - int row_nnz = 0; - int out_start_idx = start_idx * 2; + int row_nnz = 0; + uint64_t out_start_idx = start_idx * 2; for (int idx = 0; idx < stop_idx - start_idx; idx++) { int cur_row = rows[idx + start_idx]; int cur_col = cols[idx + start_idx]; T cur_val = vals[idx + start_idx]; - int lookup_row = cur_col; - int t_start = row_ind[lookup_row]; // Start at - int t_stop = get_stop_idx(lookup_row, n, cnnz, row_ind); + int lookup_row = cur_col; + uint64_t t_start = row_ind[lookup_row]; // Start at + uint64_t t_stop = get_stop_idx(lookup_row, n, cnnz, row_ind); T transpose = 0.0; bool found_match = false; - for (int t_idx = t_start; t_idx < t_stop; t_idx++) { + for (uint64_t t_idx = t_start; t_idx < t_stop; t_idx++) { // If we find a match, let's get out of the loop. We won't // need to modify the transposed value, since that will be // done in a different thread. @@ -142,7 +142,7 @@ void coo_symmetrize(COO* in, ASSERT(!out->validate_mem(), "Expecting unallocated COO for output"); - rmm::device_uvector in_row_ind(in->n_rows, stream); + rmm::device_uvector in_row_ind(in->n_rows, stream); convert::sorted_coo_to_csr(in, in_row_ind.data(), stream); diff --git a/cpp/include/raft/sparse/linalg/norm.cuh b/cpp/include/raft/sparse/linalg/norm.cuh index 43dd182fe5..f90d088ee6 100644 --- a/cpp/include/raft/sparse/linalg/norm.cuh +++ b/cpp/include/raft/sparse/linalg/norm.cuh @@ -36,11 +36,11 @@ namespace linalg { * @param result: l1 normalized data array * @param stream: cuda stream to use */ -template -void csr_row_normalize_l1(const int* ia, // csr row ex_scan (sorted by row) +template +void csr_row_normalize_l1(const indT* ia, // csr row ex_scan (sorted by row) const T* vals, - int nnz, // array of values and number of non-zeros - int m, // num rows in csr + indT nnz, // array of values and number of non-zeros + int m, // num rows in csr T* result, cudaStream_t stream) { // output array @@ -104,4 +104,4 @@ void rowNormCsr(raft::resources const& handle, }; // end NAMESPACE sparse }; // end NAMESPACE raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh b/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh index a47d5a6f34..1247b91d2e 100644 --- a/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh @@ -242,7 +242,7 @@ void perform_1nn(raft::resources const& handle, // the color components. auto colors_group_idxs = raft::make_device_vector(handle, n_components + 1); raft::sparse::convert::sorted_coo_to_csr( - colors, n_rows, colors_group_idxs.data_handle(), n_components + 1, stream); + colors, (value_idx)n_rows, colors_group_idxs.data_handle(), n_components + 1, stream); auto group_idxs_view = raft::make_device_vector_view( colors_group_idxs.data_handle() + 1, n_components); diff --git a/cpp/include/raft/sparse/op/detail/filter.cuh b/cpp/include/raft/sparse/op/detail/filter.cuh index 3df85e6871..604e5bbf3f 100644 --- a/cpp/include/raft/sparse/op/detail/filter.cuh +++ b/cpp/include/raft/sparse/op/detail/filter.cuh @@ -42,27 +42,27 @@ namespace sparse { namespace op { namespace detail { -template +template RAFT_KERNEL coo_remove_scalar_kernel(const int* rows, const int* cols, const T* vals, - int nnz, + uint64_t nnz, int* crows, int* ccols, T* cvals, - int* ex_scan, - int* cur_ex_scan, + uint64_t* ex_scan, + uint64_t* cur_ex_scan, int m, T scalar) { int row = (blockIdx.x * TPB_X) + threadIdx.x; if (row < m) { - int start = cur_ex_scan[row]; - int stop = get_stop_idx(row, m, nnz, cur_ex_scan); - int cur_out_idx = ex_scan[row]; + uint64_t start = cur_ex_scan[row]; + uint64_t stop = get_stop_idx(row, m, nnz, cur_ex_scan); + uint64_t cur_out_idx = ex_scan[row]; - for (int idx = start; idx < stop; idx++) { + for (uint64_t idx = start; idx < stop; idx++) { if (vals[idx] != scalar) { crows[cur_out_idx] = rows[idx]; ccols[cur_out_idx] = cols[idx]; @@ -94,7 +94,7 @@ template void coo_remove_scalar(const int* rows, const int* cols, const T* vals, - int nnz, + uint64_t nnz, int* crows, int* ccols, T* cvals, @@ -104,19 +104,19 @@ void coo_remove_scalar(const int* rows, int n, cudaStream_t stream) { - rmm::device_uvector ex_scan(n, stream); - rmm::device_uvector cur_ex_scan(n, stream); + rmm::device_uvector ex_scan(n, stream); + rmm::device_uvector cur_ex_scan(n, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(ex_scan.data(), 0, n * sizeof(int), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(cur_ex_scan.data(), 0, n * sizeof(int), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(ex_scan.data(), 0, (uint64_t)n * sizeof(int), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(cur_ex_scan.data(), 0, (uint64_t)n * sizeof(int), stream)); - thrust::device_ptr dev_cnnz = thrust::device_pointer_cast(cnnz); - thrust::device_ptr dev_ex_scan = thrust::device_pointer_cast(ex_scan.data()); + thrust::device_ptr dev_cnnz = thrust::device_pointer_cast(cnnz); + thrust::device_ptr dev_ex_scan = thrust::device_pointer_cast(ex_scan.data()); thrust::exclusive_scan(rmm::exec_policy(stream), dev_cnnz, dev_cnnz + n, dev_ex_scan); RAFT_CUDA_TRY(cudaPeekAtLastError()); - thrust::device_ptr dev_cur_cnnz = thrust::device_pointer_cast(cur_cnnz); - thrust::device_ptr dev_cur_ex_scan = thrust::device_pointer_cast(cur_ex_scan.data()); + thrust::device_ptr dev_cur_cnnz = thrust::device_pointer_cast(cur_cnnz); + thrust::device_ptr dev_cur_ex_scan = thrust::device_pointer_cast(cur_ex_scan.data()); thrust::exclusive_scan(rmm::exec_policy(stream), dev_cur_cnnz, dev_cur_cnnz + n, dev_cur_ex_scan); RAFT_CUDA_TRY(cudaPeekAtLastError()); @@ -151,8 +151,9 @@ void coo_remove_scalar(COO* in, COO* out, T scalar, cudaStream_t stream) rmm::device_uvector row_count_nz(in->n_rows, stream); rmm::device_uvector row_count(in->n_rows, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(row_count_nz.data(), 0, in->n_rows * sizeof(int), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(row_count.data(), 0, in->n_rows * sizeof(int), stream)); + RAFT_CUDA_TRY( + cudaMemsetAsync(row_count_nz.data(), 0, (uint64_t)in->n_rows * sizeof(int), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(row_count.data(), 0, (uint64_t)in->n_rows * sizeof(int), stream)); linalg::coo_degree(in->rows(), in->nnz, row_count.data(), stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); @@ -161,8 +162,8 @@ void coo_remove_scalar(COO* in, COO* out, T scalar, cudaStream_t stream) RAFT_CUDA_TRY(cudaPeekAtLastError()); thrust::device_ptr d_row_count_nz = thrust::device_pointer_cast(row_count_nz.data()); - int out_nnz = - thrust::reduce(rmm::exec_policy(stream), d_row_count_nz, d_row_count_nz + in->n_rows); + uint64_t out_nnz = thrust::reduce( + rmm::exec_policy(stream), d_row_count_nz, d_row_count_nz + in->n_rows, (uint64_t)0); out->allocate(out_nnz, in->n_rows, in->n_cols, false, stream); diff --git a/cpp/include/raft/sparse/op/detail/sort.h b/cpp/include/raft/sparse/op/detail/sort.h index 02287c2367..7d09ebeddc 100644 --- a/cpp/include/raft/sparse/op/detail/sort.h +++ b/cpp/include/raft/sparse/op/detail/sort.h @@ -69,7 +69,7 @@ struct TupleComp { * @param stream: cuda stream to use */ template -void coo_sort(IdxT m, IdxT n, IdxT nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream) +void coo_sort(IdxT m, IdxT n, uint64_t nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream) { auto coo_indices = thrust::make_zip_iterator(thrust::make_tuple(rows, cols)); diff --git a/cpp/include/raft/sparse/op/sort.cuh b/cpp/include/raft/sparse/op/sort.cuh index 5b8a792429..35b5fd9f31 100644 --- a/cpp/include/raft/sparse/op/sort.cuh +++ b/cpp/include/raft/sparse/op/sort.cuh @@ -38,7 +38,7 @@ namespace op { * @param stream: cuda stream to use */ template -void coo_sort(IdxT m, IdxT n, IdxT nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream) +void coo_sort(IdxT m, IdxT n, uint64_t nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream) { detail::coo_sort(m, n, nnz, rows, cols, vals, stream); } diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index c4ca2ffa61..f436d5c740 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -161,7 +161,7 @@ void construct_landmark_1nn(raft::resources const& handle, // convert to CSR for fast lookup raft::sparse::convert::sorted_coo_to_csr(R_1nn_inds.data(), - index.m, + (value_idx)index.m, index.get_R_indptr().data_handle(), index.n_landmarks + 1, resource::get_cuda_stream(handle)); diff --git a/cpp/test/sparse/solver/lanczos.cu b/cpp/test/sparse/solver/lanczos.cu index 128ab73747..23b3a7ff99 100644 --- a/cpp/test/sparse/solver/lanczos.cu +++ b/cpp/test/sparse/solver/lanczos.cu @@ -173,7 +173,7 @@ class rmat_lanczos_tests raft::make_device_vector(handle, symmetric_coo.n_rows + 1); raft::sparse::convert::sorted_coo_to_csr(symmetric_coo.rows(), - symmetric_coo.nnz, + (IndexType)symmetric_coo.nnz, row_indices.data_handle(), symmetric_coo.n_rows + 1, stream); @@ -198,7 +198,7 @@ class rmat_lanczos_tests symmetric_coo.cols(), symmetric_coo.vals(), symmetric_coo.n_rows, - symmetric_coo.nnz}; + (IndexType)symmetric_coo.nnz}; raft::sparse::solver::lanczos_solver_config config{ n_components, params.maxiter, params.restartiter, params.tol, rng.seed}; diff --git a/cpp/test/sparse/symmetrize.cu b/cpp/test/sparse/symmetrize.cu index e1a74dc40b..c3a03a942c 100644 --- a/cpp/test/sparse/symmetrize.cu +++ b/cpp/test/sparse/symmetrize.cu @@ -109,8 +109,8 @@ class SparseSymmetrizeTest rmm::device_scalar sum(stream); sum.set_value_to_zero_async(stream); - assert_symmetry<<>>( - out.rows(), out.cols(), out.vals(), out.nnz, sum.data()); + assert_symmetry<<>>( + out.rows(), out.cols(), out.vals(), (value_idx)out.nnz, sum.data()); sum_h = sum.value(stream); resource::sync_stream(handle, stream);