-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feat] Support bitset
filter for Brute Force
#560
Changes from 13 commits
1ba31da
4e30bd2
cbc5d38
3a5d4e0
8a45192
e79b1e3
8c0031a
4c53846
4a53e94
85d2dfc
5ef5bc5
f53d1ce
9beb58f
36bae13
1fcc7de
b58f2a5
6c7b583
7c4d50e
3ecccfb
6cc5059
4243fb4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,8 +67,8 @@ void _search(cuvsResources_t res, | |
using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, QueriesLayoutT>; | ||
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>; | ||
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>; | ||
using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we want to keep the filter immutable, don' we? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is to be compatible with the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are using |
||
using prefilter_bmp_type = cuvs::core::bitmap_view<const uint32_t, int64_t>; | ||
using prefilter_mds_type = raft::device_vector_view<uint32_t, int64_t>; | ||
using prefilter_bmp_type = cuvs::core::bitmap_view<uint32_t, int64_t>; | ||
|
||
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor); | ||
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor); | ||
|
@@ -88,7 +88,7 @@ void _search(cuvsResources_t res, | |
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr); | ||
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr); | ||
auto prefilter_view = cuvs::neighbors::filtering::bitmap_filter( | ||
prefilter_bmp_type((const uint32_t*)prefilter_mds.data_handle(), | ||
prefilter_bmp_type((uint32_t*)prefilter_mds.data_handle(), | ||
queries_mds.extent(0), | ||
index_ptr->dataset().extent(0))); | ||
cuvs::neighbors::brute_force::search( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,9 +56,13 @@ | |
|
||
#include <cstdint> | ||
#include <iostream> | ||
#include <optional> | ||
#include <set> | ||
#include <variant> | ||
|
||
namespace cuvs::neighbors::detail { | ||
|
||
using namespace cuvs::neighbors::filtering; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Never use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All done! Thanks! |
||
/** | ||
* Calculates brute force knn, using a fixed memory budget | ||
* by tiling over both the rows and columns of pairwise_distances | ||
|
@@ -82,8 +86,9 @@ void tiled_brute_force_knn(const raft::resources& handle, | |
size_t max_col_tile_size = 0, | ||
const DistanceT* precomputed_index_norms = nullptr, | ||
const DistanceT* precomputed_search_norms = nullptr, | ||
const uint32_t* filter_bitmap = nullptr, | ||
DistanceEpilogue distance_epilogue = raft::identity_op()) | ||
const uint32_t* filter_bits = nullptr, | ||
DistanceEpilogue distance_epilogue = raft::identity_op(), | ||
FilterType filter_type = FilterType::Bitmap) | ||
{ | ||
// Figure out the number of rows/cols to tile for | ||
size_t tile_rows = 0; | ||
|
@@ -245,21 +250,23 @@ void tiled_brute_force_knn(const raft::resources& handle, | |
} | ||
} | ||
|
||
if (filter_bitmap != nullptr) { | ||
auto distances_ptr = temp_distances.data(); | ||
auto count = thrust::make_counting_iterator<IndexType>(0); | ||
DistanceT masked_distance = select_min ? std::numeric_limits<DistanceT>::infinity() | ||
: std::numeric_limits<DistanceT>::lowest(); | ||
auto distances_ptr = temp_distances.data(); | ||
auto count = thrust::make_counting_iterator<IndexType>(0); | ||
DistanceT masked_distance = select_min ? std::numeric_limits<DistanceT>::infinity() | ||
: std::numeric_limits<DistanceT>::lowest(); | ||
|
||
if (filter_bits != nullptr) { | ||
size_t n_cols = filter_type == FilterType::Bitmap ? n : 0; | ||
thrust::for_each(raft::resource::get_thrust_policy(handle), | ||
count, | ||
count + current_query_size * current_centroid_size, | ||
[=] __device__(IndexType idx) { | ||
IndexType row = i + (idx / current_centroid_size); | ||
IndexType col = j + (idx % current_centroid_size); | ||
IndexType g_idx = row * n + col; | ||
IndexType g_idx = row * n_cols + col; | ||
IndexType item_idx = (g_idx) >> 5; | ||
uint32_t bit_idx = (g_idx)&31; | ||
uint32_t filter = filter_bitmap[item_idx]; | ||
uint32_t filter = filter_bits[item_idx]; | ||
if ((filter & (uint32_t(1) << bit_idx)) == 0) { | ||
distances_ptr[idx] = masked_distance; | ||
} | ||
|
@@ -575,12 +582,12 @@ void brute_force_search( | |
query_norms ? query_norms->data_handle() : nullptr); | ||
} | ||
|
||
template <typename T, typename IdxT, typename BitmapT, typename DistanceT = float> | ||
template <typename T, typename IdxT, typename BitsT, typename DistanceT = float> | ||
void brute_force_search_filtered( | ||
raft::resources const& res, | ||
const cuvs::neighbors::brute_force::index<T, DistanceT>& idx, | ||
raft::device_matrix_view<const T, IdxT, raft::row_major> queries, | ||
cuvs::core::bitmap_view<const BitmapT, IdxT> filter, | ||
const base_filter* filter, | ||
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors, | ||
raft::device_matrix_view<DistanceT, IdxT, raft::row_major> distances, | ||
std::optional<raft::device_vector_view<const DistanceT, IdxT>> query_norms = std::nullopt) | ||
|
@@ -601,29 +608,40 @@ void brute_force_search_filtered( | |
metric == cuvs::distance::DistanceType::CosineExpanded), | ||
"Index must has norms when using Euclidean, IP, and Cosine!"); | ||
|
||
IdxT n_queries = queries.extent(0); | ||
IdxT n_dataset = idx.dataset().extent(0); | ||
IdxT dim = idx.dataset().extent(1); | ||
IdxT k = neighbors.extent(1); | ||
IdxT n_queries = queries.extent(0); | ||
IdxT n_dataset = idx.dataset().extent(0); | ||
IdxT dim = idx.dataset().extent(1); | ||
IdxT k = neighbors.extent(1); | ||
FilterType filter_type = filter->get_filter_type(); | ||
|
||
auto stream = raft::resource::get_cuda_stream(res); | ||
|
||
// calc nnz | ||
IdxT nnz_h = 0; | ||
rmm::device_scalar<IdxT> nnz(0, stream); | ||
auto nnz_view = raft::make_device_scalar_view<IdxT>(nnz.data()); | ||
auto filter_view = | ||
raft::make_device_vector_view<const BitmapT, IdxT>(filter.data(), filter.n_elements()); | ||
IdxT size_h = n_queries * n_dataset; | ||
auto size_view = raft::make_host_scalar_view<const IdxT, IdxT>(&size_h); | ||
|
||
raft::popc(res, filter_view, size_view, nnz_view); | ||
raft::copy(&nnz_h, nnz.data(), 1, stream); | ||
std::optional<std::variant<const cuvs::core::bitmap_view<BitsT, IdxT>, | ||
const cuvs::core::bitset_view<BitsT, IdxT>>> | ||
filter_view; | ||
|
||
IdxT nnz_h = 0; | ||
float sparsity = 0.0f; | ||
|
||
const BitsT* filter_data = nullptr; | ||
|
||
if (filter_type == FilterType::Bitmap) { | ||
auto actual_filter = dynamic_cast<const bitmap_filter<BitsT, int64_t>*>(filter); | ||
filter_view.emplace(actual_filter->view()); | ||
nnz_h = actual_filter->view().count(res); | ||
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset); | ||
} else if (filter_type == FilterType::Bitset) { | ||
auto actual_filter = dynamic_cast<const bitset_filter<BitsT, int64_t>*>(filter); | ||
filter_view.emplace(actual_filter->view()); | ||
nnz_h = n_queries * actual_filter->view().count(res); | ||
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset); | ||
} else { | ||
RAFT_FAIL("Unsupported sample filter type"); | ||
} | ||
|
||
raft::resource::sync_stream(res, stream); | ||
float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset)); | ||
std::visit([&](const auto& actual_view) { filter_data = actual_view.data(); }, *filter_view); | ||
|
||
if (sparsity > 0.01f) { | ||
if (sparsity < 0.9f) { | ||
raft::resources stream_pool_handle(res); | ||
raft::resource::set_cuda_stream(stream_pool_handle, stream); | ||
auto idx_norm = idx.has_norms() ? const_cast<DistanceT*>(idx.norms().data_handle()) : nullptr; | ||
|
@@ -643,12 +661,12 @@ void brute_force_search_filtered( | |
0, | ||
idx_norm, | ||
nullptr, | ||
filter.data()); | ||
filter_data, | ||
raft::identity_op(), | ||
filter_type); | ||
} else { | ||
auto csr = raft::make_device_csr_matrix<DistanceT, IdxT>(res, n_queries, n_dataset, nnz_h); | ||
|
||
// fill csr | ||
raft::sparse::convert::bitmap_to_csr(res, filter, csr); | ||
std::visit([&](const auto& actual_view) { actual_view.to_csr(res, csr); }, *filter_view); | ||
|
||
// create filter csr view | ||
auto compressed_csr_view = csr.structure_view(); | ||
|
@@ -664,7 +682,11 @@ void brute_force_search_filtered( | |
auto csr_view = raft::make_device_csr_matrix_view<DistanceT, IdxT, IdxT, IdxT>( | ||
csr.get_elements().data(), compressed_csr_view); | ||
|
||
raft::sparse::linalg::masked_matmul(res, queries, dataset_view, filter, csr_view); | ||
std::visit( | ||
[&](const auto& actual_view) { | ||
raft::sparse::linalg::masked_matmul(res, queries, dataset_view, actual_view, csr_view); | ||
}, | ||
*filter_view); | ||
|
||
// post process | ||
std::optional<raft::device_vector<DistanceT, IdxT>> query_norms_; | ||
|
@@ -725,29 +747,32 @@ void search(raft::resources const& res, | |
raft::device_matrix_view<const T, int64_t, LayoutT> queries, | ||
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, | ||
raft::device_matrix_view<DistT, int64_t, raft::row_major> distances, | ||
const cuvs::neighbors::filtering::base_filter& sample_filter_ref) | ||
const base_filter& sample_filter_ref) | ||
{ | ||
try { | ||
auto& sample_filter = | ||
dynamic_cast<const cuvs::neighbors::filtering::none_sample_filter&>(sample_filter_ref); | ||
auto& sample_filter = dynamic_cast<const none_sample_filter&>(sample_filter_ref); | ||
return brute_force_search<T, int64_t, DistT>(res, idx, queries, neighbors, distances); | ||
} catch (const std::bad_cast&) { | ||
} | ||
if constexpr (std::is_same_v<LayoutT, raft::col_major>) { | ||
RAFT_FAIL("filtered search isn't available with col_major queries yet"); | ||
} else { | ||
try { | ||
auto& sample_filter = | ||
dynamic_cast<const bitmap_filter<uint32_t, int64_t>&>(sample_filter_ref); | ||
return brute_force_search_filtered<T, int64_t, uint32_t, DistT>( | ||
res, idx, queries, &sample_filter, neighbors, distances); | ||
} catch (const std::bad_cast&) { | ||
} | ||
|
||
try { | ||
auto& sample_filter = | ||
dynamic_cast<const cuvs::neighbors::filtering::bitmap_filter<const uint32_t, int64_t>&>( | ||
sample_filter_ref); | ||
if constexpr (std::is_same_v<LayoutT, raft::col_major>) { | ||
RAFT_FAIL("filtered search isn't available with col_major queries yet"); | ||
} else { | ||
cuvs::core::bitmap_view<const uint32_t, int64_t> sample_filter_view = | ||
sample_filter.bitmap_view_; | ||
try { | ||
auto& sample_filter = | ||
dynamic_cast<const bitset_filter<uint32_t, int64_t>&>(sample_filter_ref); | ||
return brute_force_search_filtered<T, int64_t, uint32_t, DistT>( | ||
res, idx, queries, sample_filter_view, neighbors, distances); | ||
res, idx, queries, &sample_filter, neighbors, distances); | ||
} catch (const std::bad_cast&) { | ||
RAFT_FAIL("Unsupported sample filter type"); | ||
} | ||
} catch (const std::bad_cast&) { | ||
RAFT_FAIL("Unsupported sample filter type"); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I notice no changes have been made to
brute_force.hpp
. Ideally, we'll at at least be listing out in the docs which filters are supported, right? Otherwise this is going to be very confusing for users. Also, can we set the default tobitset
filter? I suspect most users will want bitset.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ve just added the comments. I believe using bitset as the default setting might not be ideal if we don't have enough input from end-users. Perhaps we should discuss this in the team group, as I noticed that the none filter is also set as the default in CAGRA.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you may have misunderstood me. The none filter is fine as the default for the the search functions, but for the code example in the docs, we should make sure we use a bitset and leave bitmap to users who need it. FAISS doesn't even support a bitmap and users aren't asking for it generally. It's good to keep it exposed for users who might need it.