From d4f5454bc4ef98d49ad9eb86d7882f704064d7f9 Mon Sep 17 00:00:00 2001 From: NobiGo Date: Sat, 4 Jun 2022 10:19:46 +0800 Subject: [PATCH] [CALCITE-4924] REGR_SXX and similar aggregate functions return the wrong data type --- .../rel/type/DelegatingTypeSystem.java | 5 +- .../calcite/rel/type/RelDataTypeSystem.java | 3 +- .../rel/type/RelDataTypeSystemImpl.java | 112 +++++++++++++++++- .../apache/calcite/sql/type/ReturnTypes.java | 2 +- core/src/test/resources/sql/agg.iq | 31 ++++- core/src/test/resources/sql/dummy.iq | 14 ++- .../apache/calcite/test/SqlOperatorTest.java | 20 ++-- 7 files changed, 166 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/type/DelegatingTypeSystem.java b/core/src/main/java/org/apache/calcite/rel/type/DelegatingTypeSystem.java index 96b8628a9bac..a83121319961 100644 --- a/core/src/main/java/org/apache/calcite/rel/type/DelegatingTypeSystem.java +++ b/core/src/main/java/org/apache/calcite/rel/type/DelegatingTypeSystem.java @@ -16,6 +16,7 @@ */ package org.apache.calcite.rel.type; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlTypeName; import org.checkerframework.checker.nullness.qual.Nullable; @@ -76,9 +77,9 @@ protected DelegatingTypeSystem(RelDataTypeSystem typeSystem) { return typeSystem.deriveAvgAggType(typeFactory, argumentType); } - @Override public RelDataType deriveCovarType(RelDataTypeFactory typeFactory, + @Override public RelDataType deriveCovarType(RelDataTypeFactory typeFactory, SqlKind sqlKind, RelDataType arg0Type, RelDataType arg1Type) { - return typeSystem.deriveCovarType(typeFactory, arg0Type, arg1Type); + return typeSystem.deriveCovarType(typeFactory, sqlKind, arg0Type, arg1Type); } @Override public RelDataType deriveFractionalRankType(RelDataTypeFactory typeFactory) { diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java index ba2b0e6f6010..c11712f1f179 
100644 --- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java +++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java @@ -16,6 +16,7 @@ */ package org.apache.calcite.rel.type; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.util.Glossary; @@ -86,7 +87,7 @@ RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory, /** Returns the return type of a call to the {@code COVAR} aggregate function, * inferred from its argument types. */ - RelDataType deriveCovarType(RelDataTypeFactory typeFactory, + RelDataType deriveCovarType(RelDataTypeFactory typeFactory, SqlKind sqlKind, RelDataType arg0Type, RelDataType arg1Type); /** Returns the return type of the {@code CUME_DIST} and {@code PERCENT_RANK} diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java index d4a9cac0f8ce..ddbd9682b0eb 100644 --- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java +++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java @@ -16,12 +16,25 @@ */ package org.apache.calcite.rel.type; +import org.apache.calcite.runtime.CalciteException; +import org.apache.calcite.runtime.Resources; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.BasicSqlType; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidatorException; + +import com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; +import 
java.util.List; + /** Default implementation of * {@link org.apache.calcite.rel.type.RelDataTypeSystem}, * providing parameters from the SQL standard. @@ -255,8 +268,72 @@ && getDefaultPrecision(typeName) != -1) { return argumentType; } - @Override public RelDataType deriveCovarType(RelDataTypeFactory typeFactory, + @Override public RelDataType deriveCovarType(RelDataTypeFactory typeFactory, SqlKind sqlKind, RelDataType arg0Type, RelDataType arg1Type) { + switch (sqlKind) { + case REGR_SXX: + // REGR_SXX(x, y) → REGR_COUNT(x, y) * VAR_POP(y) + + // REGR_COUNT(x, y) + RelDataType regrCountTypeSXX = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.REGR_COUNT, arg0Type, arg1Type); + // VAR_POP(y) + RelDataType varPopTypeY = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.VAR_POP, arg1Type); + // REGR_COUNT(x, y) * VAR_POP(y) + return getOperatorRelDataType( + typeFactory, SqlStdOperatorTable.MULTIPLY, regrCountTypeSXX, varPopTypeY); + + case REGR_SYY: + // REGR_SYY(x, y) → REGR_COUNT(x, y) * VAR_POP(x) + + // REGR_COUNT(x, y) + RelDataType regrCountTypeSYY = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.REGR_COUNT, arg0Type, arg1Type); + // VAR_POP(x) + RelDataType varPopTypeX = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.VAR_POP, arg0Type); + // REGR_COUNT(x, y) * VAR_POP(x) + return getOperatorRelDataType( + typeFactory, SqlStdOperatorTable.MULTIPLY, regrCountTypeSYY, varPopTypeX); + + case COVAR_POP: + case COVAR_SAMP: + // COVAR_POP(x, y) → (SUM(x * y) - SUM(x) * SUM(y) / REGR_COUNT(x, y)) / REGR_COUNT(x, y) + // COVAR_SAMP(x, y) → (SUM(x * y) - SUM(x) * SUM(y) / REGR_COUNT(x, y)) / + // CASE REGR_COUNT(x, y) WHEN 1 THEN NULL ELSE REGR_COUNT(x, y) - 1 END + + // (x * y) + RelDataType multiplyXY = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.MULTIPLY, arg0Type, arg1Type); + // SUM(x * y) + RelDataType sumMultiplyXY = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.SUM, multiplyXY); + // SUM(x) +
RelDataType sumX = getOperatorRelDataType(typeFactory, SqlStdOperatorTable.SUM, arg0Type); + // SUM(y) + RelDataType sumY = getOperatorRelDataType(typeFactory, SqlStdOperatorTable.SUM, arg1Type); + // SUM(x) * SUM(y) + RelDataType multiplySumXSumY = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.MULTIPLY, sumX, sumY); + // REGR_COUNT(x, y) + RelDataType regrCountXY = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.REGR_COUNT, arg0Type, arg1Type); + // SUM(x) * SUM(y) / REGR_COUNT(x, y) + RelDataType divide = + getOperatorRelDataType(typeFactory, + SqlStdOperatorTable.DIVIDE, multiplySumXSumY, regrCountXY); + // SUM(x * y) - SUM(x) * SUM(y) / REGR_COUNT(x, y) + RelDataType minus = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.MINUS, sumMultiplyXY, divide); + // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y)) / regr_count(x, y) + RelDataType relDataType = + getOperatorRelDataType(typeFactory, SqlStdOperatorTable.DIVIDE, minus, regrCountXY); + if (sqlKind == SqlKind.COVAR_POP) { + return relDataType; + } + return typeFactory.createTypeWithNullability(relDataType, true); + } return arg0Type; } @@ -278,4 +355,37 @@ && getDefaultPrecision(typeName) != -1) { return false; } + /** + * Implementation of {@link SqlOperatorBinding} backed by a fixed list of operand types. + */ + public static class TypeCallBinding extends SqlOperatorBinding { + private final List<RelDataType> operands; + + public TypeCallBinding(RelDataTypeFactory typeFactory, + SqlOperator sqlOperator, List<RelDataType> operands) { + super(typeFactory, sqlOperator); + this.operands = operands; + } + + @Override public int getOperandCount() { + return operands.size(); + } + + @Override public RelDataType getOperandType(int ordinal) { + return operands.get(ordinal); + } + + @Override public CalciteException newError( + Resources.ExInst<SqlValidatorException> e) { + return SqlUtil.newContextException(SqlParserPos.ZERO, e); + } + } + + private RelDataType getOperatorRelDataType(RelDataTypeFactory typeFactory, + SqlOperator sqlOperator, RelDataType... 
argumentTypes) { + ImmutableList<RelDataType> operatorTypeList = ImmutableList.copyOf(argumentTypes); + TypeCallBinding operatorBinding = + new TypeCallBinding(typeFactory, sqlOperator, operatorTypeList); + return sqlOperator.getReturnTypeInference().inferReturnType(operatorBinding); + } } diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java index ea8cd95cfa5a..2fc5f9617e5c 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java @@ -1411,7 +1411,7 @@ private static RelDataType multivalentStringWithSepSumPrecision( public static final SqlReturnTypeInference COVAR_REGR_FUNCTION = opBinding -> { final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); final RelDataType relDataType = - typeFactory.getTypeSystem().deriveCovarType(typeFactory, + typeFactory.getTypeSystem().deriveCovarType(typeFactory, opBinding.getOperator().getKind(), opBinding.getOperandType(0), opBinding.getOperandType(1)); if (opBinding.getGroupCount() == 0 || opBinding.hasFilter()) { return typeFactory.createTypeWithNullability(relDataType, true); diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq index feeaf9035e66..2a05a177348d 100644 --- a/core/src/test/resources/sql/agg.iq +++ b/core/src/test/resources/sql/agg.iq @@ -2937,7 +2937,7 @@ from "scott".emp; +-----------------------+----------------------+---------------+---------------+ | COVAR_POP(COMM, COMM) | COVAR_SAMP(SAL, SAL) | VAR_POP(COMM) | VAR_SAMP(SAL) | +-----------------------+----------------------+---------------+---------------+ -| 272500.0000 | 1398313.8736 | 272500.0000 | 1398313.8736 | +| 272500.0000000000 | 1398313.873626374 | 272500.0000 | 1398313.8736 | +-----------------------+----------------------+---------------+---------------+ (1 row) @@ -2981,16 +2981,41 @@ group by MONTH(HIREDATE); | 1 | | | 1201250.0000 | | 11 
| | | | | 12 | | | 1510833.3333 | -| 2 | -35000.0000 | 10000.0000 | 831458.3333 | | 4 | | | | | 5 | | | | | 6 | | | | -| 9 | -175000.0000 | 490000.0000 | 31250.0000 | +| 2 | -35000.00000000000 | 10000.0000 | 831458.3333 | +| 9 | -175000.0000000000 | 490000.0000 | 31250.0000 | +-------+-----------------------+---------------+---------------+ (8 rows) !ok +# [CALCITE-4924] REGR_SXX and similar aggregate functions return the wrong data type +SELECT + MONTH(HIREDATE) as "MONTH", + covar_samp(SAL, COMM) as "COVAR_SAMP(SAL, COMM)", + covar_pop(SAL, COMM) as "COVAR_POP(SAL, COMM)", + regr_syy(SAL, COMM) as "REGR_SYY(COMM)", + regr_sxx(SAL, COMM) as "REGR_SXX(SAL)" +from "scott".emp +group by MONTH(HIREDATE); ++-------+-----------------------+----------------------+----------------+---------------+ +| MONTH | COVAR_SAMP(SAL, COMM) | COVAR_POP(SAL, COMM) | REGR_SYY(COMM) | REGR_SXX(SAL) | ++-------+-----------------------+----------------------+----------------+---------------+ +| 1 | | | | | +| 11 | | | | | +| 12 | | | | | +| 2 | -35000.00000000000 | -17500.00000000000 | 61250.00 | 20000.00 | +| 4 | | | | | +| 5 | | | | | +| 6 | | | | | +| 9 | -175000.0000000000 | -87500.00000000000 | 31250.00 | 980000.00 | ++-------+-----------------------+----------------------+----------------+---------------+ +(8 rows) + +!ok + # [CALCITE-2224] WITHIN GROUP clause for aggregate functions select deptno, collect(empno) within group (order by empno asc) as empnos from "scott".emp diff --git a/core/src/test/resources/sql/dummy.iq b/core/src/test/resources/sql/dummy.iq index 166e1b06d870..0d8879617f0b 100644 --- a/core/src/test/resources/sql/dummy.iq +++ b/core/src/test/resources/sql/dummy.iq @@ -15,9 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -!use post -values 1; +!use scott +select regr_sxx(SAL, COMM) from "scott".emp; EXPR$0 -1 +1090000.00 !ok + +EnumerableCalc(expr#0..2=[{inputs}], expr#3=[CAST($t0):DECIMAL(19, 2)], expr#4=[0], expr#5=[=($t2, $t4)], expr#6=[null:INTEGER], expr#7=[*($t1, $t1)], expr#8=[/($t7, $t2)], expr#9=[CASE($t5, $t6, $t8)], expr#10=[CAST($t9):DECIMAL(19, 2)], expr#11=[-($t3, $t10)], EXPR$0=[$t11]) + EnumerableAggregate(group=[{}], agg#0=[SUM($1) FILTER $2], agg#1=[SUM($0) FILTER $2], agg#2=[REGR_COUNT($0) FILTER $2]) + EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(19, 2)], expr#9=[*($t8, $t8)], expr#10=[IS NOT NULL($t8)], expr#11=[CAST($t5):DECIMAL(19, 2)], expr#12=[IS NOT NULL($t11)], expr#13=[AND($t10, $t12)], COMM=[$t6], $f8=[$t9], $f9=[$t13]) + EnumerableTableScan(table=[[scott, EMP]]) +!plan + + # End dummy.iq diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index 1fa79f309976..4fe4804b564d 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -14771,10 +14771,10 @@ void testTimestampDiff(boolean coercionEnabled) { + "'COVAR_POP\\(, \\)'.*", false); f.checkType("covar_pop(cast(null as varchar(2)),cast(null as varchar(2)))", - "DECIMAL(19, 9)"); + "DECIMAL(19, 6)"); f.checkType("covar_pop(CAST(NULL AS INTEGER),CAST(NULL AS INTEGER))", - "INTEGER"); - f.checkAggType("covar_pop(1.5, 2.5)", "DECIMAL(2, 1) NOT NULL"); + "BIGINT"); + f.checkAggType("covar_pop(1.5, 2.5)", "DECIMAL(19, 6) NOT NULL"); if (!f.brokenTestsEnabled()) { return; } @@ -14795,10 +14795,10 @@ void testTimestampDiff(boolean coercionEnabled) { + "'COVAR_SAMP\\(, \\)'.*", false); f.checkType("covar_samp(cast(null as varchar(2)),cast(null as varchar(2)))", - "DECIMAL(19, 9)"); + "DECIMAL(19, 6)"); f.checkType("covar_samp(CAST(NULL AS INTEGER),CAST(NULL AS INTEGER))", - "INTEGER"); - 
f.checkAggType("covar_samp(1.5, 2.5)", "DECIMAL(2, 1) NOT NULL"); + "BIGINT"); + f.checkAggType("covar_samp(1.5, 2.5)", "DECIMAL(19, 6)"); if (!f.brokenTestsEnabled()) { return; } @@ -14821,8 +14821,8 @@ void testTimestampDiff(boolean coercionEnabled) { f.checkType("regr_sxx(cast(null as varchar(2)), cast(null as varchar(2)))", "DECIMAL(19, 9)"); f.checkType("regr_sxx(CAST(NULL AS INTEGER), CAST(NULL AS INTEGER))", - "INTEGER"); - f.checkAggType("regr_sxx(1.5, 2.5)", "DECIMAL(2, 1) NOT NULL"); + "BIGINT"); + f.checkAggType("regr_sxx(1.5, 2.5)", "DECIMAL(19, 1) NOT NULL"); if (!f.brokenTestsEnabled()) { return; } @@ -14845,8 +14845,8 @@ void testTimestampDiff(boolean coercionEnabled) { f.checkType("regr_syy(cast(null as varchar(2)), cast(null as varchar(2)))", "DECIMAL(19, 9)"); f.checkType("regr_syy(CAST(NULL AS INTEGER), CAST(NULL AS INTEGER))", - "INTEGER"); - f.checkAggType("regr_syy(1.5, 2.5)", "DECIMAL(2, 1) NOT NULL"); + "BIGINT"); + f.checkAggType("regr_syy(1.5, 2.5)", "DECIMAL(19, 1) NOT NULL"); if (!f.brokenTestsEnabled()) { return; }