Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make aggregation statement compilation robust #2781

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ spotless {
removeUnusedImports()
trimTrailingWhitespace()
endWithNewline()
toggleOffOn()
googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format')
}
}
Expand Down
4 changes: 3 additions & 1 deletion core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ spotless {
removeUnusedImports()
trimTrailingWhitespace()
endWithNewline()
toggleOffOn()
googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format')
}
}
Expand Down Expand Up @@ -112,7 +113,8 @@ jacocoTestCoverageVerification {
'org.opensearch.sql.utils.Constants',
'org.opensearch.sql.datasource.model.DataSource',
'org.opensearch.sql.datasource.model.DataSourceStatus',
'org.opensearch.sql.datasource.model.DataSourceType'
'org.opensearch.sql.datasource.model.DataSourceType',
'org.opensearch.sql.QueryCompilationError'
]
limit {
counter = 'LINE'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql;

import static org.opensearch.sql.common.utils.StringUtils.format;

import lombok.experimental.UtilityClass;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;

/**
 * Central catalogue of the {@link SemanticCheckException}s raised while a query is being compiled.
 *
 * <p>Every factory method only builds and returns the exception; throwing it is left to the
 * caller.
 */
@UtilityClass
public class QueryCompilationError {

  /** Error for a selected field that is neither grouped nor aggregated. */
  public static SemanticCheckException fieldNotInGroupByClauseError(String name) {
    return semanticError(
        "Field [%s] must appear in the GROUP BY clause or be used in an aggregate function", name);
  }

  /** Error for an aggregate function found inside a GROUP BY clause. */
  public static SemanticCheckException aggregateFunctionNotAllowedInGroupByError(
      String functionName) {
    return semanticError(
        "Aggregate function is not allowed in a GROUP BY clause, but found [%s]", functionName);
  }

  /** Error for a FILTER or HAVING expression whose type is not boolean. */
  public static SemanticCheckException nonBooleanExpressionInFilterOrHavingError(ExprType type) {
    final String actualTypeName = type.typeName();
    return semanticError(
        "FILTER or HAVING expression must be type boolean, but found [%s]", actualTypeName);
  }

  /** Error for an aggregate function found inside a FILTER. */
  public static SemanticCheckException aggregateFunctionNotAllowedInFilterError(
      String functionName) {
    return semanticError(
        "Aggregate function is not allowed in a FILTER, but found [%s]", functionName);
  }

  /** Error for a window function used in WHERE or HAVING. */
  public static SemanticCheckException windowFunctionNotAllowedError() {
    // Fixed message with no placeholders, so it bypasses the formatter.
    return new SemanticCheckException("Window functions are not allowed in WHERE or HAVING");
  }

  /** Error for an aggregate function name that is not supported. */
  public static SemanticCheckException unsupportedAggregateFunctionError(String functionName) {
    return semanticError("Unsupported aggregation function %s", functionName);
  }

  /** Error for an ordinal that does not point into the select item list. */
  public static SemanticCheckException ordinalRefersOutOfBounds(int ordinal) {
    return semanticError("Ordinal [%d] is out of bound of select item list", ordinal);
  }

  /** Error for an expression with a non-aggregated column when no GROUP BY clause is present. */
  public static SemanticCheckException groupByClauseIsMissingError(UnresolvedExpression expr) {
    return semanticError(
        "Explicit GROUP BY clause is required because expression [%s] contains non-aggregated"
            + " column",
        expr);
  }

  /** Formats {@code template} with {@code args} and wraps the result in the exception type. */
  private static SemanticCheckException semanticError(String template, Object... args) {
    return new SemanticCheckException(format(template, args));
  }
}
140 changes: 122 additions & 18 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.DataSourceSchemaName;
import org.opensearch.sql.QueryCompilationError;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand All @@ -47,6 +52,7 @@
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.FetchCursor;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Having;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
Expand Down Expand Up @@ -81,6 +87,7 @@
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.TableFunctionImplementation;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.window.WindowFunctionExpression;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalCloseCursor;
Expand All @@ -102,6 +109,7 @@
import org.opensearch.sql.planner.logical.LogicalValues;
import org.opensearch.sql.planner.physical.datasource.DataSourceTable;
import org.opensearch.sql.storage.Table;
import org.opensearch.sql.utils.ExpressionUtils;
import org.opensearch.sql.utils.ParseUtils;

/**
Expand Down Expand Up @@ -235,32 +243,51 @@ public LogicalPlan visitLimit(Limit node, AnalysisContext context) {
public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
  // Analyze the input plan before the condition, then validate the condition.
  final LogicalPlan input = node.getChild().get(0).accept(this, context);
  final Expression predicate = expressionAnalyzer.analyze(node.getCondition(), context);
  verifyCondition(predicate);

  // Run the condition through the reference optimizer built over the analyzed input plan.
  return new LogicalFilter(
      input,
      new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), input)
          .optimize(predicate, context));
}

private void verifyCondition(Expression condition) {
  // TODO Remove this exemption when adding support for syntax - nested(path, condition)
  // Current WHERE nested(path, condition) is not a valid boolean condition.
  boolean nestedCall = false;
  if (condition instanceof FunctionExpression) {
    FunctionName calledFunction = ((FunctionExpression) condition).getFunctionName();
    nestedCall = calledFunction.equals(FunctionName.of("nested"));
  }

  // The condition has to be a predicate, unless it is the nested() special case above.
  boolean isPredicate = condition.type() == ExprCoreType.BOOLEAN;
  if (!(isPredicate || nestedCall)) {
    throw QueryCompilationError.nonBooleanExpressionInFilterOrHavingError(condition.type());
  }

  // Reject the condition if a window function appears anywhere inside it.
  List<Expression> windowFunctions =
      ExpressionUtils.findSubExpressions(condition, WindowFunctionExpression.class::isInstance);
  if (windowFunctions.isEmpty()) {
    return;
  }
  throw QueryCompilationError.windowFunctionNotAllowedError();
}

/**
* Ensure NESTED function is not used in GROUP BY, and HAVING clauses. Fallback to legacy engine.
* Can remove when support is added for NESTED function in WHERE, GROUP BY, ORDER BY, and HAVING
* clauses.
*
* @param condition : Filter condition
* Ensure NESTED function is not used in GROUP BY. Fallback to legacy engine. Can remove when
* support is added for NESTED function in WHERE, GROUP BY, ORDER BY, and HAVING clauses. Ensure
* Aggregate function is not used in GROUP BY.
*/
private void verifySupportsCondition(Expression condition) {
if (condition instanceof FunctionExpression) {
if (((FunctionExpression) condition)
private void verifySupportsGroupBy(Expression groupBy) {
if (groupBy instanceof FunctionExpression) {
if (((FunctionExpression) groupBy)
.getFunctionName()
.getFunctionName()
.equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) {
throw new SyntaxCheckException(
"Falling back to legacy engine. Nested function is not supported in WHERE,"
+ " GROUP BY, and HAVING clauses.");
}
((FunctionExpression) condition)
.getArguments().stream().forEach(e -> verifySupportsCondition(e));
((FunctionExpression) groupBy).getArguments().stream().forEach(e -> verifySupportsGroupBy(e));
} else if (groupBy instanceof Aggregator) {
throw QueryCompilationError.aggregateFunctionNotAllowedInGroupByError(
((Aggregator<?>) groupBy).getFunctionName().getFunctionName());
}
}

Expand Down Expand Up @@ -295,13 +322,7 @@ public LogicalPlan visitRename(Rename node, AnalysisContext context) {
@Override
public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {
final LogicalPlan child = node.getChild().get(0).accept(this, context);
ImmutableList.Builder<NamedAggregator> aggregatorBuilder = new ImmutableList.Builder<>();
for (UnresolvedExpression expr : node.getAggExprList()) {
NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context);
aggregatorBuilder.add(
new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated()));
}

// resolve group-by list
ImmutableList.Builder<NamedExpression> groupbyBuilder = new ImmutableList.Builder<>();
// Span should be first expression if exist.
if (node.getSpan() != null) {
Expand All @@ -310,12 +331,62 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {

for (UnresolvedExpression expr : node.getGroupExprList()) {
NamedExpression resolvedExpr = namedExpressionAnalyzer.analyze(expr, context);
verifySupportsCondition(resolvedExpr.getDelegated());
verifySupportsGroupBy(resolvedExpr.getDelegated());
groupbyBuilder.add(resolvedExpr);
}
ImmutableList<NamedExpression> groupBys = groupbyBuilder.build();

// spotless:off
// Verify group-by could work with select expressions.
// The following table shows the examples to explain the purpose:
// +------+------------------------------------------------------------------------------------------+---------+-------------------------+
// | Case | Query | IsValid | Field Missed In GroupBy |
// +------+------------------------------------------------------------------------------------------+---------+-------------------------+
// | 1 | SELECT a FROM table GROUP BY b | No | a |
// | 2 | SELECT a as c FROM table GROUP BY b | No | a |
// | 3 | SELECT a FROM table GROUP BY a * 3 | No | a |
// | 4 | SELECT a * 3 FROM table GROUP BY a | Yes | N/A |
// | 5 | SELECT a * 3 FROM table GROUP BY b | No | a |
// | 6 | SELECT a FROM table GROUP BY upper(a) | No | a |
// | 7 | SELECT upper(a) FROM table GROUP BY a | Yes | N/A |
// | 8 | SELECT upper(a) FROM table GROUP BY upper(a) | Yes | N/A |
// | 9 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY b | No | a |
// | 10 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY upper(b) | No | a |
// | 11 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY concat(upper(a), upper(b)) | Yes | N/A |
// | 12 | SELECT concat(upper(a), upper(b)) FROM table GROUP BY concat_ws(',', upper(a), upper(b)) | No | a |
// | 13 | SELECT concat(a, b) FROM table group by upper(a), upper(b) | No | a |
// | 14 | SELECT concat(a, b) FROM table group by a | No | b |
// | 15 | SELECT concat(a, b) FROM table group by a, upper(b) | No | b |
// | 16 | SELECT concat(a, b), upper(b) FROM table group by a, upper(b) | No | b |
// | 17 | SELECT concat(a, b), upper(b) FROM table group by a, b | Yes | N/A |
// | 18 | SELECT upper(concat(a, b)) FROM table group by concat(a, b) | Yes | N/A |
// | 19 | SELECT concat(concat(a, b), c) FROM table group by concat(a, b) | No | c |
// | 20 | SELECT 1, 2, 3 FROM table group by a | Yes | N/A |
// | 21 | SELECT 1, 2, b FROM table group by a | No | b |
// +------+------------------------------------------------------------------------------------------+---------+-------------------------+
// spotless:on
for (UnresolvedExpression expr : node.getAliasFreeSelectExprList()) {
Expression resolvedSelectItemExpr = expressionAnalyzer.analyze(expr, context);
Predicate<Expression> notExists =
e ->
(e instanceof ReferenceExpression || e instanceof FunctionExpression)
&& groupBys.stream().noneMatch(g -> e.equals(g.getDelegated()));
Consumer<String> action =
name -> {
throw QueryCompilationError.fieldNotInGroupByClauseError(name);
};
ExpressionUtils.actionOnCheck(resolvedSelectItemExpr, notExists, action);
}

// resolve aggregators
ImmutableList.Builder<NamedAggregator> aggregatorBuilder = new ImmutableList.Builder<>();
for (UnresolvedExpression expr : node.getAggExprList()) {
NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context);
aggregatorBuilder.add(
new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated()));
}
ImmutableList<NamedAggregator> aggregators = aggregatorBuilder.build();

// new context
context.push();
TypeEnvironment newEnv = context.peek();
Expand All @@ -329,6 +400,39 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {
return new LogicalAggregation(child, aggregators, groupBys);
}

/**
 * Analyze a HAVING clause by folding the aggregators it references into the {@link
 * LogicalAggregation} produced by its child.
 *
 * <p>The child plan is cast to {@link LogicalAggregation} without a prior type check. The HAVING
 * condition is analyzed and validated here, but it is not attached to the returned plan.
 */
@Override
public LogicalPlan visitHaving(Having node, AnalysisContext context) {
  final LogicalAggregation childAggregation =
      (LogicalAggregation) node.getChild().get(0).accept(this, context);
  verifyCondition(expressionAnalyzer.analyze(node.getCondition(), context));

  // Collect the aggregators referenced by the HAVING clause.
  List<NamedAggregator> havingAggregators = new ArrayList<>();
  for (UnresolvedExpression unresolved : node.getAggregators()) {
    NamedExpression analyzed = namedExpressionAnalyzer.analyze(unresolved, context);
    havingAggregators.add(
        new NamedAggregator(analyzed.getNameOrAlias(), (Aggregator) analyzed.getDelegated()));
  }

  // Push a new context and define each HAVING aggregator in it as a field.
  context.push();
  TypeEnvironment havingEnv = context.peek();
  for (NamedAggregator aggregator : havingAggregators) {
    havingEnv.define(new Symbol(Namespace.FIELD_NAME, aggregator.getName()), aggregator.type());
  }

  // Child aggregators come first; the insertion-ordered set drops duplicates from HAVING.
  Set<NamedAggregator> merged = new LinkedHashSet<>(childAggregation.getAggregatorList());
  merged.addAll(havingAggregators);
  List<NamedAggregator> mergedAggregators = new ArrayList<>(merged);
  return new LogicalAggregation(
      childAggregation.getChild().get(0), mergedAggregators, childAggregation.getGroupByList());
}

/** Build {@link LogicalRareTopN}. */
@Override
public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Getter;
import org.opensearch.sql.QueryCompilationError;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand Down Expand Up @@ -70,6 +71,8 @@
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
import org.opensearch.sql.expression.window.ranking.RankingWindowFunction;
import org.opensearch.sql.utils.ExpressionUtils;

/**
* Analyze the {@link UnresolvedExpression} in the {@link AnalysisContext} to construct the {@link
Expand Down Expand Up @@ -169,11 +172,23 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
builder.build());
aggregator.distinct(node.getDistinct());
if (node.condition() != null) {
aggregator.condition(analyze(node.condition(), context));
// Check if the filter condition is a valid predicate.
Expression predicate = node.condition().accept(this, context);
if (predicate.type() != ExprCoreType.BOOLEAN) {
throw QueryCompilationError.nonBooleanExpressionInFilterOrHavingError(predicate.type());
}
// Check if any aggregate function in filter
List<Expression> results =
ExpressionUtils.findSubExpressions(predicate, Aggregator.class::isInstance);
if (!results.isEmpty()) {
throw QueryCompilationError.aggregateFunctionNotAllowedInFilterError(
((Aggregator) results.get(0)).getFunctionName().getFunctionName());
}
aggregator.condition(predicate);
}
return aggregator;
} else {
throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName());
throw QueryCompilationError.unsupportedAggregateFunctionError(node.getFuncName());
}
}

Expand Down Expand Up @@ -203,6 +218,10 @@ public Expression visitFunction(Function node, AnalysisContext context) {
repository.compile(context.getFunctionProperties(), functionName, arguments);
}

/**
 * TODO: throw a SemanticCheckException here once a configuration option can be set to enable it,
 * in order to avoid a breaking change. Order is required if the function expression is a {@link
 * RankingWindowFunction}.
 */
@SuppressWarnings("unchecked")
@Override
public Expression visitWindowFunction(WindowFunction node, AnalysisContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.FetchCursor;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Having;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
Expand Down Expand Up @@ -312,4 +313,8 @@ public T visitFetchCursor(FetchCursor cursor, C context) {
public T visitCloseCursor(CloseCursor closeCursor, C context) {
return visitChildren(closeCursor, context);
}

/** Visit a {@link Having} node; this default implementation delegates to {@code visitChildren}. */
public T visitHaving(Having having, C context) {
return visitChildren(having, context);
}
}
Loading
Loading