Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/branch-25.02' into rhdong/bitset…
Browse files Browse the repository at this point in the history
…-to-csr-dev
  • Loading branch information
rhdong committed Jan 13, 2025
2 parents 7ddd5cc + 5c826d7 commit 9459d78
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 39 deletions.
24 changes: 20 additions & 4 deletions cpp/include/raft/core/bitmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,18 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
* @param bitmap_ptr Device raw pointer
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
* @param original_nbits Original number of bits used when the bitmap was created, to handle
* potential mismatches of data types. This is useful for using ANN indexes when a bitmap was
* originally created with a different data type than the ones currently supported in cuVS ANN
* indexes.
*/
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols), rows_(rows), cols_(cols)
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr,
index_t rows,
index_t cols,
index_t original_nbits = 0)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols, original_nbits),
rows_(rows),
cols_(cols)
{
}

Expand All @@ -65,11 +74,18 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
* @param bitmap_span Device vector view of the bitmap
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
* @param original_nbits Original number of bits used when the bitmap was created, to handle
* potential mismatches of data types. This is useful for using ANN indexes when a bitmap was
* originally created with a different data type than the ones currently supported in cuVS ANN
* indexes.
*/
_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span,
index_t rows,
index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols), rows_(rows), cols_(cols)
index_t cols,
index_t original_nbits = 0)
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols, original_nbits),
rows_(rows),
cols_(cols)
{
}

Expand Down
53 changes: 45 additions & 8 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,41 @@

namespace raft::core {

/**
 * @brief Translate a bit position expressed against elements of `original_nbits` bits into an
 * element index and in-element bit offset for storage made of `nbits`-bit elements.
 *
 * @tparam index_t Integer type used for indices and bit counts.
 *
 * @param[in]  original_nbits Width, in bits, of the elements the position is expressed against.
 *                            Must be non-zero (it is used as a divisor).
 * @param[in]  nbits          Width, in bits, of the elements actually being addressed.
 *                            Must be non-zero (it is used as a divisor).
 * @param[in]  sample_index   Bit position to translate.
 * @param[out] new_bit_index  Index of the `nbits`-wide element holding the bit.
 * @param[out] new_bit_offset Offset of the bit inside that element.
 */
template <typename index_t>
_RAFT_HOST_DEVICE void inline compute_original_nbits_position(const index_t original_nbits,
                                                              const index_t nbits,
                                                              const index_t sample_index,
                                                              index_t& new_bit_index,
                                                              index_t& new_bit_offset)
{
  // Where the bit sits when the storage is viewed as original_nbits-wide elements.
  const index_t src_element = sample_index / original_nbits;
  const index_t src_offset  = sample_index % original_nbits;

  // nbits-wide element that contains the first bit of the source element.
  const index_t base_element = src_element * original_nbits / nbits;

  if (original_nbits > nbits) {
    // A source element covers several destination elements: step forward to the one that holds
    // the bit and keep the remainder as the offset.
    new_bit_index  = base_element + src_offset / nbits;
    new_bit_offset = src_offset % nbits;
    return;
  }

  // A destination element packs `packed` source elements: shift the offset by the slot the
  // source element occupies inside the destination element.
  const index_t packed = nbits / original_nbits;
  new_bit_index        = base_element;
  new_bit_offset       = (src_element % packed) * original_nbits + src_offset % nbits;
}

template <typename bitset_t, typename index_t>
_RAFT_HOST_DEVICE inline bool bitset_view<bitset_t, index_t>::test(const index_t sample_index) const
{
const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size];
const index_t bit_index = sample_index % bitset_element_size;
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0;
const index_t nbits = sizeof(bitset_t) * 8;
index_t bit_index = 0;
index_t bit_offset = 0;
if (original_nbits_ == 0 || nbits == original_nbits_) {
bit_index = sample_index / bitset_element_size;
bit_offset = sample_index % bitset_element_size;
} else {
compute_original_nbits_position(original_nbits_, nbits, sample_index, bit_index, bit_offset);
}
const bitset_t bit_element = bitset_ptr_[bit_index];
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_offset)) != 0;
return is_bit_set;
}

Expand All @@ -52,14 +81,22 @@ template <typename bitset_t, typename index_t>
_RAFT_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index,
bool set_value) const
{
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
const index_t nbits = sizeof(bitset_t) * 8;
index_t bit_index = 0;
index_t bit_offset = 0;

if (original_nbits_ == 0 || nbits == original_nbits_) {
bit_index = sample_index / bitset_element_size;
bit_offset = sample_index % bitset_element_size;
} else {
compute_original_nbits_position(original_nbits_, nbits, sample_index, bit_index, bit_offset);
}
const bitset_t bitmask = bitset_t{1} << bit_offset;
if (set_value) {
atomicOr(bitset_ptr_ + bit_element, bitmask);
atomicOr(bitset_ptr_ + bit_index, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr_ + bit_element, bitmask2);
atomicAnd(bitset_ptr_ + bit_index, bitmask2);
}
}

