-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathquery.hpp
337 lines (279 loc) · 13.7 KB
/
query.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
// Copyright (c) Lorenz Hübschle-Schneider
// Copyright (c) Facebook, Inc. and its affiliates.
// All Rights Reserved. This source code is licensed under the Apache 2.0
// License (found in the LICENSE file in the root directory).
#pragma once
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <ios>
#include <utility>

#include <tlx/logger.hpp>
#include <tlx/logger/wrap_unprintable.hpp>

#include "rocksdb/math.h"
namespace ribbon {
namespace {
// Decide whether a key was bumped from `bucket`. Query callers in this file
// report such keys as bumped instead of reading the solution.
//
// `val` is the key's raw intra-bucket offset and `cval` its compressed form.
// In the one-bit threshold mode, a compressed value of 2 is resolved by
// comparing the raw offset against hasher.Get(bucket). hasher.Get() only exists
// for that compressor, hence the `if constexpr`. In every other case the
// compressed value is compared against the bucket metadata stored in `sol`.
template <typename Hasher, typename SolutionStorage>
inline bool CheckBumped([[maybe_unused]] typename Hasher::Index val,
                        typename Hasher::Index cval, typename Hasher::Index bucket,
                        const Hasher &hasher, const SolutionStorage &sol) {
    [[maybe_unused]] constexpr bool debug = false;
    if constexpr (Hasher::kThreshMode == ThreshMode::onebit) {
        if (cval == 2) {
            sol.PrefetchMeta(bucket);
            const auto plusthresh = hasher.Get(bucket);
            const bool bumped_by_plus = (val >= plusthresh);
            sLOG << "plus bumping:" << val << (bumped_by_plus ? ">=" : "<")
                 << plusthresh << "bucket" << bucket;
            return bumped_by_plus;
        }
    }
    // Generic path: the stored per-bucket metadata acts as the threshold.
    const bool bumped = (cval >= sol.GetMeta(bucket));
    return bumped;
}
} // namespace
// Query an already-hashed key against a SimpleSolutionStorage, visiting only
// the result rows whose coefficient bit is set.
//
// The set bits of the key's coefficient row are walked from lowest to highest.
// The result row at start_slot + (bit position) is XORed in for each set bit.
//
// Returns {true, 0} if CheckBumped reports the key as bumped. Otherwise returns
// {false, XOR of the selected result rows}.
template <typename SimpleSolutionStorage, typename Hasher>
std::pair<bool, typename SimpleSolutionStorage::ResultRow> inline ShiftQueryHelper(
    const Hasher &hasher, typename Hasher::Hash hash,
    const SimpleSolutionStorage &sss) {
    using Index = typename SimpleSolutionStorage::Index;
    using CoeffRow = typename SimpleSolutionStorage::CoeffRow;
    using ResultRow = typename SimpleSolutionStorage::ResultRow;
    constexpr bool debug = false;

    const Index start_slot = hasher.GetStart(hash, sss.GetNumStarts());
    // Issue the prefetch for the result rows (and, for CLS, the metadata)
    // before doing any further hashing work.
    sss.PrefetchQuery(start_slot);
    const Index bucket = hasher.GetBucket(start_slot);
    const Index val = hasher.GetIntraBucketFromStart(start_slot);
    const Index cval = hasher.Compress(val);
    CoeffRow remaining = hasher.GetCoeffs(hash);

    if (CheckBumped(val, cval, bucket, hasher, sss)) {
        sLOG << "Item was bumped, hash" << hash << "start" << start_slot
             << "bucket" << bucket << "val" << val << cval << "thresh"
             << static_cast<size_t>(sss.GetMeta(bucket));
        return std::make_pair(true, 0);
    }
    sLOG << "Searching in bucket" << bucket << "start" << start_slot << "val"
         << val << cval << "below thresh =" << static_cast<size_t>(sss.GetMeta(bucket))
         << "coeffs" << std::hex << static_cast<uint64_t>(remaining) << std::dec;

    ResultRow result = 0;
    while (remaining != 0) {
        const int offset = rocksdb::CountTrailingZeroBits(remaining);
        result ^= sss.GetResult(start_slot + offset);
        remaining &= remaining - 1; // drop the bit we just consumed
    }
    return std::make_pair(false, result);
}
// Query an already-hashed key against a SimpleSolutionStorage, visiting all
// kCoeffBits result rows that start at the key's start slot.
//
// Rows are read through the storage's incremental state accessors
// (PrepareGetResult / GetFromState / AdvanceState). Each row is masked in or
// out depending on the matching coefficient bit, without a branch. With
// RIBBON_CHECK defined, the incremental state is cross-checked against a fresh
// PrepareGetResult at every step.
//
// Returns {true, 0} if CheckBumped reports the key as bumped. Otherwise returns
// {false, XOR of the rows whose coefficient bit is set}.
template <typename SimpleSolutionStorage, typename Hasher>
std::pair<bool, typename SimpleSolutionStorage::ResultRow>
SimpleQueryHelper(const Hasher &hasher, typename Hasher::Hash hash,
                  const SimpleSolutionStorage &sss) {
    using Index = typename SimpleSolutionStorage::Index;
    using CoeffRow = typename SimpleSolutionStorage::CoeffRow;
    using ResultRow = typename SimpleSolutionStorage::ResultRow;
    constexpr bool debug = false;
    constexpr unsigned kCoeffBits = static_cast<unsigned>(sizeof(CoeffRow) * 8U);

    const Index start_slot = hasher.GetStart(hash, sss.GetNumStarts());
    // Issue the prefetch for the result rows (and, for CLS, the metadata)
    // before doing any further hashing work.
    sss.PrefetchQuery(start_slot);
    const Index bucket = hasher.GetBucket(start_slot);
    const Index val = hasher.GetIntraBucketFromStart(start_slot);
    const Index cval = hasher.Compress(val);
    const CoeffRow cr = hasher.GetCoeffs(hash);

    if (CheckBumped(val, cval, bucket, hasher, sss)) {
        sLOG << "Item was bumped, hash" << hash << "start" << start_slot
             << "bucket" << bucket << "val" << val << cval << "thresh"
             << static_cast<size_t>(sss.GetMeta(bucket));
        return std::make_pair(true, 0);
    }
    sLOG << "Searching in bucket" << bucket << "start" << start_slot << "val"
         << val << cval << "below thresh =" << static_cast<size_t>(sss.GetMeta(bucket))
         << "coeffs" << std::hex << static_cast<uint64_t>(cr) << std::dec;

    ResultRow result = 0;
    auto cursor = sss.PrepareGetResult(start_slot);
    for (unsigned bit_idx = 0; bit_idx < kCoeffBits; ++bit_idx) {
#ifdef RIBBON_CHECK
        auto expected = sss.PrepareGetResult(start_slot + bit_idx);
        assert(cursor == expected);
#endif
        const ResultRow slot_row = sss.GetFromState(cursor);
        const ResultRow selected =
            static_cast<ResultRow>(cr >> bit_idx) & ResultRow{1};
        // 0 - 1 gives an all-ones mask and 0 - 0 gives zero. Masking the whole
        // row is generally faster here than `if (selected) result ^= slot_row;`.
        result ^= slot_row & (ResultRow{0} - selected);
        cursor = sss.AdvanceState(cursor);
        if (debug && selected) {
            LOG << "Coeff " << bit_idx << " set, using row " << std::hex
                << static_cast<uint64_t>(slot_row) << std::dec;
        }
    }
    return std::make_pair(false, result);
}
// Retrieval query for `key` against a SimpleSolutionStorage.
//
// Hashes the key and delegates to ShiftQueryHelper. Returns {bumped?, result
// row}; the row is 0 when the key was bumped.
template <typename SimpleSolutionStorage, typename Hasher>
std::pair<bool, typename SimpleSolutionStorage::ResultRow>
SimpleRetrievalQuery(const typename HashTraits<Hasher>::mhc_or_key_t &key,
                     const Hasher &hasher, const SimpleSolutionStorage &sss) {
    // Storage and hasher must agree on the width of slot indices and
    // coefficient rows.
    static_assert(sizeof(typename SimpleSolutionStorage::Index) ==
                      sizeof(typename Hasher::Index),
                  "must be same");
    static_assert(sizeof(typename SimpleSolutionStorage::CoeffRow) ==
                      sizeof(typename Hasher::CoeffRow),
                  "must be same");

    const auto hash = hasher.GetHash(key);
    // An empty ribbon has fewer slots than one coefficient row spans.
    assert(sss.GetNumSlots() >= Hasher::kCoeffBits);
    return ShiftQueryHelper(hasher, hash, sss);
}
// Filter query for `key` against a SimpleSolutionStorage.
//
// Retrieves the key's result row via ShiftQueryHelper and compares it to the
// row the hasher derives from the same hash. Returns {bumped?, retrieved ==
// expected}. If the key was bumped, the retrieved row is 0 and is still
// compared.
template <typename SimpleSolutionStorage, typename Hasher>
std::pair<bool, bool>
SimpleFilterQuery(const typename HashTraits<Hasher>::mhc_or_key_t &key,
                  const Hasher &hasher, const SimpleSolutionStorage &sss) {
    using ResultRow = typename SimpleSolutionStorage::ResultRow;
    // Storage and hasher must agree on index, coefficient and result widths.
    static_assert(sizeof(typename SimpleSolutionStorage::Index) ==
                      sizeof(typename Hasher::Index),
                  "must be same");
    static_assert(sizeof(typename SimpleSolutionStorage::CoeffRow) ==
                      sizeof(typename Hasher::CoeffRow),
                  "must be same");
    static_assert(sizeof(ResultRow) == sizeof(typename Hasher::ResultRow),
                  "must be same");
    constexpr bool debug = false;

    const auto hash = hasher.GetHash(key);
    const ResultRow expected = hasher.GetResultRowFromHash(hash);
    // An empty filter has fewer slots than one coefficient row spans.
    assert(sss.GetNumSlots() >= Hasher::kCoeffBits);

    const auto lookup = ShiftQueryHelper(hasher, hash, sss);
    const bool bumped = lookup.first;
    const ResultRow retrieved = lookup.second;
    sLOG << "Key" << tlx::wrap_unprintable(key) << "b?" << bumped << "retrieved"
         << std::hex << static_cast<size_t>(retrieved) << "expected"
         << static_cast<size_t>(expected) << std::dec;
    return std::make_pair(bumped, retrieved == expected);
}
/******************************************************************************/
// General retrieval query a key from InterleavedSolutionStorage.
//
// Hashes `key` and checks with CheckBumped whether it was bumped from its
// bucket. If not, rebuilds its result row column by column. Result bit i is
// the parity of (column-i segment AND coefficient row). The coefficient row is
// shifted into position within the start block and, when the start slot is not
// block-aligned, its remaining high bits are matched against the next block's
// segments.
//
// Returns {true, 0} if the key was bumped, otherwise {false, result row}.
template <typename InterleavedSolutionStorage, typename Hasher>
std::pair<bool, typename InterleavedSolutionStorage::ResultRow>
InterleavedRetrievalQuery(const typename HashTraits<Hasher>::mhc_or_key_t &key,
                          const Hasher &hasher,
                          const InterleavedSolutionStorage &iss) {
    using Hash = typename Hasher::Hash;
    using Index = typename InterleavedSolutionStorage::Index;
    using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
    using ResultRow = typename InterleavedSolutionStorage::ResultRow;
    static_assert(sizeof(Index) == sizeof(typename Hasher::Index),
                  "must be same");
    static_assert(sizeof(CoeffRow) == sizeof(typename Hasher::CoeffRow),
                  "must be same");
    constexpr bool debug = false;
    constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
    constexpr Index num_columns = InterleavedSolutionStorage::kResultBits;
    // don't query an empty ribbon, please
    assert(iss.GetNumSlots() >= kCoeffBits);

    const Hash hash = hasher.GetHash(key);
    const Index start_slot = hasher.GetStart(hash, iss.GetNumStarts());
    const Index bucket = hasher.GetBucket(start_slot);
    const Index start_block_num = start_slot / kCoeffBits;
    // First segment of the block containing start_slot; a block holds one
    // segment per result column.
    const Index segment = start_block_num * num_columns;
    iss.PrefetchQuery(segment);
    const Index val = hasher.GetIntraBucketFromStart(start_slot),
                cval = hasher.Compress(val);
    if (CheckBumped(val, cval, bucket, hasher, iss)) {
        sLOG << "Item was bumped, hash" << hash << "start" << start_slot
             << "bucket" << bucket << "val" << val << cval << "thresh"
             << (size_t)iss.GetMeta(bucket);
        return std::make_pair(true, 0);
    }
    sLOG << "Searching in bucket" << bucket << "start" << start_slot << "val"
         << val << cval << "below thresh =" << (size_t)iss.GetMeta(bucket);

    const Index start_bit = start_slot % kCoeffBits;
    const CoeffRow cr = hasher.GetCoeffs(hash);
    ResultRow sr = 0;

    // Part of the coefficient row that lands in the start block.
    const CoeffRow cr_left = cr << start_bit;
    for (Index i = 0; i < num_columns; ++i) {
        // Widen the parity bit to ResultRow *before* shifting. BitParity is
        // presumably int-valued (as in upstream RocksDB math.h; confirm for
        // this copy). Shifting an int by i >= 31 either yields a negative value
        // that sign-extends into the high bits of a wider ResultRow (i == 31)
        // or is undefined (i >= 32).
        sr ^= static_cast<ResultRow>(
                  rocksdb::BitParity(iss.GetSegment(segment + i) & cr_left))
              << i;
    }
    if (start_bit > 0) {
        // Part of the coefficient row that spills into the next block. Only
        // taken for start_bit > 0, so the shift amount stays < kCoeffBits.
        const Index next_segment = segment + num_columns;
        const CoeffRow cr_right = cr >> (kCoeffBits - start_bit);
        for (Index i = 0; i < num_columns; ++i) {
            sr ^= static_cast<ResultRow>(rocksdb::BitParity(
                      iss.GetSegment(next_segment + i) & cr_right))
                  << i;
        }
    }
    return std::make_pair(false, sr);
}
// Filter query a key from InterleavedSolutionStorage.
//
// Hashes `key` and checks with CheckBumped whether it was bumped from its
// bucket. If not, compares, column by column, the parity of (solution segment
// AND shifted coefficient row) against the matching bit of the expected result
// row that the hasher derives from the hash.
//
// Returns {true, false} if the key was bumped, {false, false} at the first
// mismatching bit, and {false, true} if every bit matches.
template <typename InterleavedSolutionStorage, typename Hasher>
std::pair<bool, bool>
InterleavedFilterQuery(const typename HashTraits<Hasher>::mhc_or_key_t &key,
const Hasher &hasher,
const InterleavedSolutionStorage &iss) {
// BEGIN mostly copied from InterleavedRetrievalQuery
using Index = typename InterleavedSolutionStorage::Index;
using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
using ResultRow = typename InterleavedSolutionStorage::ResultRow;
static_assert(sizeof(Index) == sizeof(typename Hasher::Index),
"must be same");
static_assert(sizeof(CoeffRow) == sizeof(typename Hasher::CoeffRow),
"must be same");
constexpr bool debug = false;
constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
constexpr auto num_columns = InterleavedSolutionStorage::kResultBits;
// don't query an empty filter, please
assert(iss.GetNumSlots() >= kCoeffBits);
const typename HashTraits<Hasher>::hash_t hash = hasher.GetHash(key);
const Index start_slot = hasher.GetStart(hash, iss.GetNumStarts());
const Index bucket = hasher.GetBucket(start_slot);
const Index start_block_num = start_slot / kCoeffBits;
// First segment of the block containing start_slot (one segment per column).
const Index segment = start_block_num * num_columns;
iss.PrefetchQuery(segment);
const Index val = hasher.GetIntraBucketFromStart(start_slot),
cval = hasher.Compress(val);
if (CheckBumped(val, cval, bucket, hasher, iss)) {
sLOG << "Item was bumped, hash" << hash << "start" << start_slot
<< "bucket" << bucket << "val" << val << cval << "thresh"
<< (size_t)iss.GetMeta(bucket);
return std::make_pair(true, false);
}
sLOG << "Searching for" << tlx::wrap_unprintable(key) << "in bucket"
<< bucket << "start" << start_slot << "val" << val << cval
<< "below thresh =" << (size_t)iss.GetMeta(bucket);
const Index start_bit = start_slot % kCoeffBits;
const CoeffRow cr = hasher.GetCoeffs(hash);
// END mostly copied from InterleavedRetrievalQuery.
const ResultRow expected = hasher.GetResultRowFromHash(hash);
sLOG << "\tSlot" << start_slot << "-> block" << start_block_num << "segment"
<< segment << "start_bit" << start_bit << "expecting" << std::hex
<< (size_t)expected << "="
<< std::bitset<sizeof(ResultRow) * 8u>(expected).to_string()
<< "coeffs" << cr << std::dec << "="
<< std::bitset<sizeof(CoeffRow) * 8u>(cr).to_string();
const CoeffRow cr_left = cr << start_bit;
// The `% kCoeffBits` turns the start_bit == 0 case into a shift by 0
// (cr_right == cr) rather than a shift by the full width, which is undefined.
// In that case maybe_num_columns is 0, so both loads below read the same
// segment, and (seg & cr) | (seg & cr) is still exactly seg & cr.
const CoeffRow cr_right =
cr >> static_cast<unsigned>((kCoeffBits - start_bit) % kCoeffBits);
// This determines whether our two memory loads are to different
// addresses (common) or the same address (1/kCoeffBits chance)
const Index maybe_num_columns = (start_bit != 0) * num_columns;
for (Index i = 0; i < num_columns; ++i) {
// Combine the start-block part and the next-block part of column i, then
// compare its parity with bit i of the expected row; bail out early.
CoeffRow soln_data =
(iss.GetSegment(segment + i) & cr_left) |
(iss.GetSegment(segment + maybe_num_columns + i) & cr_right);
if (rocksdb::BitParity(soln_data) != (static_cast<int>(expected >> i) & 1)) {
sLOG << "\tMismatch at bit" << i << "have"
<< rocksdb::BitParity(soln_data) << "expect"
<< ((expected >> i) & 1) << "soln_data"
<< std::bitset<sizeof(CoeffRow) * 8u>(soln_data).to_string();
return std::make_pair(false, false);
}
}
// otherwise, all match
LOG << "\tGot all the right bits";
return std::make_pair(false, true);
}
} // namespace ribbon