Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drier and faster SumAggregator and AvgAggregator #120436

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
Expand All @@ -25,12 +23,9 @@
import java.io.IOException;
import java.util.Map;

class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue {
class AvgAggregator extends SumAggregator {

LongArray counts;
DoubleArray sums;
DoubleArray compensations;
DocValueFormat format;

AvgAggregator(
String name,
Expand All @@ -40,63 +35,40 @@ class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue {
Map<String, Object> metadata
) throws IOException {
super(name, valuesSourceConfig, context, parent, metadata);
assert valuesSourceConfig.hasValues();
this.format = valuesSourceConfig.format();
final BigArrays bigArrays = context.bigArrays();
counts = bigArrays.newLongArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
counts = context.bigArrays().newLongArray(1, true);
}

@Override
protected LeafBucketCollector getLeafCollector(SortedNumericDoubleValues values, final LeafBucketCollector sub) {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
maybeGrow(bucket);
final int valueCount = values.docValueCount();
counts.increment(bucket, valueCount);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
for (int i = 0; i < valueCount; i++) {
kahanSummation.add(values.nextValue());
}
sums.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
counts.increment(bucket, sumSortedDoubles(bucket, values, sums, compensations));
}
}
};
}

@Override
protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final LeafBucketCollector sub) {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
maybeGrow(bucket);
computeSum(bucket, values, sums, compensations);
counts.increment(bucket, 1L);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
kahanSummation.add(values.doubleValue());
sums.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
}
}
};
}

private void maybeGrow(long bucket) {
if (bucket >= counts.size()) {
counts = bigArrays().grow(counts, bucket + 1);
sums = bigArrays().grow(sums, bucket + 1);
compensations = bigArrays().grow(compensations, bucket + 1);
}
@Override
protected void doGrow(long bucket, BigArrays bigArrays) {
    // First let SumAggregator grow its sums/compensations arrays, then grow
    // the avg-specific counts array so all three cover the same bucket ordinal.
    super.doGrow(bucket, bigArrays);
    counts = bigArrays.grow(counts, bucket + 1);
}

@Override
Expand All @@ -122,7 +94,8 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(counts, sums, compensations);
super.doClose();
Releasables.close(counts);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
Expand All @@ -25,10 +26,9 @@

public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue {

private final DocValueFormat format;

private DoubleArray sums;
private DoubleArray compensations;
protected final DocValueFormat format;
protected DoubleArray sums;
protected DoubleArray compensations;

SumAggregator(
String name,
Expand All @@ -40,72 +40,98 @@ public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue {
super(name, valuesSourceConfig, context, parent, metadata);
assert valuesSourceConfig.hasValues();
this.format = valuesSourceConfig.format();
sums = bigArrays().newDoubleArray(1, true);
compensations = bigArrays().newDoubleArray(1, true);
var bigArrays = context.bigArrays();
sums = bigArrays.newDoubleArray(1, true);
compensations = bigArrays.newDoubleArray(1, true);
}

@Override
protected LeafBucketCollector getLeafCollector(SortedNumericDoubleValues values, final LeafBucketCollector sub) {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
maybeGrow(bucket);
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
for (int i = 0; i < values.docValueCount(); i++) {
kahanSummation.add(values.nextValue());
}
compensations.set(bucket, kahanSummation.delta());
sums.set(bucket, kahanSummation.value());
sumSortedDoubles(bucket, values, sums, compensations);
}
}
};
}

/**
 * Adds every value of the current document to the per-bucket running sum using
 * the Kahan (compensated) summation algorithm, which is more accurate than
 * naive summation.
 *
 * If any value is NaN or +/-Infinity it is folded directly into the running
 * total via {@code addIfNonOrInf}, and compensation is skipped while the total
 * is non-finite — this keeps behavior backwards compatible with the
 * pre-Kahan implementation.
 *
 * @param bucket        the bucket ordinal whose sum is updated
 * @param values        positioned doc values; the caller must already have
 *                      called {@code advanceExact} on the current doc
 * @param sums          per-bucket running sums (updated in place)
 * @param compensations per-bucket Kahan compensation terms (updated in place)
 * @return the number of values added, i.e. the document's value count
 */
