Speed up the sort when building forward index (#12712)
gf2121 committed Oct 25, 2023
1 parent 8b38b73 commit 1cb1a14
Showing 3 changed files with 291 additions and 99 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
@@ -229,6 +229,8 @@ Optimizations

* GITHUB#12710: Use Arrays#mismatch for Outputs#common operations. (Guo Feng)

* GITHUB#12712: Speed up sorting the postings file with an offline radix sorter in BPIndexReorderer. (Guo Feng)

Changes in runtime behavior
---------------------

312 changes: 213 additions & 99 deletions lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java
@@ -35,7 +35,7 @@
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.ByteBuffersDataOutput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
@@ -46,13 +46,11 @@
import org.apache.lucene.store.TrackingDirectoryWrapper;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefComparator;
import org.apache.lucene.util.CloseableThreadLocal;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.OfflineSorter;
import org.apache.lucene.util.OfflineSorter.BufferSize;
import org.apache.lucene.util.packed.PackedInts;

/**
* Implementation of "recursive graph bisection", also called "bipartite graph partitioning" and
@@ -654,9 +652,7 @@ private int writePostings(
for (int doc = postings.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = postings.nextDoc()) {
// reverse bytes so that byte order matches natural order
postingsOut.writeInt(Integer.reverseBytes(doc));
postingsOut.writeInt(Integer.reverseBytes(termID));
postingsOut.writeLong(Integer.toUnsignedLong(termID) << 32 | Integer.toUnsignedLong(doc));
}
}
}
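The single writeLong above packs the termID into the high 32 bits and the docID into the low 32 bits of each entry, so the radix sorter introduced below can key on docIDs directly instead of byte-swapping two ints. A minimal sketch (not part of the patch) of how the packed layout round-trips:

    long packed = Integer.toUnsignedLong(termID) << 32 | Integer.toUnsignedLong(doc);
    int docBack = (int) packed;              // low 32 bits: the sort key
    int termIDBack = (int) (packed >>> 32);  // high 32 bits: recovered after sorting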
@@ -665,107 +661,60 @@ private int writePostings(

private ForwardIndex buildForwardIndex(
Directory tempDir, String postingsFileName, int maxDoc, int maxTerm) throws IOException {
String sortedPostingsFile =
new OfflineSorter(
tempDir,
"forward-index",
// Implement BytesRefComparator to make OfflineSorter use radix sort
new BytesRefComparator(2 * Integer.BYTES) {
@Override
protected int byteAt(BytesRef ref, int i) {
return ref.bytes[ref.offset + i] & 0xFF;
}

@Override
public int compare(BytesRef o1, BytesRef o2, int k) {
assert o1.length == 2 * Integer.BYTES;
assert o2.length == 2 * Integer.BYTES;
return ArrayUtil.compareUnsigned8(o1.bytes, o1.offset, o2.bytes, o2.offset);
}
},
BufferSize.megabytes((long) (ramBudgetMB / getParallelism())),
OfflineSorter.MAX_TEMPFILES,
2 * Integer.BYTES,
forkJoinPool,
getParallelism()) {

@Override
protected ByteSequencesReader getReader(ChecksumIndexInput in, String name)
throws IOException {
return new ByteSequencesReader(in, postingsFileName) {
{
ref.grow(2 * Integer.BYTES);
ref.setLength(2 * Integer.BYTES);
}

@Override
public BytesRef next() throws IOException {
if (in.getFilePointer() >= end) {
return null;
}
// optimized read of 8 bytes
in.readBytes(ref.bytes(), 0, 2 * Integer.BYTES);
return ref.get();
}
};
}

@Override
protected ByteSequencesWriter getWriter(IndexOutput out, long itemCount)
throws IOException {
return new ByteSequencesWriter(out) {
@Override
public void write(byte[] bytes, int off, int len) throws IOException {
assert len == 2 * Integer.BYTES;
// optimized write of 8 bytes
out.writeBytes(bytes, off, len);
}
};
}
}.sort(postingsFileName);

String termIDsFileName;
String startOffsetsFileName;
int prevDoc = -1;
try (IndexInput sortedPostings = tempDir.openInput(sortedPostingsFile, IOContext.READONCE);
IndexOutput termIDs = tempDir.createTempOutput("term-ids", "", IOContext.DEFAULT);
try (IndexOutput termIDs = tempDir.createTempOutput("term-ids", "", IOContext.DEFAULT);
IndexOutput startOffsets =
tempDir.createTempOutput("start-offsets", "", IOContext.DEFAULT)) {
termIDsFileName = termIDs.getName();
startOffsetsFileName = startOffsets.getName();
final long end = sortedPostings.length() - CodecUtil.footerLength();
int[] buffer = new int[TERM_IDS_BLOCK_SIZE];
int bufferLen = 0;
while (sortedPostings.getFilePointer() < end) {
final int doc = Integer.reverseBytes(sortedPostings.readInt());
final int termID = Integer.reverseBytes(sortedPostings.readInt());
if (doc != prevDoc) {
if (bufferLen != 0) {
writeMonotonicInts(buffer, bufferLen, termIDs);
bufferLen = 0;
}
new ForwardIndexSorter(tempDir)
.sortAndConsume(
postingsFileName,
maxDoc,
new LongConsumer() {
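// Entries arrive sorted by docID (the low 32 bits, used as the radix-sort key);
// termIDs within a doc remain ascending because the LSB radix sort is stable.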

int prevDoc = -1;
int bufferLen = 0;

@Override
public void accept(long value) throws IOException {
int doc = (int) value;
int termID = (int) (value >>> 32);
if (doc != prevDoc) {
if (bufferLen != 0) {
writeMonotonicInts(buffer, bufferLen, termIDs);
bufferLen = 0;
}

assert doc > prevDoc;
for (int d = prevDoc + 1; d <= doc; ++d) {
startOffsets.writeLong(termIDs.getFilePointer());
}
prevDoc = doc;
}
assert termID < maxTerm : termID + " " + maxTerm;
if (bufferLen == buffer.length) {
writeMonotonicInts(buffer, bufferLen, termIDs);
bufferLen = 0;
}
buffer[bufferLen++] = termID;
}

assert doc > prevDoc;
for (int d = prevDoc + 1; d <= doc; ++d) {
startOffsets.writeLong(termIDs.getFilePointer());
}
prevDoc = doc;
}
assert termID < maxTerm : termID + " " + maxTerm;
if (bufferLen == buffer.length) {
writeMonotonicInts(buffer, bufferLen, termIDs);
bufferLen = 0;
}
buffer[bufferLen++] = termID;
}
if (bufferLen != 0) {
writeMonotonicInts(buffer, bufferLen, termIDs);
}
for (int d = prevDoc + 1; d <= maxDoc; ++d) {
startOffsets.writeLong(termIDs.getFilePointer());
}
CodecUtil.writeFooter(termIDs);
CodecUtil.writeFooter(startOffsets);
@Override
public void onFinish() throws IOException {
if (bufferLen != 0) {
writeMonotonicInts(buffer, bufferLen, termIDs);
}
for (int d = prevDoc + 1; d <= maxDoc; ++d) {
startOffsets.writeLong(termIDs.getFilePointer());
}
CodecUtil.writeFooter(termIDs);
CodecUtil.writeFooter(startOffsets);
}
});
}
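// At this point the forward index is two temp files: start-offsets holds one long
// per docID pointing into term-ids, and term-ids holds each doc's termIDs as blocks
// of monotonically increasing ints.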

IndexInput termIDsInput = tempDir.openInput(termIDsFileName, IOContext.READ);
@@ -991,4 +940,169 @@ static int readMonotonicInts(DataInput in, int[] ints) throws IOException {
}
return len;
}

/**
* Use an LSB radix sorter to sort the (docID, termID) entries. Only docIDs need to be
* compared, because the LSB radix sort is stable and the termIDs are already sorted.
*
* <p>This sorter requires at least 16MB ({@link #BUFFER_BYTES} * {@link #HISTOGRAM_SIZE}) of
* RAM.
*/
static class ForwardIndexSorter {

private static final int HISTOGRAM_SIZE = 256;
private static final int BUFFER_SIZE = 8192;
private static final int BUFFER_BYTES = BUFFER_SIZE * Long.BYTES;
private final Directory directory;
private final Bucket[] buckets = new Bucket[HISTOGRAM_SIZE];

private static class Bucket {
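// Buffers up to BUFFER_SIZE entries, then appends them to the current pass's
// output as one block, recording vlong-encoded file-pointer deltas so the blocks
// can be replayed in bucket order on the next pass.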
private final ByteBuffersDataOutput fps = new ByteBuffersDataOutput();
private final long[] buffer = new long[BUFFER_SIZE];
private IndexOutput output;
private int bufferUsed;
private int blockNum;
private long lastFp;
private int finalBlockSize;

private void addEntry(long l) throws IOException {
buffer[bufferUsed++] = l;
if (bufferUsed == BUFFER_SIZE) {
flush(false);
}
}

private void flush(boolean isFinal) throws IOException {
if (isFinal) {
finalBlockSize = bufferUsed;
}
long fp = output.getFilePointer();
fps.writeVLong(encode(fp - lastFp));
lastFp = fp;
for (int i = 0; i < bufferUsed; i++) {
output.writeLong(buffer[i]);
}
blockNum++;
bufferUsed = 0;
}

private void reset(IndexOutput resetOutput) {
output = resetOutput;
finalBlockSize = 0;
bufferUsed = 0;
blockNum = 0;
lastFp = 0;
fps.reset();
}
}

private static long encode(long fpDelta) {
assert (fpDelta & 0x07) == 0 : "fpDelta should be multiple of 8";
if (fpDelta % BUFFER_BYTES == 0) {
return ((fpDelta / BUFFER_BYTES) << 1) | 1;
} else {
return fpDelta;
}
}

private static long decode(long fpDelta) {
if ((fpDelta & 1) == 1) {
return (fpDelta >>> 1) * BUFFER_BYTES;
} else {
return fpDelta;
}
}
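// Example of the delta encoding above: a full block advances the file pointer by
// BUFFER_BYTES = 64KB, so its delta of 65536 encodes as (65536 / 65536) << 1 | 1 = 3,
// a single vlong byte. Other deltas are stored as-is; their low bit stays clear
// because every delta is a multiple of 8.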

ForwardIndexSorter(Directory directory) {
this.directory = directory;
for (int i = 0; i < HISTOGRAM_SIZE; i++) {
buckets[i] = new Bucket();
}
}

private void consume(String fileName, LongConsumer consumer) throws IOException {
try (IndexInput in = directory.openInput(fileName, IOContext.READONCE)) {
final long end = in.length() - CodecUtil.footerLength();
while (in.getFilePointer() < end) {
consumer.accept(in.readLong());
}
}
consumer.onFinish();
}

private void consume(String fileName, long indexFP, LongConsumer consumer) throws IOException {
try (IndexInput index = directory.openInput(fileName, IOContext.READONCE);
IndexInput value = directory.openInput(fileName, IOContext.READONCE)) {
index.seek(indexFP);
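// Replay this pass's output in bucket order: each bucket is a chain of blocks
// linked by encoded fp deltas, where every block but the last holds BUFFER_SIZE
// entries and the last holds finalBlockSize entries.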
for (int i = 0; i < buckets.length; i++) {
int blockNum = index.readVInt();
int finalBlockSize = index.readVInt();
long fp = decode(index.readVLong());
for (int block = 0; block < blockNum - 1; block++) {
value.seek(fp);
for (int j = 0; j < BUFFER_SIZE; j++) {
consumer.accept(value.readLong());
}
fp += decode(index.readVLong());
}
value.seek(fp);
for (int j = 0; j < finalBlockSize; j++) {
consumer.accept(value.readLong());
}
}
consumer.onFinish();
}
}

private LongConsumer consumer(int shift) {
return new LongConsumer() {
@Override
public void accept(long value) throws IOException {
int b = (int) ((value >>> shift) & 0xFF);
Bucket bucket = buckets[b];
bucket.addEntry(value);
}

@Override
public void onFinish() throws IOException {
for (Bucket bucket : buckets) {
bucket.flush(true);
}
}
};
}

void sortAndConsume(String fileName, int maxDoc, LongConsumer consumer) throws IOException {
int bitsRequired = PackedInts.bitsRequired(maxDoc);
String sourceFileName = fileName;
long indexFP = -1;
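// One pass per 8-bit digit of the docID, least significant digit first: each pass
// scatters entries into 256 bucketed blocks in a temp file, and the next pass reads
// the buckets back in ascending order, which keeps the sort stable.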
for (int shift = 0; shift < bitsRequired; shift += 8) {
try (IndexOutput output = directory.createTempOutput(fileName, "sort", IOContext.DEFAULT)) {
Arrays.stream(buckets).forEach(b -> b.reset(output));
if (shift == 0) {
consume(sourceFileName, consumer(shift));
} else {
consume(sourceFileName, indexFP, consumer(shift));
directory.deleteFile(sourceFileName);
}
indexFP = output.getFilePointer();
for (Bucket bucket : buckets) {
output.writeVInt(bucket.blockNum);
output.writeVInt(bucket.finalBlockSize);
bucket.fps.copyTo(output);
}
CodecUtil.writeFooter(output);
sourceFileName = output.getName();
}
}
consume(sourceFileName, indexFP, consumer);
}
}

interface LongConsumer {
void accept(long value) throws IOException;

default void onFinish() throws IOException {}
}
}
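For reference, the on-disk sort above is a file-backed version of the classic in-memory LSB radix sort. The following self-contained sketch (hypothetical, not from the patch; the class name and main are illustrative) shows the same idea on an array of packed (termID << 32 | doc) longs: sorting byte-by-byte from the least significant digit is stable, so entries with equal docIDs keep the termID order they were written in.

    class LsbRadixSketch {
      // Sort packed entries by their low 32 bits (the docID), one byte per pass.
      static void sortByDoc(long[] entries) {
        long[] scratch = new long[entries.length];
        for (int shift = 0; shift < 32; shift += 8) {
          int[] offsets = new int[257];
          for (long e : entries) {
            offsets[(int) ((e >>> shift) & 0xFF) + 1]++; // histogram of this digit
          }
          for (int i = 1; i < offsets.length; i++) {
            offsets[i] += offsets[i - 1]; // prefix sums = bucket start offsets
          }
          for (long e : entries) {
            scratch[offsets[(int) ((e >>> shift) & 0xFF)]++] = e; // stable scatter
          }
          System.arraycopy(scratch, 0, entries, 0, entries.length);
        }
      }

      public static void main(String[] args) {
        long[] entries = {(1L << 32) | 7, (1L << 32) | 2, (2L << 32) | 2, (3L << 32) | 7};
        sortByDoc(entries);
        for (long e : entries) {
          System.out.println("doc=" + (int) e + " termID=" + (int) (e >>> 32));
        }
        // Prints doc=2 termID=1, doc=2 termID=2, doc=7 termID=1, doc=7 termID=3:
        // ties on doc keep their original termID order.
      }
    }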