Skip to content

Commit

Permalink
reorder comments and remove filtering namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Jan 30, 2025
1 parent 6cc5059 commit 4243fb4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 29 deletions.
24 changes: 12 additions & 12 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,11 @@ auto build(raft::resources const& handle,
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
* 2. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
Expand Down Expand Up @@ -435,11 +435,11 @@ void search(raft::resources const& handle,
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
* 2. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
Expand Down Expand Up @@ -502,11 +502,11 @@ void search(raft::resources const& handle,
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
* 2. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
Expand Down Expand Up @@ -569,11 +569,11 @@ void search(raft::resources const& handle,
* be considered for each query.
*
* - Supports two types of filters:
* 1. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* 1. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
* 2. **Bitmap Filter**: A per-query filter with a logical shape of `[n_queries, index->size()]`,
* where each bit indicates whether a specific dataset element should be considered for a
* particular query. (1 for inclusion, 0 for exclusion).
* 2. **Bitset Filter**: A shared filter where each bit corresponds to a dataset element.
* All queries share the same filter, with a logical shape of `[1, index->size()]`.
*
* - The default value is `none_sample_filter`, which applies no filtering.
*/
Expand Down
39 changes: 22 additions & 17 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@

namespace cuvs::neighbors::detail {

using namespace cuvs::neighbors::filtering;
/**
* Calculates brute force knn, using a fixed memory budget
* by tiling over both the rows and columns of pairwise_distances
Expand All @@ -88,7 +87,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
const DistanceT* precomputed_search_norms = nullptr,
const uint32_t* filter_bits = nullptr,
DistanceEpilogue distance_epilogue = raft::identity_op(),
FilterType filter_type = FilterType::Bitmap)
cuvs::neighbors::filtering::FilterType filter_type =
cuvs::neighbors::filtering::FilterType::Bitmap)
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
Expand Down Expand Up @@ -256,7 +256,7 @@ void tiled_brute_force_knn(const raft::resources& handle,
: std::numeric_limits<DistanceT>::lowest();

if (filter_bits != nullptr) {
size_t n_cols = filter_type == FilterType::Bitmap ? n : 0;
size_t n_cols = filter_type == cuvs::neighbors::filtering::FilterType::Bitmap ? n : 0;
thrust::for_each(raft::resource::get_thrust_policy(handle),
count,
count + current_query_size * current_centroid_size,
Expand Down Expand Up @@ -587,7 +587,7 @@ void brute_force_search_filtered(
raft::resources const& res,
const cuvs::neighbors::brute_force::index<T, DistanceT>& idx,
raft::device_matrix_view<const T, IdxT, raft::row_major> queries,
const base_filter* filter,
const cuvs::neighbors::filtering::base_filter* filter,
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors,
raft::device_matrix_view<DistanceT, IdxT, raft::row_major> distances,
std::optional<raft::device_vector_view<const DistanceT, IdxT>> query_norms = std::nullopt)
Expand All @@ -608,11 +608,11 @@ void brute_force_search_filtered(
metric == cuvs::distance::DistanceType::CosineExpanded),
"Index must has norms when using Euclidean, IP, and Cosine!");

IdxT n_queries = queries.extent(0);
IdxT n_dataset = idx.dataset().extent(0);
IdxT dim = idx.dataset().extent(1);
IdxT k = neighbors.extent(1);
FilterType filter_type = filter->get_filter_type();
IdxT n_queries = queries.extent(0);
IdxT n_dataset = idx.dataset().extent(0);
IdxT dim = idx.dataset().extent(1);
IdxT k = neighbors.extent(1);
cuvs::neighbors::filtering::FilterType filter_type = filter->get_filter_type();

auto stream = raft::resource::get_cuda_stream(res);

Expand All @@ -625,13 +625,15 @@ void brute_force_search_filtered(

const BitsT* filter_data = nullptr;

if (filter_type == FilterType::Bitmap) {
auto actual_filter = dynamic_cast<const bitmap_filter<BitsT, int64_t>*>(filter);
if (filter_type == cuvs::neighbors::filtering::FilterType::Bitmap) {
auto actual_filter =
dynamic_cast<const cuvs::neighbors::filtering::bitmap_filter<BitsT, int64_t>*>(filter);
filter_view.emplace(actual_filter->view());
nnz_h = actual_filter->view().count(res);
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset);
} else if (filter_type == FilterType::Bitset) {
auto actual_filter = dynamic_cast<const bitset_filter<BitsT, int64_t>*>(filter);
} else if (filter_type == cuvs::neighbors::filtering::FilterType::Bitset) {
auto actual_filter =
dynamic_cast<const cuvs::neighbors::filtering::bitset_filter<BitsT, int64_t>*>(filter);
filter_view.emplace(actual_filter->view());
nnz_h = n_queries * actual_filter->view().count(res);
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset);
Expand Down Expand Up @@ -747,10 +749,11 @@ void search(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, LayoutT> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<DistT, int64_t, raft::row_major> distances,
const base_filter& sample_filter_ref)
const cuvs::neighbors::filtering::base_filter& sample_filter_ref)
{
try {
auto& sample_filter = dynamic_cast<const none_sample_filter&>(sample_filter_ref);
auto& sample_filter =
dynamic_cast<const cuvs::neighbors::filtering::none_sample_filter&>(sample_filter_ref);
return brute_force_search<T, int64_t, DistT>(res, idx, queries, neighbors, distances);
} catch (const std::bad_cast&) {
}
Expand All @@ -759,15 +762,17 @@ void search(raft::resources const& res,
} else {
try {
auto& sample_filter =
dynamic_cast<const bitmap_filter<uint32_t, int64_t>&>(sample_filter_ref);
dynamic_cast<const cuvs::neighbors::filtering::bitmap_filter<uint32_t, int64_t>&>(
sample_filter_ref);
return brute_force_search_filtered<T, int64_t, uint32_t, DistT>(
res, idx, queries, &sample_filter, neighbors, distances);
} catch (const std::bad_cast&) {
}

try {
auto& sample_filter =
dynamic_cast<const bitset_filter<uint32_t, int64_t>&>(sample_filter_ref);
dynamic_cast<const cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>&>(
sample_filter_ref);
return brute_force_search_filtered<T, int64_t, uint32_t, DistT>(
res, idx, queries, &sample_filter, neighbors, distances);
} catch (const std::bad_cast&) {
Expand Down

0 comments on commit 4243fb4

Please sign in to comment.