static int sumSortedDoubles(long bucket, SortedNumericDoubleValues values, DoubleArray sums, DoubleArray compensations)
    throws IOException {
    final int valueCount = values.docValueCount();
    // Compute the sum of double values with Kahan summation algorithm which is more
    // accurate than naive summation.
    double value = sums.get(bucket);
    double delta = compensations.get(bucket);
    for (int i = 0; i < valueCount; i++) {
        double added = values.nextValue();
        value = addIfNonOrInf(added, value);
        if (Double.isFinite(value)) {
            double correctedSum = added + delta;
            double updatedValue = value + correctedSum;
            delta = correctedSum - (updatedValue - value);
            value = updatedValue;
        }
    }
    compensations.set(bucket, delta);
    sums.set(bucket, value);
    return valueCount;
}

/**
 * Folds a non-finite addend (NaN or +/-Infinity) directly into the running
 * total so the total "converts" to NaN/Infinity; a finite addend leaves the
 * total untouched (it is added separately via compensated summation). This
 * keeps the behavior backwards compatible with pre-Kahan summing.
 */
private static double addIfNonOrInf(double added, double value) {
    return Double.isFinite(added) ? value : added + value;
}

@Override
protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final LeafBucketCollector sub) {
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
if (values.advanceExact(doc)) {
maybeGrow(bucket);
var sums = SumAggregator.this.sums;
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double value = sums.get(bucket);
// If the value is Inf or NaN, just add it to the running tally to "convert" to
// Inf/NaN. This keeps the behavior bwc from before kahan summing
double v = values.doubleValue();
if (Double.isFinite(v) == false) {
value = v + value;
}

if (Double.isFinite(value)) {
var compensations = SumAggregator.this.compensations;
double delta = compensations.get(bucket);
double correctedSum = v + delta;
double updatedValue = value + correctedSum;
delta = correctedSum - (updatedValue - value);
value = updatedValue;
compensations.set(bucket, delta);
}

sums.set(bucket, value);
computeSum(bucket, values, sums, compensations);
}
}
};
}

private void maybeGrow(long bucket) {
/**
 * Adds the single value of the current document to the per-bucket running sum
 * using the Kahan (compensated) summation algorithm, which is more accurate
 * than naive summation.
 *
 * A NaN or infinite value is folded straight into the running total via
 * {@code addIfNonOrInf}; once the total is non-finite the compensation step is
 * skipped, preserving pre-Kahan backwards-compatible behavior.
 *
 * @param bucket        the bucket ordinal whose sum is updated
 * @param values        positioned single-valued doc values; the caller must
 *                      already have called {@code advanceExact} on the doc
 * @param sums          per-bucket running sums (updated in place)
 * @param compensations per-bucket Kahan compensation terms (updated in place)
 */
static void computeSum(long bucket, NumericDoubleValues values, DoubleArray sums, DoubleArray compensations) throws IOException {
    // Compute the sum of double values with Kahan summation algorithm which is more
    // accurate than naive summation.
    double added = values.doubleValue();
    double value = addIfNonOrInf(added, sums.get(bucket));
    if (Double.isFinite(value)) {
        double delta = compensations.get(bucket);
        double correctedSum = added + delta;
        double updatedValue = value + correctedSum;
        delta = correctedSum - (updatedValue - value);
        value = updatedValue;
        compensations.set(bucket, delta);
    }

    sums.set(bucket, value);
}

/**
 * Ensures the per-bucket backing arrays can hold {@code bucket}, delegating
 * the actual resizing to {@link #doGrow} only when the ordinal is beyond the
 * current capacity.
 */
protected final void maybeGrow(long bucket) {
    if (bucket < sums.size()) {
        return; // capacity already sufficient
    }
    doGrow(bucket, bigArrays());
}

/**
 * Grows the sum and compensation arrays so index {@code bucket} is valid.
 * Subclasses that keep additional per-bucket state (e.g. AvgAggregator's
 * counts array) override this and call {@code super.doGrow} first.
 */
protected void doGrow(long bucket, BigArrays bigArrays) {
    sums = bigArrays.grow(sums, bucket + 1);
    compensations = bigArrays.grow(compensations, bucket + 1);
}

@Override
public double metric(long owningBucketOrd) {
if (owningBucketOrd >= sums.size()) {
Expand Down