Skip to content

Commit

Permalink
Drier and faster SumAggregator and AvgAggregator (elastic#120436)
Browse files Browse the repository at this point in the history
Dried up the summation logic for both implementations (and moved it to the much faster inline form).
Obviously this could have been done even drier but it didn't seem like that was possible without a performance
hit (we really don't want to sub-class the leaf-collector I think).
Benchmarks suggest this variant is ~10% faster than the previous iteration of `SumAggregator` (probably from
making the grow method smaller) and shows an even bigger improvement for the `AvgAggregator`.
  • Loading branch information
original-brownbear committed Jan 22, 2025
1 parent dc66c15 commit d5749a0
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
Expand All @@ -25,12 +23,9 @@
import java.io.IOException;
import java.util.Map;

class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue {
class AvgAggregator extends SumAggregator {

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DocValueFormat format;

AvgAggregator(
String name,
Expand All @@ -40,63 +35,40 @@ class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue {
Map<String, Object> metadata
) throws IOException {
super(name, valuesSourceConfig, context, parent, metadata);
assert valuesSourceConfig.hasValues();
this.format = valuesSourceConfig.format();
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
counts = context.bigArrays().newLongArray(1, true);
}

@Override
protected LeafBucketCollector getLeafCollector(SortedNumericDoubleValues values, final LeafBucketCollector sub) {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
maybeGrow(bucket);
final int valueCount = values.docValueCount();
counts.increment(bucket, valueCount);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
for (int i = 0; i < valueCount; i++) {
kahanSummation.add(values.nextValue());
}
sums.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
counts.increment(bucket, sumSortedDoubles(bucket, values, sums, compensations));
}
}
};
}

@Override
protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final LeafBucketCollector sub) {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
maybeGrow(bucket);
computeSum(bucket, values, sums, compensations);
counts.increment(bucket, 1L);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
kahanSummation.add(values.doubleValue());
sums.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
}
}
};
}

private void maybeGrow(long bucket) {
if (bucket >= counts.size()) {
counts = bigArrays().grow(counts, bucket + 1);
sums = bigArrays().grow(sums, bucket + 1);
compensations = bigArrays().grow(compensations, bucket + 1);
}
@Override
protected void doGrow(long bucket, BigArrays bigArrays) {
    // Let SumAggregator resize the shared sums/compensations arrays first,
    // then resize this aggregator's additional per-bucket value counts.
    super.doGrow(bucket, bigArrays);
    final long minSize = bucket + 1;
    counts = bigArrays.grow(counts, minSize);
}

@Override
Expand All @@ -122,7 +94,8 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, sums, compensations);
super.doClose();
Releasables.close(counts);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
Expand All @@ -25,10 +26,9 @@

public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue {

private final DocValueFormat format;

private DoubleArray sums;
private DoubleArray compensations;
protected final DocValueFormat format;
protected DoubleArray sums;
protected DoubleArray compensations;

SumAggregator(
String name,
Expand All @@ -40,72 +40,98 @@ public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue {
super(name, valuesSourceConfig, context, parent, metadata);
assert valuesSourceConfig.hasValues();
this.format = valuesSourceConfig.format();
sums = bigArrays().newDoubleArray(1, true);
compensations = bigArrays().newDoubleArray(1, true);
var bigArrays = context.bigArrays();
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
}

@Override
protected LeafBucketCollector getLeafCollector(SortedNumericDoubleValues values, final LeafBucketCollector sub) {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
maybeGrow(bucket);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
for (int i = 0; i < values.docValueCount(); i++) {
kahanSummation.add(values.nextValue());
}
compensations.set(bucket, kahanSummation.delta());
sums.set(bucket, kahanSummation.value());
sumSortedDoubles(bucket, values, sums, compensations);
}
}
};
}

