Skip to content

Commit

Permalink
[Feat] Support bitset filter for Brute Force (#560)
Browse files Browse the repository at this point in the history
Authors:
  - rhdong (https://github.com/rhdong)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #560
  • Loading branch information
rhdong authored Jan 31, 2025
1 parent 833f28c commit c778c88
Show file tree
Hide file tree
Showing 7 changed files with 750 additions and 99 deletions.
165 changes: 140 additions & 25 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,28 @@ auto build(raft::resources const& handle,
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`:
* eliminate entirely allocations happening within `search`.
*
* Usage example:
* @code{.cpp}
* ...
* // Use the same allocator across multiple searches to reduce the number of
* // cuda memory allocations
* brute_force::search(handle, index, queries1, out_inds1, out_dists1);
* brute_force::search(handle, index, queries2, out_inds2, out_dists2);
* brute_force::search(handle, index, queries3, out_inds3, out_dists3);
* ...
* using namespace cuvs::neighbors;
*
* // use default index parameters
* brute_force::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(res, index_params, dataset);
* // use default search parameters
* brute_force::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
* @endcode
*
* @param[in] handle
Expand All @@ -350,9 +363,17 @@ auto build(raft::resources const& handle,
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter An optional device bitmap filter function with a `row-major` layout and
* the shape of [n_queries, index->size()], which means the filter will use the first
* `index->size()` bits to indicate whether queries[0] should compute the distance with dataset.
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
Expand All @@ -379,15 +400,28 @@ void search(raft::resources const& handle,
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`:
* eliminate entirely allocations happening within `search`.
*
* Usage example:
* @code{.cpp}
* ...
* // Use the same allocator across multiple searches to reduce the number of
* // cuda memory allocations
* brute_force::search(handle, index, queries1, out_inds1, out_dists1);
* brute_force::search(handle, index, queries2, out_inds2, out_dists2);
* brute_force::search(handle, index, queries3, out_inds3, out_dists3);
* ...
* using namespace cuvs::neighbors;
*
* // use default index parameters
* brute_force::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(res, index_params, dataset);
* // use default search parameters
* brute_force::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<half>(res, n_queries, k);
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
* @endcode
*
* @param[in] handle
Expand All @@ -397,8 +431,17 @@ void search(raft::resources const& handle,
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a
* given
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
Expand All @@ -421,15 +464,51 @@ void search(raft::resources const& handle,
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
*
* // use default index parameters
* brute_force::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(res, index_params, dataset);
* // use default search parameters
* brute_force::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
Expand All @@ -452,15 +531,51 @@ void search(raft::resources const& handle,
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
*
* // use default index parameters
* brute_force::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(res, index_params, dataset);
* // use default search parameters
* brute_force::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<int64_t, int64_t>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, int64_t> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<half>(res, n_queries, k);
* auto filter = filtering::bitset_filter(removed_indices_bitset.view());
* brute_force::search(res, search_params, index, queries, neighbors, distances, filter);
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
* @param[in] sample_filter An optional device filter that restricts which dataset elements should
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
Expand Down
34 changes: 29 additions & 5 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cstdint>
#include <cuvs/distance/distance.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdspan.hpp>
Expand Down Expand Up @@ -456,8 +457,11 @@ inline constexpr bool is_vpq_dataset_v = is_vpq_dataset<DatasetT>::value;

namespace filtering {

enum class FilterType { None, Bitmap, Bitset };

struct base_filter {
virtual ~base_filter() = default;
virtual ~base_filter() = default;
virtual FilterType get_filter_type() const = 0;
};

/* A filter that filters nothing. This is the default behavior. */
Expand All @@ -475,6 +479,8 @@ struct none_sample_filter : public base_filter {
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::None; }
};

/**
Expand Down Expand Up @@ -513,15 +519,24 @@ struct ivf_to_sample_filter {
*/
template <typename bitmap_t, typename index_t>
struct bitmap_filter : public base_filter {
using view_t = cuvs::core::bitmap_view<bitmap_t, index_t>;

// View of the bitmap to use as a filter
const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_view_;
const view_t bitmap_view_;

bitmap_filter(const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_for_filtering);
bitmap_filter(const view_t bitmap_for_filtering);
inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::Bitmap; }

view_t view() const { return bitmap_view_; }

template <typename csr_matrix_t>
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
};

/**
Expand All @@ -532,15 +547,24 @@ struct bitmap_filter : public base_filter {
*/
template <typename bitset_t, typename index_t>
struct bitset_filter : public base_filter {
using view_t = cuvs::core::bitset_view<bitset_t, index_t>;

// View of the bitset to use as a filter
const cuvs::core::bitset_view<bitset_t, index_t> bitset_view_;
const view_t bitset_view_;

bitset_filter(const cuvs::core::bitset_view<bitset_t, index_t> bitset_for_filtering);
bitset_filter(const view_t bitset_for_filtering);
inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::Bitset; }

view_t view() const { return bitset_view_; }

template <typename csr_matrix_t>
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
};

/**
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ void _search(cuvsResources_t res,
using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, QueriesLayoutT>;
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>;
using prefilter_bmp_type = cuvs::core::bitmap_view<const uint32_t, int64_t>;
using prefilter_mds_type = raft::device_vector_view<uint32_t, int64_t>;
using prefilter_bmp_type = cuvs::core::bitmap_view<uint32_t, int64_t>;

auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
Expand All @@ -85,14 +85,14 @@ void _search(cuvsResources_t res,
distances_mds,
cuvs::neighbors::filtering::none_sample_filter{});
} else if (prefilter.type == BITMAP) {
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr);
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr);
auto prefilter_view = cuvs::neighbors::filtering::bitmap_filter(
prefilter_bmp_type((const uint32_t*)prefilter_mds.data_handle(),
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr);
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr);
const auto prefilter = cuvs::neighbors::filtering::bitmap_filter(
prefilter_bmp_type((uint32_t*)prefilter_mds.data_handle(),
queries_mds.extent(0),
index_ptr->dataset().extent(0)));
cuvs::neighbors::brute_force::search(
*res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter_view);
*res_ptr, params, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter);
} else {
RAFT_FAIL("Unsupported prefilter type: BITSET");
}
Expand Down
Loading

0 comments on commit c778c88

Please sign in to comment.