Skip to content

Commit

Permalink
[CALCITE-6744] Support getColumnOrigins for correlate in RelMdColumnO…
Browse files Browse the repository at this point in the history
…rigins
  • Loading branch information
suibianwanwan committed Dec 24, 2024
1 parent 648a832 commit 5aaf7c5
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,27 @@
package org.apache.calcite.rel.metadata;

import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.core.CorrelationId;

import org.checkerframework.checker.nullness.qual.Nullable;

import static java.util.Objects.requireNonNull;

/**
* RelColumnOrigin is a data structure describing one of the origins of an
* output column produced by a relational expression.
*/
public class RelColumnOrigin {
//~ Instance fields --------------------------------------------------------

private final RelOptTable originTable;
@Nullable private final RelOptTable originTable;

@Nullable private final CorrelationId correlationId;

private final int iOriginColumn;

private final boolean isCorVar;

private final boolean isDerived;

//~ Constructors -----------------------------------------------------------
Expand All @@ -39,18 +46,40 @@ public RelColumnOrigin(
RelOptTable originTable,
int iOriginColumn,
boolean isDerived) {
this(originTable, null, iOriginColumn, isDerived, false);
}

public RelColumnOrigin(
CorrelationId correlationId,
int iOriginColumn,
boolean isDerived) {
this(null, correlationId, iOriginColumn, isDerived, true);
}

public RelColumnOrigin(@Nullable RelOptTable originTable,
@Nullable CorrelationId correlationId,
int iOriginColumn,
boolean isDerived,
boolean isCorVar) {
this.originTable = originTable;
this.correlationId = correlationId;
this.iOriginColumn = iOriginColumn;
this.isDerived = isDerived;
this.isCorVar = isCorVar;
}

//~ Methods ----------------------------------------------------------------

/** Returns table of origin. */
public RelOptTable getOriginTable() {
/** Returns table of origin and null only if isCorVar is true. */
@Nullable public RelOptTable getOriginTable() {
return originTable;
}

/** Returns correlateId of origin and null only if isCorVar is true. */
@Nullable public CorrelationId getCorrelationId() {
return correlationId;
}

/** Returns the 0-based index of column in origin table; whether this ordinal
* is flattened or unflattened depends on whether UDT flattening has already
* been performed on the relational expression which produced this
Expand All @@ -71,21 +100,39 @@ public boolean isDerived() {
return isDerived;
}

/** Returns whether this columnOrigin is from an external Correlate field. */
public boolean isCorVar() {
return isCorVar;
}

// override Object
@Override public boolean equals(@Nullable Object obj) {
if (!(obj instanceof RelColumnOrigin)) {
return false;
}
RelColumnOrigin other = (RelColumnOrigin) obj;
return originTable.getQualifiedName().equals(
other.originTable.getQualifiedName())
&& (iOriginColumn == other.iOriginColumn)
&& (isDerived == other.isDerived);

if (isCorVar != other.isCorVar
|| iOriginColumn != other.iOriginColumn
|| isDerived != other.isDerived) {
return false;
}

if (isCorVar) {
return requireNonNull(correlationId, "correlationId")
.equals(requireNonNull(other.getCorrelationId(), "other correlationId"));
}
return requireNonNull(originTable, "originTable").getQualifiedName()
.equals(requireNonNull(other.getOriginTable(), "originTable").getQualifiedName());
}

// override Object
@Override public int hashCode() {
return originTable.getQualifiedName().hashCode()
if (isCorVar) {
return requireNonNull(correlationId, "correlationId").hashCode()
+ iOriginColumn + (isDerived ? 313 : 0);
}
return requireNonNull(originTable, "originTable").getQualifiedName().hashCode()
+ iOriginColumn + (isDerived ? 313 : 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
Expand All @@ -32,6 +33,8 @@
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
Expand Down Expand Up @@ -143,6 +146,33 @@ private RelMdColumnOrigins() {}
return createDerivedColumnOrigins(set);
}

public @Nullable Set<RelColumnOrigin> getColumnOrigins(Correlate rel,
RelMetadataQuery mq, int iOutputColumn) {
int nLeftColumns = rel.getLeft().getRowType().getFieldList().size();
if (iOutputColumn < nLeftColumns) {
return mq.getColumnOrigins(rel.getLeft(), iOutputColumn);
}
Set<RelColumnOrigin> result = new HashSet<>();

Set<RelColumnOrigin> columnOrigins =
mq.getColumnOrigins(rel.getRight(), iOutputColumn - nLeftColumns);
if (columnOrigins != null) {
for (RelColumnOrigin columnOrigin : columnOrigins) {
// If corVar, get the origin column of the left input.
if (columnOrigin.isCorVar()) {
Set<RelColumnOrigin> corVarOrigin =
mq.getColumnOrigins(rel.getLeft(), columnOrigin.getOriginColumnOrdinal());
if (corVarOrigin != null) {
result.addAll(corVarOrigin);
}
continue;
}
result.add(columnOrigin);
}
}
return rel.getJoinType().generatesNullsOnRight() ? createDerivedColumnOrigins(result) : result;
}

public @Nullable Set<RelColumnOrigin> getColumnOrigins(Calc rel,
final RelMetadataQuery mq, int iOutputColumn) {
final RelNode input = rel.getInput();
Expand Down Expand Up @@ -281,10 +311,8 @@ private RelMdColumnOrigins() {}
final Set<RelColumnOrigin> set = new HashSet<>();
for (RelColumnOrigin rco : inputSet) {
RelColumnOrigin derived =
new RelColumnOrigin(
rco.getOriginTable(),
rco.getOriginColumnOrdinal(),
true);
new RelColumnOrigin(rco.getOriginTable(), rco.getCorrelationId(),
rco.getOriginColumnOrdinal(), true, rco.isCorVar());
set.add(derived);
}
return set;
Expand All @@ -303,6 +331,17 @@ private static Set<RelColumnOrigin> getMultipleColumns(RexNode rexNode, RelNode
}
return null;
}

@Override public Void visitFieldAccess(RexFieldAccess fieldAccess) {
final RexNode ref = fieldAccess.getReferenceExpr();
if (ref instanceof RexCorrelVariable) {
RexCorrelVariable variable = (RexCorrelVariable) ref;
RelColumnOrigin columnOrigin =
new RelColumnOrigin(variable.id, fieldAccess.getField().getIndex(), false);
set.add(columnOrigin);
}
return null;
}
};
rexNode.accept(visitor);
return set;
Expand Down
52 changes: 52 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelMetadataTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Intersect;
Expand Down Expand Up @@ -143,6 +144,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;

import static org.apache.calcite.test.Matchers.hasFieldNames;
import static org.apache.calcite.test.Matchers.hasTree;
import static org.apache.calcite.test.Matchers.isAlmost;
import static org.apache.calcite.test.Matchers.sortsAs;

Expand Down Expand Up @@ -458,6 +460,56 @@ void testColumnOriginsUnion() {
equalTo("SAL"));
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6744">[CALCITE-6744]
* Support getColumnOrigins for correlate in RelMdColumnOrigins</a>. */
@Test void testColumnOriginsForCorrelate() {
final String sql = "select (select max(dept.name || '_' || emp.ename)"
+ "from dept where emp.deptno = dept.deptno) from emp";
final RelMetadataFixture fixture = sql(sql);

final HepProgramBuilder programBuilder = HepProgram.builder();
programBuilder.addRuleInstance(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE);
final HepPlanner planner = new HepPlanner(programBuilder.build());
planner.setRoot(fixture.toRel());
final RelNode relNode = planner.findBestExp();

String expect = "LogicalProject(EXPR$0=[$9])\n"
+ " LogicalCorrelate(correlation=[$cor1], joinType=[left], requiredColumns=[{1, 7}])\n"
+ " LogicalTableScan(table=[[CATALOG, SALES, EMP]])\n"
+ " LogicalAggregate(group=[{}], EXPR$0=[MAX($0)])\n"
+ " LogicalProject($f0=[||(||($1, '_'), $cor1.ENAME)])\n"
+ " LogicalFilter(condition=[=($cor1.DEPTNO, $0)])\n"
+ " LogicalTableScan(table=[[CATALOG, SALES, DEPT]])\n";
assertThat(relNode, hasTree(expect));

fixture.withRelTransform(a -> relNode)
.assertColumnOriginDouble("EMP", "ENAME",
"DEPT", "NAME", true);

// check correlate input column origins
final RelMetadataFixture.MetadataConfig metadataConfig =
fixture.metadataConfig;
final RelMetadataQuery mq =
new RelMetadataQuery(metadataConfig.getDefaultHandlerProvider());
Set<RelColumnOrigin> origins =
mq.getColumnOrigins(relNode.getInput(0).getInput(1), 0);

assertThat(origins, hasSize(2));
for (RelColumnOrigin origin : origins) {
if (origin.isCorVar()) {
CorrelationId correlationId = origin.getCorrelationId();
assertThat(correlationId, notNullValue());
assertThat(correlationId.getName(), equalTo("$cor1"));
assertThat(origin.getOriginColumnOrdinal(), equalTo(1));
continue;
}
assertThat(origin.getOriginTable(), notNullValue());
assertThat(origin.getOriginTable().getQualifiedName().get(2), equalTo("DEPT"));
assertThat(origin.getOriginColumnOrdinal(), equalTo(1));
}
}

// ----------------------------------------------------------------------
// Tests for getRowCount, getMinRowCount, getMaxRowCount
// ----------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ private java.util.Set getColumnOrigins_(
return provider0.getColumnOrigins((org.apache.calcite.rel.core.Aggregate) r, mq, a2);
} else if (r instanceof org.apache.calcite.rel.core.Calc) {
return provider0.getColumnOrigins((org.apache.calcite.rel.core.Calc) r, mq, a2);
} else if (r instanceof org.apache.calcite.rel.core.Correlate) {
return provider0.getColumnOrigins((org.apache.calcite.rel.core.Correlate) r, mq, a2);
} else if (r instanceof org.apache.calcite.rel.core.Exchange) {
return provider0.getColumnOrigins((org.apache.calcite.rel.core.Exchange) r, mq, a2);
} else if (r instanceof org.apache.calcite.rel.core.Filter) {
Expand Down
8 changes: 8 additions & 0 deletions site/_docs/history.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ large results set to a manageable value. Users that need a bigger/smaller limit
should create a new instance of `RelMdUniqueKeys` and register it using the
metadata provider of their choice.

* [<a href="https://issues.apache.org/jira/browse/CALCITE-6744">CALCITE-6744</a>]
Support getColumnOrigins for correlate in RelMdColumnOrigins.
In RelMetadataQuery#RelMdColumnOrigin, if the source of the column is an
external correlation variable, add the isCorVar and correlationId fields
in RelColumnOrigin to indicate this. When isCorVar is true, the field
comes from a table outside the input RelNode. This also causes getOriginTable
to be nullable, so the user needs to first determine isCorVar when getting the originTable.

#### New features
{: #new-features-1-39-0}

Expand Down

0 comments on commit 5aaf7c5

Please sign in to comment.