From c8a513b96185374345ea24e64deebfdbcd126268 Mon Sep 17 00:00:00 2001 From: Stamatis Zampetakis Date: Fri, 6 Dec 2024 12:51:37 +0100 Subject: [PATCH] [CALCITE-6720] Refactor cross product logic in RelMdUniqueKeys#getPassedThroughCols using Linq4j#product The RelMdUniqueKeys#getPassedThroughCols method exists only for performing a cross product of the various mappings between input and output "pass through" columns. The entire method can be replaced by exploiting the built-in Linq4j#product API and few other utility methods. After the refactoring the code is easier to follow and potentially more efficient since the result is computed gradually and we don't have to retain the entire cross product result in memory. --- .../calcite/rel/metadata/RelMdUniqueKeys.java | 29 +++---------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUniqueKeys.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUniqueKeys.java index 8745d060646..ced86022cb9 100644 --- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUniqueKeys.java +++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUniqueKeys.java @@ -336,7 +336,9 @@ public Set getUniqueKeys(Aggregate rel, RelMetadataQuery mq, final ImmutableSet.Builder keysBuilder = ImmutableSet.builder(); if (inputUniqueKeys != null) { for (ImmutableBitSet inputKey : inputUniqueKeys) { - keysBuilder.addAll(getPassedThroughCols(inputKey, rel)); + Iterable> product = + Linq4j.product(Util.transform(inputKey, i -> getPassedThroughCols(i, rel))); + keysBuilder.addAll(Util.transform(product, ImmutableBitSet::of)); } } @@ -371,30 +373,6 @@ private static Set filterSupersets( return minimalKeys; } - /** - * Given a set of columns in the input of an Aggregate rel, returns the set of mappings from the - * input columns to the output of the aggregations. A mapping for a particular column exists if - * it is part of a simple group by and/or it is "passed through" unmodified by a - * {@link RelMdColumnUniqueness#PASSTHROUGH_AGGREGATIONS pass-through aggregation function}. - */ - private static Set getPassedThroughCols( - ImmutableBitSet inputColumns, Aggregate rel) { - checkArgument(Aggregate.isSimple(rel)); - Set conbinations = new HashSet<>(); - conbinations.add(ImmutableBitSet.of()); - for (Integer inputColumn : inputColumns.asSet()) { - final ImmutableBitSet passedThroughCols = getPassedThroughCols(inputColumn, rel); - final Set crossProduct = new HashSet<>(); - for (ImmutableBitSet set : conbinations) { - for (Integer passedThroughCol : passedThroughCols) { - crossProduct.add(set.rebuild().set(passedThroughCol).build()); - } - } - conbinations = crossProduct; - } - return conbinations; - } - /** * Given a column in the input of an Aggregate rel, returns the mappings from the input column to * the output of the aggregations. A mapping for the column exists if it is part of a simple @@ -403,6 +381,7 @@ private static Set getPassedThroughCols( */ private static ImmutableBitSet getPassedThroughCols(Integer inputColumn, Aggregate rel) { + checkArgument(Aggregate.isSimple(rel)); final ImmutableBitSet.Builder builder = ImmutableBitSet.builder(); if (rel.getGroupSet().get(inputColumn)) { builder.set(inputColumn);