-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CALCITE-6744] Support getColumnOrigins for correlate in RelMdColumnOrigins #4109
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,8 @@ | |
import org.apache.calcite.rel.core.Aggregate; | ||
import org.apache.calcite.rel.core.AggregateCall; | ||
import org.apache.calcite.rel.core.Calc; | ||
import org.apache.calcite.rel.core.Correlate; | ||
import org.apache.calcite.rel.core.CorrelationId; | ||
import org.apache.calcite.rel.core.Exchange; | ||
import org.apache.calcite.rel.core.Filter; | ||
import org.apache.calcite.rel.core.Join; | ||
|
@@ -32,6 +34,8 @@ | |
import org.apache.calcite.rel.core.TableFunctionScan; | ||
import org.apache.calcite.rel.core.TableModify; | ||
import org.apache.calcite.rel.core.TableScan; | ||
import org.apache.calcite.rex.RexCorrelVariable; | ||
import org.apache.calcite.rex.RexFieldAccess; | ||
import org.apache.calcite.rex.RexInputRef; | ||
import org.apache.calcite.rex.RexLocalRef; | ||
import org.apache.calcite.rex.RexNode; | ||
|
@@ -47,6 +51,8 @@ | |
import java.util.List; | ||
import java.util.Set; | ||
|
||
import static java.util.Objects.requireNonNull; | ||
|
||
/** | ||
* RelMdColumnOrigins supplies a default implementation of | ||
* {@link RelMetadataQuery#getColumnOrigins} for the standard logical algebra. | ||
|
@@ -143,6 +149,37 @@ private RelMdColumnOrigins() {} | |
return createDerivedColumnOrigins(set); | ||
} | ||
|
||
public @Nullable Set<RelColumnOrigin> getColumnOrigins(Correlate rel, | ||
RelMetadataQuery mq, int iOutputColumn) { | ||
int nLeftColumns = rel.getLeft().getRowType().getFieldList().size(); | ||
if (iOutputColumn < nLeftColumns) { | ||
return mq.getColumnOrigins(rel.getLeft(), iOutputColumn); | ||
} | ||
Set<RelColumnOrigin> result = new HashSet<>(); | ||
|
||
Set<RelColumnOrigin> columnOrigins = | ||
mq.getColumnOrigins(rel.getRight(), iOutputColumn - nLeftColumns); | ||
if (columnOrigins != null) { | ||
for (RelColumnOrigin columnOrigin : columnOrigins) { | ||
// If corVar, get the origin column of the left input. | ||
if (columnOrigin.isCorVar()) { | ||
CorrelationId correlationId = | ||
requireNonNull(columnOrigin.getCorrelationId(), "correlationId"); | ||
if (correlationId.equals(rel.getCorrelationId())) { | ||
Set<RelColumnOrigin> corVarOrigin = | ||
mq.getColumnOrigins(rel.getLeft(), columnOrigin.getOriginColumnOrdinal()); | ||
if (corVarOrigin != null) { | ||
result.addAll(corVarOrigin); | ||
} | ||
continue; | ||
} | ||
} | ||
result.add(columnOrigin); | ||
} | ||
} | ||
return rel.getJoinType().generatesNullsOnRight() ? createDerivedColumnOrigins(result) : result; | ||
} | ||
|
||
public @Nullable Set<RelColumnOrigin> getColumnOrigins(Calc rel, | ||
final RelMetadataQuery mq, int iOutputColumn) { | ||
final RelNode input = rel.getInput(); | ||
|
@@ -280,12 +317,7 @@ private RelMdColumnOrigins() {} | |
} | ||
final Set<RelColumnOrigin> set = new HashSet<>(); | ||
for (RelColumnOrigin rco : inputSet) { | ||
RelColumnOrigin derived = | ||
new RelColumnOrigin( | ||
rco.getOriginTable(), | ||
rco.getOriginColumnOrdinal(), | ||
true); | ||
set.add(derived); | ||
set.add(rco.copyWith(true)); | ||
} | ||
return set; | ||
} | ||
|
@@ -303,6 +335,17 @@ private static Set<RelColumnOrigin> getMultipleColumns(RexNode rexNode, RelNode | |
} | ||
return null; | ||
} | ||
|
||
@Override public Void visitFieldAccess(RexFieldAccess fieldAccess) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't this actually use as origin the table that correlated variable comes from (recursively)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I fully understand what you mean. There is a case where the RelNode that calls getColumnOrigin is not the top-level RelNode, so it may not get the original table. |
||
final RexNode ref = fieldAccess.getReferenceExpr(); | ||
if (ref instanceof RexCorrelVariable) { | ||
RexCorrelVariable variable = (RexCorrelVariable) ref; | ||
RelColumnOrigin columnOrigin = | ||
new RelColumnOrigin(variable.id, fieldAccess.getField().getIndex(), false); | ||
set.add(columnOrigin); | ||
} | ||
return null; | ||
} | ||
}; | ||
rexNode.accept(visitor); | ||
return set; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,7 @@ | |
import org.apache.calcite.rel.core.Aggregate; | ||
import org.apache.calcite.rel.core.AggregateCall; | ||
import org.apache.calcite.rel.core.Correlate; | ||
import org.apache.calcite.rel.core.CorrelationId; | ||
import org.apache.calcite.rel.core.Exchange; | ||
import org.apache.calcite.rel.core.Filter; | ||
import org.apache.calcite.rel.core.Intersect; | ||
|
@@ -143,6 +144,7 @@ | |
import static com.google.common.collect.ImmutableList.toImmutableList; | ||
|
||
import static org.apache.calcite.test.Matchers.hasFieldNames; | ||
import static org.apache.calcite.test.Matchers.hasTree; | ||
import static org.apache.calcite.test.Matchers.isAlmost; | ||
import static org.apache.calcite.test.Matchers.sortsAs; | ||
|
||
|
@@ -458,6 +460,56 @@ void testColumnOriginsUnion() { | |
equalTo("SAL")); | ||
} | ||
|
||
/** Test case for | ||
* <a href="https://issues.apache.org/jira/browse/CALCITE-6744">[CALCITE-6744] | ||
* Support getColumnOrigins for correlate in RelMdColumnOrigins</a>. */ | ||
@Test void testColumnOriginsForCorrelate() { | ||
final String sql = "select (select max(dept.name || '_' || emp.ename)" | ||
+ "from dept where emp.deptno = dept.deptno) from emp"; | ||
final RelMetadataFixture fixture = sql(sql); | ||
|
||
final HepProgramBuilder programBuilder = HepProgram.builder(); | ||
programBuilder.addRuleInstance(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE); | ||
final HepPlanner planner = new HepPlanner(programBuilder.build()); | ||
planner.setRoot(fixture.toRel()); | ||
final RelNode relNode = planner.findBestExp(); | ||
|
||
String expect = "LogicalProject(EXPR$0=[$9])\n" | ||
+ " LogicalCorrelate(correlation=[$cor1], joinType=[left], requiredColumns=[{1, 7}])\n" | ||
+ " LogicalTableScan(table=[[CATALOG, SALES, EMP]])\n" | ||
+ " LogicalAggregate(group=[{}], EXPR$0=[MAX($0)])\n" | ||
+ " LogicalProject($f0=[||(||($1, '_'), $cor1.ENAME)])\n" | ||
+ " LogicalFilter(condition=[=($cor1.DEPTNO, $0)])\n" | ||
+ " LogicalTableScan(table=[[CATALOG, SALES, DEPT]])\n"; | ||
assertThat(relNode, hasTree(expect)); | ||
|
||
fixture.withRelTransform(a -> relNode) | ||
.assertColumnOriginDouble("EMP", "ENAME", | ||
"DEPT", "NAME", true); | ||
|
||
// check correlate input column origins | ||
final RelMetadataFixture.MetadataConfig metadataConfig = | ||
fixture.metadataConfig; | ||
final RelMetadataQuery mq = | ||
new RelMetadataQuery(metadataConfig.getDefaultHandlerProvider()); | ||
Set<RelColumnOrigin> origins = | ||
mq.getColumnOrigins(relNode.getInput(0).getInput(1), 0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this LogicalAggregate? Maybe you can say this in a comment. |
||
|
||
assertThat(origins, hasSize(2)); | ||
for (RelColumnOrigin origin : origins) { | ||
if (origin.isCorVar()) { | ||
CorrelationId correlationId = origin.getCorrelationId(); | ||
assertThat(correlationId, notNullValue()); | ||
assertThat(correlationId.getName(), equalTo("$cor1")); | ||
assertThat(origin.getOriginColumnOrdinal(), equalTo(1)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. want to check that originTable is null too? |
||
continue; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in this case I think that an |
||
} | ||
assertThat(origin.getOriginTable(), notNullValue()); | ||
assertThat(origin.getOriginTable().getQualifiedName().get(2), equalTo("DEPT")); | ||
assertThat(origin.getOriginColumnOrdinal(), equalTo(1)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this |
||
} | ||
} | ||
|
||
// ---------------------------------------------------------------------- | ||
// Tests for getRowCount, getMinRowCount, getMaxRowCount | ||
// ---------------------------------------------------------------------- | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally this would be implemented using an interface and two subclasses, one for each kind of origin.
Not sure how easy this can be accomplished while keeping compatibility.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this would also be difficult to keep compatible. Another compromise would be to keep the RelColumnOrigin unchanged and return null if the input field's CorField comes from outside the input RelNode.
This way we just need to add a Map to maintain the mapping of CorrelationId, Index to ColumnOrigin, and go top-down to getColumnOrigin.
public @Nullable Set<RelColumnOrigin> getColumnOrigins(Aggregate rel, RelMetadataQuery mq, int iOutputColumn, Map<Pair<CorrelationId, Integer>, ColumnOrigin>) corVarOriginMap) { ... }