From edef1b2df6cc714c28e169b34baa7171cfcdd183 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Thu, 14 Nov 2024 17:28:40 +0900 Subject: [PATCH 01/14] Initial commit: Add TwoPhaseKnnVectorQuery --- .../lucene/search/AbstractKnnVectorQuery.java | 2 +- .../lucene/search/KnnFloatVectorQuery.java | 2 +- .../lucene/search/TwoPhaseKnnVectorQuery.java | 111 ++++++++++++++++++ 3 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index df19de6cc8d8..d51fce32eb65 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -116,7 +116,7 @@ private TopDocs searchLeaf( return results; } - private TopDocs getLeafResults( + protected TopDocs getLeafResults( LeafReaderContext ctx, Weight filterWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 585893fa3c2a..2bf214850830 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final float[] target; + final float[] target; /** * Find the k nearest documents to the target vector according to the vectors in the diff --git a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java new file mode 100644 index 000000000000..31f9f5fddd75 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReaderContext; + +public class TwoPhaseKnnVectorQuery extends KnnFloatVectorQuery { + + private final int originalK; + private final double oversample; + + public TwoPhaseKnnVectorQuery( + String field, float[] target, int k, double oversample, Query filter) { + super(field, target, k + (int) Math.round(k * oversample), filter); + if (oversample < 0) { + throw new IllegalArgumentException("oversample must be non-negative, got " + oversample); + } + this.originalK = k; + this.oversample = oversample; + } + + @Override + protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { + return TopDocs.merge(originalK, perLeafResults); + } + + @Override + protected TopDocs getLeafResults( + LeafReaderContext context, + Weight filterWeight, + TimeLimitingKnnCollectorManager knnCollectorManager) + throws IOException { + TopDocs results = super.getLeafResults(context, filterWeight, knnCollectorManager); + if (results.scoreDocs.length <= originalK) { + // short-circuit: no re-ranking needed. we got what we need + return results; + } + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); + if (fi == null) { + return results; + } + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field); + if (floatVectorValues == null) { + return results; + } + + for (int i = 0; i < results.scoreDocs.length; i++) { + // get the raw vector value + float[] vectorValue = floatVectorValues.vectorValue(results.scoreDocs[i].doc); + + // recompute the score + results.scoreDocs[i].score = fi.getVectorSimilarityFunction().compare(vectorValue, target); + } + + // Sort the ScoreDocs by the new scores in descending order + Arrays.sort(results.scoreDocs, (a, b) -> Float.compare(b.score, a.score)); + + // Select the top-k ScoreDocs after re-ranking + ScoreDoc[] topKDocs = Arrays.copyOfRange(results.scoreDocs, 0, originalK); + + return new TopDocs(results.totalHits, topKDocs); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Objects.hash(originalK, oversample); + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (super.equals(o) == false) return false; + TwoPhaseKnnVectorQuery that = (TwoPhaseKnnVectorQuery) o; + return oversample == that.oversample && originalK == that.originalK; + } + + @Override + public String toString(String field) { + return getClass().getSimpleName() + + ":" + + this.field + + "[" + + target[0] + + ",...][" + + originalK + + "][" + + oversample + + "]"; + } +} From bbc7081cbcf59ba29ef1b2d09586a5ba692b1015 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Fri, 15 Nov 2024 15:15:46 +0900 Subject: [PATCH 02/14] Add tests --- .../services/org.apache.lucene.codecs.Codec | 1 + .../search/TestTwoPhaseKnnVectorQuery.java | 102 ++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java diff --git a/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec b/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec index 8c7c0df63966..0512052bdc3a 100644 --- a/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec +++ b/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec @@ -15,3 +15,4 @@ org.apache.lucene.codecs.TestMinimalCodec$MinimalCodec org.apache.lucene.codecs.TestMinimalCodec$MinimalCompoundCodec +org.apache.lucene.search.TestTwoPhaseKnnVectorQuery$QuantizedCodec diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java new file mode 100644 index 000000000000..54021e8063bd --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java @@ -0,0 +1,102 @@ +package org.apache.lucene.search; + +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; +import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.ByteBuffersDirectory; +import org.apache.lucene.store.Directory; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +public class TestTwoPhaseKnnVectorQuery { + + private static final String FIELD = "vector"; + public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = VectorSimilarityFunction.COSINE; + private Directory directory; + private IndexWriterConfig config; + private static final int NUM_VECTORS = 1000; + private static final int VECTOR_DIMENSION = 128; + + @Before + public void setUp() throws Exception { + directory = new ByteBuffersDirectory(); + + // Set up the IndexWriterConfig to use quantized vector storage + config = new IndexWriterConfig(); + config.setCodec(new QuantizedCodec()); + } + + @Test + public void testTwoPhaseKnnVectorQuery() throws Exception { + Map vectors = new HashMap<>(); + + // Step 1: Index random vectors in quantized format + try (IndexWriter writer = new IndexWriter(directory, config)) { + Random random = new Random(); + for (int i = 0; i < NUM_VECTORS; i++) { + float[] vector = randomFloatVector(VECTOR_DIMENSION, random); + Document doc = new Document(); + doc.add(new IntField("id", i, Field.Store.YES)); + doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); + writer.addDocument(doc); + vectors.put(i, vector); + } + } + + // Step 2: Run TwoPhaseKnnVectorQuery with a random target vector + try (IndexReader reader = DirectoryReader.open(directory)) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] targetVector = randomFloatVector(VECTOR_DIMENSION, new Random()); + int k = 10; + double oversample = 1.0; + + TwoPhaseKnnVectorQuery query = new TwoPhaseKnnVectorQuery(FIELD, targetVector, k, oversample, null); + TopDocs topDocs = searcher.search(query, k); + + // Step 3: Verify that TopDocs scores match similarity with unquantized vectors + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + Document retrievedDoc = searcher.storedFields().document(scoreDoc.doc); + float[] docVector = vectors.get(retrievedDoc.getField("id").numericValue().intValue()); + float expectedScore = VECTOR_SIMILARITY_FUNCTION.compare(targetVector, docVector); + Assert.assertEquals( + "Score does not match expected similarity for docId: " + scoreDoc.doc, + expectedScore, scoreDoc.score, 1e-5); + } + } + } + + private float[] randomFloatVector(int dimension, Random random) { + float[] vector = new float[dimension]; + for (int i = 0; i < dimension; i++) { + vector[i] = random.nextFloat(); + } + return vector; + } + + public static class QuantizedCodec extends FilterCodec { + + public QuantizedCodec() { + super("QuantizedCodec", new Lucene100Codec()); + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene99HnswScalarQuantizedVectorsFormat(); + } + } +} From 96d298734912cdc5f117f2910aa30ad35522ad8d Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Fri, 15 Nov 2024 15:22:04 +0900 Subject: [PATCH 03/14] Remove forbidden API --- .../lucene/search/TwoPhaseKnnVectorQuery.java | 5 +- .../search/TestTwoPhaseKnnVectorQuery.java | 131 +++++++++--------- 2 files changed, 71 insertions(+), 65 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java index 31f9f5fddd75..2c9c622bb86c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.util.ArrayUtil; public class TwoPhaseKnnVectorQuery extends KnnFloatVectorQuery { @@ -75,7 +76,9 @@ protected TopDocs getLeafResults( Arrays.sort(results.scoreDocs, (a, b) -> Float.compare(b.score, a.score)); // Select the top-k ScoreDocs after re-ranking - ScoreDoc[] topKDocs = Arrays.copyOfRange(results.scoreDocs, 0, originalK); + ScoreDoc[] topKDocs = ArrayUtil.copyOfSubArray(results.scoreDocs, 0, originalK); + + assert topKDocs.length == originalK; return new TopDocs(results.totalHits, topKDocs); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java index 54021e8063bd..f1a35c6f3143 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java @@ -1,5 +1,8 @@ package org.apache.lucene.search; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene100.Lucene100Codec; @@ -19,84 +22,84 @@ import org.junit.Before; import org.junit.Test; -import java.util.HashMap; -import java.util.Map; -import java.util.Random; - public class TestTwoPhaseKnnVectorQuery { - private static final String FIELD = "vector"; - public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = VectorSimilarityFunction.COSINE; - private Directory directory; - private IndexWriterConfig config; - private static final int NUM_VECTORS = 1000; - private static final int VECTOR_DIMENSION = 128; + private static final String FIELD = "vector"; + public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = + VectorSimilarityFunction.COSINE; + private Directory directory; + private IndexWriterConfig config; + private static final int NUM_VECTORS = 1000; + private static final int VECTOR_DIMENSION = 128; - @Before - public void setUp() throws Exception { - directory = new ByteBuffersDirectory(); + @Before + public void setUp() throws Exception { + directory = new ByteBuffersDirectory(); - // Set up the IndexWriterConfig to use quantized vector storage - config = new IndexWriterConfig(); - config.setCodec(new QuantizedCodec()); - } + // Set up the IndexWriterConfig to use quantized vector storage + config = new IndexWriterConfig(); + config.setCodec(new QuantizedCodec()); + } - @Test - public void testTwoPhaseKnnVectorQuery() throws Exception { - Map vectors = new HashMap<>(); + @Test + public void testTwoPhaseKnnVectorQuery() throws Exception { + Map vectors = new HashMap<>(); - // Step 1: Index random vectors in quantized format - try (IndexWriter writer = new IndexWriter(directory, config)) { - Random random = new Random(); - for (int i = 0; i < NUM_VECTORS; i++) { - float[] vector = randomFloatVector(VECTOR_DIMENSION, random); - Document doc = new Document(); - doc.add(new IntField("id", i, Field.Store.YES)); - doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); - writer.addDocument(doc); - vectors.put(i, vector); - } - } + // Step 1: Index random vectors in quantized format + try (IndexWriter writer = new IndexWriter(directory, config)) { + Random random = new Random(); + for (int i = 0; i < NUM_VECTORS; i++) { + float[] vector = randomFloatVector(VECTOR_DIMENSION, random); + Document doc = new Document(); + doc.add(new IntField("id", i, Field.Store.YES)); + doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); + writer.addDocument(doc); + vectors.put(i, vector); + } + } - // Step 2: Run TwoPhaseKnnVectorQuery with a random target vector - try (IndexReader reader = DirectoryReader.open(directory)) { - IndexSearcher searcher = new IndexSearcher(reader); - float[] targetVector = randomFloatVector(VECTOR_DIMENSION, new Random()); - int k = 10; - double oversample = 1.0; + // Step 2: Run TwoPhaseKnnVectorQuery with a random target vector + try (IndexReader reader = DirectoryReader.open(directory)) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] targetVector = randomFloatVector(VECTOR_DIMENSION, new Random()); + int k = 10; + double oversample = 1.0; - TwoPhaseKnnVectorQuery query = new TwoPhaseKnnVectorQuery(FIELD, targetVector, k, oversample, null); - TopDocs topDocs = searcher.search(query, k); + TwoPhaseKnnVectorQuery query = + new TwoPhaseKnnVectorQuery(FIELD, targetVector, k, oversample, null); + TopDocs topDocs = searcher.search(query, k); - // Step 3: Verify that TopDocs scores match similarity with unquantized vectors - for (ScoreDoc scoreDoc : topDocs.scoreDocs) { - Document retrievedDoc = searcher.storedFields().document(scoreDoc.doc); - float[] docVector = vectors.get(retrievedDoc.getField("id").numericValue().intValue()); - float expectedScore = VECTOR_SIMILARITY_FUNCTION.compare(targetVector, docVector); - Assert.assertEquals( - "Score does not match expected similarity for docId: " + scoreDoc.doc, - expectedScore, scoreDoc.score, 1e-5); - } - } + // Step 3: Verify that TopDocs scores match similarity with unquantized vectors + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + Document retrievedDoc = searcher.storedFields().document(scoreDoc.doc); + float[] docVector = vectors.get(retrievedDoc.getField("id").numericValue().intValue()); + float expectedScore = VECTOR_SIMILARITY_FUNCTION.compare(targetVector, docVector); + Assert.assertEquals( + "Score does not match expected similarity for docId: " + scoreDoc.doc, + expectedScore, + scoreDoc.score, + 1e-5); + } } + } - private float[] randomFloatVector(int dimension, Random random) { - float[] vector = new float[dimension]; - for (int i = 0; i < dimension; i++) { - vector[i] = random.nextFloat(); - } - return vector; + private float[] randomFloatVector(int dimension, Random random) { + float[] vector = new float[dimension]; + for (int i = 0; i < dimension; i++) { + vector[i] = random.nextFloat(); } + return vector; + } - public static class QuantizedCodec extends FilterCodec { + public static class QuantizedCodec extends FilterCodec { - public QuantizedCodec() { - super("QuantizedCodec", new Lucene100Codec()); - } + public QuantizedCodec() { + super("QuantizedCodec", new Lucene100Codec()); + } - @Override - public KnnVectorsFormat knnVectorsFormat() { - return new Lucene99HnswScalarQuantizedVectorsFormat(); - } + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene99HnswScalarQuantizedVectorsFormat(); } + } } From e2ab4bcf19d8fe10c8b50be720c2e49e2b953228 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Fri, 15 Nov 2024 15:22:04 +0900 Subject: [PATCH 04/14] Remove forbidden API --- .../lucene/search/AbstractKnnVectorQuery.java | 2 +- .../lucene/search/TwoPhaseKnnVectorQuery.java | 12 ++++++--- .../search/TestTwoPhaseKnnVectorQuery.java | 26 ++++++++++++++++--- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index d51fce32eb65..df19de6cc8d8 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -116,7 +116,7 @@ private TopDocs searchLeaf( return results; } - protected TopDocs getLeafResults( + private TopDocs getLeafResults( LeafReaderContext ctx, Weight filterWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) diff --git a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java index 2c9c622bb86c..fd8538e1e07e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java @@ -22,7 +22,9 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.Bits; public class TwoPhaseKnnVectorQuery extends KnnFloatVectorQuery { @@ -45,12 +47,14 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { } @Override - protected TopDocs getLeafResults( + protected TopDocs approximateSearch( LeafReaderContext context, - Weight filterWeight, - TimeLimitingKnnCollectorManager knnCollectorManager) + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) throws IOException { - TopDocs results = super.getLeafResults(context, filterWeight, knnCollectorManager); + TopDocs results = + super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager); if (results.scoreDocs.length <= originalK) { // short-circuit: no re-ranking needed. we got what we need return results; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java index f1a35c6f3143..099a44366fb4 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.lucene.search; import java.util.HashMap; @@ -18,11 +34,12 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -public class TestTwoPhaseKnnVectorQuery { +public class TestTwoPhaseKnnVectorQuery extends LuceneTestCase { private static final String FIELD = "vector"; public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = @@ -33,7 +50,9 @@ public class TestTwoPhaseKnnVectorQuery { private static final int VECTOR_DIMENSION = 128; @Before + @Override public void setUp() throws Exception { + super.setUp(); directory = new ByteBuffersDirectory(); // Set up the IndexWriterConfig to use quantized vector storage @@ -45,9 +64,10 @@ public void setUp() throws Exception { public void testTwoPhaseKnnVectorQuery() throws Exception { Map vectors = new HashMap<>(); + Random random = random(); + // Step 1: Index random vectors in quantized format try (IndexWriter writer = new IndexWriter(directory, config)) { - Random random = new Random(); for (int i = 0; i < NUM_VECTORS; i++) { float[] vector = randomFloatVector(VECTOR_DIMENSION, random); Document doc = new Document(); @@ -61,7 +81,7 @@ public void testTwoPhaseKnnVectorQuery() throws Exception { // Step 2: Run TwoPhaseKnnVectorQuery with a random target vector try (IndexReader reader = DirectoryReader.open(directory)) { IndexSearcher searcher = new IndexSearcher(reader); - float[] targetVector = randomFloatVector(VECTOR_DIMENSION, new Random()); + float[] targetVector = randomFloatVector(VECTOR_DIMENSION, random); int k = 10; double oversample = 1.0; From 4e3297191b7d09f8b044e6e7bc78fbdf49a55d18 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Sun, 17 Nov 2024 22:53:42 +0900 Subject: [PATCH 05/14] Add javadoc --- .../lucene/search/TwoPhaseKnnVectorQuery.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java index fd8538e1e07e..c2adccb00a25 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; @@ -26,11 +27,24 @@ import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; +/** A subclass of KnnFloatVectorQuery which does oversampling and full-precision reranking. */ public class TwoPhaseKnnVectorQuery extends KnnFloatVectorQuery { private final int originalK; private final double oversample; + /** + * Find the k nearest documents to the target vector according to the vectors in the + * given field. target vector. It also over-samples by oversample parameter and does + * full precision reranking if oversample > 0 + * + * @param field a field that has been indexed as a {@link KnnFloatVectorField}. + * @param target the target of the search + * @param k the number of documents to find + * @param oversample the oversampling factor, a value of 0 means no oversampling + * @param filter a filter applied before the vector search + * @throws IllegalArgumentException if k is less than 1 + */ public TwoPhaseKnnVectorQuery( String field, float[] target, int k, double oversample, Query filter) { super(field, target, k + (int) Math.round(k * oversample), filter); From ccd3e25c62252f9777e148798d8184cf4438ec45 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Sun, 17 Nov 2024 22:58:26 +0900 Subject: [PATCH 06/14] Make the Query experimental --- .../org/apache/lucene/search/TwoPhaseKnnVectorQuery.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java index c2adccb00a25..84be6fc5a663 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java @@ -27,7 +27,11 @@ import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; -/** A subclass of KnnFloatVectorQuery which does oversampling and full-precision reranking. */ +/** + * A subclass of KnnFloatVectorQuery which does oversampling and full-precision reranking. + * + * @lucene.experimental + */ public class TwoPhaseKnnVectorQuery extends KnnFloatVectorQuery { private final int originalK; From f9da336ed935d8783795a86b67dab647124bfce3 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Mon, 18 Nov 2024 10:16:49 +0900 Subject: [PATCH 07/14] Use Math.ceil instead of rounding --- .../java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java index 84be6fc5a663..132794d63ca7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java @@ -51,7 +51,7 @@ public class TwoPhaseKnnVectorQuery extends KnnFloatVectorQuery { */ public TwoPhaseKnnVectorQuery( String field, float[] target, int k, double oversample, Query filter) { - super(field, target, k + (int) Math.round(k * oversample), filter); + super(field, target, k + (int) Math.ceil(k * oversample), filter); if (oversample < 0) { throw new IllegalArgumentException("oversample must be non-negative, got " + oversample); } From 8d88cab9f290fc9744187f275784862035a3e966 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Mon, 18 Nov 2024 15:38:58 +0900 Subject: [PATCH 08/14] Store target separately in child class --- .../src/java/org/apache/lucene/search/KnnFloatVectorQuery.java | 2 +- .../java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 2bf214850830..585893fa3c2a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - final float[] target; + private final float[] target; /** * Find the k nearest documents to the target vector according to the vectors in the diff --git a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java index 132794d63ca7..16223e99c3c6 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java @@ -36,6 +36,7 @@ public class TwoPhaseKnnVectorQuery extends KnnFloatVectorQuery { private final int originalK; private final double oversample; + private final float[] target; /** * Find the k nearest documents to the target vector according to the vectors in the @@ -55,6 +56,7 @@ public TwoPhaseKnnVectorQuery( if (oversample < 0) { throw new IllegalArgumentException("oversample must be non-negative, got " + oversample); } + this.target = target; this.originalK = k; this.oversample = oversample; } From b67637a8d0c79ac0dc411bf83a6f5a747b9591dc Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Thu, 21 Nov 2024 17:02:14 +0900 Subject: [PATCH 09/14] Change abstraction to wrap around KNN query --- .../lucene/search/AbstractKnnVectorQuery.java | 14 +- .../search/RerankKnnFloatVectorQuery.java | 110 ++++++++++++++ .../lucene/search/TwoPhaseKnnVectorQuery.java | 138 ------------------ .../services/org.apache.lucene.codecs.Codec | 2 +- ...ava => TestRerankKnnFloatVectorQuery.java} | 7 +- 5 files changed, 122 insertions(+), 149 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java delete mode 100644 lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java rename lucene/core/src/test/org/apache/lucene/search/{TestTwoPhaseKnnVectorQuery.java => TestRerankKnnFloatVectorQuery.java} (94%) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index df19de6cc8d8..ce7c1ea41012 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -99,7 +99,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (topK.scoreDocs.length == 0) { return new MatchNoDocsQuery(); } - return createRewrittenQuery(reader, topK); + return createRewrittenQuery(reader, topK.scoreDocs); } private TopDocs searchLeaf( @@ -255,18 +255,18 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { return TopDocs.merge(k, perLeafResults); } - private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { - int len = topK.scoreDocs.length; + static Query createRewrittenQuery(IndexReader reader, ScoreDoc[] scoreDocs) { + int len = scoreDocs.length; assert len > 0; - float maxScore = topK.scoreDocs[0].score; + float maxScore = scoreDocs[0].score; - Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); + Arrays.sort(scoreDocs, Comparator.comparingInt(a -> a.doc)); int[] docs = new int[len]; float[] scores = new float[len]; for (int i = 0; i < len; i++) { - docs[i] = topK.scoreDocs[i].doc; - scores[i] = topK.scoreDocs[i].score; + docs[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; } int[] segmentStarts = findSegmentStarts(reader.leaves(), docs); return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id()); diff --git a/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java new file mode 100644 index 000000000000..eb1f65f54863 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import static org.apache.lucene.search.AbstractKnnVectorQuery.createRewrittenQuery; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.VectorSimilarityFunction; + +/** + * A wrapper of KnnFloatVectorQuery which does full-precision reranking. + * + * @lucene.experimental + */ +public class RerankKnnFloatVectorQuery extends Query { + + private final int k; + private final float[] target; + private final KnnFloatVectorQuery query; + + /** + * Execute the KnnFloatVectorQuery and re-rank using full-precision vectors + * + * @param query the KNN query to execute as initial phase + * @param target the target of the search + * @param k the number of documents to find + * @throws IllegalArgumentException if k is less than 1 + */ + public RerankKnnFloatVectorQuery(KnnFloatVectorQuery query, float[] target, int k) { + this.query = query; + this.target = target; + this.k = k; + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + IndexReader reader = indexSearcher.getIndexReader(); + Query rewritten = indexSearcher.rewrite(query); + Weight weight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f); + HitQueue queue = new HitQueue(k, false); + for (var leaf : reader.leaves()) { + Scorer scorer = weight.scorer(leaf); + if (scorer == null) { + continue; + } + FloatVectorValues floatVectorValues = leaf.reader().getFloatVectorValues(query.getField()); + if (floatVectorValues == null) { + continue; + } + FieldInfo fi = leaf.reader().getFieldInfos().fieldInfo(query.getField()); + VectorSimilarityFunction comparer = fi.getVectorSimilarityFunction(); + DocIdSetIterator iterator = scorer.iterator(); + while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + int docId = iterator.docID(); + float[] vectorValue = floatVectorValues.vectorValue(docId); + float score = comparer.compare(vectorValue, target); + queue.insertWithOverflow(new ScoreDoc(docId, score)); + } + } + int i = 0; + ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()]; + for (ScoreDoc topDoc : queue) { + scoreDocs[i++] = topDoc; + } + return createRewrittenQuery(reader, scoreDocs); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(target); + result = 31 * result + Objects.hash(query, k); + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + RerankKnnFloatVectorQuery that = (RerankKnnFloatVectorQuery) o; + return Objects.equals(query, that.query) && k == that.k; + } + + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor); + } + + @Override + public String toString(String field) { + return getClass().getSimpleName() + ":" + query.toString(field) + "[" + k + "]"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java deleted file mode 100644 index 16223e99c3c6..000000000000 --- a/lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.search; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Objects; -import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.knn.KnnCollectorManager; -import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.Bits; - -/** - * A subclass of KnnFloatVectorQuery which does oversampling and full-precision reranking. - * - * @lucene.experimental - */ -public class TwoPhaseKnnVectorQuery extends KnnFloatVectorQuery { - - private final int originalK; - private final double oversample; - private final float[] target; - - /** - * Find the k nearest documents to the target vector according to the vectors in the - * given field. target vector. It also over-samples by oversample parameter and does - * full precision reranking if oversample > 0 - * - * @param field a field that has been indexed as a {@link KnnFloatVectorField}. - * @param target the target of the search - * @param k the number of documents to find - * @param oversample the oversampling factor, a value of 0 means no oversampling - * @param filter a filter applied before the vector search - * @throws IllegalArgumentException if k is less than 1 - */ - public TwoPhaseKnnVectorQuery( - String field, float[] target, int k, double oversample, Query filter) { - super(field, target, k + (int) Math.ceil(k * oversample), filter); - if (oversample < 0) { - throw new IllegalArgumentException("oversample must be non-negative, got " + oversample); - } - this.target = target; - this.originalK = k; - this.oversample = oversample; - } - - @Override - protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { - return TopDocs.merge(originalK, perLeafResults); - } - - @Override - protected TopDocs approximateSearch( - LeafReaderContext context, - Bits acceptDocs, - int visitedLimit, - KnnCollectorManager knnCollectorManager) - throws IOException { - TopDocs results = - super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager); - if (results.scoreDocs.length <= originalK) { - // short-circuit: no re-ranking needed. we got what we need - return results; - } - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null) { - return results; - } - FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field); - if (floatVectorValues == null) { - return results; - } - - for (int i = 0; i < results.scoreDocs.length; i++) { - // get the raw vector value - float[] vectorValue = floatVectorValues.vectorValue(results.scoreDocs[i].doc); - - // recompute the score - results.scoreDocs[i].score = fi.getVectorSimilarityFunction().compare(vectorValue, target); - } - - // Sort the ScoreDocs by the new scores in descending order - Arrays.sort(results.scoreDocs, (a, b) -> Float.compare(b.score, a.score)); - - // Select the top-k ScoreDocs after re-ranking - ScoreDoc[] topKDocs = ArrayUtil.copyOfSubArray(results.scoreDocs, 0, originalK); - - assert topKDocs.length == originalK; - - return new TopDocs(results.totalHits, topKDocs); - } - - @Override - public int hashCode() { - int result = super.hashCode(); - result = 31 * result + Objects.hash(originalK, oversample); - return result; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (super.equals(o) == false) return false; - TwoPhaseKnnVectorQuery that = (TwoPhaseKnnVectorQuery) o; - return oversample == that.oversample && originalK == that.originalK; - } - - @Override - public String toString(String field) { - return getClass().getSimpleName() - + ":" - + this.field - + "[" - + target[0] - + ",...][" - + originalK - + "][" - + oversample - + "]"; - } -} diff --git a/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec b/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec index 0512052bdc3a..3905502af834 100644 --- a/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec +++ b/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec @@ -15,4 +15,4 @@ org.apache.lucene.codecs.TestMinimalCodec$MinimalCodec org.apache.lucene.codecs.TestMinimalCodec$MinimalCompoundCodec -org.apache.lucene.search.TestTwoPhaseKnnVectorQuery$QuantizedCodec +org.apache.lucene.search.TestRerankKnnFloatVectorQuery$QuantizedCodec diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java similarity index 94% rename from lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java rename to lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java index 099a44366fb4..a90494eaf15f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java @@ -39,7 +39,7 @@ import org.junit.Before; import org.junit.Test; -public class TestTwoPhaseKnnVectorQuery extends LuceneTestCase { +public class TestRerankKnnFloatVectorQuery extends LuceneTestCase { private static final String FIELD = "vector"; public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = @@ -85,8 +85,9 @@ public void testTwoPhaseKnnVectorQuery() throws Exception { int k = 10; double oversample = 1.0; - TwoPhaseKnnVectorQuery query = - new TwoPhaseKnnVectorQuery(FIELD, targetVector, k, oversample, null); + KnnFloatVectorQuery knnQuery = + new KnnFloatVectorQuery(FIELD, targetVector, k + (int) (k * oversample)); + RerankKnnFloatVectorQuery query = new RerankKnnFloatVectorQuery(knnQuery, targetVector, k); TopDocs topDocs = searcher.search(query, k); // Step 3: Verify that TopDocs scores match similarity with unquantized vectors From 8cd3ccf44162a3e64a9867b49719245b6f93fca4 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Thu, 21 Nov 2024 21:08:00 +0900 Subject: [PATCH 10/14] Fix doc ord bug & flush writer multiple times --- .../lucene/search/RerankKnnFloatVectorQuery.java | 2 +- .../search/TestRerankKnnFloatVectorQuery.java | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java index eb1f65f54863..db8180dfafbf 100644 --- a/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java @@ -73,7 +73,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { int docId = iterator.docID(); float[] vectorValue = floatVectorValues.vectorValue(docId); float score = comparer.compare(vectorValue, target); - queue.insertWithOverflow(new ScoreDoc(docId, score)); + queue.insertWithOverflow(new ScoreDoc(leaf.docBase + docId, score)); } } int i = 0; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java index a90494eaf15f..373523969d51 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java @@ -46,7 +46,6 @@ public class TestRerankKnnFloatVectorQuery extends LuceneTestCase { VectorSimilarityFunction.COSINE; private Directory directory; private IndexWriterConfig config; - private static final int NUM_VECTORS = 1000; private static final int VECTOR_DIMENSION = 128; @Before @@ -66,15 +65,22 @@ public void testTwoPhaseKnnVectorQuery() throws Exception { Random random = random(); + int numVectors = atLeast(1000); + // Step 1: Index random vectors in quantized format try (IndexWriter writer = new IndexWriter(directory, config)) { - for (int i = 0; i < NUM_VECTORS; i++) { + for (int i = 0; i < numVectors; i++) { float[] vector = randomFloatVector(VECTOR_DIMENSION, random); Document doc = new Document(); doc.add(new IntField("id", i, Field.Store.YES)); doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); writer.addDocument(doc); vectors.put(i, vector); + + // flush to create multiple segments + if (random.nextInt(10) == 0) { + writer.flush(); + } } } @@ -93,10 +99,11 @@ public void testTwoPhaseKnnVectorQuery() throws Exception { // Step 3: Verify that TopDocs scores match similarity with unquantized vectors for (ScoreDoc scoreDoc : topDocs.scoreDocs) { Document retrievedDoc = searcher.storedFields().document(scoreDoc.doc); - float[] docVector = vectors.get(retrievedDoc.getField("id").numericValue().intValue()); + int id = retrievedDoc.getField("id").numericValue().intValue(); + float[] docVector = vectors.get(id); float expectedScore = VECTOR_SIMILARITY_FUNCTION.compare(targetVector, docVector); Assert.assertEquals( - "Score does not match expected similarity for docId: " + scoreDoc.doc, + "Score does not match expected similarity for doc ord: " + scoreDoc.doc + ", id: " + id, expectedScore, scoreDoc.score, 1e-5); From 30e377a361c3e213d1e91f16d7287c355c111121 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Fri, 22 Nov 2024 09:28:00 +0900 Subject: [PATCH 11/14] Add null check --- .../org/apache/lucene/search/RerankKnnFloatVectorQuery.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java index db8180dfafbf..753ab22f95dc 100644 --- a/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java @@ -67,6 +67,9 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { continue; } FieldInfo fi = leaf.reader().getFieldInfos().fieldInfo(query.getField()); + if (fi == null) { + continue; + } VectorSimilarityFunction comparer = fi.getVectorSimilarityFunction(); DocIdSetIterator iterator = scorer.iterator(); while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { @@ -95,7 +98,7 @@ public int hashCode() { public boolean equals(Object o) { if (this == o) return true; RerankKnnFloatVectorQuery that = (RerankKnnFloatVectorQuery) o; - return Objects.equals(query, that.query) && k == that.k; + return Objects.equals(query, that.query) && Arrays.equals(target, that.target) && k == that.k; } @Override From 5d1910c1c651a94e12b0b806fcdfd60a7ad506aa Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Fri, 22 Nov 2024 09:44:47 +0900 Subject: [PATCH 12/14] Refactor test case --- .../search/TestRerankKnnFloatVectorQuery.java | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java index 373523969d51..533b09dccc73 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java @@ -42,11 +42,13 @@ public class TestRerankKnnFloatVectorQuery extends LuceneTestCase { private static final String FIELD = "vector"; - public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = + private static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = VectorSimilarityFunction.COSINE; + private static final int NUM_VECTORS = 1000; + private static final int VECTOR_DIMENSION = 128; + private Directory directory; private IndexWriterConfig config; - private static final int VECTOR_DIMENSION = 128; @Before @Override @@ -65,20 +67,21 @@ public void testTwoPhaseKnnVectorQuery() throws Exception { Random random = random(); - int numVectors = atLeast(1000); + int numVectors = atLeast(NUM_VECTORS); + int numSegments = random.nextInt(2, 10); // Step 1: Index random vectors in quantized format try (IndexWriter writer = new IndexWriter(directory, config)) { - for (int i = 0; i < numVectors; i++) { - float[] vector = randomFloatVector(VECTOR_DIMENSION, random); - Document doc = new Document(); - doc.add(new IntField("id", i, Field.Store.YES)); - doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); - writer.addDocument(doc); - vectors.put(i, vector); - - // flush to create multiple segments - if (random.nextInt(10) == 0) { + for (int j = 0; j < numSegments; j++) { + for (int i = 0; i < numVectors; i++) { + float[] vector = randomFloatVector(VECTOR_DIMENSION, random); + Document doc = new Document(); + int id = j * numVectors + i; + doc.add(new IntField("id", id, Field.Store.YES)); + doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); + writer.addDocument(doc); + vectors.put(id, vector); + writer.flush(); } } @@ -101,6 +104,7 @@ public void testTwoPhaseKnnVectorQuery() throws Exception { Document retrievedDoc = searcher.storedFields().document(scoreDoc.doc); int id = retrievedDoc.getField("id").numericValue().intValue(); float[] docVector = vectors.get(id); + assert docVector != null : "Vector for id " + id + " not found"; float expectedScore = VECTOR_SIMILARITY_FUNCTION.compare(targetVector, docVector); Assert.assertEquals( "Score does not match expected similarity for doc ord: " + scoreDoc.doc + ", id: " + id, From feda6af1a78f72076a3a99c9225298c603176516 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Fri, 22 Nov 2024 10:05:35 +0900 Subject: [PATCH 13/14] Simplify Codec --- .../services/org.apache.lucene.codecs.Codec | 1 - .../search/TestRerankKnnFloatVectorQuery.java | 19 +++---------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec b/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec index 3905502af834..8c7c0df63966 100644 --- a/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec +++ b/lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec @@ -15,4 +15,3 @@ org.apache.lucene.codecs.TestMinimalCodec$MinimalCodec org.apache.lucene.codecs.TestMinimalCodec$MinimalCompoundCodec -org.apache.lucene.search.TestRerankKnnFloatVectorQuery$QuantizedCodec diff --git a/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java index 533b09dccc73..fb8c9cf7e91c 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java @@ -19,9 +19,6 @@ import java.util.HashMap; import java.util.Map; import java.util.Random; -import org.apache.lucene.codecs.FilterCodec; -import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.lucene100.Lucene100Codec; import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -35,6 +32,7 @@ import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -58,7 +56,8 @@ public void setUp() throws Exception { // Set up the IndexWriterConfig to use quantized vector storage config = new IndexWriterConfig(); - config.setCodec(new QuantizedCodec()); + config.setCodec( + TestUtil.alwaysKnnVectorsFormat(new Lucene99HnswScalarQuantizedVectorsFormat())); } @Test @@ -122,16 +121,4 @@ private float[] randomFloatVector(int dimension, Random random) { } return vector; } - - public static class QuantizedCodec extends FilterCodec { - - public QuantizedCodec() { - super("QuantizedCodec", new Lucene100Codec()); - } - - @Override - public KnnVectorsFormat knnVectorsFormat() { - return new Lucene99HnswScalarQuantizedVectorsFormat(); - } - } } From 3178bbc22e66abd1f4d87f0031f626407a0f43b4 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Tue, 26 Nov 2024 15:18:03 +0900 Subject: [PATCH 14/14] short-circuit for case there is no oversample --- .../org/apache/lucene/search/RerankKnnFloatVectorQuery.java | 4 ++++ .../apache/lucene/search/TestRerankKnnFloatVectorQuery.java | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java index 753ab22f95dc..a7bf7ee5be95 100644 --- a/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/RerankKnnFloatVectorQuery.java @@ -55,6 +55,10 @@ public RerankKnnFloatVectorQuery(KnnFloatVectorQuery query, float[] target, int public Query rewrite(IndexSearcher indexSearcher) throws IOException { IndexReader reader = indexSearcher.getIndexReader(); Query rewritten = indexSearcher.rewrite(query); + // short-circuit: don't re-rank if we already got all possible results + if (query.getK() <= k) { + return rewritten; + } Weight weight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f); HitQueue queue = new HitQueue(k, false); for (var leaf : reader.leaves()) { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java index fb8c9cf7e91c..9f9ee842e6d3 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java @@ -91,7 +91,7 @@ public void testTwoPhaseKnnVectorQuery() throws Exception { IndexSearcher searcher = new IndexSearcher(reader); float[] targetVector = randomFloatVector(VECTOR_DIMENSION, random); int k = 10; - double oversample = 1.0; + double oversample = random.nextFloat(1.5f, 3.0f); KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(FIELD, targetVector, k + (int) (k * oversample));