Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binary vector format for flat and hnsw vectors #14078

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lucene/core/src/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
exports org.apache.lucene.codecs.lucene95;
exports org.apache.lucene.codecs.lucene99;
exports org.apache.lucene.codecs.lucene101;
exports org.apache.lucene.codecs.lucene102;
exports org.apache.lucene.codecs.perfield;
exports org.apache.lucene.codecs;
exports org.apache.lucene.document;
Expand Down Expand Up @@ -76,7 +77,9 @@
provides org.apache.lucene.codecs.KnnVectorsFormat with
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene102;

import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;

import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;

/** Binarized byte vector values */
abstract class BinarizedByteVectorValues extends ByteVectorValues {

/**
* Retrieve the corrective terms for the given vector ordinal. For the dot-product family of
* distances, the corrective terms are, in order
*
* <ul>
* <li>the lower optimized interval
* <li>the upper optimized interval
* <li>the dot-product of the non-centered vector with the centroid
* <li>the sum of quantized components
* </ul>
*
* For euclidean:
*
* <ul>
* <li>the lower optimized interval
* <li>the upper optimized interval
* <li>the l2norm of the centered vector
* <li>the sum of quantized components
* </ul>
*
* @param vectorOrd the vector ordinal
* @return the corrective terms
* @throws IOException if an I/O error occurs
*/
public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd)
throws IOException;

/**
* @return the quantizer used to quantize the vectors
*/
public abstract OptimizedScalarQuantizer getQuantizer();

public abstract float[] getCentroid() throws IOException;

int discretizedDimensions() {
return discretize(dimension(), 64);
}

/**
* Return a {@link VectorScorer} for the given query vector.
*
* @param query the query vector
* @return a {@link VectorScorer} instance or null
*/
public abstract VectorScorer scorer(float[] query) throws IOException;

@Override
public abstract BinarizedByteVectorValues copy() throws IOException;

float getCentroidDP() throws IOException {
// this only gets executed on-merge
float[] centroid = getCentroid();
return VectorUtil.dotProduct(centroid, centroid);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene102;

import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer.QuantizationResult;

/** Vector scorer over binarized vector values */
public class Lucene102BinaryFlatVectorsScorer implements FlatVectorsScorer {
private final FlatVectorsScorer nonQuantizedDelegate;
private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);

public Lucene102BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
this.nonQuantizedDelegate = nonQuantizedDelegate;
}

@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
if (vectorValues instanceof BinarizedByteVectorValues) {
throw new UnsupportedOperationException(
"getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format");
}
return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
throws IOException {
if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) {
OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer();
float[] centroid = binarizedVectors.getCentroid();
// We make a copy as the quantization process mutates the input
float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
if (similarityFunction == COSINE) {
VectorUtil.l2normalize(copy);
}
target = copy;
byte[] initial = new byte[target.length];
byte[] quantized = new byte[QUERY_BITS * binarizedVectors.discretizedDimensions() / 8];
OptimizedScalarQuantizer.QuantizationResult queryCorrections =
quantizer.scalarQuantize(target, initial, (byte) 4, centroid);
transposeHalfByte(initial, quantized);
BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections);
return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction);
}
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException {
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}

RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction,
Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
BinarizedByteVectorValues targetVectors) {
return new BinarizedRandomVectorScorerSupplier(
scoringVectors, targetVectors, similarityFunction);
}

@Override
public String toString() {
return "Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
}

/** Vector scorer supplier over binarized vector values */
static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
private final Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues
queryVectors;
private final BinarizedByteVectorValues targetVectors;
private final VectorSimilarityFunction similarityFunction;

BinarizedRandomVectorScorerSupplier(
Lucene102BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
BinarizedByteVectorValues targetVectors,
VectorSimilarityFunction similarityFunction) {
this.queryVectors = queryVectors;
this.targetVectors = targetVectors;
this.similarityFunction = similarityFunction;
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] vector = queryVectors.vectorValue(ord);
QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord);
BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms);
return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BinarizedRandomVectorScorerSupplier(
queryVectors.copy(), targetVectors.copy(), similarityFunction);
}
}

/** A binarized query representing its quantized form along with factors */
public record BinaryQueryVector(
byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {}

/** Vector scorer over binarized vector values */
public static class BinarizedRandomVectorScorer
extends RandomVectorScorer.AbstractRandomVectorScorer {
private final BinaryQueryVector queryVector;
private final BinarizedByteVectorValues targetVectors;
private final VectorSimilarityFunction similarityFunction;

public BinarizedRandomVectorScorer(
BinaryQueryVector queryVectors,
BinarizedByteVectorValues targetVectors,
VectorSimilarityFunction similarityFunction) {
super(targetVectors);
this.queryVector = queryVectors;
this.targetVectors = targetVectors;
this.similarityFunction = similarityFunction;
}

@Override
public float score(int targetOrd) throws IOException {
byte[] quantizedQuery = queryVector.vector();
byte[] binaryCode = targetVectors.vectorValue(targetOrd);
float qcDist = VectorUtil.int4BitDotProduct(quantizedQuery, binaryCode);
OptimizedScalarQuantizer.QuantizationResult queryCorrections =
queryVector.quantizationResult();
OptimizedScalarQuantizer.QuantizationResult indexCorrections =
targetVectors.getCorrectiveTerms(targetOrd);
float x1 = indexCorrections.quantizedComponentSum();
float ax = indexCorrections.lowerInterval();
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
float lx = indexCorrections.upperInterval() - ax;
float ay = queryCorrections.lowerInterval();
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
float y1 = queryCorrections.quantizedComponentSum();
float score =
ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
// For euclidean, we need to invert the score and apply the additional correction, which is
// assumed to be the squared l2norm of the centroid centered vectors.
if (similarityFunction == EUCLIDEAN) {
score =
queryCorrections.additionalCorrection()
+ indexCorrections.additionalCorrection()
- 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
score +=
queryCorrections.additionalCorrection()
+ indexCorrections.additionalCorrection()
- targetVectors.getCentroidDP();
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
}
return Math.max((1f + score) / 2f, 0);
}
}
}
}
Loading
Loading