From 40d0061d3b2deb46890300b8ea53816dec6ece56 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Sat, 18 Jan 2025 17:38:27 +0100 Subject: [PATCH] Drier and faster SumAggregator and AvgAggregator Dried up the summation logic for both implementations (and moved it to the much faster inline form). Obviously this could have been made even drier, but that didn't seem possible without a performance hit (we really don't want to sub-class the leaf-collector, I think). Benchmarks suggest this variant is ~10% faster than the previous iteration of `SumAggregator` (probably from making the grow method smaller), with an even bigger improvement for `AvgAggregator`. --- .../aggregations/metrics/AvgAggregator.java | 47 ++------ .../aggregations/metrics/SumAggregator.java | 106 +++++++++++------- 2 files changed, 76 insertions(+), 77 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java index c1bf1a085b709..e5a9886b4f450 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java @@ -9,12 +9,10 @@ package org.elasticsearch.search.aggregations.metrics; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.DoubleArray; import org.elasticsearch.common.util.LongArray; import org.elasticsearch.core.Releasables; import org.elasticsearch.index.fielddata.NumericDoubleValues; import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; -import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.LeafBucketCollector; @@ -25,12 +23,9 @@ import java.io.IOException; import java.util.Map; -class AvgAggregator extends 
NumericMetricsAggregator.SingleDoubleValue { +class AvgAggregator extends SumAggregator { LongArray counts; - DoubleArray sums; - DoubleArray compensations; - DocValueFormat format; AvgAggregator( String name, @@ -40,32 +35,17 @@ class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue { Map metadata ) throws IOException { super(name, valuesSourceConfig, context, parent, metadata); - assert valuesSourceConfig.hasValues(); - this.format = valuesSourceConfig.format(); - final BigArrays bigArrays = context.bigArrays(); - counts = bigArrays.newLongArray(1, true); - sums = bigArrays.newDoubleArray(1, true); - compensations = bigArrays.newDoubleArray(1, true); + counts = context.bigArrays().newLongArray(1, true); } @Override protected LeafBucketCollector getLeafCollector(SortedNumericDoubleValues values, final LeafBucketCollector sub) { - final CompensatedSum kahanSummation = new CompensatedSum(0, 0); return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { if (values.advanceExact(doc)) { maybeGrow(bucket); - final int valueCount = values.docValueCount(); - counts.increment(bucket, valueCount); - // Compute the sum of double values with Kahan summation algorithm which is more - // accurate than naive summation. 
- kahanSummation.reset(sums.get(bucket), compensations.get(bucket)); - for (int i = 0; i < valueCount; i++) { - kahanSummation.add(values.nextValue()); - } - sums.set(bucket, kahanSummation.value()); - compensations.set(bucket, kahanSummation.delta()); + counts.increment(bucket, sumSortedDoubles(bucket, values, sums, compensations)); } } }; @@ -73,30 +53,22 @@ public void collect(int doc, long bucket) throws IOException { @Override protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final LeafBucketCollector sub) { - final CompensatedSum kahanSummation = new CompensatedSum(0, 0); return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { if (values.advanceExact(doc)) { maybeGrow(bucket); + computeSum(bucket, values, sums, compensations); counts.increment(bucket, 1L); - // Compute the sum of double values with Kahan summation algorithm which is more - // accurate than naive summation. - kahanSummation.reset(sums.get(bucket), compensations.get(bucket)); - kahanSummation.add(values.doubleValue()); - sums.set(bucket, kahanSummation.value()); - compensations.set(bucket, kahanSummation.delta()); } } }; } - private void maybeGrow(long bucket) { - if (bucket >= counts.size()) { - counts = bigArrays().grow(counts, bucket + 1); - sums = bigArrays().grow(sums, bucket + 1); - compensations = bigArrays().grow(compensations, bucket + 1); - } + @Override + protected void doGrow(long bucket, BigArrays bigArrays) { + super.doGrow(bucket, bigArrays); + counts = bigArrays.grow(counts, bucket + 1); } @Override @@ -122,7 +94,8 @@ public InternalAggregation buildEmptyAggregation() { @Override public void doClose() { - Releasables.close(counts, sums, compensations); + super.doClose(); + Releasables.close(counts); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java 
index c8b364c08bec5..237ba6dfe4060 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java @@ -8,6 +8,7 @@ */ package org.elasticsearch.search.aggregations.metrics; +import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.DoubleArray; import org.elasticsearch.core.Releasables; import org.elasticsearch.index.fielddata.NumericDoubleValues; @@ -25,10 +26,9 @@ public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue { - private final DocValueFormat format; - - private DoubleArray sums; - private DoubleArray compensations; + protected final DocValueFormat format; + protected DoubleArray sums; + protected DoubleArray compensations; SumAggregator( String name, @@ -40,31 +40,56 @@ public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue { super(name, valuesSourceConfig, context, parent, metadata); assert valuesSourceConfig.hasValues(); this.format = valuesSourceConfig.format(); - sums = bigArrays().newDoubleArray(1, true); - compensations = bigArrays().newDoubleArray(1, true); + var bigArrays = context.bigArrays(); + sums = bigArrays.newDoubleArray(1, true); + compensations = bigArrays.newDoubleArray(1, true); } @Override protected LeafBucketCollector getLeafCollector(SortedNumericDoubleValues values, final LeafBucketCollector sub) { - final CompensatedSum kahanSummation = new CompensatedSum(0, 0); return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { if (values.advanceExact(doc)) { maybeGrow(bucket); - // Compute the sum of double values with Kahan summation algorithm which is more - // accurate than naive summation. 
- kahanSummation.reset(sums.get(bucket), compensations.get(bucket)); - for (int i = 0; i < values.docValueCount(); i++) { - kahanSummation.add(values.nextValue()); - } - compensations.set(bucket, kahanSummation.delta()); - sums.set(bucket, kahanSummation.value()); + sumSortedDoubles(bucket, values, sums, compensations); } } }; } + // returns number of values added + static int sumSortedDoubles(long bucket, SortedNumericDoubleValues values, DoubleArray sums, DoubleArray compensations) + throws IOException { + final int valueCount = values.docValueCount(); + // Compute the sum of double values with Kahan summation algorithm which is more + // accurate than naive summation. + double value = sums.get(bucket); + double delta = compensations.get(bucket); + for (int i = 0; i < valueCount; i++) { + double added = values.nextValue(); + value = addIfNonOrInf(added, value); + if (Double.isFinite(value)) { + double correctedSum = added + delta; + double updatedValue = value + correctedSum; + delta = correctedSum - (updatedValue - value); + value = updatedValue; + } + } + compensations.set(bucket, delta); + sums.set(bucket, value); + return valueCount; + } + + private static double addIfNonOrInf(double added, double value) { + // If the value is Inf or NaN, just add it to the running tally to "convert" to + // Inf/NaN. 
This keeps the behavior bwc from before kahan summing + if (Double.isFinite(added)) { + return value; + } + return added + value; + } + @Override protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final LeafBucketCollector sub) { return new LeafBucketCollectorBase(sub, values) { @@ -72,40 +97,41 @@ protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final public void collect(int doc, long bucket) throws IOException { if (values.advanceExact(doc)) { maybeGrow(bucket); - var sums = SumAggregator.this.sums; - // Compute the sum of double values with Kahan summation algorithm which is more - // accurate than naive summation. - double value = sums.get(bucket); - // If the value is Inf or NaN, just add it to the running tally to "convert" to - // Inf/NaN. This keeps the behavior bwc from before kahan summing - double v = values.doubleValue(); - if (Double.isFinite(v) == false) { - value = v + value; - } - - if (Double.isFinite(value)) { - var compensations = SumAggregator.this.compensations; - double delta = compensations.get(bucket); - double correctedSum = v + delta; - double updatedValue = value + correctedSum; - delta = correctedSum - (updatedValue - value); - value = updatedValue; - compensations.set(bucket, delta); - } - - sums.set(bucket, value); + computeSum(bucket, values, sums, compensations); } } }; } - private void maybeGrow(long bucket) { + static void computeSum(long bucket, NumericDoubleValues values, DoubleArray sums, DoubleArray compensations) throws IOException { + // Compute the sum of double values with Kahan summation algorithm which is more + // accurate than naive summation. 
+ double added = values.doubleValue(); + double value = addIfNonOrInf(added, sums.get(bucket)); + if (Double.isFinite(value)) { + double delta = compensations.get(bucket); + double correctedSum = added + delta; + double updatedValue = value + correctedSum; + delta = correctedSum - (updatedValue - value); + value = updatedValue; + compensations.set(bucket, delta); + } + + sums.set(bucket, value); + } + + protected final void maybeGrow(long bucket) { if (bucket >= sums.size()) { - sums = bigArrays().grow(sums, bucket + 1); - compensations = bigArrays().grow(compensations, bucket + 1); + var bigArrays = bigArrays(); + doGrow(bucket, bigArrays); } } + protected void doGrow(long bucket, BigArrays bigArrays) { + sums = bigArrays.grow(sums, bucket + 1); + compensations = bigArrays.grow(compensations, bucket + 1); + } + @Override public double metric(long owningBucketOrd) { if (owningBucketOrd >= sums.size()) {