Skip to content

Commit

Permalink
Fix filtered aggregate with ordering (#13784)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang authored and ankitsultana committed Aug 10, 2024
1 parent ad284a6 commit 6943c0c
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ public class TableResizer {
private final int _numGroupByExpressions;
private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
private final AggregationFunction[] _aggregationFunctions;
private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
private final Map<Pair<FunctionContext, FilterContext>, Integer> _filteredAggregationIndexMap;
private final List<Pair<AggregationFunction, FilterContext>> _filteredAggregationFunctions;
private final int _numOrderByExpressions;
private final OrderByValueExtractor[] _orderByValueExtractors;
private final Comparator<IntermediateRecord> _intermediateRecordComparator;
Expand All @@ -82,10 +80,8 @@ public TableResizer(DataSchema dataSchema, boolean hasFinalInput, QueryContext q

_aggregationFunctions = queryContext.getAggregationFunctions();
assert _aggregationFunctions != null;
_aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap();
assert _aggregationFunctionIndexMap != null;
_filteredAggregationIndexMap = queryContext.getFilteredAggregationsIndexMap();
_filteredAggregationFunctions = queryContext.getFilteredAggregationFunctions();
assert _filteredAggregationIndexMap != null;

List<OrderByExpressionContext> orderByExpressions = queryContext.getOrderByExpressions();
assert orderByExpressions != null;
Expand Down Expand Up @@ -148,26 +144,26 @@ private OrderByValueExtractor getOrderByValueExtractor(ExpressionContext express
FunctionContext function = expression.getFunction();
Preconditions.checkState(function != null, "Failed to find ORDER-BY expression: %s in the GROUP-BY clause",
expression);
FunctionContext aggregation;
FilterContext filter;
if (function.getType() == FunctionContext.Type.AGGREGATION) {
// Aggregation function
int index = _aggregationFunctionIndexMap.get(function);
// For final aggregate result, we can handle it the same way as group key
return _hasFinalInput ? new GroupByExpressionExtractor(_numGroupByExpressions + index)
: new AggregationFunctionExtractor(index);
aggregation = function;
filter = null;
} else if (function.getType() == FunctionContext.Type.TRANSFORM && "FILTER".equalsIgnoreCase(
function.getFunctionName())) {
// Filtered aggregation
FunctionContext aggregation = function.getArguments().get(0).getFunction();
ExpressionContext filterExpression = function.getArguments().get(1);
FilterContext filter = RequestContextUtils.getFilter(filterExpression);
int index = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
// For final aggregate result, we can handle it the same way as group key
return _hasFinalInput ? new GroupByExpressionExtractor(_numGroupByExpressions + index)
: new AggregationFunctionExtractor(index, _filteredAggregationFunctions.get(index).getLeft());
aggregation = function.getArguments().get(0).getFunction();
filter = RequestContextUtils.getFilter(function.getArguments().get(1));
} else {
// Post-aggregation function
return new PostAggregationFunctionExtractor(function);
}

int index = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
// For final aggregate result, we can handle it the same way as group key
return _hasFinalInput ? new GroupByExpressionExtractor(_numGroupByExpressions + index)
: new AggregationFunctionExtractor(index);
}

/**
Expand Down Expand Up @@ -441,11 +437,6 @@ private class AggregationFunctionExtractor implements OrderByValueExtractor {
_aggregationFunction = _aggregationFunctions[aggregationFunctionIndex];
}

AggregationFunctionExtractor(int aggregationFunctionIndex, AggregationFunction aggregationFunction) {
_index = aggregationFunctionIndex + _numGroupByExpressions;
_aggregationFunction = aggregationFunction;
}

@Override
public ColumnDataType getValueType() {
return _aggregationFunction.getFinalResultColumnType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@ public class QueryContext {

// Pre-calculate the aggregation functions and columns for the query so that it can be shared across all the segments
private AggregationFunction[] _aggregationFunctions;
private Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
private boolean _hasFilteredAggregations;
private List<Pair<AggregationFunction, FilterContext>> _filteredAggregationFunctions;
private Map<Pair<FunctionContext, FilterContext>, Integer> _filteredAggregationsIndexMap;
private boolean _hasFilteredAggregations;
private Set<String> _columns;

// Other properties to be shared across all the segments
Expand Down Expand Up @@ -272,22 +271,6 @@ public List<Pair<AggregationFunction, FilterContext>> getFilteredAggregationFunc
return _filteredAggregationFunctions;
}

/**
* Returns the filtered aggregation expressions for the query.
*/
public boolean hasFilteredAggregations() {
return _hasFilteredAggregations;
}

/**
* Returns a map from the AGGREGATION FunctionContext to the index of the corresponding AggregationFunction in the
* aggregation functions array.
*/
@Nullable
public Map<FunctionContext, Integer> getAggregationFunctionIndexMap() {
return _aggregationFunctionIndexMap;
}

/**
* Returns a map from the filtered aggregation (pair of AGGREGATION FunctionContext and FILTER FilterContext) to the
* index of corresponding AggregationFunction in the aggregation functions array.
Expand All @@ -297,6 +280,13 @@ public Map<Pair<FunctionContext, FilterContext>, Integer> getFilteredAggregation
return _filteredAggregationsIndexMap;
}

/**
* Returns the filtered aggregation expressions for the query.
*/
public boolean hasFilteredAggregations() {
return _hasFilteredAggregations;
}

/**
* Returns the columns (IDENTIFIER expressions) in the query.
*/
Expand Down Expand Up @@ -619,12 +609,7 @@ private void generateAggregationFunctions(QueryContext queryContext) {
for (int i = 0; i < numAggregations; i++) {
aggregationFunctions[i] = filteredAggregationFunctions.get(i).getLeft();
}
Map<FunctionContext, Integer> aggregationFunctionIndexMap = new HashMap<>();
for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : filteredAggregationsIndexMap.entrySet()) {
aggregationFunctionIndexMap.put(entry.getKey().getLeft(), entry.getValue());
}
queryContext._aggregationFunctions = aggregationFunctions;
queryContext._aggregationFunctionIndexMap = aggregationFunctionIndexMap;
queryContext._filteredAggregationFunctions = filteredAggregationFunctions;
queryContext._filteredAggregationsIndexMap = filteredAggregationsIndexMap;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,21 +480,21 @@ public void testHardcodedQueries() {
assertEquals(aggregationFunctions[3].getResultColumnName(), "sum(col4)");
assertEquals(aggregationFunctions[4].getResultColumnName(), "max(col4)");
assertEquals(aggregationFunctions[5].getResultColumnName(), "max(col1)");
Map<FunctionContext, Integer> aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap();
assertNotNull(aggregationFunctionIndexMap);
assertEquals(aggregationFunctionIndexMap.size(), 6);
assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
Collections.singletonList(ExpressionContext.forIdentifier("col1")))), 0);
assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "max",
Collections.singletonList(ExpressionContext.forIdentifier("col2")))), 1);
assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "min",
Collections.singletonList(ExpressionContext.forIdentifier("col2")))), 2);
assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
Collections.singletonList(ExpressionContext.forIdentifier("col4")))), 3);
assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "max",
Collections.singletonList(ExpressionContext.forIdentifier("col4")))), 4);
assertEquals((int) aggregationFunctionIndexMap.get(new FunctionContext(FunctionContext.Type.AGGREGATION, "max",
Collections.singletonList(ExpressionContext.forIdentifier("col1")))), 5);
Map<Pair<FunctionContext, FilterContext>, Integer> indexMap = queryContext.getFilteredAggregationsIndexMap();
assertNotNull(indexMap);
assertEquals(indexMap.size(), 6);
assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
Collections.singletonList(ExpressionContext.forIdentifier("col1"))), null)), 0);
assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "max",
Collections.singletonList(ExpressionContext.forIdentifier("col2"))), null)), 1);
assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "min",
Collections.singletonList(ExpressionContext.forIdentifier("col2"))), null)), 2);
assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "sum",
Collections.singletonList(ExpressionContext.forIdentifier("col4"))), null)), 3);
assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "max",
Collections.singletonList(ExpressionContext.forIdentifier("col4"))), null)), 4);
assertEquals((int) indexMap.get(Pair.of(new FunctionContext(FunctionContext.Type.AGGREGATION, "max",
Collections.singletonList(ExpressionContext.forIdentifier("col1"))), null)), 5);
}

