Skip to content

Commit

Permalink
[#25041] docdb: Update HNSW lib to support custom vector label
Browse files Browse the repository at this point in the history
Summary:
This change extends the HNSW lib to support a custom vector label type (via a template argument).
hnswlib_wrapper is also updated accordingly to use VectorId directly as the vector label.
Jira: DB-14175

Test Plan: Jenkins

Reviewers: sergei, slingam

Reviewed By: sergei

Subscribers: ybase

Differential Revision: https://phorge.dev.yugabyte.com/D41440
  • Loading branch information
arybochkin committed Jan 23, 2025
1 parent ec77d2b commit ea9c77c
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 123 deletions.
33 changes: 18 additions & 15 deletions src/inline-thirdparty/hnswlib/hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <assert.h>

namespace hnswlib {
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
template<typename dist_t, typename label_t>
class BruteforceSearch : public AlgorithmInterface<dist_t, label_t> {
public:
char *data_;
size_t maxelements_;
Expand All @@ -19,7 +19,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
void *dist_func_param_;
std::mutex index_lock;

std::unordered_map<labeltype, size_t > dict_external_to_internal;
std::unordered_map<label_t, size_t > dict_external_to_internal;


BruteforceSearch(SpaceInterface <dist_t> *s)
Expand Down Expand Up @@ -48,7 +48,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
size_per_element_ = data_size_ + sizeof(label_t);
data_ = (char *) malloc(maxElements * size_per_element_);
if (data_ == nullptr)
throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
Expand All @@ -61,7 +61,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
}


void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
void addPoint(const void *datapoint, label_t label, bool replace_deleted = false) {
int idx;
{
std::unique_lock<std::mutex> lock(index_lock);
Expand All @@ -78,12 +78,12 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
cur_element_count++;
}
}
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(label_t));
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
}