Expand Down
34 changes: 30 additions & 4 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,38 @@ template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset_view {
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

_RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}
/**
* @brief Create a bitset view from a device pointer to the bitset.
*
* @param bitset_ptr Device pointer to the bitset
* @param bitset_len Number of bits in the bitset
* @param original_nbits Original number of bits used when the bitset was created, to handle
* potential mismatches of data types. This is useful for using ANN indexes when a bitset was
* originally created with a different data type than the ones currently supported in cuVS ANN
* indexes.
*/
_RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr,
index_t bitset_len,
index_t original_nbits = 0)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}, original_nbits_{original_nbits}
{
}
/**
* @brief Create a bitset view from a device vector view of the bitset.
*
* @param bitset_span Device vector view of the bitset
* @param bitset_len Number of bits in the bitset
* @param original_nbits Original number of bits used when the bitset was created, to handle
* potential mismatches of data types. This is useful for using ANN indexes when a bitset was
* originally created with a different data type than the ones currently supported in cuVS ANN
* indexes.
*/
_RAFT_HOST_DEVICE bitset_view(raft::device_vector_view<bitset_t, index_t> bitset_span,
index_t bitset_len)
: bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_len}
index_t bitset_len,
index_t original_nbits = 0)
: bitset_ptr_{bitset_span.data_handle()},
bitset_len_{bitset_len},
original_nbits_{original_nbits}
{
}
/**
Expand Down Expand Up @@ -180,6 +199,12 @@ struct bitset_view {
return (bitset_len + bits_per_element - 1) / bits_per_element;
}

/**
 * @brief Get the original number of bits of the bitset.
 *
 * @return The `original_nbits` value stored in this view. It is 0 when the view was constructed
 * without an explicit value (the constructor default).
 */
auto get_original_nbits() const -> index_t { return original_nbits_; }
/**
 * @brief Set the original number of bits of the bitset.
 *
 * @param original_nbits New value to store. `test()` and `set()` use plain
 * `sizeof(bitset_t) * 8` indexing when this is 0 or equal to that element width, and remap the
 * position through `compute_original_nbits_position` otherwise.
 */
void set_original_nbits(index_t original_nbits) { original_nbits_ = original_nbits; }

