Concurrent HNSW merging, rebased, not tested after rebase
zhaih committed Oct 25, 2023
1 parent 7795927 commit a81714d
Showing 18 changed files with 899 additions and 155 deletions.
Lucene99HnswVectorsFormat.java

@@ -18,6 +18,7 @@
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
@@ -151,6 +152,9 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
*/
public static final int DEFAULT_BEAM_WIDTH = 100;

/** Default number of merge workers: use a single thread to merge */
public static final int DEFAULT_NUM_MERGE_WORKER = 1;

static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;

/**
@@ -169,20 +173,36 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
/** Should this codec scalar quantize float32 vectors and use this format */
private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat;

private final int numMergeWorkers;
private final ExecutorService mergeExec;

/** Constructs a format using default graph construction parameters */
public Lucene99HnswVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null);
}

public Lucene99HnswVectorsFormat(
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
this(maxConn, beamWidth, scalarQuantize, DEFAULT_NUM_MERGE_WORKER, null);
}

/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param scalarQuantize the scalar quantization format
* @param numMergeWorkers number of workers (threads) that will be used when merging. If larger
* than 1, a non-null {@link ExecutorService} must be passed as mergeExec
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers generated
* by this format to perform merges
*/
public Lucene99HnswVectorsFormat(
- int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
+ int maxConn,
+ int beamWidth,
+ Lucene99ScalarQuantizedVectorsFormat scalarQuantize,
+ int numMergeWorkers,
+ ExecutorService mergeExec) {
super("Lucene99HnswVectorsFormat");
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
Expand All @@ -198,14 +218,25 @@ public Lucene99HnswVectorsFormat(
+ "; beamWidth="
+ beamWidth);
}
if (numMergeWorkers > 1 && mergeExec == null) {
throw new IllegalArgumentException(
"No executor service passed in when " + numMergeWorkers + " merge workers are requested");
}
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException(
"No executor service is needed as we'll use single thread to merge");
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.scalarQuantizedVectorsFormat = scalarQuantize;
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
- return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, scalarQuantizedVectorsFormat);
+ return new Lucene99HnswVectorsWriter(
+ state, maxConn, beamWidth, scalarQuantizedVectorsFormat, numMergeWorkers, mergeExec);
}

@Override
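A minimal usage sketch of the new five-argument constructor (not part of this commit; the class name, pool size, and field wiring are illustrative). It assumes the usual per-field override pattern on Lucene99Codec:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.index.IndexWriterConfig;

public class ConcurrentHnswMergeExample {
  public static void main(String[] args) {
    int numMergeWorkers = 4; // illustrative pool size
    ExecutorService mergeExec = Executors.newFixedThreadPool(numMergeWorkers);
    KnnVectorsFormat format =
        new Lucene99HnswVectorsFormat(
            Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
            Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
            null, // no scalar quantization in this sketch
            numMergeWorkers,
            mergeExec);
    IndexWriterConfig iwc =
        new IndexWriterConfig()
            .setCodec(
                new Lucene99Codec() {
                  @Override
                  public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                    return format;
                  }
                });
    // pass iwc to an IndexWriter as usual; note that the writer shuts the executor
    // down in close() (via shutdownNow(), see the writer diff below), so the pool
    // should be dedicated to this format rather than shared
  }
}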
Lucene99HnswVectorsWriter.java
@@ -28,6 +28,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
@@ -52,9 +53,11 @@
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.ConcurrentHnswMerger;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphMerger;
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
@@ -75,6 +78,8 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private final int M;
private final int beamWidth;
private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter;
private final int numMergeWorkers;
private final ExecutorService mergeExec;

private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
@@ -83,10 +88,14 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
SegmentWriteState state,
int M,
int beamWidth,
- Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat)
+ Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat,
+ int numMergeWorkers,
+ ExecutorService mergeExec)
throws IOException {
this.M = M;
this.beamWidth = beamWidth;
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
@@ -557,6 +566,12 @@ public void close() throws IOException {
IOUtils.close(finalVectorDataInput);
segmentWriteState.directory.deleteFile(tempFileName);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
// simply return a copy of the inner supplier, since only this outer supplier needs to be closed
return innerScoreSupplier.copy();
}
};
} else {
// No need to use temporary file as we don't have to re-open for reading
@@ -579,8 +594,7 @@ public void close() throws IOException {
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
// build graph
- IncrementalHnswGraphMerger merger =
- new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
+ HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
@@ -592,9 +606,9 @@ public void close() throws IOException {
case FLOAT32 -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
- HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
- hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
- graph = hnswGraphBuilder.build(docsWithField.cardinality());
+ graph =
+ merger.merge(
+ mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@@ -675,6 +689,15 @@ public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
return sortedNodes;
}

private HnswGraphMerger createGraphMerger(
FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) {
if (mergeExec != null) {
return new ConcurrentHnswMerger(
fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers);
}
return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
}
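The HnswGraphMerger interface itself is not shown in this excerpt. From its call sites above (one addReader per source segment, then a single merge call), its contract can be reconstructed roughly as the hedged sketch below; the method names follow the call sites, everything else is assumption:

import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;

// Hedged reconstruction, not the actual interface from this commit.
interface HnswGraphMergerSketch {
  // register one source segment's vectors, doc-id mapping, and live docs
  void addReader(KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs)
      throws IOException;

  // build and return the merged graph over maxOrd vectors
  OnHeapHnswGraph merge(DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd)
      throws IOException;
}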

private void writeMeta(
boolean isQuantized,
FieldInfo field,
@@ -819,6 +842,9 @@ private static DocsWithFieldSet writeVectorData(
@Override
public void close() throws IOException {
IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData);
if (mergeExec != null) {
mergeExec.shutdownNow();
}
}

private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
Lucene99ScalarQuantizedVectorsWriter.java
@@ -52,6 +52,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

/**
* Writes quantized vector values and metadata to index segments.
@@ -761,6 +762,11 @@ public RandomVectorScorer scorer(int ord) throws IOException {
return supplier.scorer(ord);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return supplier.copy();
}

@Override
public void close() throws IOException {
onClose.close();
ScalarQuantizedRandomVectorScorerSupplier.java
@@ -39,11 +39,22 @@ final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
this.values = values;
}

private ScalarQuantizedRandomVectorScorerSupplier(
ScalarQuantizedVectorSimilarity similarity, RandomAccessQuantizedByteVectorValues values) {
this.similarity = similarity;
this.values = values;
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
final byte[] queryVector = values.vectorValue(ord);
final float queryOffset = values.getScoreCorrectionConstant();
return new ScalarQuantizedRandomVectorScorer(similarity, vectorsCopy, queryVector, queryOffset);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ScalarQuantizedRandomVectorScorerSupplier(similarity, values.copy());
}
}
CloseableRandomVectorScorerSupplier.java
@@ -22,6 +22,9 @@
/**
* A supplier that creates {@link RandomVectorScorer} from an ordinal. Caller should be sure to
* close after use
*
* <p>NOTE: the {@link RandomVectorScorerSupplier} returned by {@link #copy()} is not necessarily
* closeable
*/
public interface CloseableRandomVectorScorerSupplier
extends Closeable, RandomVectorScorerSupplier {}
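This non-closeable copy() matters for concurrent merging: each worker can take its own copy of a shared supplier so per-thread scorers never share reader state, while only the original, closeable supplier owns the underlying resources. A minimal sketch under that assumption (the helper name and usage are hypothetical, not part of this commit):

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

// Hypothetical helper: each merge worker gets its own supplier copy, so
// per-thread scorers never share a reader cursor. Only the original,
// closeable supplier is closed by its owner.
final class PerWorkerScoring {
  static void fanOut(ExecutorService exec, int numWorker, RandomVectorScorerSupplier shared)
      throws IOException {
    for (int i = 0; i < numWorker; i++) {
      final RandomVectorScorerSupplier local = shared.copy(); // thread-private state
      exec.submit(
          () -> {
            try {
              local.scorer(0); // e.g. score candidates against ordinal 0
            } catch (IOException e) {
              throw new RuntimeException(e);
            }
          });
    }
  }
}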
ConcurrentHnswMerger.java (new file)
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;

/** This merger merges graphs concurrently, using a {@link HnswConcurrentMergeBuilder} */
public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {

private final ExecutorService exec;
private final int numWorker;

/**
* @param fieldInfo FieldInfo for the field being merged
* @param scorerSupplier supplier of scorers over the merged vectors
* @param M max number of connections per node in the graph
* @param beamWidth size of the candidate queue maintained during graph construction
* @param exec the {@link ExecutorService} that runs the merge workers
* @param numWorker number of workers (threads) to use for the merge
*/
public ConcurrentHnswMerger(
FieldInfo fieldInfo,
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
ExecutorService exec,
int numWorker) {
super(fieldInfo, scorerSupplier, M, beamWidth);
this.exec = exec;
this.numWorker = numWorker;
}

@Override
protected IHnswGraphBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
throws IOException {
if (initReader == null) {
return new HnswConcurrentMergeBuilder(
exec, numWorker, scorerSupplier, M, beamWidth, new OnHeapHnswGraph(M, maxOrd), null);
}

HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);

return new HnswConcurrentMergeBuilder(
exec,
numWorker,
scorerSupplier,
M,
beamWidth,
InitializedHnswGraphBuilder.initGraph(M, initializerGraph, oldToNewOrdinalMap, maxOrd),
initializedNodes);
}
}
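HnswConcurrentMergeBuilder is referenced above but not included in this excerpt, so the following is only an illustrative sketch (all names hypothetical) of the general fan-out pattern such a builder can use: numWorker tasks claim node ordinals from a shared counter and insert them into a shared graph. Real concurrent HNSW insertion also has to synchronize neighbor-list updates, which is elided here:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

public class FanOutSketch {
  public static void main(String[] args) throws Exception {
    int numWorker = 4;
    int maxOrd = 10_000;
    ExecutorService exec = Executors.newFixedThreadPool(numWorker);
    AtomicInteger nextOrd = new AtomicInteger();
    List<Future<?>> futures = new ArrayList<>();
    for (int i = 0; i < numWorker; i++) {
      futures.add(
          exec.submit(
              () -> {
                int ord;
                while ((ord = nextOrd.getAndIncrement()) < maxOrd) {
                  addToGraph(ord); // placeholder for the guarded HNSW insertion
                }
              }));
    }
    for (Future<?> f : futures) {
      f.get(); // propagate any worker failure
    }
    exec.shutdown();
  }

  private static void addToGraph(int ord) {
    // stand-in for the real, synchronized graph insertion
  }
}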
(12 of the 18 changed files are not shown above)