// DistinctCountThetaSketch (string literal and escape quote)
Expand Down Expand Up @@ -540,21 +540,10 @@ public void testFilteredAggregations() {
assertTrue(filteredAggregationFunctions.get(1).getLeft() instanceof CountAggregationFunction);
assertEquals(filteredAggregationFunctions.get(1).getRight().toString(), "foo < '6'");

Map<FunctionContext, Integer> aggregationIndexMap = queryContext.getAggregationFunctionIndexMap();
assertNotNull(aggregationIndexMap);
assertEquals(aggregationIndexMap.size(), 1);
for (Map.Entry<FunctionContext, Integer> entry : aggregationIndexMap.entrySet()) {
FunctionContext aggregation = entry.getKey();
int index = entry.getValue();
assertEquals(aggregation.toString(), "count(*)");
assertTrue(index == 0 || index == 1);
}

Map<Pair<FunctionContext, FilterContext>, Integer> filteredAggregationsIndexMap =
queryContext.getFilteredAggregationsIndexMap();
assertNotNull(filteredAggregationsIndexMap);
assertEquals(filteredAggregationsIndexMap.size(), 2);
for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : filteredAggregationsIndexMap.entrySet()) {
Map<Pair<FunctionContext, FilterContext>, Integer> indexMap = queryContext.getFilteredAggregationsIndexMap();
assertNotNull(indexMap);
assertEquals(indexMap.size(), 2);
for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : indexMap.entrySet()) {
Pair<FunctionContext, FilterContext> pair = entry.getKey();
FunctionContext aggregation = pair.getLeft();
FilterContext filter = pair.getRight();
Expand Down Expand Up @@ -600,32 +589,10 @@ public void testFilteredAggregations() {
assertTrue(filteredAggregationFunctions.get(3).getLeft() instanceof MinAggregationFunction);
assertEquals(filteredAggregationFunctions.get(3).getRight().toString(), "salary > '50000'");

Map<FunctionContext, Integer> aggregationIndexMap = queryContext.getAggregationFunctionIndexMap();
assertNotNull(aggregationIndexMap);
assertEquals(aggregationIndexMap.size(), 2);
for (Map.Entry<FunctionContext, Integer> entry : aggregationIndexMap.entrySet()) {
FunctionContext aggregation = entry.getKey();
int index = entry.getValue();
switch (index) {
case 0:
case 1:
assertEquals(aggregation.toString(), "sum(salary)");
break;
case 2:
case 3:
assertEquals(aggregation.toString(), "min(salary)");
break;
default:
fail();
break;
}
}

Map<Pair<FunctionContext, FilterContext>, Integer> filteredAggregationsIndexMap =
queryContext.getFilteredAggregationsIndexMap();
assertNotNull(filteredAggregationsIndexMap);
assertEquals(filteredAggregationsIndexMap.size(), 4);
for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : filteredAggregationsIndexMap.entrySet()) {
Map<Pair<FunctionContext, FilterContext>, Integer> indexMap = queryContext.getFilteredAggregationsIndexMap();
assertNotNull(indexMap);
assertEquals(indexMap.size(), 4);
for (Map.Entry<Pair<FunctionContext, FilterContext>, Integer> entry : indexMap.entrySet()) {
Pair<FunctionContext, FilterContext> pair = entry.getKey();
FunctionContext aggregation = pair.getLeft();
FilterContext filter = pair.getRight();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,6 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet
new StarTreeIndexConfig(Collections.singletonList("DestState"), null,
Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()), null, 100);
private static final String TEST_STAR_TREE_QUERY_2 = "SELECT COUNT(*) FROM mytable WHERE DestState = 'CA'";
private static final String TEST_STAR_TREE_QUERY_FILTERED_AGG =
"SELECT COUNT(*), COUNT(*) FILTER (WHERE Carrier = 'UA') FROM mytable WHERE DestState = 'CA'";
// This query contains a filtered aggregation which cannot be solved with startree, but the COUNT(*) still should be
private static final String TEST_STAR_TREE_QUERY_FILTERED_AGG_MIXED =
"SELECT COUNT(*), AVG(ArrDelay) FILTER (WHERE Carrier = 'UA') FROM mytable WHERE DestState = 'CA'";
private static final StarTreeIndexConfig STAR_TREE_INDEX_CONFIG_3 =
new StarTreeIndexConfig(List.of("Carrier", "DestState"), null,
Collections.singletonList(AggregationFunctionColumnPair.COUNT_STAR.toColumnName()), null, 100);

// For default columns test
private static final String TEST_EXTRA_COLUMNS_QUERY = "SELECT COUNT(*) FROM mytable WHERE NewAddedIntMetric = 1";
Expand Down Expand Up @@ -3472,6 +3464,24 @@ public void testBooleanAggregation()
testQuery("SELECT BOOL_OR(CAST(Diverted AS BOOLEAN)) FROM mytable");
}

@Test(dataProvider = "useBothQueryEngines")
public void testGroupByAggregationWithLimitZero(boolean useMultiStageQueryEngine)
throws Exception {
setUseMultiStageQueryEngine(useMultiStageQueryEngine);
testQuery("SELECT Origin, SUM(ArrDelay) FROM mytable GROUP BY Origin LIMIT 0");
}

@Test(dataProvider = "useBothQueryEngines")
public void testFilteredAggregationWithGroupByOrdering(boolean useMultiStageQueryEngine)
throws Exception {
setUseMultiStageQueryEngine(useMultiStageQueryEngine);

// Test the ordering is correctly applied to the correct aggregation (the one without FILTER clause)
// See https://github.com/apache/pinot/pull/13784
testQuery("SELECT DestCityName, COUNT(*) AS c1, COUNT(*) FILTER (WHERE AirTime = 0) AS c2 FROM mytable "
+ "GROUP BY DestCityName ORDER BY c1 DESC LIMIT 10");
}

private String buildSkipIndexesOption(String columnsAndIndexes) {
return "SET " + SKIP_INDEXES + "='" + columnsAndIndexes + "'; ";
}
Expand Down

0 comments on commit 6943c0c

Please sign in to comment.