Skip to content

Commit

Permalink
[CALCITE-6720] Refactor cross product logic in RelMdUniqueKeys#getPas…
Browse files Browse the repository at this point in the history
…sedThroughCols using Linq4j#product

The RelMdUniqueKeys#getPassedThroughCols method exists only for performing a cross product of the various mappings between input and output "pass through" columns.

The entire method can be replaced by exploiting the built-in Linq4j#product API and few other utility methods.

After the refactoring the code is easier to follow and potentially more efficient since the result is computed gradually and we don't have to retain the entire cross product result in memory.
  • Loading branch information
zabetak authored and mihaibudiu committed Dec 6, 2024
1 parent 84ea4be commit c8a513b
Showing 1 changed file with 4 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ public Set<ImmutableBitSet> getUniqueKeys(Aggregate rel, RelMetadataQuery mq,
final ImmutableSet.Builder<ImmutableBitSet> keysBuilder = ImmutableSet.builder();
if (inputUniqueKeys != null) {
for (ImmutableBitSet inputKey : inputUniqueKeys) {
keysBuilder.addAll(getPassedThroughCols(inputKey, rel));
Iterable<List<Integer>> product =
Linq4j.product(Util.transform(inputKey, i -> getPassedThroughCols(i, rel)));
keysBuilder.addAll(Util.transform(product, ImmutableBitSet::of));
}
}

Expand Down Expand Up @@ -371,30 +373,6 @@ private static Set<ImmutableBitSet> filterSupersets(
return minimalKeys;
}

/**
* Given a set of columns in the input of an Aggregate rel, returns the set of mappings from the
* input columns to the output of the aggregations. A mapping for a particular column exists if
* it is part of a simple group by and/or it is "passed through" unmodified by a
* {@link RelMdColumnUniqueness#PASSTHROUGH_AGGREGATIONS pass-through aggregation function}.
*/
private static Set<ImmutableBitSet> getPassedThroughCols(
ImmutableBitSet inputColumns, Aggregate rel) {
checkArgument(Aggregate.isSimple(rel));
Set<ImmutableBitSet> conbinations = new HashSet<>();
conbinations.add(ImmutableBitSet.of());
for (Integer inputColumn : inputColumns.asSet()) {
final ImmutableBitSet passedThroughCols = getPassedThroughCols(inputColumn, rel);
final Set<ImmutableBitSet> crossProduct = new HashSet<>();
for (ImmutableBitSet set : conbinations) {
for (Integer passedThroughCol : passedThroughCols) {
crossProduct.add(set.rebuild().set(passedThroughCol).build());
}
}
conbinations = crossProduct;
}
return conbinations;
}

/**
* Given a column in the input of an Aggregate rel, returns the mappings from the input column to
* the output of the aggregations. A mapping for the column exists if it is part of a simple
Expand All @@ -403,6 +381,7 @@ private static Set<ImmutableBitSet> getPassedThroughCols(
*/
private static ImmutableBitSet getPassedThroughCols(Integer inputColumn,
Aggregate rel) {
checkArgument(Aggregate.isSimple(rel));
final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
if (rel.getGroupSet().get(inputColumn)) {
builder.set(inputColumn);
Expand Down

0 comments on commit c8a513b

Please sign in to comment.