/**
* @brief Converts to a Compressed Sparse Row (CSR) format matrix.
*
Expand Down Expand Up @@ -246,6 +271,7 @@ struct bitset_view {
private:
bitset_t* bitset_ptr_;
index_t bitset_len_;
index_t original_nbits_;
};

/**
Expand Down
14 changes: 9 additions & 5 deletions cpp/include/raft/sparse/detail/coo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class COO {
* @param n_rows: number of rows in the dense matrix
* @param n_cols: number of columns in the dense matrix
*/
void setSize(int n_rows, int n_cols)
void setSize(Index_Type n_rows, Index_Type n_cols)
{
this->n_rows = n_rows;
this->n_cols = n_cols;
Expand All @@ -192,7 +192,7 @@ class COO {
* @brief Set the number of rows and cols for a square dense matrix
* @param n: number of rows and cols
*/
void setSize(int n)
void setSize(Index_Type n)
{
this->n_rows = n;
this->n_cols = n;
Expand All @@ -204,7 +204,10 @@ class COO {
* @param init: should values be initialized to 0?
* @param stream: CUDA stream to use
*/
void allocate(int nnz, bool init, cudaStream_t stream) { this->allocate(nnz, 0, init, stream); }
void allocate(Index_Type nnz, bool init, cudaStream_t stream)
{
this->allocate(nnz, 0, init, stream);
}

/**
* @brief Allocate the underlying arrays
Expand All @@ -213,7 +216,7 @@ class COO {
* @param init: should values be initialized to 0?
* @param stream: CUDA stream to use
*/
void allocate(int nnz, int size, bool init, cudaStream_t stream)
void allocate(Index_Type nnz, Index_Type size, bool init, cudaStream_t stream)
{
this->allocate(nnz, size, size, init, stream);
}
Expand All @@ -226,7 +229,8 @@ class COO {
* @param init: should values be initialized to 0?
* @param stream: stream to use for init
*/
void allocate(int nnz, int n_rows, int n_cols, bool init, cudaStream_t stream)
void allocate(
Index_Type nnz, Index_Type n_rows, Index_Type n_cols, bool init, cudaStream_t stream)
{
this->n_rows = n_rows;
this->n_cols = n_cols;
Expand Down
21 changes: 12 additions & 9 deletions cpp/include/raft/sparse/solver/detail/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ static int lanczosRestart(raft::resources const& handle,
value_type_t* shifts_host;

// Orthonormal matrix for similarity transform
value_type_t* V_dev = work_dev + n * iter;
value_type_t* V_dev = work_dev + (size_t)n * (size_t)iter;

// -------------------------------------------------------
// Implementation
Expand All @@ -641,7 +641,7 @@ static int lanczosRestart(raft::resources const& handle,
// std::cout <<std::endl;

// Initialize similarity transform with identity matrix
memset(V_host, 0, iter * iter * sizeof(value_type_t));
memset(V_host, 0, (size_t)iter * (size_t)iter * (size_t)sizeof(value_type_t));
for (i = 0; i < iter; ++i)
V_host[IDX(i, i, iter)] = 1;

Expand Down Expand Up @@ -679,8 +679,11 @@ static int lanczosRestart(raft::resources const& handle,
WARNING("error in implicitly shifted QR algorithm");

// Obtain new residual
RAFT_CUDA_TRY(cudaMemcpyAsync(
V_dev, V_host, iter * iter * sizeof(value_type_t), cudaMemcpyHostToDevice, stream));
RAFT_CUDA_TRY(cudaMemcpyAsync(V_dev,
V_host,
(size_t)iter * (size_t)iter * (size_t)sizeof(value_type_t),
cudaMemcpyHostToDevice,
stream));

beta_host[iter - 1] = beta_host[iter - 1] * V_host[IDX(iter - 1, iter_new - 1, iter)];
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(cublas_h,
Expand Down Expand Up @@ -716,7 +719,7 @@ static int lanczosRestart(raft::resources const& handle,

RAFT_CUDA_TRY(cudaMemcpyAsync(lanczosVecs_dev,
work_dev,
n * iter_new * sizeof(value_type_t),
(size_t)n * (size_t)iter_new * (size_t)sizeof(value_type_t),
cudaMemcpyDeviceToDevice,
stream));

Expand Down Expand Up @@ -1045,10 +1048,10 @@ int computeSmallestEigenvectors(
unsigned long long seed = 1234567)
{
// Matrix dimension
index_type_t n = A.nrows_;
size_t n = A.nrows_;

// Check that parameters are valid
RAFT_EXPECTS(nEigVecs > 0 && nEigVecs <= n, "Invalid number of eigenvectors.");
RAFT_EXPECTS(nEigVecs > 0 && (size_t)nEigVecs <= n, "Invalid number of eigenvectors.");
RAFT_EXPECTS(restartIter > 0, "Invalid restartIter.");
RAFT_EXPECTS(tol > 0, "Invalid tolerance.");
RAFT_EXPECTS(maxIter >= nEigVecs, "Invalid maxIter.");
Expand Down Expand Up @@ -1395,10 +1398,10 @@ int computeLargestEigenvectors(
unsigned long long seed = 123456)
{
// Matrix dimension
index_type_t n = A.nrows_;
size_t n = A.nrows_;

// Check that parameters are valid
RAFT_EXPECTS(nEigVecs > 0 && nEigVecs <= n, "Invalid number of eigenvectors.");
RAFT_EXPECTS(nEigVecs > 0 && (size_t)nEigVecs <= n, "Invalid number of eigenvectors.");
RAFT_EXPECTS(restartIter > 0, "Invalid restartIter.");
RAFT_EXPECTS(tol > 0, "Invalid tolerance.");
RAFT_EXPECTS(maxIter >= nEigVecs, "Invalid maxIter.");
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/spectral/detail/matrix_wrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
// =========================================================

// Get index of matrix entry
#define IDX(i, j, lda) ((i) + (j) * (lda))
// Linear offset (i) + (j) * (lda) of a matrix entry. Every operand is widened to size_t before
// the multiply: casting only (i) would still evaluate (j) * (lda) in the operands' own type
// (typically int), which overflows for large matrices before the result is promoted.
#define IDX(i, j, lda) ((size_t)(i) + (size_t)(j) * (size_t)(lda))

namespace raft {
namespace spectral {
namespace matrix {
namespace detail {

using size_type = int; // for now; TODO: move it in appropriate header
using size_type = size_t; // for now; TODO: move it in appropriate header

// Apply diagonal matrix to vector:
//
Expand Down Expand Up @@ -326,7 +326,7 @@ struct laplacian_matrix_t : sparse_matrix_t<index_type, value_type> {
raft_handle, row_offsets, col_indices, values, nrows, nnz),
diagonal_(raft_handle, nrows)
{
vector_t<value_type> ones{raft_handle, nrows};
vector_t<value_type> ones{raft_handle, (size_t)nrows};
ones.fill(1.0);
sparse_matrix_t<index_type, value_type>::mv(1, ones.raw(), 0, diagonal_.raw());
}
Expand All @@ -341,7 +341,7 @@ struct laplacian_matrix_t : sparse_matrix_t<index_type, value_type> {
csr_m.nnz_),
diagonal_(raft_handle, csr_m.nrows_)
{
vector_t<value_type> ones{raft_handle, csr_m.nrows_};
vector_t<value_type> ones{raft_handle, (size_t)csr_m.nrows_};
ones.fill(1.0);
sparse_matrix_t<index_type, value_type>::mv(1, ones.raw(), 0, diagonal_.raw());
}
Expand Down
Loading

0 comments on commit 9459d78

Please sign in to comment.