forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.h
395 lines (302 loc) · 11.8 KB
/
utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD+Patents license found in the
* LICENSE file in the root directory of this source tree.
*/
/** Copyright 2004-present Facebook. All Rights Reserved
* -*- c++ -*-
*
* A few utility functions for similarity search:
* - random generators
* - optimized exhaustive distance and knn search functions
* - some functions reimplemented from torch for speed
*/
#ifndef FAISS_utils_h
#define FAISS_utils_h
#include <stdint.h>
// for the random data struct
#include <cstdlib>
#include "Heap.h"
namespace faiss {
/**************************************************
* Get some stats about the system
**************************************************/
/// Milliseconds elapsed since some arbitrary epoch; only the difference
/// between two calls is meaningful.
double getmillisecs ();
/// Current RSS (resident set size) usage, in kB.
size_t get_mem_usage_kb ();
/**************************************************
* Random data generation functions
**************************************************/
/// Random generator that can be used in multithreaded contexts: on the
/// supported platforms the generator state is a member of the object.
struct RandomGenerator {
#ifdef __linux__
// Linux state: a small state buffer plus the struct random_data made
// available by <cstdlib> (presumably driven through the glibc *_r random
// functions -- TODO confirm against the implementation).
char rand_state [8];
struct random_data rand_data;
#elif __APPLE__
// Apple state: a single unsigned value.
unsigned rand_state;
#endif
// NOTE(review): on any other platform no state member is declared here.
/// random 31-bit positive integer
int rand_int ();
/// random long < 2 ^ 62
long rand_long ();
/// random integer between 0 and max-1
int rand_int (int max);
/// random float between 0 and 1
float rand_float ();
/// random double (presumably also between 0 and 1 -- TODO confirm)
double rand_double ();
/// initialize the generator from a seed (default seed: 1234)
explicit RandomGenerator (long seed = 1234);
/// User-declared copy constructor: the default copy constructor messes up
/// the pointer in rand_data.
/// NOTE(review): no copy-assignment operator is declared here; the
/// implicitly generated one presumably has the same pointer problem --
/// confirm whether assignment needs to be handled as well.
RandomGenerator (const RandomGenerator & other);
};
/* Fill an array x of n values with random data derived from seed
 * (multi-threaded implementation).
 *
 * float_rand generates uniform random floats. float_randn, long_rand and
 * byte_rand share the same calling convention for float, long and uint8_t
 * arrays.
 * NOTE(review): the "randn" suffix presumably means normally distributed
 * values -- confirm against the implementation. */
void float_rand (float * x, size_t n, long seed);
void float_randn (float * x, size_t n, long seed);
void long_rand (long * x, size_t n, long seed);
void byte_rand (uint8_t * x, size_t n, long seed);
/* Random permutation, derived from seed.
 * perm is the output array; presumably it has size n and receives a
 * permutation of 0..n-1 -- TODO confirm against the implementation. */
void rand_perm (int * perm, size_t n, long seed);
/*********************************************************
* Optimized distance/norm/inner prod computations
*********************************************************/
/// Squared L2 distance between two vectors
///
/// @param x  first vector, size d
/// @param y  second vector, size d
/// @param d  dimension of the vectors
/// @return   squared L2 distance between x and y
float fvec_L2sqr (
const float * x,
const float * y,
size_t d);
/* Inner product of two vectors x and y of size d.
 * Per the original note, the inner product and the L2 distance are
 * SSE implementations. */
float fvec_inner_product (
const float * x,
const float * y,
size_t d);
/// Imbalance factor (IF) of an assignment; a balanced assignment has an
/// IF of 1.
/// NOTE(review): assign presumably has size n with values in [0, k) --
/// confirm against the implementation.
double imbalance_factor (int n, int k, const long *assign);
/// Same, but takes a histogram as input (hist presumably has size k --
/// TODO confirm).
double imbalance_factor (int k, const int *hist);
/** Compute pairwise distances between sets of vectors
 *
 * @param d dimension of the vectors
 * @param nq nb of query vectors
 * @param nb nb of database vectors
 * @param xq query vectors (size nq * d)
 * @param xb database vectors (size nb * d)
 * @param dis output distances (size nq * nb)
 * @param ldq,ldb, ldd strides for the matrices. The default value of -1
 *        presumably selects the natural contiguous stride of each matrix
 *        (TODO confirm against the implementation)
 */
void pairwise_L2sqr (long d,
long nq, const float *xq,
long nb, const float *xb,
float *dis,
long ldq = -1, long ldb = -1, long ldd = -1);
/* Compute the inner products between one vector x and ny vectors y.
 * NOTE(review): the previous wording ("nx vectors x and one y") did not
 * match the signature, which only takes ny; the roles of x and y given here
 * are inferred from the parameter names -- confirm against the
 * implementation. y is presumably contiguous, size ny * d. */
void fvec_inner_products_ny (
float * ip, /* output inner product */
const float * x,
const float * y,
size_t d, size_t ny);
/* Compute ny squared L2 distances between x and a set of contiguous
 * y vectors.
 *
 * dis: output distances, one per y vector (ny values)
 * x:   single vector, size d
 * y:   ny contiguous vectors, size ny * d */
void fvec_L2sqr_ny (
float * __restrict dis,
const float * x,
const float * y,
size_t d, size_t ny);
/** Squared norm of a vector
 *
 * @param x  vector, size d
 * @param d  dimension of the vector
 */
float fvec_norm_L2sqr (const float * x,
size_t d);
/** Compute the L2 norms for a set of vectors
 *
 * @param ip output norms, size nx
 * @param x set of vectors, size nx * d
 * @param d dimension of the vectors
 * @param nx nb of vectors
 */
void fvec_norms_L2 (float * ip, const float * x, size_t d, size_t nx);
/// Same as fvec_norms_L2, but computes squared norms
/// (ip: output, size nx; x: set of vectors, size nx * d).
void fvec_norms_L2sqr (float * ip, const float * x, size_t d, size_t nx);
/* L2-renormalize a set of nx vectors of dimension d. Nothing is done for a
 * vector whose norm is 0.
 * x is presumably size nx * d and updated in place (non-const pointer, no
 * separate output) -- TODO confirm. */
void fvec_renorm_L2 (size_t d, size_t nx, float * x);
/* This function exists because the Torch counterpart is extremely slow
   (not multi-threaded + unexpected overhead even in single thread).
   It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2<x|y>

   NOTE(review): the exact contract is not visible from this header.
   Presumably dis holds n1 * n2 inner products on input and is overwritten
   with squared L2 distances, with nr1 (size n1) and nr2 (size n2) holding
   the squared norms of the two vector sets -- confirm against the
   implementation. */
void inner_product_to_L2sqr (float * __restrict dis,
const float * nr1,
const float * nr2,
size_t n1, size_t n2);
/***************************************************************************
* Compute a subset of distances
***************************************************************************/
/* Compute the inner products between x and a subset of ny vectors of y,
   whose indices are given by ids.
   NOTE(review): the layout of ids and ip (presumably nx * ny entries each,
   one row per x vector) is not visible from this header -- confirm. */
void fvec_inner_products_by_idx (
float * __restrict ip,
const float * x,
const float * y,
const long * __restrict ids,
size_t d, size_t nx, size_t ny);
/* Same as fvec_inner_products_by_idx, but computes squared L2 distances:
   the subset of y is indexed by ids (ny vectors in total). */
void fvec_L2sqr_by_idx (
float * __restrict dis,
const float * x,
const float * y,
const long * __restrict ids, /* ids of y vecs */
size_t d, size_t nx, size_t ny);
/***************************************************************************
* KNN functions
***************************************************************************/
// Threshold on nx above which we switch to BLAS to compute distances.
// Declaration only: the variable is defined in another translation unit.
extern int distance_compute_blas_threshold;
/** Return the k nearest neighbors of each of the nx vectors x among the ny
 * vectors y, w.r.t. max inner product
 *
 * @param x query vectors, size nx * d
 * @param y database vectors, size ny * d
 * @param d dimension of the vectors
 * @param nx nb of query vectors
 * @param ny nb of database vectors
 * @param res result array, which also provides k. Sorted on output
 */
void knn_inner_product (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float_minheap_array_t * res);
/** Same as knn_inner_product, for the L2 distance
 *
 * @param x query vectors, size nx * d
 * @param y database vectors, size ny * d
 * @param res result array (a max-heap array here), which also provides k
 */
void knn_L2sqr (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float_maxheap_array_t * res);
/** Same as knn_L2sqr, but base_shift[bno] is subtracted from all
 * computed distances (bno presumably being the index of the database
 * vector y involved in the distance -- TODO confirm).
 *
 * @param base_shift size ny
 */
void knn_L2sqr_base_shift (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float_maxheap_array_t * res,
const float *base_shift);
/* Find the nearest neighbors (w.r.t. inner product) for nx queries in a set
 * of ny vectors indexed by ids. May be useful for re-ranking a pre-selected
 * vector list.
 * NOTE(review): the layout of ids (presumably ny indices per query) is not
 * visible from this header -- confirm against the implementation.
 */
void knn_inner_products_by_idx (
const float * x,
const float * y,
const long * ids,
size_t d, size_t nx, size_t ny,
float_minheap_array_t * res);
/* Presumably the counterpart of knn_inner_products_by_idx for the squared
 * L2 distance: nearest neighbors for nx queries among ny vectors indexed by
 * ids, collected in a max-heap result array -- TODO confirm against the
 * implementation. */
void knn_L2sqr_by_idx (const float * x,
const float * y,
const long * __restrict ids,
size_t d, size_t nx, size_t ny,
float_maxheap_array_t * res);
/***************************************************************************
* Range search
***************************************************************************/
/// Forward declaration (the full definition is in AuxIndexStructures.h);
/// only pointers to this type are used in this header.
struct RangeSearchResult;
/** Range search: for each of the nx vectors x, report the vectors y that
 * fall within the search radius, w.r.t. the L2 distance.
 *
 * NOTE(review): the previous description was copied from the k-NN /
 * inner-product functions. Given the L2sqr suffix, radius is presumably
 * compared against squared distances -- confirm against the implementation.
 *
 * @param x query vectors, size nx * d
 * @param y database vectors, size ny * d
 * @param radius search radius around the x vectors
 * @param result result structure
 */
void range_search_L2sqr (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
RangeSearchResult *result);
/// Same as range_search_L2sqr, for the inner product similarity
///
/// @param x query vectors, size nx * d
/// @param y database vectors, size ny * d
/// @param radius search radius around the x vectors
/// @param result result structure
void range_search_inner_product (
const float * x,
const float * y,
size_t d, size_t nx, size_t ny,
float radius,
RangeSearchResult *result);
/***************************************************************************
* Misc matrix and vector manipulation functions
***************************************************************************/
/** Compute c := a + bf * b for a, b and c tables
 *
 * @param n size of the tables
 * @param a size n
 * @param bf scalar factor applied to b
 * @param b size n
 * @param c result table, size n
 */
void fvec_madd (size_t n, const float *a,
float bf, const float *b, float *c);
/** Same as fvec_madd (c := a + bf * b on tables of size n), also returns
 * the index of the min of the result table
 *
 * @return index of the min of table c
 */
int fvec_madd_and_argmin (size_t n, const float *a,
float bf, const float *b, float *c);
/* Perform a reflection (not an efficient implementation, just for test).
 * NOTE(review): parameter roles are not documented here. Presumably u holds
 * nu reflection vectors of dimension d that are applied to the n vectors of
 * x, which is updated in place -- confirm against the implementation. */
void reflection (const float * u, float * x, size_t n, size_t d, size_t nu);
/** For k-means: update stage.
 *
 * @param x training vectors, size n * d
 * @param centroids centroid vectors, size k * d
 * @param assign nearest centroid for each training vector, size n
 *        (passed as a non-const pointer, so it may be modified -- TODO
 *        confirm against the implementation)
 * @param d dimension of the vectors
 * @param k nb of centroids
 * @param n nb of training vectors
 * @param k_frozen do not update the k_frozen first centroids
 * @return nb of splitting operations to fight empty clusters
 */
int km_update_centroids (
const float * x,
float * centroids,
long * assign,
size_t d, size_t k, size_t n,
size_t k_frozen);
/** Compute the Q of the QR decomposition for m > n
 *
 * @param m,n matrix dimensions (storage layout not visible from this
 *        header -- TODO confirm row/column order)
 * @param a size n * m: input matrix and output Q (overwritten in place)
 */
void matrix_qr (int m, int n, float *a);
/** Distances are supposed to be sorted. Sorts the indices that share the
 * same distance.
 *
 * idx and dis presumably both have size k, with idx reordered in place --
 * TODO confirm against the implementation.
 */
void ranklist_handle_ties (int k, long *idx, const float *dis);
/** Count the number of common elements between v1 and v2
 * algorithm = sorting + bisection to avoid double-counting duplicates
 *
 * @param k1 nb of elements in v1
 * @param k2 nb of elements in v2
 */
size_t ranklist_intersection_size (size_t k1, const long *v1,
size_t k2, const long *v2);
/** Merge a result table into another one: the second table (I1, D1) is
 * read-only, the first one (I0, D0) receives the merged result.
 *
 * @param I0, D0 first result table, size (n, k)
 * @param I1, D1 second result table, size (n, k)
 * @param keep_min if true, keep min values, otherwise keep max
 * @param translation add this value to all I1's indexes
 * @return nb of values that were taken from the second table
 */
size_t merge_result_table_with (size_t n, size_t k,
long *I0, float *D0,
const long *I1, const float *D1,
bool keep_min = true,
long translation = 0);
/* NOTE(review): undocumented in the original header. Presumably writes into
 * perm (size n) the permutation of indices that sorts vals (size n); the
 * sort direction is not visible here -- confirm against the
 * implementation. */
void fvec_argsort (size_t n, const float *vals,
size_t *perm);
/* NOTE(review): undocumented in the original header. Same signature as
 * fvec_argsort; presumably a multi-threaded variant producing the same
 * permutation -- confirm against the implementation. */
void fvec_argsort_parallel (size_t n, const float *vals,
size_t *perm);
/// Compute a histogram on the n values of v, written to hist.
/// NOTE(review): presumably hist has size vmax and counts values in
/// [0, vmax); the meaning of the int return value is not visible from this
/// header -- confirm against the implementation.
int ivec_hist (size_t n, const int * v, int vmax, int *hist);
/** Compute histogram of bits on a code array
 *
 * @param n nb of codes
 * @param nbits nb of bits per code
 * @param codes size(n, nbits / 8)
 * @param hist size(nbits): nb of 1s in the array of codes
 */
void bincode_hist(size_t n, size_t nbits, const uint8_t *codes, int *hist);
/// Compute a checksum on a table a of n ints.
size_t ivec_checksum (size_t n, const int *a);
/** Randomly subsamples a set of vectors if there are too many of them
 *
 * @param d dimension of the vectors
 * @param n on input: nb of input vectors, output: nb of output vectors
 * @param nmax max nb of vectors to keep
 * @param x input array, size *n-by-d
 * @param verbose presumably enables a message when subsampling happens
 *        (TODO confirm against the implementation)
 * @param seed random seed to use for sampling
 * @return x or an array allocated with new [] with *n vectors; when the
 *         result differs from x the caller is presumably responsible for
 *         releasing it with delete [] (TODO confirm)
 */
const float *fvecs_maybe_subsample (
size_t d, size_t *n, size_t nmax, const float *x,
bool verbose = false, long seed = 1234);
} // namespace faiss
#endif /* FAISS_utils_h */