diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelColumnOrigin.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelColumnOrigin.java index 974501b50bdd..57d25b4cda77 100644 --- a/core/src/main/java/org/apache/calcite/rel/metadata/RelColumnOrigin.java +++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelColumnOrigin.java @@ -17,9 +17,12 @@ package org.apache.calcite.rel.metadata; import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.rel.core.CorrelationId; import org.checkerframework.checker.nullness.qual.Nullable; +import static java.util.Objects.requireNonNull; + /** * RelColumnOrigin is a data structure describing one of the origins of an * output column produced by a relational expression. @@ -27,10 +30,14 @@ public class RelColumnOrigin { //~ Instance fields -------------------------------------------------------- - private final RelOptTable originTable; + private final @Nullable RelOptTable originTable; + + private final @Nullable CorrelationId correlationId; private final int iOriginColumn; + private final boolean isCorVar; + private final boolean isDerived; //~ Constructors ----------------------------------------------------------- @@ -39,18 +46,40 @@ public RelColumnOrigin( RelOptTable originTable, int iOriginColumn, boolean isDerived) { + this(originTable, null, iOriginColumn, isDerived, false); + } + + public RelColumnOrigin( + CorrelationId correlationId, + int iOriginColumn, + boolean isDerived) { + this(null, correlationId, iOriginColumn, isDerived, true); + } + + private RelColumnOrigin(@Nullable RelOptTable originTable, + @Nullable CorrelationId correlationId, + int iOriginColumn, + boolean isDerived, + boolean isCorVar) { this.originTable = originTable; + this.correlationId = correlationId; this.iOriginColumn = iOriginColumn; this.isDerived = isDerived; + this.isCorVar = isCorVar; } //~ Methods ---------------------------------------------------------------- - /** Returns table of origin. 
*/ - public RelOptTable getOriginTable() { + /** Returns the table of origin; the result is null only if isCorVar is true. */ + public @Nullable RelOptTable getOriginTable() { return originTable; } + /** Returns the correlation id of origin; the result is null only if isCorVar is false. */ + public @Nullable CorrelationId getCorrelationId() { + return correlationId; + } + /** Returns the 0-based index of column in origin table; whether this ordinal * is flattened or unflattened depends on whether UDT flattening has already * been performed on the relational expression which produced this @@ -71,21 +100,48 @@ public boolean isDerived() { return isDerived; } + /** Returns whether this column origin is an external correlation variable field rather than a table column. */ + public boolean isCorVar() { + return isCorVar; + } + + public RelColumnOrigin copyWith(boolean isDerived) { + if (isCorVar) { + return new RelColumnOrigin( + requireNonNull(correlationId, "correlationId"), iOriginColumn, isDerived); + } + return new RelColumnOrigin( + requireNonNull(originTable, "originTable"), iOriginColumn, isDerived); + } + // override Object @Override public boolean equals(@Nullable Object obj) { if (!(obj instanceof RelColumnOrigin)) { return false; } RelColumnOrigin other = (RelColumnOrigin) obj; - return originTable.getQualifiedName().equals( - other.originTable.getQualifiedName()) - && (iOriginColumn == other.iOriginColumn) - && (isDerived == other.isDerived); + + if (isCorVar != other.isCorVar + || iOriginColumn != other.iOriginColumn + || isDerived != other.isDerived) { + return false; + } + + if (isCorVar) { + return requireNonNull(correlationId, "correlationId") + .equals(requireNonNull(other.getCorrelationId(), "other correlationId")); + } + return requireNonNull(originTable, "originTable").getQualifiedName() + .equals(requireNonNull(other.getOriginTable(), "originTable").getQualifiedName()); } // override Object @Override public int hashCode() { - return originTable.getQualifiedName().hashCode() + if (isCorVar) { + return 
requireNonNull(correlationId, "correlationId").hashCode() + + iOriginColumn + (isDerived ? 313 : 0); + } + return requireNonNull(originTable, "originTable").getQualifiedName().hashCode() + iOriginColumn + (isDerived ? 313 : 0); } } diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnOrigins.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnOrigins.java index cbf48990fff5..d96e9a0ec0b6 100644 --- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnOrigins.java +++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnOrigins.java @@ -21,6 +21,8 @@ import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.Correlate; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Exchange; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Join; @@ -32,6 +34,8 @@ import org.apache.calcite.rel.core.TableFunctionScan; import org.apache.calcite.rel.core.TableModify; import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLocalRef; import org.apache.calcite.rex.RexNode; @@ -47,6 +51,8 @@ import java.util.List; import java.util.Set; +import static java.util.Objects.requireNonNull; + /** * RelMdColumnOrigins supplies a default implementation of * {@link RelMetadataQuery#getColumnOrigins} for the standard logical algebra. 
@@ -143,6 +149,37 @@ private RelMdColumnOrigins() {} return createDerivedColumnOrigins(set); } + public @Nullable Set<RelColumnOrigin> getColumnOrigins(Correlate rel, + RelMetadataQuery mq, int iOutputColumn) { + int nLeftColumns = rel.getLeft().getRowType().getFieldList().size(); + if (iOutputColumn < nLeftColumns) { + return mq.getColumnOrigins(rel.getLeft(), iOutputColumn); + } + Set<RelColumnOrigin> result = new HashSet<>(); + + Set<RelColumnOrigin> columnOrigins = + mq.getColumnOrigins(rel.getRight(), iOutputColumn - nLeftColumns); + if (columnOrigins != null) { + for (RelColumnOrigin columnOrigin : columnOrigins) { + // A correlation variable bound by this Correlate resolves to the origins of the referenced left-input column. + if (columnOrigin.isCorVar()) { + CorrelationId correlationId = + requireNonNull(columnOrigin.getCorrelationId(), "correlationId"); + if (correlationId.equals(rel.getCorrelationId())) { + Set<RelColumnOrigin> corVarOrigin = + mq.getColumnOrigins(rel.getLeft(), columnOrigin.getOriginColumnOrdinal()); + if (corVarOrigin != null) { + result.addAll(corVarOrigin); + } + continue; + } + } + result.add(columnOrigin); + } + } + return rel.getJoinType().generatesNullsOnRight() ? 
createDerivedColumnOrigins(result) : result; + } + public @Nullable Set<RelColumnOrigin> getColumnOrigins(Calc rel, final RelMetadataQuery mq, int iOutputColumn) { final RelNode input = rel.getInput(); @@ -280,12 +317,7 @@ private RelMdColumnOrigins() {} } final Set<RelColumnOrigin> set = new HashSet<>(); for (RelColumnOrigin rco : inputSet) { - RelColumnOrigin derived = - new RelColumnOrigin( - rco.getOriginTable(), - rco.getOriginColumnOrdinal(), - true); - set.add(derived); + set.add(rco.copyWith(true)); } return set; } @@ -303,6 +335,17 @@ private static Set<RelColumnOrigin> getMultipleColumns(RexNode rexNode, RelNode } return null; } + + @Override public Void visitFieldAccess(RexFieldAccess fieldAccess) { + final RexNode ref = fieldAccess.getReferenceExpr(); + if (ref instanceof RexCorrelVariable) { + RexCorrelVariable variable = (RexCorrelVariable) ref; + RelColumnOrigin columnOrigin = + new RelColumnOrigin(variable.id, fieldAccess.getField().getIndex(), false); + set.add(columnOrigin); + } + return null; + } }; rexNode.accept(visitor); return set; diff --git a/core/src/main/java/org/apache/calcite/rel/rules/LoptSemiJoinOptimizer.java b/core/src/main/java/org/apache/calcite/rel/rules/LoptSemiJoinOptimizer.java index 724884641dec..6ba4deb35fbd 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/LoptSemiJoinOptimizer.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/LoptSemiJoinOptimizer.java @@ -418,12 +418,12 @@ private RexNode adjustSemiJoinCondition( mq.getColumnOrigin(factRel, keyIter.next()); // can't use the rid column as a semijoin key - if ((colOrigin == null || !colOrigin.isDerived()) + if ((colOrigin == null || colOrigin.isCorVar() || !colOrigin.isDerived()) || LucidDbSpecialOperators.isLcsRidColumnId( colOrigin.getOriginColumnOrdinal())) { removeKey = true; } else { - RelOptTable table = colOrigin.getOriginTable(); + RelOptTable table = requireNonNull(colOrigin.getOriginTable(), "originTable"); if (theTable == null) { if (!(table instanceof LcsTable)) { // not a column store table 
diff --git a/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java b/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java index bbe599923dd6..2de7183434e6 100644 --- a/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java @@ -43,6 +43,7 @@ import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Correlate; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Exchange; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Intersect; @@ -143,6 +144,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.calcite.test.Matchers.hasFieldNames; +import static org.apache.calcite.test.Matchers.hasTree; import static org.apache.calcite.test.Matchers.isAlmost; import static org.apache.calcite.test.Matchers.sortsAs; @@ -458,6 +460,56 @@ void testColumnOriginsUnion() { equalTo("SAL")); } + /** Test case for + * [CALCITE-6744] + * Support getColumnOrigins for correlate in RelMdColumnOrigins. 
*/ + @Test void testColumnOriginsForCorrelate() { + final String sql = "select (select max(dept.name || '_' || emp.ename)" + + "from dept where emp.deptno = dept.deptno) from emp"; + final RelMetadataFixture fixture = sql(sql); + + final HepProgramBuilder programBuilder = HepProgram.builder(); + programBuilder.addRuleInstance(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE); + final HepPlanner planner = new HepPlanner(programBuilder.build()); + planner.setRoot(fixture.toRel()); + final RelNode relNode = planner.findBestExp(); + + String expect = "LogicalProject(EXPR$0=[$9])\n" + + "  LogicalCorrelate(correlation=[$cor1], joinType=[left], requiredColumns=[{1, 7}])\n" + + "    LogicalTableScan(table=[[CATALOG, SALES, EMP]])\n" + + "    LogicalAggregate(group=[{}], EXPR$0=[MAX($0)])\n" + + "      LogicalProject($f0=[||(||($1, '_'), $cor1.ENAME)])\n" + + "        LogicalFilter(condition=[=($cor1.DEPTNO, $0)])\n" + + "          LogicalTableScan(table=[[CATALOG, SALES, DEPT]])\n"; + assertThat(relNode, hasTree(expect)); + + fixture.withRelTransform(a -> relNode) + .assertColumnOriginDouble("EMP", "ENAME", + "DEPT", "NAME", true); + + // check the origins of the correlate's right input: one correlation variable field and one DEPT column + final RelMetadataFixture.MetadataConfig metadataConfig = + fixture.metadataConfig; + final RelMetadataQuery mq = + new RelMetadataQuery(metadataConfig.getDefaultHandlerProvider()); + Set<RelColumnOrigin> origins = + mq.getColumnOrigins(relNode.getInput(0).getInput(1), 0); + + assertThat(origins, hasSize(2)); + for (RelColumnOrigin origin : origins) { + if (origin.isCorVar()) { + CorrelationId correlationId = origin.getCorrelationId(); + assertThat(correlationId, notNullValue()); + assertThat(correlationId.getName(), equalTo("$cor1")); + assertThat(origin.getOriginColumnOrdinal(), equalTo(1)); + continue; + } + assertThat(origin.getOriginTable(), notNullValue()); + assertThat(origin.getOriginTable().getQualifiedName().get(2), equalTo("DEPT")); + assertThat(origin.getOriginColumnOrdinal(), equalTo(1)); + } + } + // 
---------------------------------------------------------------------- // Tests for getRowCount, getMinRowCount, getMaxRowCount // ---------------------------------------------------------------------- diff --git a/core/src/test/resources/org/apache/calcite/rel/metadata/janino/GeneratedMetadata_ColumnOriginHandler.java b/core/src/test/resources/org/apache/calcite/rel/metadata/janino/GeneratedMetadata_ColumnOriginHandler.java index ddb1a317b35e..53312129aea3 100644 --- a/core/src/test/resources/org/apache/calcite/rel/metadata/janino/GeneratedMetadata_ColumnOriginHandler.java +++ b/core/src/test/resources/org/apache/calcite/rel/metadata/janino/GeneratedMetadata_ColumnOriginHandler.java @@ -72,6 +72,8 @@ private java.util.Set getColumnOrigins_( return provider0.getColumnOrigins((org.apache.calcite.rel.core.Aggregate) r, mq, a2); } else if (r instanceof org.apache.calcite.rel.core.Calc) { return provider0.getColumnOrigins((org.apache.calcite.rel.core.Calc) r, mq, a2); + } else if (r instanceof org.apache.calcite.rel.core.Correlate) { + return provider0.getColumnOrigins((org.apache.calcite.rel.core.Correlate) r, mq, a2); } else if (r instanceof org.apache.calcite.rel.core.Exchange) { return provider0.getColumnOrigins((org.apache.calcite.rel.core.Exchange) r, mq, a2); } else if (r instanceof org.apache.calcite.rel.core.Filter) { diff --git a/site/_docs/history.md b/site/_docs/history.md index feaa06069dcd..c0782075ae9c 100644 --- a/site/_docs/history.md +++ b/site/_docs/history.md @@ -68,6 +68,14 @@ large results set to a manageable value. Users that need a bigger/smaller limit should create a new instance of `RelMdUniqueKeys` and register it using the metadata provider of their choice. +* [CALCITE-6744] +Support getColumnOrigins for correlate in RelMdColumnOrigins. +In RelMetadataQuery#RelMdColumnOrigin, if the source of the column is an +external correlation variable, add the isCorVar and correlationId fields +in RelColumnOrigin to indicate this. 
When isCorVar is true, the column +comes from a table outside the input RelNode. As a result, getOriginTable +is now nullable, so callers must check isCorVar before using the table returned by getOriginTable. + #### New features {: #new-features-1-39-0}