From ea9c77c838f07fef84000cb7964ea262c8f074df Mon Sep 17 00:00:00 2001 From: Anton Rybochkin <79331145+arybochkin@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:19:22 +0000 Subject: [PATCH] [#25041] docdb: Update HNSW lib to support custom vector label Summary: The change extends HNSW lib to support a custom vector label (via template argument). Also hnswlib_wrapper is updated accordingly to use VectorId as vector label directly. Jira: DB-14175 Test Plan: Jenkins Reviewers: sergei, slingam Reviewed By: sergei Subscribers: ybase Differential Revision: https://phorge.dev.yugabyte.com/D41440 --- .../hnswlib/hnswlib/bruteforce.h | 33 ++++---- .../hnswlib/hnswlib/hnswalg.h | 84 +++++++++---------- .../hnswlib/hnswlib/hnswlib.h | 34 ++++---- .../hnswlib/hnswlib/stop_condition.h | 20 ++--- src/yb/vector_index/hnswlib_wrapper.cc | 48 ++--------- 5 files changed, 96 insertions(+), 123 deletions(-) diff --git a/src/inline-thirdparty/hnswlib/hnswlib/bruteforce.h b/src/inline-thirdparty/hnswlib/hnswlib/bruteforce.h index 76043434278a..d9722433f35c 100644 --- a/src/inline-thirdparty/hnswlib/hnswlib/bruteforce.h +++ b/src/inline-thirdparty/hnswlib/hnswlib/bruteforce.h @@ -6,8 +6,8 @@ #include namespace hnswlib { -template -class BruteforceSearch : public AlgorithmInterface { +template +class BruteforceSearch : public AlgorithmInterface { public: char *data_; size_t maxelements_; @@ -19,7 +19,7 @@ class BruteforceSearch : public AlgorithmInterface { void *dist_func_param_; std::mutex index_lock; - std::unordered_map dict_external_to_internal; + std::unordered_map dict_external_to_internal; BruteforceSearch(SpaceInterface *s) @@ -48,7 +48,7 @@ class BruteforceSearch : public AlgorithmInterface { data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); - size_per_element_ = data_size_ + sizeof(labeltype); + size_per_element_ = data_size_ + sizeof(label_t); data_ = (char *) malloc(maxElements * size_per_element_); if (data_ 
== nullptr) throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); @@ -61,7 +61,7 @@ class BruteforceSearch : public AlgorithmInterface { } - void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { + void addPoint(const void *datapoint, label_t label, bool replace_deleted = false) { int idx; { std::unique_lock lock(index_lock); @@ -78,12 +78,12 @@ class BruteforceSearch : public AlgorithmInterface { cur_element_count++; } } - memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); + memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(label_t)); memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); } - void removePoint(labeltype cur_external) { + void removePoint(label_t cur_external) { std::unique_lock lock(index_lock); auto found = dict_external_to_internal.find(cur_external); @@ -94,23 +94,25 @@ class BruteforceSearch : public AlgorithmInterface { dict_external_to_internal.erase(found); size_t cur_c = found->second; - labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); + label_t label; + memcpy(&label, data_ + size_per_element_ * (cur_element_count-1) + data_size_, sizeof(label_t)); dict_external_to_internal[label] = cur_c; memcpy(data_ + size_per_element_ * cur_c, data_ + size_per_element_ * (cur_element_count-1), - data_size_+sizeof(labeltype)); + data_size_+sizeof(label_t)); cur_element_count--; } - std::priority_queue> - searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + std::priority_queue> + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { assert(k <= cur_element_count); - std::priority_queue> topResults; + std::priority_queue> topResults; if (cur_element_count == 0) return topResults; for (size_t i = 0; i < k; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - 
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); + label_t label; + memcpy(&label, data_ + size_per_element_ * i + data_size_, sizeof(label_t)); if ((!isIdAllowed) || (*isIdAllowed)(label)) { topResults.emplace(dist, label); } @@ -119,7 +121,8 @@ class BruteforceSearch : public AlgorithmInterface { for (int i = k; i < cur_element_count; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { - labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); + label_t label; + memcpy(&label, data_ + size_per_element_ * i + data_size_, sizeof(label_t)); if ((!isIdAllowed) || (*isIdAllowed)(label)) { topResults.emplace(dist, label); } @@ -160,7 +163,7 @@ class BruteforceSearch : public AlgorithmInterface { data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); - size_per_element_ = data_size_ + sizeof(labeltype); + size_per_element_ = data_size_ + sizeof(label_t); data_ = (char *) malloc(maxelements_ * size_per_element_); if (data_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); diff --git a/src/inline-thirdparty/hnswlib/hnswlib/hnswalg.h b/src/inline-thirdparty/hnswlib/hnswlib/hnswalg.h index eb23a143b3b9..0b7e3f7ef6bc 100644 --- a/src/inline-thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/src/inline-thirdparty/hnswlib/hnswlib/hnswalg.h @@ -15,7 +15,7 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; -template +template class VectorIterator; struct Stats { @@ -43,8 +43,8 @@ struct InternalParameters { size_t bytes_per_vector = 0; }; -template -class HierarchicalNSW : public AlgorithmInterface { +template +class HierarchicalNSW : public AlgorithmInterface { public: static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; static const unsigned char DELETE_MARK = 0x01; @@ -87,7 +87,7 @@ class HierarchicalNSW : public 
AlgorithmInterface { void *dist_func_param_{nullptr}; mutable std::mutex label_lookup_lock; // lock for label_lookup_ - std::unordered_map label_lookup_; + std::unordered_map label_lookup_; std::default_random_engine level_generator_; std::default_random_engine update_probability_generator_; @@ -101,16 +101,15 @@ class HierarchicalNSW : public AlgorithmInterface { std::unordered_set deleted_elements; // contains internal ids of deleted elements // Function to return the begin iterator - VectorIterator vectors_begin() { - return VectorIterator(0, this); + auto vectors_begin() { + return VectorIterator(0, this); } // Function to return the end iterator - VectorIterator vectors_end() { + auto vectors_end() { return VectorIterator(cur_element_count, this); } - HierarchicalNSW(SpaceInterface *s) { } @@ -159,7 +158,7 @@ class HierarchicalNSW : public AlgorithmInterface { update_probability_generator_.seed(random_seed + 1); size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(label_t); offsetData_ = size_links_level0_; label_offset_ = size_links_level0_ + data_size_; offsetLevel0_ = 0; @@ -216,27 +215,28 @@ class HierarchicalNSW : public AlgorithmInterface { } - inline std::mutex& getLabelOpMutex(labeltype label) const { + inline std::mutex& getLabelOpMutex(label_t label) const { // calculate hash - size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); + size_t lock_id = std::hash{}(label) & (MAX_LABEL_OPERATION_LOCKS - 1); return label_op_locks_[lock_id]; } - inline labeltype getExternalLabel(tableint internal_id) const { - labeltype return_label; - memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + inline label_t getExternalLabel(tableint internal_id) const { + label_t return_label; + memcpy(&return_label, (data_level0_memory_ + 
internal_id * size_data_per_element_ + label_offset_), sizeof(label_t)); return return_label; } - inline void setExternalLabel(tableint internal_id, labeltype label) const { - memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + inline void setExternalLabel(tableint internal_id, label_t label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(label_t)); } - inline labeltype *getExternalLabeLp(tableint internal_id) const { - return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + // NB! Dereferencing the returned pointer is undefined behavior; memcpy must be used instead. + inline label_t *getExternalLabeLp(tableint internal_id) const { + return (label_t *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); } @@ -353,8 +353,8 @@ class HierarchicalNSW : public AlgorithmInterface { tableint ep_id, const void *data_point, size_t ef, - BaseFilterFunctor* isIdAllowed = nullptr, - BaseSearchStopCondition* stop_condition = nullptr) const { + BaseFilterFunctor* isIdAllowed = nullptr, + BaseSearchStopCondition* stop_condition = nullptr) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -726,7 +726,7 @@ class HierarchicalNSW : public AlgorithmInterface { void saveIndex(const std::string &location) override { std::ofstream output(location, std::ios::binary); std::streampos position; - + // TODO(vector-index): we may want to store sizeof(label_t) to make sure an index can be loaded. 
writeBinaryPOD(output, offsetLevel0_); writeBinaryPOD(output, max_elements_); writeBinaryPOD(output, cur_element_count); @@ -864,7 +864,7 @@ class HierarchicalNSW : public AlgorithmInterface { template - std::vector getDataByLabel(labeltype label) const { + std::vector getDataByLabel(label_t label) const { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); @@ -891,7 +891,7 @@ class HierarchicalNSW : public AlgorithmInterface { /* * Marks an element with the given label deleted, does NOT really change the current graph. */ - void markDelete(labeltype label) { + void markDelete(label_t label) { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); @@ -933,7 +933,7 @@ class HierarchicalNSW : public AlgorithmInterface { * Note: the method is not safe to use when replacement of deleted elements is enabled, * because elements marked as deleted can be completely removed by addPoint */ - void unmarkDelete(labeltype label) { + void unmarkDelete(label_t label) { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); @@ -992,7 +992,7 @@ class HierarchicalNSW : public AlgorithmInterface { * Adds point. Updates the point if it is already in the index. 
* If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point */ - void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) override { + void addPoint(const void *data_point, label_t label, bool replace_deleted = false) override { if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); } @@ -1019,7 +1019,7 @@ class HierarchicalNSW : public AlgorithmInterface { addPoint(data_point, label, -1); } else { // we assume that there are no concurrent operations on deleted element - labeltype label_replaced = getExternalLabel(internal_id_replaced); + label_t label_replaced = getExternalLabel(internal_id_replaced); setExternalLabel(internal_id_replaced, label); std::unique_lock lock_table(label_lookup_lock); @@ -1191,7 +1191,7 @@ class HierarchicalNSW : public AlgorithmInterface { } - tableint addPoint(const void *data_point, labeltype label, int level) { + tableint addPoint(const void *data_point, label_t label, int level) { tableint cur_c = 0; { // Checking if the element with the same label already exists @@ -1240,7 +1240,7 @@ class HierarchicalNSW : public AlgorithmInterface { memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); // Initialisation of the data and label - memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getExternalLabeLp(cur_c), &label, sizeof(label_t)); memcpy(getDataByInternalId(cur_c), data_point, data_size_); if (curlevel) { @@ -1307,9 +1307,9 @@ class HierarchicalNSW : public AlgorithmInterface { } - std::priority_queue> - searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const override { - std::priority_queue> result; + std::priority_queue> + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const override { + 
std::priority_queue> result; if (cur_element_count == 0) return result; tableint currObj = enterpoint_node_; @@ -1357,19 +1357,19 @@ class HierarchicalNSW : public AlgorithmInterface { } while (top_candidates.size() > 0) { std::pair rez = top_candidates.top(); - result.push(std::pair(rez.first, getExternalLabel(rez.second))); + result.push(std::pair(rez.first, getExternalLabel(rez.second))); top_candidates.pop(); } return result; } - std::vector> + std::vector> searchStopConditionClosest( const void *query_data, - BaseSearchStopCondition& stop_condition, - BaseFilterFunctor* isIdAllowed = nullptr) const { - std::vector> result; + BaseSearchStopCondition& stop_condition, + BaseFilterFunctor* isIdAllowed = nullptr) const { + std::vector> result; if (cur_element_count == 0) return result; tableint currObj = enterpoint_node_; @@ -1476,7 +1476,7 @@ class HierarchicalNSW : public AlgorithmInterface { return {}; } max_level = std::min(maxlevel_, max_level); - size_t node_head_bytes = size_links_level0_ + sizeof(labeltype); // Node header size + size_t node_head_bytes = size_links_level0_ + sizeof(label_t); // Node header size // Iterate through all the elements auto num_elements = cur_element_count.load(std::memory_order_seq_cst); @@ -1523,15 +1523,15 @@ class HierarchicalNSW : public AlgorithmInterface { }; // Define an iterator class for the stored vectors -template +template class VectorIterator { public: // Constructor for the iterator - VectorIterator(tableint internal_id, HierarchicalNSW * outer) - : curr_internal_id_(internal_id),outer_(outer) {} + VectorIterator(tableint internal_id, HierarchicalNSW* outer) + : curr_internal_id_(internal_id), outer_(outer) {} // Dereference operator to access the vector data - std::pair operator*() const { + std::pair operator*() const { return std::make_pair( outer_->getDataByInternalId(curr_internal_id_), outer_->getExternalLabel(curr_internal_id_)); @@ -1550,7 +1550,7 @@ class VectorIterator { private: tableint 
curr_internal_id_; - HierarchicalNSW * outer_; + HierarchicalNSW* outer_; }; } // namespace hnswlib diff --git a/src/inline-thirdparty/hnswlib/hnswlib/hnswlib.h b/src/inline-thirdparty/hnswlib/hnswlib/hnswlib.h index e118f1990da6..fcd0f0116105 100644 --- a/src/inline-thirdparty/hnswlib/hnswlib/hnswlib.h +++ b/src/inline-thirdparty/hnswlib/hnswlib/hnswlib.h @@ -122,21 +122,21 @@ static bool AVX512Capable() { #include namespace hnswlib { -typedef size_t labeltype; // This can be extended to store state for filtering (e.g. from a std::set) +template class BaseFilterFunctor { public: - virtual bool operator()(hnswlib::labeltype id) { return true; } + virtual bool operator()(label_t id) { return true; } virtual ~BaseFilterFunctor() {}; }; -template +template class BaseSearchStopCondition { public: - virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0; + virtual void add_point_to_result(label_t label, const void *datapoint, dist_t dist) = 0; - virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0; + virtual void remove_point_from_result(label_t label, const void *datapoint, dist_t dist) = 0; virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0; @@ -144,7 +144,7 @@ class BaseSearchStopCondition { virtual bool should_remove_extra() = 0; - virtual void filter_results(std::vector> &candidates) = 0; + virtual void filter_results(std::vector> &candidates) = 0; virtual ~BaseSearchStopCondition() {} }; @@ -183,17 +183,17 @@ class SpaceInterface { virtual ~SpaceInterface() {} }; -template +template class AlgorithmInterface { public: - virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; + virtual void addPoint(const void *datapoint, label_t label, bool replace_deleted = false) = 0; - virtual std::priority_queue> - searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; + virtual std::priority_queue> + 
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; // Return k nearest neighbor in the order of closer fist - virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; + virtual std::vector> + searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; virtual void saveIndex(const std::string &location) = 0; @@ -203,11 +203,11 @@ class AlgorithmInterface { } }; -template -std::vector> -AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, - BaseFilterFunctor* isIdAllowed) const { - std::vector> result; +template +std::vector> +AlgorithmInterface::searchKnnCloserFirst( + const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed) const { + std::vector> result; // here searchKnn returns the result in the order of further first auto ret = searchKnn(query_data, k, isIdAllowed); diff --git a/src/inline-thirdparty/hnswlib/hnswlib/stop_condition.h b/src/inline-thirdparty/hnswlib/hnswlib/stop_condition.h index 2fc6199065a8..4d972afc2713 100644 --- a/src/inline-thirdparty/hnswlib/hnswlib/stop_condition.h +++ b/src/inline-thirdparty/hnswlib/hnswlib/stop_condition.h @@ -143,8 +143,8 @@ class MultiVectorInnerProductSpace : public BaseMultiVectorSpace { }; -template -class MultiVectorSearchStopCondition : public BaseSearchStopCondition { +template +class MultiVectorSearchStopCondition : public BaseSearchStopCondition { size_t curr_num_docs_; size_t num_docs_to_search_; size_t ef_collection_; @@ -163,7 +163,7 @@ class MultiVectorSearchStopCondition : public BaseSearchStopCondition { ef_collection_ = std::max(ef_collection, num_docs_to_search); } - void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { + void add_point_to_result(label_t label, const void *datapoint, dist_t dist) override { DOCIDTYPE doc_id = space_.get_doc_id(datapoint); if (doc_counter_[doc_id] == 0) { 
curr_num_docs_ += 1; @@ -172,7 +172,7 @@ class MultiVectorSearchStopCondition : public BaseSearchStopCondition { doc_counter_[doc_id] += 1; } - void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { + void remove_point_from_result(label_t label, const void *datapoint, dist_t dist) override { DOCIDTYPE doc_id = space_.get_doc_id(datapoint); doc_counter_[doc_id] -= 1; if (doc_counter_[doc_id] == 0) { @@ -196,7 +196,7 @@ class MultiVectorSearchStopCondition : public BaseSearchStopCondition { return flag_remove_extra; } - void filter_results(std::vector> &candidates) override { + void filter_results(std::vector> &candidates) override { while (curr_num_docs_ > num_docs_to_search_) { dist_t dist_cand = candidates.back().first; dist_t dist_res = search_results_.top().first; @@ -215,8 +215,8 @@ class MultiVectorSearchStopCondition : public BaseSearchStopCondition { }; -template -class EpsilonSearchStopCondition : public BaseSearchStopCondition { +template +class EpsilonSearchStopCondition : public BaseSearchStopCondition { float epsilon_; size_t min_num_candidates_; size_t max_num_candidates_; @@ -231,11 +231,11 @@ class EpsilonSearchStopCondition : public BaseSearchStopCondition { curr_num_items_ = 0; } - void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { + void add_point_to_result(label_t label, const void *datapoint, dist_t dist) override { curr_num_items_ += 1; } - void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { + void remove_point_from_result(label_t label, const void *datapoint, dist_t dist) override { curr_num_items_ -= 1; } @@ -262,7 +262,7 @@ class EpsilonSearchStopCondition : public BaseSearchStopCondition { return flag_remove_extra; } - void filter_results(std::vector> &candidates) override { + void filter_results(std::vector> &candidates) override { while (!candidates.empty() && candidates.back().first > epsilon_) { candidates.pop_back(); } 
diff --git a/src/yb/vector_index/hnswlib_wrapper.cc b/src/yb/vector_index/hnswlib_wrapper.cc index 9aef69604895..55eac5616ca6 100644 --- a/src/yb/vector_index/hnswlib_wrapper.cc +++ b/src/yb/vector_index/hnswlib_wrapper.cc @@ -59,7 +59,7 @@ class HnswlibIndex : public: using Scalar = typename Vector::value_type; - using HNSWImpl = typename hnswlib::HierarchicalNSW; + using HNSWImpl = hnswlib::HierarchicalNSW; explicit HnswlibIndex(const HNSWOptions& options) : options_(options) { @@ -95,15 +95,8 @@ class HnswlibIndex : return Status::OK(); } - Status DoInsert(VectorId vertex_id, const Vector& v) { - // TODO(vector-index) temp solution for hnsw lib which accepts only integers as vector id. - auto it = vector_id_label_map_.find(vertex_id); - if (it == vector_id_label_map_.end()) { - static std::atomic counter {1}; - auto label = counter.fetch_add(1, std::memory_order_relaxed); - std::tie(it, std::ignore) = vector_id_label_map_.insert({vertex_id, label}); - } - hnsw_->addPoint(v.data(), it->second); + Status DoInsert(VectorId vector_id, const Vector& v) { + hnsw_->addPoint(v.data(), vector_id); return Status::OK(); } @@ -149,11 +142,11 @@ class HnswlibIndex : auto tmp_result = hnsw_->searchKnnCloserFirst(query_vector.data(), options.max_num_results); result.reserve(tmp_result.size()); for (const auto& entry : tmp_result) { - // Being careful to avoid switching the order of distance and vertex id.. + // Being careful to avoid switching the order of distance and vertex id. const auto distance = entry.first; static_assert(std::is_same_v, DistanceResult>); - result.push_back({ GetVectorIdByLabel(entry.second), distance }); + result.push_back({ entry.second, distance }); } return result; } @@ -229,29 +222,6 @@ class HnswlibIndex : HNSWOptions options_; std::unique_ptr> space_; std::unique_ptr hnsw_; - - // TODO(vector-index) refer to https://github.com/yugabyte/yugabyte-db/issues/25041. 
- struct VectorLabelTag; - using VectorLabel = hnswlib::labeltype; - using VectorIdLabelPair = std::pair; - using VectorIdLabelMap = boost::multi_index::multi_index_container< - VectorIdLabelPair, - boost::multi_index::indexed_by< - boost::multi_index::hashed_unique< - BOOST_MULTI_INDEX_MEMBER(VectorIdLabelPair, VectorId, first)>, - boost::multi_index::hashed_unique< - boost::multi_index::tag, - BOOST_MULTI_INDEX_MEMBER(VectorIdLabelPair, VectorLabel, second)>>>; - - const VectorId& GetVectorIdByLabel(VectorLabel label) const { - const auto& label_to_id_map = vector_id_label_map_.template get(); - auto it = label_to_id_map.find(label); - CHECK(it != label_to_id_map.end()); - CHECK_NE(0, it->second); - return it->first; - } - VectorIdLabelMap vector_id_label_map_; - friend class HnswlibVectorIterator; }; @@ -259,9 +229,9 @@ template class HnswlibVectorIterator : public AbstractIterator> { public: using VectorIndex = HnswlibIndex; + using HNSWIterator = hnswlib::VectorIterator; - HnswlibVectorIterator(const VectorIndex& index, - typename hnswlib::VectorIterator position, int dimensions) + HnswlibVectorIterator(const VectorIndex& index, HNSWIterator position, int dimensions) : internal_iterator_(position), dimensions_(dimensions), index_(index) {} protected: @@ -270,7 +240,7 @@ class HnswlibVectorIterator : public AbstractIterator internal_iterator_; + HNSWIterator internal_iterator_; int dimensions_; // TODO(vector-index) refer to https://github.com/yugabyte/yugabyte-db/issues/25041.