// Shared hot-path helper (also called by AvgAggregator): folds every value of the
// current doc into the per-bucket running sum using inline Kahan summation.
// returns number of values added
static int sumSortedDoubles(long bucket, SortedNumericDoubleValues values, DoubleArray sums, DoubleArray compensations)
    throws IOException {
    final int valueCount = values.docValueCount();
    // Compute the sum of double values with Kahan summation algorithm which is more
    // accurate than naive summation.
    double value = sums.get(bucket);
    double delta = compensations.get(bucket);
    for (int i = 0; i < valueCount; i++) {
        double added = values.nextValue();
        // Non-finite inputs are folded straight into the running total, "poisoning"
        // it to NaN/Inf; once non-finite, the compensated step below is skipped.
        value = addIfNonOrInf(added, value);
        if (Double.isFinite(value)) {
            // Kahan step: `delta` carries the low-order bits lost by the addition.
            double correctedSum = added + delta;
            double updatedValue = value + correctedSum;
            delta = correctedSum - (updatedValue - value);
            value = updatedValue;
        }
    }
    compensations.set(bucket, delta);
    sums.set(bucket, value);
    return valueCount;
}

/**
 * Folds {@code added} into the running total only when it is NaN or infinite.
 * This keeps the behavior bwc from before kahan summing: a non-finite input
 * immediately converts the sum to Inf/NaN, while a finite input leaves the
 * total untouched here (it is accumulated by the compensated step in the caller).
 */
private static double addIfNonOrInf(double added, double value) {
    return Double.isFinite(added) ? value : added + value;
}

@Override
protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final LeafBucketCollector sub) {
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
maybeGrow(bucket);
var sums = SumAggregator.this.sums;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double value = sums.get(bucket);
// If the value is Inf or NaN, just add it to the running tally to "convert" to
// Inf/NaN. This keeps the behavior bwc from before kahan summing
double v = values.doubleValue();
if (Double.isFinite(v) == false) {
value = v + value;
}

if (Double.isFinite(value)) {
var compensations = SumAggregator.this.compensations;
double delta = compensations.get(bucket);
double correctedSum = v + delta;
double updatedValue = value + correctedSum;
delta = correctedSum - (updatedValue - value);
value = updatedValue;
compensations.set(bucket, delta);
}

sums.set(bucket, value);
computeSum(bucket, values, sums, compensations);
}
}
};
}

private void maybeGrow(long bucket) {
// Single-valued counterpart of sumSortedDoubles (also called by AvgAggregator):
// adds the one value of the current doc to the per-bucket running sum.
static void computeSum(long bucket, NumericDoubleValues values, DoubleArray sums, DoubleArray compensations) throws IOException {
    // Compute the sum of double values with Kahan summation algorithm which is more
    // accurate than naive summation.
    double added = values.doubleValue();
    // A non-finite input "poisons" the running total to NaN/Inf and skips the
    // compensated step below.
    double value = addIfNonOrInf(added, sums.get(bucket));
    if (Double.isFinite(value)) {
        // Kahan step: `delta` carries the low-order bits lost by the addition.
        double delta = compensations.get(bucket);
        double correctedSum = added + delta;
        double updatedValue = value + correctedSum;
        delta = correctedSum - (updatedValue - value);
        value = updatedValue;
        compensations.set(bucket, delta);
    }

    sums.set(bucket, value);
}

protected final void maybeGrow(long bucket) {
if (bucket >= sums.size()) {
sums = bigArrays().grow(sums, bucket + 1);
compensations = bigArrays().grow(compensations, bucket + 1);
var bigArrays = bigArrays();
doGrow(bucket, bigArrays);
}
}

// Resizes the per-bucket arrays so that `bucket` is a valid index. Overridable
// hook so subclasses tracking extra per-bucket state can grow it in lockstep.
protected void doGrow(long bucket, BigArrays bigArrays) {
    final long minSize = bucket + 1;
    sums = bigArrays.grow(sums, minSize);
    compensations = bigArrays.grow(compensations, minSize);
}

@Override
public double metric(long owningBucketOrd) {
if (owningBucketOrd >= sums.size()) {
Expand Down

0 comments on commit d5749a0

Please sign in to comment.