void removePoint(labeltype cur_external) {
void removePoint(label_t cur_external) {
std::unique_lock<std::mutex> lock(index_lock);

auto found = dict_external_to_internal.find(cur_external);
Expand All @@ -94,23 +94,25 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
dict_external_to_internal.erase(found);

size_t cur_c = found->second;
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
label_t label;
memcpy(&label, data_ + size_per_element_ * (cur_element_count-1) + data_size_, sizeof(label_t));
dict_external_to_internal[label] = cur_c;
memcpy(data_ + size_per_element_ * cur_c,
data_ + size_per_element_ * (cur_element_count-1),
data_size_+sizeof(labeltype));
data_size_+sizeof(label_t));
cur_element_count--;
}


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
std::priority_queue<std::pair<dist_t, label_t>>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor<label_t>* isIdAllowed = nullptr) const {
assert(k <= cur_element_count);
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
std::priority_queue<std::pair<dist_t, label_t >> topResults;
if (cur_element_count == 0) return topResults;
for (size_t i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
label_t label;
memcpy(&label, data_ + size_per_element_ * i + data_size_, sizeof(label_t));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.emplace(dist, label);
}
Expand All @@ -119,7 +121,8 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
for (int i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
label_t label;
memcpy(&label, data_ + size_per_element_ * i + data_size_, sizeof(label_t));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.emplace(dist, label);
}
Expand Down Expand Up @@ -160,7 +163,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
size_per_element_ = data_size_ + sizeof(label_t);
data_ = (char *) malloc(maxelements_ * size_per_element_);
if (data_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
Expand Down
84 changes: 42 additions & 42 deletions src/inline-thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;

template<typename dist_t>
template<typename dist_t, typename label_t>
class VectorIterator;

struct Stats {
Expand Down Expand Up @@ -43,8 +43,8 @@ struct InternalParameters {
size_t bytes_per_vector = 0;
};

template<typename dist_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
template<typename dist_t, typename label_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t, label_t> {
public:
static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
static const unsigned char DELETE_MARK = 0x01;
Expand Down Expand Up @@ -87,7 +87,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
void *dist_func_param_{nullptr};

mutable std::mutex label_lookup_lock; // lock for label_lookup_
std::unordered_map<labeltype, tableint> label_lookup_;
std::unordered_map<label_t, tableint> label_lookup_;

std::default_random_engine level_generator_;
std::default_random_engine update_probability_generator_;
Expand All @@ -101,16 +101,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements

// Function to return the begin iterator
VectorIterator<dist_t> vectors_begin() {
return VectorIterator<dist_t>(0, this);
auto vectors_begin() {
return VectorIterator<dist_t, label_t>(0, this);
}

// Function to return the end iterator
VectorIterator<dist_t> vectors_end() {
auto vectors_end() {
return VectorIterator(cur_element_count, this);
}


HierarchicalNSW(SpaceInterface<dist_t> *s) {
}

Expand Down Expand Up @@ -159,7 +158,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
update_probability_generator_.seed(random_seed + 1);

size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(label_t);
offsetData_ = size_links_level0_;
label_offset_ = size_links_level0_ + data_size_;
offsetLevel0_ = 0;
Expand Down Expand Up @@ -216,27 +215,28 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}


inline std::mutex& getLabelOpMutex(labeltype label) const {
inline std::mutex& getLabelOpMutex(label_t label) const {
// calculate hash
size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1);
size_t lock_id = std::hash<label_t>{}(label) & (MAX_LABEL_OPERATION_LOCKS - 1);
return label_op_locks_[lock_id];
}


inline labeltype getExternalLabel(tableint internal_id) const {
labeltype return_label;
memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
inline label_t getExternalLabel(tableint internal_id) const {
label_t return_label;
memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(label_t));
return return_label;
}


inline void setExternalLabel(tableint internal_id, labeltype label) const {
memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
inline void setExternalLabel(tableint internal_id, label_t label) const {
memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(label_t));
}


inline labeltype *getExternalLabeLp(tableint internal_id) const {
return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
// NB: Dereferencing the returned pointer is undefined behavior; use memcpy to read or write the label instead.
inline label_t *getExternalLabeLp(tableint internal_id) const {
return (label_t *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
}


Expand Down Expand Up @@ -353,8 +353,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
tableint ep_id,
const void *data_point,
size_t ef,
BaseFilterFunctor* isIdAllowed = nullptr,
BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
BaseFilterFunctor<label_t>* isIdAllowed = nullptr,
BaseSearchStopCondition<dist_t, label_t>* stop_condition = nullptr) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand Down Expand Up @@ -726,7 +726,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
void saveIndex(const std::string &location) override {
std::ofstream output(location, std::ios::binary);
std::streampos position;

// TODO(vector-index): consider storing sizeof(label_t) in the saved index so that loading can verify it matches the label type in use.
writeBinaryPOD(output, offsetLevel0_);
writeBinaryPOD(output, max_elements_);
writeBinaryPOD(output, cur_element_count);
Expand Down Expand Up @@ -864,7 +864,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {


template<typename data_t>
std::vector<data_t> getDataByLabel(labeltype label) const {
std::vector<data_t> getDataByLabel(label_t label) const {
// lock all operations with element by label
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));

Expand All @@ -891,7 +891,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
/*
* Marks an element with the given label deleted, does NOT really change the current graph.
*/
void markDelete(labeltype label) {
void markDelete(label_t label) {
// lock all operations with element by label
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));

Expand Down Expand Up @@ -933,7 +933,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
* Note: the method is not safe to use when replacement of deleted elements is enabled,
* because elements marked as deleted can be completely removed by addPoint
*/
void unmarkDelete(labeltype label) {
void unmarkDelete(label_t label) {
// lock all operations with element by label
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));

Expand Down Expand Up @@ -992,7 +992,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
* Adds point. Updates the point if it is already in the index.
* If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
*/
void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) override {
void addPoint(const void *data_point, label_t label, bool replace_deleted = false) override {
if ((allow_replace_deleted_ == false) && (replace_deleted == true)) {
throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
}
Expand All @@ -1019,7 +1019,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
addPoint(data_point, label, -1);
} else {
// we assume that there are no concurrent operations on deleted element
labeltype label_replaced = getExternalLabel(internal_id_replaced);
label_t label_replaced = getExternalLabel(internal_id_replaced);
setExternalLabel(internal_id_replaced, label);

std::unique_lock <std::mutex> lock_table(label_lookup_lock);
Expand Down Expand Up @@ -1191,7 +1191,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}


tableint addPoint(const void *data_point, labeltype label, int level) {
tableint addPoint(const void *data_point, label_t label, int level) {
tableint cur_c = 0;
{
// Checking if the element with the same label already exists
Expand Down Expand Up @@ -1240,7 +1240,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);

// Initialisation of the data and label
memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
memcpy(getExternalLabeLp(cur_c), &label, sizeof(label_t));
memcpy(getDataByInternalId(cur_c), data_point, data_size_);

if (curlevel) {
Expand Down Expand Up @@ -1307,9 +1307,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const override {
std::priority_queue<std::pair<dist_t, labeltype >> result;
std::priority_queue<std::pair<dist_t, label_t >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor<label_t>* isIdAllowed = nullptr) const override {
std::priority_queue<std::pair<dist_t, label_t >> result;
if (cur_element_count == 0) return result;

tableint currObj = enterpoint_node_;
Expand Down Expand Up @@ -1357,19 +1357,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
while (top_candidates.size() > 0) {
std::pair<dist_t, tableint> rez = top_candidates.top();
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
result.push(std::pair<dist_t, label_t>(rez.first, getExternalLabel(rez.second)));
top_candidates.pop();
}
return result;
}


std::vector<std::pair<dist_t, labeltype >>
std::vector<std::pair<dist_t, label_t >>
searchStopConditionClosest(
const void *query_data,
BaseSearchStopCondition<dist_t>& stop_condition,
BaseFilterFunctor* isIdAllowed = nullptr) const {
std::vector<std::pair<dist_t, labeltype >> result;
BaseSearchStopCondition<dist_t, label_t>& stop_condition,
BaseFilterFunctor<label_t>* isIdAllowed = nullptr) const {
std::vector<std::pair<dist_t, label_t >> result;
if (cur_element_count == 0) return result;

tableint currObj = enterpoint_node_;
Expand Down Expand Up @@ -1476,7 +1476,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return {};
}
max_level = std::min(maxlevel_, max_level);
size_t node_head_bytes = size_links_level0_ + sizeof(labeltype); // Node header size
size_t node_head_bytes = size_links_level0_ + sizeof(label_t); // Node header size

// Iterate through all the elements
auto num_elements = cur_element_count.load(std::memory_order_seq_cst);
Expand Down Expand Up @@ -1523,15 +1523,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
};

// Define an iterator class for the stored vectors
template<typename dist_t>
template<typename dist_t, typename label_t>
class VectorIterator {
public:
// Constructor for the iterator
VectorIterator(tableint internal_id, HierarchicalNSW<dist_t> * outer)
: curr_internal_id_(internal_id),outer_(outer) {}
VectorIterator(tableint internal_id, HierarchicalNSW<dist_t, label_t>* outer)
: curr_internal_id_(internal_id), outer_(outer) {}

// Dereference operator to access the vector data
std::pair<const void*, labeltype> operator*() const {
std::pair<const void*, label_t> operator*() const {
return std::make_pair(
outer_->getDataByInternalId(curr_internal_id_),
outer_->getExternalLabel(curr_internal_id_));
Expand All @@ -1550,7 +1550,7 @@ class VectorIterator {

private:
tableint curr_internal_id_;
HierarchicalNSW<dist_t> * outer_;
HierarchicalNSW<dist_t, label_t>* outer_;
};

} // namespace hnswlib
Expand Down
Loading

0 comments on commit ea9c77c

Please sign in to comment.