Skip to content

Commit

Permalink
[#25687] DocDB: Remove key value storage callbacks from vector LSM
Browse files Browse the repository at this point in the history
Summary:
In initial design vector LSM was responsible for binding vector id with indexed table key.
So we added a way to store this binding in regular DB.
Currently vector id is generated externally to vector LSM, so it could avoid dealing with indexed table keys, and this logic could be implemented in higher level.
Jira: DB-14945

Test Plan: Jenkins

Reviewers: arybochkin

Reviewed By: arybochkin

Subscribers: yql, ybase

Tags: #jenkins-ready

Differential Revision: https://phorge.dev.yugabyte.com/D41337
  • Loading branch information
spolitov committed Jan 21, 2025
1 parent 6eb1d81 commit a80e755
Show file tree
Hide file tree
Showing 22 changed files with 157 additions and 207 deletions.
2 changes: 1 addition & 1 deletion src/yb/docdb/pgsql_operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ class PgsqlVectorFilter {
}

bool operator()(const vector_index::VectorId& vector_id) {
auto key = VectorIdKey(vector_id);
auto key = dockv::VectorIdKey(vector_id);
// TODO(vector_index) handle failure
auto ybctid = CHECK_RESULT(iter_.impl().FetchDirect(key.AsSlice()));
if (ybctid.empty()) {
Expand Down
44 changes: 35 additions & 9 deletions src/yb/docdb/rocksdb_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

#include "yb/docdb/rocksdb_writer.h"

#include <boost/dynamic_bitset/dynamic_bitset.hpp>

#include "yb/common/row_mark.h"

#include "yb/docdb/conflict_resolution.h"
Expand All @@ -30,6 +32,7 @@
#include "yb/dockv/packed_value.h"
#include "yb/dockv/schema_packing.h"
#include "yb/dockv/value_type.h"
#include "yb/dockv/vector_id.h"

#include "yb/gutil/walltime.h"

Expand Down Expand Up @@ -688,7 +691,7 @@ Result<bool> ApplyIntentsContext::Entry(
}

if (vector_indexes_) {
RETURN_NOT_OK(ProcessVectorIndexes(intent.doc_path, decoded_value.body));
RETURN_NOT_OK(ProcessVectorIndexes(handler, intent.doc_path, decoded_value.body));
}

++write_id_;
Expand All @@ -704,25 +707,40 @@ Result<bool> ApplyIntentsContext::Entry(
return false;
}

Status ApplyIntentsContext::ProcessVectorIndexes(Slice key, Slice value) {
void ApplyIntentsContext::AddVectorIndexReverseEntry(
rocksdb::DirectWriteHandler* handler, Slice ybctid, Slice value) {
DocHybridTimeBuffer ht_buf;
auto encoded_write_time = ht_buf.EncodeWithValueType({ commit_ht_, write_id_ });

handler->Put(dockv::VectorIndexReverseEntryKeyParts(value, encoded_write_time), {&ybctid, 1});
}

Status ApplyIntentsContext::ProcessVectorIndexes(
rocksdb::DirectWriteHandler* handler, Slice key, Slice value) {
auto sizes = VERIFY_RESULT(dockv::DocKey::EncodedPrefixAndDocKeySizes(key));
if (sizes.doc_key_size < key.size()) {
auto entry_type = static_cast<KeyEntryType>(key[sizes.doc_key_size]);
if (entry_type == KeyEntryType::kColumnId) {
auto column_id = VERIFY_RESULT(ColumnId::FullyDecode(
key.WithoutPrefix(sizes.doc_key_size + 1)));
// We expect small amount of vector indexes, usually 1. So it is faster to iterate over them.
bool added_to_vector_index = false;
for (size_t i = 0; i != vector_indexes_->size(); ++i) {
if (!ApplyToVectorIndex(i)) {
continue;
}
const auto& vector_index = *(*vector_indexes_)[i];
auto table_key_prefix = vector_index.indexed_table_key_prefix();
if (key.starts_with(table_key_prefix) && vector_index.column_id() == column_id) {
auto ybctid = key.Prefix(sizes.doc_key_size).WithoutPrefix(table_key_prefix.size());
vector_index_batches_[i].push_back(VectorIndexInsertEntry {
.key = KeyBuffer(key.Prefix(sizes.doc_key_size).WithoutPrefix(table_key_prefix.size())),
.key = KeyBuffer(ybctid),
.value = ValueBuffer(value),
});
if (!added_to_vector_index) {
AddVectorIndexReverseEntry(handler, ybctid, value);
added_to_vector_index = true;
}
}
}
} else {
Expand All @@ -740,10 +758,10 @@ Status ApplyIntentsContext::ProcessVectorIndexes(Slice key, Slice value) {
switch (*packed_row_version) {
case dockv::PackedRowVersion::kV1:
return ProcessVectorIndexesForPackedRow<dockv::PackedRowDecoderV1>(
sizes.prefix_size, key, value);
handler, sizes.prefix_size, key, value);
case dockv::PackedRowVersion::kV2:
return ProcessVectorIndexesForPackedRow<dockv::PackedRowDecoderV2>(
sizes.prefix_size, key, value);
handler, sizes.prefix_size, key, value);
}
FATAL_INVALID_ENUM_VALUE(dockv::PackedRowVersion, *packed_row_version);
}
Expand All @@ -752,7 +770,7 @@ Status ApplyIntentsContext::ProcessVectorIndexes(Slice key, Slice value) {

template <class Decoder>
Status ApplyIntentsContext::ProcessVectorIndexesForPackedRow(
size_t prefix_size, Slice key, Slice value) {
rocksdb::DirectWriteHandler* handler, size_t prefix_size, Slice key, Slice value) {
value.consume_byte();

auto schema_version = narrow_cast<SchemaVersion>(VERIFY_RESULT(FastDecodeUnsignedVarInt(&value)));
Expand All @@ -771,6 +789,7 @@ Status ApplyIntentsContext::ProcessVectorIndexesForPackedRow(
}
Decoder decoder(*schema_packing_, value.data());

boost::dynamic_bitset<> columns_added_to_vector_index;
for (size_t i = 0; i != vector_indexes_->size(); ++i) {
if (!ApplyToVectorIndex(i)) {
continue;
Expand All @@ -786,10 +805,18 @@ Status ApplyIntentsContext::ProcessVectorIndexesForPackedRow(
continue;
}

auto ybctid = key.WithoutPrefix(table_key_prefix.size());
vector_index_batches_[i].push_back(VectorIndexInsertEntry {
.key = KeyBuffer(key.WithoutPrefix(table_key_prefix.size())),
.key = KeyBuffer(ybctid),
.value = ValueBuffer(*column_value),
});

size_t column_index = schema_packing_->GetIndex(vector_index.column_id());
columns_added_to_vector_index.resize(
std::max(columns_added_to_vector_index.size(), column_index + 1));
if (!columns_added_to_vector_index.test_set(column_index)) {
AddVectorIndexReverseEntry(handler, ybctid, *column_value);
}
}
return Status::OK();
}
Expand All @@ -804,8 +831,7 @@ Status ApplyIntentsContext::Complete(rocksdb::DirectWriteHandler* handler) {
DocHybridTime write_time { commit_ht_, write_id_ };
for (size_t i = 0; i != vector_index_batches_.size(); ++i) {
if (!vector_index_batches_[i].empty()) {
RETURN_NOT_OK((*vector_indexes_)[i]->Insert(
vector_index_batches_[i], frontiers(), handler, write_time));
RETURN_NOT_OK((*vector_indexes_)[i]->Insert(vector_index_batches_[i], frontiers()));
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/yb/docdb/rocksdb_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,11 @@ class ApplyIntentsContext : public IntentsWriterContext, public FrontierSchemaVe

private:
Result<bool> StoreApplyState(const Slice& key, rocksdb::DirectWriteHandler* handler);
Status ProcessVectorIndexes(Slice key, Slice value);
Status ProcessVectorIndexes(rocksdb::DirectWriteHandler* handler, Slice key, Slice value);
template <class Decoder>
Status ProcessVectorIndexesForPackedRow(size_t prefix_size, Slice key, Slice value);
Status ProcessVectorIndexesForPackedRow(
rocksdb::DirectWriteHandler* handler, size_t prefix_size, Slice key, Slice value);
void AddVectorIndexReverseEntry(rocksdb::DirectWriteHandler* handler, Slice ybctid, Slice value);

bool ApplyToRegularDB() const {
return apply_to_storages_.TestRegularDB();
Expand Down
77 changes: 17 additions & 60 deletions src/yb/docdb/vector_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

#include "yb/qlexpr/index.h"

#include "yb/rocksdb/write_batch.h"

#include "yb/util/decimal.h"
#include "yb/util/endian_util.h"
#include "yb/util/path_util.h"
Expand Down Expand Up @@ -119,8 +117,7 @@ Result<vector_index::VectorLSMInsertEntry<Vector>> ConvertEntry(

auto encoded = dockv::EncodedDocVectorValue::FromSlice(entry.value.AsSlice());
return vector_index::VectorLSMInsertEntry<Vector> {
.vertex_id = VERIFY_RESULT(encoded.DecodeId()),
.base_table_key = entry.key,
.vector_id = VERIFY_RESULT(encoded.DecodeId()),
.vector = VERIFY_RESULT(VectorFromBinary<Vector>(encoded.data)),
};
}
Expand All @@ -129,14 +126,9 @@ size_t EncodeDistance(float distance) {
return bit_cast<uint32_t>(util::CanonicalizeFloat(distance));
}

struct VectorIndexInsertContext : public vector_index::VectorLSMInsertContext {
rocksdb::DirectWriteHandler* handler;
DocHybridTime write_time;
};

template<vector_index::IndexableVectorType Vector,
vector_index::ValidDistanceResultType DistanceResult>
class VectorIndexImpl : public VectorIndex, public vector_index::VectorLSMKeyValueStorage {
class VectorIndexImpl : public VectorIndex {
public:
VectorIndexImpl(
const TableId& table_id, Slice indexed_table_key_prefix, ColumnId column_id,
Expand Down Expand Up @@ -167,40 +159,45 @@ class VectorIndexImpl : public VectorIndex, public vector_index::VectorLSMKeyVal
.vector_index_factory = VERIFY_RESULT((GetVectorLSMFactory<Vector, DistanceResult>(
idx_options.idx_type(), idx_options.dimensions()))),
.points_per_chunk = FLAGS_vector_index_initial_chunk_size,
.key_value_storage = this,
.thread_pool = &thread_pool,
.frontiers_factory = [] { return std::make_unique<docdb::ConsensusFrontiers>(); },
};
return lsm_.Open(std::move(lsm_options));
}

Status Insert(
const VectorIndexInsertEntries& entries,
const rocksdb::UserFrontiers* frontiers,
rocksdb::DirectWriteHandler* handler,
DocHybridTime write_time) override {
const VectorIndexInsertEntries& entries, const rocksdb::UserFrontiers* frontiers) override {
typename LSM::InsertEntries lsm_entries;
lsm_entries.reserve(entries.size());
for (const auto& entry : entries) {
lsm_entries.push_back(VERIFY_RESULT(ConvertEntry<Vector>(entry)));
}
VectorIndexInsertContext context;
context.frontiers = frontiers;
context.handler = handler;
context.write_time = write_time;
vector_index::VectorLSMInsertContext context {
.frontiers = frontiers,
};
return lsm_.Insert(lsm_entries, context);
}

Result<VectorIndexSearchResult> Search(
Slice vector, const vector_index::SearchOptions& options) override {
auto entries = VERIFY_RESULT(lsm_.Search(
VERIFY_RESULT(VectorFromYSQL<Vector>(vector)), options));

// TODO(vector-index): check if ReadOptions are required.
docdb::BoundedRocksDbIterator iter(doc_db_.regular, {}, doc_db_.key_bounds);

VectorIndexSearchResult result;
result.reserve(entries.size());
for (auto& entry : entries) {
auto key = dockv::VectorIdKey(entry.vector_id);
const auto& db_entry = iter.Seek(key.AsSlice());
if (!db_entry.Valid() || !db_entry.key.starts_with(key.AsSlice())) {
return STATUS_FORMAT(NotFound, "Vector not found: $0", entry.vector_id);
}

result.push_back(VectorIndexSearchResultEntry {
.encoded_distance = EncodeDistance(entry.distance),
.key = entry.base_table_key,
.key = KeyBuffer(db_entry.value),
});
}
return result;
Expand Down Expand Up @@ -237,38 +234,6 @@ class VectorIndexImpl : public VectorIndex, public vector_index::VectorLSMKeyVal
}

private:
Status StoreBaseTableKeys(
const vector_index::BaseTableKeysBatch& batch,
const vector_index::VectorLSMInsertContext& insert_context) override {
const auto& context = static_cast<const VectorIndexInsertContext&>(insert_context);
for (const auto& [vector_id, base_table_key] : batch) {
DocHybridTimeBuffer ht_buf;
auto kb = VectorIdKey(vector_id);
kb.Append(ht_buf.EncodeWithValueType(context.write_time));
auto kbs = kb.AsSlice();

ValueBuffer vb;
vb.Append(base_table_key);
auto vbs = vb.AsSlice();
context.handler->Put({&kbs, 1}, {&vbs, 1});
}

return Status::OK();
}

Result<KeyBuffer> ReadBaseTableKey(vector_index::VectorId vector_id) override {
// TODO(vector-index) check if ReadOptions are required.
docdb::BoundedRocksDbIterator iter(doc_db_.regular, {}, doc_db_.key_bounds);

auto key = VectorIdKey(vector_id);
const auto& entry = iter.Seek(key.AsSlice());
if (!entry.Valid()) {
return STATUS_FORMAT(NotFound, "Vector not found: $0", vector_id);
}

return KeyBuffer { entry.value };
}

std::string DirName() const {
return kVectorIndexDirPrefix + table_id_;
}
Expand Down Expand Up @@ -299,12 +264,4 @@ Result<VectorIndexPtr> CreateVectorIndex(
return result;
}

KeyBuffer VectorIdKey(vector_index::VectorId vector_id) {
KeyBuffer key;
key.PushBack(dockv::KeyEntryTypeAsChar::kVectorIndexMetadata);
key.PushBack(dockv::KeyEntryTypeAsChar::kVectorId);
key.Append(vector_id.AsSlice());
return key;
}

} // namespace yb::docdb
7 changes: 1 addition & 6 deletions src/yb/docdb/vector_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ class VectorIndex {
virtual const std::string& path() const = 0;

virtual Status Insert(
const VectorIndexInsertEntries& entries,
const rocksdb::UserFrontiers* frontiers,
rocksdb::DirectWriteHandler* handler,
DocHybridTime write_time) = 0;
const VectorIndexInsertEntries& entries, const rocksdb::UserFrontiers* frontiers) = 0;
virtual Result<VectorIndexSearchResult> Search(
Slice vector, const vector_index::SearchOptions& options) = 0;
virtual Result<EncodedDistance> Distance(Slice lhs, Slice rhs) = 0;
Expand All @@ -77,8 +74,6 @@ Result<VectorIndexPtr> CreateVectorIndex(
const qlexpr::IndexInfo& index_info,
const DocDB& doc_db);

KeyBuffer VectorIdKey(vector_index::VectorId vector_id);

extern const std::string kVectorIndexDirPrefix;

} // namespace yb::docdb
17 changes: 17 additions & 0 deletions src/yb/dockv/vector_id.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace {

constexpr const size_t kEncodedVectorIdValueSize = 1 + kUuidSize;
constexpr const size_t kEncodedVectorIdSize = kEncodedVectorIdValueSize + 1;
constexpr std::array<char, 2> kVectorIdKeyPrefix =
{ dockv::KeyEntryTypeAsChar::kVectorIndexMetadata, dockv::KeyEntryTypeAsChar::kVectorId };

char* GrowAtLeast(std::string* buffer, size_t size) {
const auto current_size = buffer->size();
Expand Down Expand Up @@ -120,4 +122,19 @@ bool IsNull(const dockv::DocVectorValue& v) {
return IsNull(v.value());
}

KeyBuffer VectorIdKey(vector_index::VectorId vector_id) {
KeyBuffer key;
key.Append(Slice(kVectorIdKeyPrefix));
key.Append(vector_id.AsSlice());
return key;
}

std::array<Slice, 3> VectorIndexReverseEntryKeyParts(Slice value, Slice encoded_write_time) {
return std::array<Slice, 3>{
Slice(kVectorIdKeyPrefix),
EncodedDocVectorValue::FromSlice(value).id,
encoded_write_time,
};
}

} // namespace yb::dockv
3 changes: 3 additions & 0 deletions src/yb/dockv/vector_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,7 @@ class DocVectorValue final {

bool IsNull(const DocVectorValue& v);

KeyBuffer VectorIdKey(vector_index::VectorId vector_id);
std::array<Slice, 3> VectorIndexReverseEntryKeyParts(Slice value, Slice encoded_write_time);

} // namespace yb::dockv
10 changes: 5 additions & 5 deletions src/yb/vector_index/ann_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ namespace yb::vector_index {
namespace {

template<ValidDistanceResultType DistanceResult>
using VerticesWithDistances = std::vector<VertexWithDistance<DistanceResult>>;
using VerticesWithDistances = std::vector<VectorWithDistance<DistanceResult>>;

template<ValidDistanceResultType DistanceResult>
std::vector<VectorId> VertexIdsOnly(
const VerticesWithDistances<DistanceResult>& vertices_with_distances) {
std::vector<VectorId> result;
result.reserve(vertices_with_distances.size());
for (const auto& v_dist : vertices_with_distances) {
result.push_back(v_dist.vertex_id);
result.push_back(v_dist.vector_id);
}
return result;
}
Expand Down Expand Up @@ -228,14 +228,14 @@ Status GroundTruth<Vector, DistanceResult>::ProcessQuery(
}

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
std::vector<VertexWithDistance<DistanceResult>>
std::vector<VectorWithDistance<DistanceResult>>
GroundTruth<Vector, DistanceResult>::AugmentWithDistancesAndTrimToK(
const std::vector<VectorId>& precomputed_correct_results,
const Vector& query) {
VerticesWithDistances<DistanceResult> result;
result.reserve(k_);
for (auto vertex_id : precomputed_correct_results) {
result.push_back(VertexWithDistance<DistanceResult>(vertex_id, distance_fn_(vertex_id, query)));
result.push_back(VectorWithDistance<DistanceResult>(vertex_id, distance_fn_(vertex_id, query)));
if (result.size() == k_) {
break;
}
Expand All @@ -255,7 +255,7 @@ void GroundTruth<Vector, DistanceResult>::DoApproxSearchAndUpdateStats(
vector_cast<Vector>(query), SearchOptions{.max_num_results = k_}));
std::unordered_set<VectorId> approx_set;
for (const auto& approx_entry : approx_result) {
approx_set.insert(approx_entry.vertex_id);
approx_set.insert(approx_entry.vector_id);
}

size_t overlap = 0;
Expand Down
Loading

0 comments on commit a80e755

Please sign in to comment.