From a80e7558140a835dca58f3e33078a2d29ba50b9a Mon Sep 17 00:00:00 2001 From: Sergei Politov Date: Mon, 20 Jan 2025 20:51:24 +0300 Subject: [PATCH] [#25687] DocDB: Remove key value storage callbacks from vector LSM Summary: In initial design vector LSM was responsible for binding vector id with indexed table key. So we added a way to store this binding in regular DB. Currently vector id is generated externally to vector LSM, so it could avoid dealing with indexed table keys, and this logic could be implemented in higher level. Jira: DB-14945 Test Plan: Jenkins Reviewers: arybochkin Reviewed By: arybochkin Subscribers: yql, ybase Tags: #jenkins-ready Differential Revision: https://phorge.dev.yugabyte.com/D41337 --- src/yb/docdb/pgsql_operation.cc | 2 +- src/yb/docdb/rocksdb_writer.cc | 44 ++++++++--- src/yb/docdb/rocksdb_writer.h | 6 +- src/yb/docdb/vector_index.cc | 77 ++++--------------- src/yb/docdb/vector_index.h | 7 +- src/yb/dockv/vector_id.cc | 17 ++++ src/yb/dockv/vector_id.h | 3 + src/yb/vector_index/ann_validation.cc | 10 +-- src/yb/vector_index/ann_validation.h | 2 +- src/yb/vector_index/distance.h | 30 ++++---- src/yb/vector_index/hnswlib_wrapper.cc | 4 +- src/yb/vector_index/index_merge-test.cc | 8 +- src/yb/vector_index/index_wrapper_base.h | 4 +- src/yb/vector_index/sharded_index.h | 2 +- src/yb/vector_index/usearch_wrapper.cc | 6 +- src/yb/vector_index/vector_index_if.h | 2 +- .../vector_index/vector_index_wrapper_util.h | 2 +- src/yb/vector_index/vector_lsm-test.cc | 48 +++++------- src/yb/vector_index/vector_lsm.cc | 35 +++------ src/yb/vector_index/vector_lsm.h | 39 ++-------- src/yb/vector_index/vectorann.cc | 2 +- src/yb/vector_index/vectorann_util.h | 14 ++-- 22 files changed, 157 insertions(+), 207 deletions(-) diff --git a/src/yb/docdb/pgsql_operation.cc b/src/yb/docdb/pgsql_operation.cc index f2a359bd6bb3..f48db5f325fd 100644 --- a/src/yb/docdb/pgsql_operation.cc +++ b/src/yb/docdb/pgsql_operation.cc @@ -919,7 +919,7 @@ class PgsqlVectorFilter { } bool operator()(const vector_index::VectorId& vector_id) { - auto key = VectorIdKey(vector_id); + auto key = dockv::VectorIdKey(vector_id); // TODO(vector_index) handle failure auto ybctid = CHECK_RESULT(iter_.impl().FetchDirect(key.AsSlice())); if (ybctid.empty()) { diff --git a/src/yb/docdb/rocksdb_writer.cc b/src/yb/docdb/rocksdb_writer.cc index f9a072a2b78d..c41d542a938a 100644 --- a/src/yb/docdb/rocksdb_writer.cc +++ b/src/yb/docdb/rocksdb_writer.cc @@ -13,6 +13,8 @@ #include "yb/docdb/rocksdb_writer.h" +#include + #include "yb/common/row_mark.h" #include "yb/docdb/conflict_resolution.h" @@ -30,6 +32,7 @@ #include "yb/dockv/packed_value.h" #include "yb/dockv/schema_packing.h" #include "yb/dockv/value_type.h" +#include "yb/dockv/vector_id.h" #include "yb/gutil/walltime.h" @@ -688,7 +691,7 @@ Result ApplyIntentsContext::Entry( } if (vector_indexes_) { - RETURN_NOT_OK(ProcessVectorIndexes(intent.doc_path, decoded_value.body)); + RETURN_NOT_OK(ProcessVectorIndexes(handler, intent.doc_path, decoded_value.body)); } ++write_id_; @@ -704,7 +707,16 @@ Result ApplyIntentsContext::Entry( return false; } -Status ApplyIntentsContext::ProcessVectorIndexes(Slice key, Slice value) { +void ApplyIntentsContext::AddVectorIndexReverseEntry( + rocksdb::DirectWriteHandler* handler, Slice ybctid, Slice value) { + DocHybridTimeBuffer ht_buf; + auto encoded_write_time = ht_buf.EncodeWithValueType({ commit_ht_, write_id_ }); + + handler->Put(dockv::VectorIndexReverseEntryKeyParts(value, encoded_write_time), {&ybctid, 1}); +} + +Status ApplyIntentsContext::ProcessVectorIndexes( + rocksdb::DirectWriteHandler* handler, Slice key, Slice value) { auto sizes = VERIFY_RESULT(dockv::DocKey::EncodedPrefixAndDocKeySizes(key)); if (sizes.doc_key_size < key.size()) { auto entry_type = static_cast(key[sizes.doc_key_size]); @@ -712,6 +724,7 @@ Status ApplyIntentsContext::ProcessVectorIndexes(Slice key, Slice value) { auto column_id = VERIFY_RESULT(ColumnId::FullyDecode( key.WithoutPrefix(sizes.doc_key_size + 1))); // We expect small amount of vector indexes, usually 1. So it is faster to iterate over them. + bool added_to_vector_index = false; for (size_t i = 0; i != vector_indexes_->size(); ++i) { if (!ApplyToVectorIndex(i)) { continue; @@ -719,10 +732,15 @@ Status ApplyIntentsContext::ProcessVectorIndexes(Slice key, Slice value) { const auto& vector_index = *(*vector_indexes_)[i]; auto table_key_prefix = vector_index.indexed_table_key_prefix(); if (key.starts_with(table_key_prefix) && vector_index.column_id() == column_id) { + auto ybctid = key.Prefix(sizes.doc_key_size).WithoutPrefix(table_key_prefix.size()); vector_index_batches_[i].push_back(VectorIndexInsertEntry { - .key = KeyBuffer(key.Prefix(sizes.doc_key_size).WithoutPrefix(table_key_prefix.size())), + .key = KeyBuffer(ybctid), .value = ValueBuffer(value), }); + if (!added_to_vector_index) { + AddVectorIndexReverseEntry(handler, ybctid, value); + added_to_vector_index = true; + } } } } else { @@ -740,10 +758,10 @@ Status ApplyIntentsContext::ProcessVectorIndexes(Slice key, Slice value) { switch (*packed_row_version) { case dockv::PackedRowVersion::kV1: return ProcessVectorIndexesForPackedRow( - sizes.prefix_size, key, value); + handler, sizes.prefix_size, key, value); case dockv::PackedRowVersion::kV2: return ProcessVectorIndexesForPackedRow( - sizes.prefix_size, key, value); + handler, sizes.prefix_size, key, value); } FATAL_INVALID_ENUM_VALUE(dockv::PackedRowVersion, *packed_row_version); } @@ -752,7 +770,7 @@ Status ApplyIntentsContext::ProcessVectorIndexes(Slice key, Slice value) { template Status ApplyIntentsContext::ProcessVectorIndexesForPackedRow( - size_t prefix_size, Slice key, Slice value) { + rocksdb::DirectWriteHandler* handler, size_t prefix_size, Slice key, Slice value) { value.consume_byte(); auto schema_version = narrow_cast(VERIFY_RESULT(FastDecodeUnsignedVarInt(&value))); @@ -771,6 +789,7 @@ Status ApplyIntentsContext::ProcessVectorIndexesForPackedRow( } Decoder decoder(*schema_packing_, value.data()); + boost::dynamic_bitset<> columns_added_to_vector_index; for (size_t i = 0; i != vector_indexes_->size(); ++i) { if (!ApplyToVectorIndex(i)) { continue; @@ -786,10 +805,18 @@ Status ApplyIntentsContext::ProcessVectorIndexesForPackedRow( continue; } + auto ybctid = key.WithoutPrefix(table_key_prefix.size()); vector_index_batches_[i].push_back(VectorIndexInsertEntry { - .key = KeyBuffer(key.WithoutPrefix(table_key_prefix.size())), + .key = KeyBuffer(ybctid), .value = ValueBuffer(*column_value), }); + + size_t column_index = schema_packing_->GetIndex(vector_index.column_id()); + columns_added_to_vector_index.resize( + std::max(columns_added_to_vector_index.size(), column_index + 1)); + if (!columns_added_to_vector_index.test_set(column_index)) { + AddVectorIndexReverseEntry(handler, ybctid, *column_value); + } } return Status::OK(); } @@ -804,8 +831,7 @@ Status ApplyIntentsContext::Complete(rocksdb::DirectWriteHandler* handler) { DocHybridTime write_time { commit_ht_, write_id_ }; for (size_t i = 0; i != vector_index_batches_.size(); ++i) { if (!vector_index_batches_[i].empty()) { - RETURN_NOT_OK((*vector_indexes_)[i]->Insert( - vector_index_batches_[i], frontiers(), handler, write_time)); + RETURN_NOT_OK((*vector_indexes_)[i]->Insert(vector_index_batches_[i], frontiers())); } } } diff --git a/src/yb/docdb/rocksdb_writer.h b/src/yb/docdb/rocksdb_writer.h index f75f936e0dbd..f1a361e66fda 100644 --- a/src/yb/docdb/rocksdb_writer.h +++ b/src/yb/docdb/rocksdb_writer.h @@ -259,9 +259,11 @@ class ApplyIntentsContext : public IntentsWriterContext, public FrontierSchemaVe private: Result StoreApplyState(const Slice& key, rocksdb::DirectWriteHandler* handler); - Status ProcessVectorIndexes(Slice key, Slice value); + Status ProcessVectorIndexes(rocksdb::DirectWriteHandler* handler, Slice key, Slice value); template - Status ProcessVectorIndexesForPackedRow(size_t prefix_size, Slice key, Slice value); + Status ProcessVectorIndexesForPackedRow( + rocksdb::DirectWriteHandler* handler, size_t prefix_size, Slice key, Slice value); + void AddVectorIndexReverseEntry(rocksdb::DirectWriteHandler* handler, Slice ybctid, Slice value); bool ApplyToRegularDB() const { return apply_to_storages_.TestRegularDB(); diff --git a/src/yb/docdb/vector_index.cc b/src/yb/docdb/vector_index.cc index cb5d039c71f1..b669915a134c 100644 --- a/src/yb/docdb/vector_index.cc +++ b/src/yb/docdb/vector_index.cc @@ -23,8 +23,6 @@ #include "yb/qlexpr/index.h" -#include "yb/rocksdb/write_batch.h" - #include "yb/util/decimal.h" #include "yb/util/endian_util.h" #include "yb/util/path_util.h" @@ -119,8 +117,7 @@ Result> ConvertEntry( auto encoded = dockv::EncodedDocVectorValue::FromSlice(entry.value.AsSlice()); return vector_index::VectorLSMInsertEntry { - .vertex_id = VERIFY_RESULT(encoded.DecodeId()), - .base_table_key = entry.key, + .vector_id = VERIFY_RESULT(encoded.DecodeId()), .vector = VERIFY_RESULT(VectorFromBinary(encoded.data)), }; } @@ -129,14 +126,9 @@ size_t EncodeDistance(float distance) { return bit_cast(util::CanonicalizeFloat(distance)); } -struct VectorIndexInsertContext : public vector_index::VectorLSMInsertContext { - rocksdb::DirectWriteHandler* handler; - DocHybridTime write_time; -}; - template -class VectorIndexImpl : public VectorIndex, public vector_index::VectorLSMKeyValueStorage { +class VectorIndexImpl : public VectorIndex { public: VectorIndexImpl( const TableId& table_id, Slice indexed_table_key_prefix, ColumnId column_id, @@ -167,7 +159,6 @@ class VectorIndexImpl : public VectorIndex, public vector_index::VectorLSMKeyVal .vector_index_factory = VERIFY_RESULT((GetVectorLSMFactory( idx_options.idx_type(), idx_options.dimensions()))), .points_per_chunk = FLAGS_vector_index_initial_chunk_size, - .key_value_storage = this, .thread_pool = &thread_pool, .frontiers_factory = [] { return std::make_unique(); }, }; @@ -175,19 +166,15 @@ class VectorIndexImpl : public VectorIndex, public vector_index::VectorLSMKeyVal } Status Insert( - const VectorIndexInsertEntries& entries, - const rocksdb::UserFrontiers* frontiers, - rocksdb::DirectWriteHandler* handler, - DocHybridTime write_time) override { + const VectorIndexInsertEntries& entries, const rocksdb::UserFrontiers* frontiers) override { typename LSM::InsertEntries lsm_entries; lsm_entries.reserve(entries.size()); for (const auto& entry : entries) { lsm_entries.push_back(VERIFY_RESULT(ConvertEntry(entry))); } - VectorIndexInsertContext context; - context.frontiers = frontiers; - context.handler = handler; - context.write_time = write_time; + vector_index::VectorLSMInsertContext context { + .frontiers = frontiers, + }; return lsm_.Insert(lsm_entries, context); } @@ -195,12 +182,22 @@ class VectorIndexImpl : public VectorIndex, public vector_index::VectorLSMKeyVal Slice vector, const vector_index::SearchOptions& options) override { auto entries = VERIFY_RESULT(lsm_.Search( VERIFY_RESULT(VectorFromYSQL(vector)), options)); + + // TODO(vector-index): check if ReadOptions are required. + docdb::BoundedRocksDbIterator iter(doc_db_.regular, {}, doc_db_.key_bounds); + VectorIndexSearchResult result; result.reserve(entries.size()); for (auto& entry : entries) { + auto key = dockv::VectorIdKey(entry.vector_id); + const auto& db_entry = iter.Seek(key.AsSlice()); + if (!db_entry.Valid() || !db_entry.key.starts_with(key.AsSlice())) { + return STATUS_FORMAT(NotFound, "Vector not found: $0", entry.vector_id); + } + result.push_back(VectorIndexSearchResultEntry { .encoded_distance = EncodeDistance(entry.distance), - .key = entry.base_table_key, + .key = KeyBuffer(db_entry.value), }); } return result; @@ -237,38 +234,6 @@ class VectorIndexImpl : public VectorIndex, public vector_index::VectorLSMKeyVal } private: - Status StoreBaseTableKeys( - const vector_index::BaseTableKeysBatch& batch, - const vector_index::VectorLSMInsertContext& insert_context) override { - const auto& context = static_cast(insert_context); - for (const auto& [vector_id, base_table_key] : batch) { - DocHybridTimeBuffer ht_buf; - auto kb = VectorIdKey(vector_id); - kb.Append(ht_buf.EncodeWithValueType(context.write_time)); - auto kbs = kb.AsSlice(); - - ValueBuffer vb; - vb.Append(base_table_key); - auto vbs = vb.AsSlice(); - context.handler->Put({&kbs, 1}, {&vbs, 1}); - } - - return Status::OK(); - } - - Result ReadBaseTableKey(vector_index::VectorId vector_id) override { - // TODO(vector-index) check if ReadOptions are required. - docdb::BoundedRocksDbIterator iter(doc_db_.regular, {}, doc_db_.key_bounds); - - auto key = VectorIdKey(vector_id); - const auto& entry = iter.Seek(key.AsSlice()); - if (!entry.Valid()) { - return STATUS_FORMAT(NotFound, "Vector not found: $0", vector_id); - } - - return KeyBuffer { entry.value }; - } - std::string DirName() const { return kVectorIndexDirPrefix + table_id_; } @@ -299,12 +264,4 @@ Result CreateVectorIndex( return result; } -KeyBuffer VectorIdKey(vector_index::VectorId vector_id) { - KeyBuffer key; - key.PushBack(dockv::KeyEntryTypeAsChar::kVectorIndexMetadata); - key.PushBack(dockv::KeyEntryTypeAsChar::kVectorId); - key.Append(vector_id.AsSlice()); - return key; -} - } // namespace yb::docdb diff --git a/src/yb/docdb/vector_index.h b/src/yb/docdb/vector_index.h index 76625bd92d75..63bfdbcbaea9 100644 --- a/src/yb/docdb/vector_index.h +++ b/src/yb/docdb/vector_index.h @@ -54,10 +54,7 @@ class VectorIndex { virtual const std::string& path() const = 0; virtual Status Insert( - const VectorIndexInsertEntries& entries, - const rocksdb::UserFrontiers* frontiers, - rocksdb::DirectWriteHandler* handler, - DocHybridTime write_time) = 0; + const VectorIndexInsertEntries& entries, const rocksdb::UserFrontiers* frontiers) = 0; virtual Result Search( Slice vector, const vector_index::SearchOptions& options) = 0; virtual Result Distance(Slice lhs, Slice rhs) = 0; @@ -77,8 +74,6 @@ Result CreateVectorIndex( const qlexpr::IndexInfo& index_info, const DocDB& doc_db); -KeyBuffer VectorIdKey(vector_index::VectorId vector_id); - extern const std::string kVectorIndexDirPrefix; } // namespace yb::docdb diff --git a/src/yb/dockv/vector_id.cc b/src/yb/dockv/vector_id.cc index 7af6122f6eb4..f57e239fe190 100644 --- a/src/yb/dockv/vector_id.cc +++ b/src/yb/dockv/vector_id.cc @@ -34,6 +34,8 @@ namespace { constexpr const size_t kEncodedVectorIdValueSize = 1 + kUuidSize; constexpr const size_t kEncodedVectorIdSize = kEncodedVectorIdValueSize + 1; +constexpr std::array kVectorIdKeyPrefix = + { dockv::KeyEntryTypeAsChar::kVectorIndexMetadata, dockv::KeyEntryTypeAsChar::kVectorId }; char* GrowAtLeast(std::string* buffer, size_t size) { const auto current_size = buffer->size(); @@ -120,4 +122,19 @@ bool IsNull(const dockv::DocVectorValue& v) { return IsNull(v.value()); } +KeyBuffer VectorIdKey(vector_index::VectorId vector_id) { + KeyBuffer key; + key.Append(Slice(kVectorIdKeyPrefix)); + key.Append(vector_id.AsSlice()); + return key; +} + +std::array VectorIndexReverseEntryKeyParts(Slice value, Slice encoded_write_time) { + return std::array{ + Slice(kVectorIdKeyPrefix), + EncodedDocVectorValue::FromSlice(value).id, + encoded_write_time, + }; +} + } // namespace yb::dockv diff --git a/src/yb/dockv/vector_id.h b/src/yb/dockv/vector_id.h index 906cfb1a628a..39d179ee9474 100644 --- a/src/yb/dockv/vector_id.h +++ b/src/yb/dockv/vector_id.h @@ -61,4 +61,7 @@ class DocVectorValue final { bool IsNull(const DocVectorValue& v); +KeyBuffer VectorIdKey(vector_index::VectorId vector_id); +std::array VectorIndexReverseEntryKeyParts(Slice value, Slice encoded_write_time); + } // namespace yb::dockv diff --git a/src/yb/vector_index/ann_validation.cc b/src/yb/vector_index/ann_validation.cc index 5ca696daa5cf..35678d40e672 100644 --- a/src/yb/vector_index/ann_validation.cc +++ b/src/yb/vector_index/ann_validation.cc @@ -28,7 +28,7 @@ namespace yb::vector_index { namespace { template -using VerticesWithDistances = std::vector>; +using VerticesWithDistances = std::vector>; template std::vector VertexIdsOnly( @@ -36,7 +36,7 @@ std::vector VertexIdsOnly( std::vector result; result.reserve(vertices_with_distances.size()); for (const auto& v_dist : vertices_with_distances) { - result.push_back(v_dist.vertex_id); + result.push_back(v_dist.vector_id); } return result; } @@ -228,14 +228,14 @@ Status GroundTruth::ProcessQuery( } template -std::vector> +std::vector> GroundTruth::AugmentWithDistancesAndTrimToK( const std::vector& precomputed_correct_results, const Vector& query) { VerticesWithDistances result; result.reserve(k_); for (auto vertex_id : precomputed_correct_results) { - result.push_back(VertexWithDistance(vertex_id, distance_fn_(vertex_id, query))); + result.push_back(VectorWithDistance(vertex_id, distance_fn_(vertex_id, query))); if (result.size() == k_) { break; } @@ -255,7 +255,7 @@ void GroundTruth::DoApproxSearchAndUpdateStats( vector_cast(query), SearchOptions{.max_num_results = k_})); std::unordered_set approx_set; for (const auto& approx_entry : approx_result) { - approx_set.insert(approx_entry.vertex_id); + approx_set.insert(approx_entry.vector_id); } size_t overlap = 0; diff --git a/src/yb/vector_index/ann_validation.h b/src/yb/vector_index/ann_validation.h index 2ebb06062936..e35671ce835d 100644 --- a/src/yb/vector_index/ann_validation.h +++ b/src/yb/vector_index/ann_validation.h @@ -76,7 +76,7 @@ class GroundTruth { // This works on queries convertered from input vector io indexed vector format. Only uses up to // k_ first elements of precomputed_correct_results. - std::vector> AugmentWithDistancesAndTrimToK( + std::vector> AugmentWithDistancesAndTrimToK( const std::vector& precomputed_correct_results, const Vector& converted_query); VertexIdToVectorDistanceFunction distance_fn_; diff --git a/src/yb/vector_index/distance.h b/src/yb/vector_index/distance.h index da3b91e93505..162e791b41b0 100644 --- a/src/yb/vector_index/distance.h +++ b/src/yb/vector_index/distance.h @@ -150,48 +150,48 @@ using VertexIdToVectorDistanceFunction = std::function; template -struct VertexWithDistance { - VectorId vertex_id = VectorId::Nil(); +struct VectorWithDistance { + VectorId vector_id = VectorId::Nil(); DistanceResult distance{}; // Constructor with the wrong order. Only delete it if DistanceResult is not uint64_t. template ::value, int>::type = 0> - VertexWithDistance(DistanceResult, VectorId) = delete; + VectorWithDistance(DistanceResult, VectorId) = delete; - VertexWithDistance() = default; + VectorWithDistance() = default; // Constructor with the correct order - VertexWithDistance(VectorId vertex_id_, DistanceResult distance_) - : vertex_id(vertex_id_), distance(distance_) {} + VectorWithDistance(VectorId vector_id_, DistanceResult distance_) + : vector_id(vector_id_), distance(distance_) {} std::string ToString() const { - return YB_STRUCT_TO_STRING(vertex_id, distance); + return YB_STRUCT_TO_STRING(vector_id, distance); } // Sort in lexicographical order of (distance, vertex_id). - bool operator <(const VertexWithDistance& other) const { + bool operator <(const VectorWithDistance& other) const { return distance < other.distance || - (distance == other.distance && vertex_id < other.vertex_id); + (distance == other.distance && vector_id < other.vector_id); } - bool operator>(const VertexWithDistance& other) const { + bool operator>(const VectorWithDistance& other) const { return other < *this; } - bool operator<=(const VertexWithDistance& other) const { + bool operator<=(const VectorWithDistance& other) const { return !(other < *this); } - bool operator>=(const VertexWithDistance& other) const { + bool operator>=(const VectorWithDistance& other) const { return !(*this < other); } }; template -bool operator==(const VertexWithDistance& lhs, - const VertexWithDistance& rhs) { - return YB_STRUCT_EQUALS(vertex_id, distance); +bool operator==(const VectorWithDistance& lhs, + const VectorWithDistance& rhs) { + return YB_STRUCT_EQUALS(vector_id, distance); } template diff --git a/src/yb/vector_index/hnswlib_wrapper.cc b/src/yb/vector_index/hnswlib_wrapper.cc index 15b5a6eedf48..9aef69604895 100644 --- a/src/yb/vector_index/hnswlib_wrapper.cc +++ b/src/yb/vector_index/hnswlib_wrapper.cc @@ -143,9 +143,9 @@ class HnswlibIndex : return space_->get_dist_func()(lhs.data(), rhs.data(), space_->get_dist_func_param()); } - std::vector> DoSearch( + std::vector> DoSearch( const Vector& query_vector, const SearchOptions& options) const { - std::vector> result; + std::vector> result; auto tmp_result = hnsw_->searchKnnCloserFirst(query_vector.data(), options.max_num_results); result.reserve(tmp_result.size()); for (const auto& entry : tmp_result) { diff --git a/src/yb/vector_index/index_merge-test.cc b/src/yb/vector_index/index_merge-test.cc index e2e6dd6cf352..97bdc313298f 100644 --- a/src/yb/vector_index/index_merge-test.cc +++ b/src/yb/vector_index/index_merge-test.cc @@ -71,8 +71,8 @@ class IndexMergeTest : public YBTest { void VerifyExpectedVertexIds(const VectorIndexReaderIf::SearchResult& results, std::set&& expected_ids) { for (const auto& result : results) { - ASSERT_TRUE(expected_ids.find(result.vertex_id) != expected_ids.end()); - expected_ids.erase(result.vertex_id); // Remove found ID from the set. + ASSERT_TRUE(expected_ids.find(result.vector_id) != expected_ids.end()); + expected_ids.erase(result.vector_id); // Remove found ID from the set. } ASSERT_TRUE(expected_ids.empty()); // Verify all expected IDs were found. } @@ -90,12 +90,12 @@ class IndexMergeTest : public YBTest { auto result_a = ASSERT_RESULT(merged_index->Search( input_vectors_[0], {.max_num_results = 1})); ASSERT_EQ(result_a.size(), 1); - ASSERT_EQ(result_a[0].vertex_id, data_a.vector_ids[0]); + ASSERT_EQ(result_a[0].vector_id, data_a.vector_ids[0]); auto result_b = ASSERT_RESULT(merged_index->Search( input_vectors_[half_size], {.max_num_results = 1})); ASSERT_EQ(result_b.size(), 1); - ASSERT_EQ(result_b[0].vertex_id, data_b.vector_ids[0]); + ASSERT_EQ(result_b[0].vector_id, data_b.vector_ids[0]); // Verify the size of the merged index. auto all_results = ASSERT_RESULT(merged_index->Search( diff --git a/src/yb/vector_index/index_wrapper_base.h b/src/yb/vector_index/index_wrapper_base.h index 32d5db111531..e08a65f80358 100644 --- a/src/yb/vector_index/index_wrapper_base.h +++ b/src/yb/vector_index/index_wrapper_base.h @@ -52,11 +52,11 @@ class IndexWrapperBase : public VectorIndexIf { return Status::OK(); } - Result>> Search( + Result>> Search( const Vector& query_vector, const SearchOptions& options) const override { if (!has_entries_) { - return std::vector>(); + return std::vector>(); } return impl().DoSearch(query_vector, options); } diff --git a/src/yb/vector_index/sharded_index.h b/src/yb/vector_index/sharded_index.h index fb0f31171acc..464553f6848f 100644 --- a/src/yb/vector_index/sharded_index.h +++ b/src/yb/vector_index/sharded_index.h @@ -83,7 +83,7 @@ class ShardedVectorIndex : public VectorIndexIf { // Search for the closest vectors across all shards. Result Search( const Vector& query_vector, const SearchOptions& options) const override { - std::vector> all_results; + std::vector> all_results; for (const auto& index : indexes_) { auto results = VERIFY_RESULT(index->Search(query_vector, options)); all_results.insert(all_results.end(), results.begin(), results.end()); diff --git a/src/yb/vector_index/usearch_wrapper.cc b/src/yb/vector_index/usearch_wrapper.cc index b9e067094538..f5386423ea12 100644 --- a/src/yb/vector_index/usearch_wrapper.cc +++ b/src/yb/vector_index/usearch_wrapper.cc @@ -197,7 +197,7 @@ class UsearchIndex : pointer_cast(lhs.data()), pointer_cast(rhs.data())); } - Result>> DoSearch( + Result>> DoSearch( const Vector& query_vector, const SearchOptions& options) const { SemaphoreLock lock(*search_semaphore_); auto usearch_results = index_.filtered_search( @@ -205,11 +205,11 @@ class UsearchIndex : RSTATUS_DCHECK( usearch_results, RuntimeError, "Failed to search a vector: $0", usearch_results.error.release()); - std::vector> result_vec; + std::vector> result_vec; result_vec.reserve(usearch_results.size()); for (size_t i = 0; i < usearch_results.size(); ++i) { auto match = usearch_results[i]; - result_vec.push_back(VertexWithDistance(match.member.key, match.distance)); + result_vec.push_back(VectorWithDistance(match.member.key, match.distance)); } return result_vec; } diff --git a/src/yb/vector_index/vector_index_if.h b/src/yb/vector_index/vector_index_if.h index 3c02d8603283..0fa2e81bffdb 100644 --- a/src/yb/vector_index/vector_index_if.h +++ b/src/yb/vector_index/vector_index_if.h @@ -35,7 +35,7 @@ class VectorIndexReaderIf; template class VectorIndexReaderIf { public: - using SearchResult = std::vector>; + using SearchResult = std::vector>; using IteratorValue = std::pair; using Iterator = PolymorphicIterator; diff --git a/src/yb/vector_index/vector_index_wrapper_util.h b/src/yb/vector_index/vector_index_wrapper_util.h index 1e5d34841b2c..8faf52158d38 100644 --- a/src/yb/vector_index/vector_index_wrapper_util.h +++ b/src/yb/vector_index/vector_index_wrapper_util.h @@ -62,7 +62,7 @@ class VectorIndexReaderAdapter for (const auto& source_result : source_results) { auto cast_distance = static_cast(source_result.distance); - destination_results.emplace_back(source_result.vertex_id, cast_distance); + destination_results.emplace_back(source_result.vector_id, cast_distance); } return destination_results; diff --git a/src/yb/vector_index/vector_lsm-test.cc b/src/yb/vector_index/vector_lsm-test.cc index 61ff801a4d89..d042a602e11a 100644 --- a/src/yb/vector_index/vector_lsm-test.cc +++ b/src/yb/vector_index/vector_lsm-test.cc @@ -39,27 +39,22 @@ using FloatVectorLSM = VectorLSM, float>; using TestUsearchIndexFactory = MakeVectorIndexFactory; using TestHnswlibIndexFactory = MakeVectorIndexFactory; -class SimpleVectorLSMKeyValueStorage : public VectorLSMKeyValueStorage { +class SimpleVectorLSMKeyValueStorage { public: SimpleVectorLSMKeyValueStorage() = default; - Status StoreBaseTableKeys(const BaseTableKeysBatch& batch, const VectorLSMInsertContext&) { - for (const auto& [vertex_id, base_table_key] : batch) { - storage_.emplace(vertex_id, KeyBuffer(base_table_key)); - } - return Status::OK(); + void StoreVector(const vector_index::VectorId& vector_id, size_t index) { + storage_.emplace(vector_id, index); } - Result ReadBaseTableKey(VectorId vertex_id) { - auto it = storage_.find(vertex_id); - if (it == storage_.end()) { - return STATUS_FORMAT(NotFound, "Vertex id not found: $0", vertex_id); - } + size_t GetVectorIndex(VectorId vector_id) { + auto it = storage_.find(vector_id); + CHECK(it != storage_.end()); return it->second; } private: - std::unordered_map storage_; + std::unordered_map storage_; }; class TestFrontier : public rocksdb::UserFrontier { @@ -170,10 +165,6 @@ class VectorLSMTest : public YBTest, public testing::WithParamInterface> d) & 1); } result.emplace_back(FloatVectorLSM::InsertEntry { - .vertex_id = VectorId::GenerateRandom(), - .base_table_key = KeyBuffer(Slice(VertexKey(i))), + .vector_id = VectorId::GenerateRandom(), .vector = std::move(vector), }); } @@ -226,8 +216,12 @@ Status VectorLSMTest::InsertCube( } FloatVectorLSM::InsertEntries block_entries(begin, end); TestFrontiers frontiers; - frontiers.Smallest().SetVertexId(block_entries.front().vertex_id); - frontiers.Largest().SetVertexId(block_entries.front().vertex_id); + frontiers.Smallest().SetVertexId(block_entries.front().vector_id); + frontiers.Largest().SetVertexId(block_entries.front().vector_id); + for (; begin != end; ++begin) { + key_value_storage_.StoreVector( + begin->vector_id, begin - inserted_entries_.begin() + 1); + } RETURN_NOT_OK(lsm.Insert(block_entries, { .frontiers = &frontiers })); } return Status::OK(); @@ -250,7 +244,6 @@ Status VectorLSMTest::OpenVectorLSM( return factory(hnsw_options); }, .points_per_chunk = points_per_chunk, - .key_value_storage = &key_value_storage_, .thread_pool = &thread_pool_, .frontiers_factory = [] { return std::make_unique(); }, }; @@ -275,13 +268,10 @@ void VectorLSMTest::CheckQueryVector( FloatVectorLSM::SearchResults expected_results; for (const auto& entry : inserted_entries_) { - expected_results.push_back({ - .distance = lsm.Distance(query_vector, entry.vector), - .base_table_key = entry.base_table_key, - }); + expected_results.emplace_back(entry.vector_id, lsm.Distance(query_vector, entry.vector)); } auto less_condition = [](const auto& lhs, const auto& rhs) { - return lhs.distance == rhs.distance ? lhs.base_table_key < rhs.base_table_key + return lhs.distance == rhs.distance ? lhs.vector_id < rhs.vector_id : lhs.distance < rhs.distance; }; std::sort(expected_results.begin(), expected_results.end(), less_condition); @@ -303,7 +293,7 @@ void VectorLSMTest::CheckQueryVector( for (size_t i = 0; i != expected_results.size(); ++i) { ASSERT_EQ(search_result[i].distance, expected_results[i].distance); - ASSERT_EQ(search_result[i].base_table_key, expected_results[i].base_table_key); + ASSERT_EQ(search_result[i].vector_id, expected_results[i].vector_id); } } } @@ -349,7 +339,7 @@ void VectorLSMTest::TestBootstrap(bool flush) { if (frontier_ptr) { const auto frontier_vertex_id = down_cast(frontier_ptr.get())->vertex_id(); for (; frontier_entry_idx < inserted_entries_.size(); ++frontier_entry_idx) { - if (inserted_entries_[frontier_entry_idx].vertex_id == frontier_vertex_id) { + if (inserted_entries_[frontier_entry_idx].vector_id == frontier_vertex_id) { break; } } @@ -378,7 +368,7 @@ TEST_P(VectorLSMTest, NotSavedChunk) { TEST_F(VectorLSMTest, MergeChunkResults) { const auto kIds = GenerateVectorIds(7); - using ChunkResults = std::vector>; + using ChunkResults = std::vector>; ChunkResults a_src = {{kIds[4], 1}, {kIds[2], 3}, {kIds[0], 5}, {kIds[5], 7}}; ChunkResults b_src = {{kIds[1], 2}, {kIds[2], 3}, {kIds[3], 4}, {kIds[6], 7}, {kIds[5], 7}}; for (size_t i = 1; i != a_src.size() + b_src.size(); ++i) { diff --git a/src/yb/vector_index/vector_lsm.cc b/src/yb/vector_index/vector_lsm.cc index 06b2241662e3..839d8877fe95 100644 --- a/src/yb/vector_index/vector_lsm.cc +++ b/src/yb/vector_index/vector_lsm.cc @@ -90,8 +90,8 @@ class VectorLSMInsertTask : public: using Types = VectorLSMTypes; using InsertRegistry = typename Types::InsertRegistry; - using VertexWithDistance = typename Types::VertexWithDistance; - using SearchHeap = std::priority_queue; + using VectorWithDistance = typename Types::VectorWithDistance; + using SearchHeap = std::priority_queue; using LSM = VectorLSM; using MutableChunk = typename LSM::MutableChunk; @@ -125,7 +125,7 @@ class VectorLSMInsertTask : continue; } auto distance = chunk_->index->Distance(query_vector, vector); - VertexWithDistance vertex(id, distance); + VectorWithDistance vertex(id, distance); if (heap.size() < options.max_num_results) { heap.push(vertex); } else if (heap.top() > vertex) { @@ -151,7 +151,7 @@ class VectorLSMInsertRegistry { using Types = VectorLSMTypes; using LSM = VectorLSM; using VectorIndex = typename Types::VectorIndex; - using VertexWithDistance = typename Types::VertexWithDistance; + using VectorWithDistance = typename Types::VectorWithDistance; using InsertTask = VectorLSMInsertTask; using InsertTaskList = boost::intrusive::list; using InsertTaskPtr = std::unique_ptr; @@ -490,16 +490,13 @@ Status VectorLSM::Insert( auto tasks = insert_registry_->AllocateTasks(*this, chunk, num_tasks); auto tasks_it = tasks.begin(); size_t index_in_task = 0; - BaseTableKeysBatch keys_batch; - for (auto& [vertex_id, base_table_key, v] : entries) { + for (auto& [vertex_id, v] : entries) { if (index_in_task++ >= entries_per_task) { ++tasks_it; index_in_task = 0; } tasks_it->Add(vertex_id, std::move(v)); - keys_batch.emplace_back(vertex_id, base_table_key.AsSlice()); } - RETURN_NOT_OK(options_.key_value_storage->StoreBaseTableKeys(keys_batch, context)); insert_registry_->ExecuteTasks(tasks); return Status::OK(); @@ -512,8 +509,8 @@ Status VectorLSM::Insert( // Expects that results_with_chunk and chunk_results already ordered by distance. template void MergeChunkResults( - std::vector>& combined_results, - std::vector>& chunk_results, + std::vector>& combined_results, + std::vector>& chunk_results, size_t max_num_results) { // Store the current size of the existing results. auto old_size = std::min(combined_results.size(), max_num_results); @@ -529,7 +526,7 @@ void MergeChunkResults( while (it != end) { if (entry > *it) { ++it; - } else if (entry.vertex_id == it->vertex_id) { + } else if (entry.vector_id == it->vector_id) { return true; } else { break; @@ -618,17 +615,7 @@ auto VectorLSM::Search( MergeChunkResults(intermediate_results, chunk_results, options.max_num_results); } - SearchResults final_results; - final_results.reserve(intermediate_results.size()); - for (const auto& [vertex_id, distance] : intermediate_results) { - auto base_table_key = VERIFY_RESULT(options_.key_value_storage->ReadBaseTableKey(vertex_id)); - final_results.push_back({ - .distance = distance, - .base_table_key = std::move(base_table_key) - }); - } - - return final_results; + return intermediate_results; } template @@ -878,8 +865,8 @@ Status VectorLSM::RemoveUpdateQueueEntry(size_t order_no YB_INSTANTIATE_TEMPLATE_FOR_ALL_VECTOR_AND_DISTANCE_RESULT_TYPES(VectorLSM); template void MergeChunkResults( - std::vector>& combined_results, - std::vector>& chunk_results, + std::vector>& combined_results, + std::vector>& chunk_results, size_t max_num_results); } // namespace yb::vector_index diff --git a/src/yb/vector_index/vector_lsm.h b/src/yb/vector_index/vector_lsm.h index b47df57589f5..da0b7eedf176 100644 --- a/src/yb/vector_index/vector_lsm.h +++ b/src/yb/vector_index/vector_lsm.h @@ -29,24 +29,10 @@ namespace yb::vector_index { -template -struct VectorLSMSearchEntry { - DistanceResult distance; - // base_table_key could be the encoded DocKey of the corresponding row in the base - // (indexed) table, and the hybrid time of the vector insertion. - KeyBuffer base_table_key; - - std::string ToString() const { - return YB_STRUCT_TO_STRING( - distance, (base_table_key, base_table_key.AsSlice().ToDebugHexString())); - } -}; - template struct VectorLSMInsertEntry { - VectorId vertex_id; - KeyBuffer base_table_key; - Vector vector; + VectorId vector_id; + Vector vector; }; template; using VectorIndexPtr = VectorIndexIfPtr; using VectorIndexFactory = vector_index::VectorIndexFactory; - using SearchResults = std::vector>; + using SearchResults = typename VectorIndex::SearchResult; using InsertEntry = VectorLSMInsertEntry; using InsertEntries = std::vector; using Options = VectorLSMOptions; using InsertRegistry = VectorLSMInsertRegistry; - using VertexWithDistance = vector_index::VertexWithDistance; + using VectorWithDistance = vector_index::VectorWithDistance; }; -using BaseTableKeysBatch = std::vector>; - struct VectorLSMInsertContext { const rocksdb::UserFrontiers* frontiers = nullptr; }; -class VectorLSMKeyValueStorage { - public: - virtual Status StoreBaseTableKeys( - const BaseTableKeysBatch& batch, const VectorLSMInsertContext& context) = 0; - - virtual Result ReadBaseTableKey(VectorId vertex_id) = 0; - - virtual ~VectorLSMKeyValueStorage() = default; -}; - template struct VectorLSMOptions { @@ -96,7 +70,6 @@ struct VectorLSMOptions { std::string storage_dir; typename Types::VectorIndexFactory vector_index_factory; size_t points_per_chunk; - VectorLSMKeyValueStorage* key_value_storage; rpc::ThreadPool* thread_pool; std::function frontiers_factory; }; @@ -204,8 +177,8 @@ using MakeVectorIndexFactory = template void MergeChunkResults( - std::vector>& combined_results, - std::vector>& chunk_results, + std::vector>& combined_results, + std::vector>& chunk_results, size_t max_num_results); } // namespace yb::vector_index diff --git a/src/yb/vector_index/vectorann.cc b/src/yb/vector_index/vectorann.cc index ccefdd2054e6..d65fdc4d2b17 100644 --- a/src/yb/vector_index/vectorann.cc +++ b/src/yb/vector_index/vectorann.cc @@ -119,7 +119,7 @@ class DummyANN final : public VectorANN { continue; } - auto it = vectors_.find(vd.vertex_id); + auto it = vectors_.find(vd.vector_id); CHECK(it != vectors_.end()); // Sanity check, it is expected the vector exists. out.push_back(DocKeyWithDistance(std::get(it->second), vd.distance)); diff --git a/src/yb/vector_index/vectorann_util.h b/src/yb/vector_index/vectorann_util.h index fc6fd1a3077d..7fbf6a753f23 100644 --- a/src/yb/vector_index/vectorann_util.h +++ b/src/yb/vector_index/vectorann_util.h @@ -56,18 +56,18 @@ class DocKeyWithDistance { bool operator>(const DocKeyWithDistance& other) const { return Compare(other) > 0; } }; -// Our default comparator for VertexWithDistance already orders the pairs by increasing distance. +// Our default comparator for VectorWithDistance already orders the pairs by increasing distance. template using MaxDistanceQueue = - std::priority_queue, - std::vector>>; + std::priority_queue, + std::vector>>; -// Drain a max-queue of (vertex, distance) pairs and return a list of VertexWithDistance instances +// Drain a max-queue of (vertex, distance) pairs and return a list of VectorWithDistance instances // ordered by increasing distance. template auto DrainMaxQueueToIncreasingDistanceList(MaxDistanceQueue& queue) { - std::vector> result_list; + std::vector> result_list; while (!queue.empty()) { result_list.push_back(queue.top()); queue.pop(); @@ -82,7 +82,7 @@ auto DrainMaxQueueToIncreasingDistanceList(MaxDistanceQueue& que // multiple results having the same distance from the query, results with lower vertex ids are // preferred. template -std::vector> BruteForcePreciseNearestNeighbors( +std::vector> BruteForcePreciseNearestNeighbors( const Vector& query, const std::vector& vertex_ids, const VertexIdToVectorDistanceFunction& distance_fn, @@ -93,7 +93,7 @@ std::vector> BruteForcePreciseNearestNeighbor MaxDistanceQueue queue; for (const auto& vertex_id : vertex_ids) { auto distance = distance_fn(vertex_id, query); - auto new_element = VertexWithDistance(vertex_id, distance); + auto new_element = VectorWithDistance(vertex_id, distance); if (queue.size() < num_results || new_element < queue.top()) { // Add a new element if there is a room in the result set, or if the new element is better // than the worst element of the result set. The comparsion is done using the (distance,