Skip to content

Commit

Permalink
optimize by reviewing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Jan 10, 2025
1 parent 1fe0a97 commit 7ddd5cc
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 12 deletions.
2 changes: 0 additions & 2 deletions cpp/include/raft/core/bitmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,6 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
* The bitmap is interpreted as a row-major matrix, with rows and columns defined by
* the dimensions of the bitmap.
*
* @tparam bitmap_t The data type of the elements in the bitmap matrix.
* @tparam index_t The data type used for indexing the elements in the matrices.
* @tparam csr_matrix_t Specifies the CSR matrix type, constrained to raft::device_csr_matrix.
*
* @param[in] res RAFT resources for managing CUDA streams and execution policies.
Expand Down
2 changes: 0 additions & 2 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,6 @@ struct bitset_view {
* // 1, 1, 1, 1];
* @endcode
*
* @tparam bitset_t The data type of the elements in the bitset matrix.
* @tparam index_t The data type used for indexing the elements in the matrices.
* @tparam csr_matrix_t Specifies the CSR matrix type, constrained to raft::device_csr_matrix.
*
* @param[in] res RAFT resources for managing CUDA streams and execution policies.
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ void bitmap_to_csr(raft::resources const& handle,
thrust_policy, sub_nnz.data(), sub_nnz.data() + sub_nnz_size + 1, sub_nnz.data());

if constexpr (is_device_csr_sparsity_owning_v<csr_matrix_t>) {
index_t nnz = 0;
nnz_t nnz = 0;
RAFT_CUDA_TRY(cudaMemcpyAsync(
&nnz, sub_nnz.data() + sub_nnz_size, sizeof(index_t), cudaMemcpyDeviceToHost, stream));
&nnz, sub_nnz.data() + sub_nnz_size, sizeof(nnz_t), cudaMemcpyDeviceToHost, stream));
resource::sync_stream(handle);
csr.initialize_sparsity(nnz);
}
Expand Down
9 changes: 3 additions & 6 deletions cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,8 @@ RAFT_KERNEL repeat_csr_kernel(const index_t* indptr,

__syncthreads();

int block_offset = blockIdx.x * blockDim.x;

index_t item;
int idx = block_offset + threadIdx.x;
item = (idx < nnz) ? indices[idx] : -1;
item = (global_id < nnz) ? indices[global_id] : -1;

__syncthreads();

Expand Down Expand Up @@ -144,10 +141,10 @@ void bitset_to_csr(raft::resources const& handle,
thrust::exclusive_scan(
thrust_policy, sub_nnz.data(), sub_nnz.data() + sub_nnz_size + 1, sub_nnz.data());

index_t bitset_nnz = 0;
nnz_t bitset_nnz = 0;
if constexpr (is_device_csr_sparsity_owning_v<csr_matrix_t>) {
RAFT_CUDA_TRY(cudaMemcpyAsync(
&bitset_nnz, sub_nnz.data() + sub_nnz_size, sizeof(index_t), cudaMemcpyDeviceToHost, stream));
&bitset_nnz, sub_nnz.data() + sub_nnz_size, sizeof(nnz_t), cudaMemcpyDeviceToHost, stream));
resource::sync_stream(handle);
csr.initialize_sparsity(bitset_nnz * csr_view.get_n_rows());
} else {
Expand Down

0 comments on commit 7ddd5cc

Please sign in to comment.