diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f0a72cc9dd7ca..ef202724e2822 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -178,15 +178,21 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // Only using secondary join optimization when both lower and upper conditions // are specified (e.g. t1.a < t2.b + x and t1.a > t2.b - x) + val joinKeySet = joinKeys.map{ + case (a, b) => Set(a.references.toSet, b.references.toSet) }.flatten.flatten + val leftRangeCondRefs = rangeConditions.map( + c => c.left.references).flatten.toSet.diff(joinKeySet.toSet) + val rightRangeCondRefs = rangeConditions.map( + c => c.right.references).flatten.toSet.diff(joinKeySet.toSet) + if (rangeConditions.size != 2 || // Looking for one < and one > comparison: rangeConditions.forall(x => !x.isInstanceOf[LessThan] && !x.isInstanceOf[LessThanOrEqual]) || rangeConditions.forall(x => !x.isInstanceOf[GreaterThan] && !x.isInstanceOf[GreaterThanOrEqual]) || - // Check if both comparisons reference the same columns: - rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct.size != 1 || - rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct.size != 1) { + leftRangeCondRefs.size != 1 || + rightRangeCondRefs.size != 1) { logDebug("Inner range optimization conditions not met. Clearing range conditions") rangeConditions = Nil rangePreds.clear() @@ -212,7 +218,10 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { private def checkRangeConditions(l : Expression, r : Expression, left : LogicalPlan, right : LogicalPlan, joinKeys : Seq[(Expression, Expression)]): Option[Boolean] = { - val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq) + val joinKeySet = joinKeys.map{ + case (a, b) => Set(a.references.toSet, b.references.toSet) }.flatten.flatten + val lattrs = l.references.toSet.diff(joinKeySet.toSet).toSeq + val rattrs = r.references.toSet.diff(joinKeySet.toSet).toSeq if (lattrs.size != 1 || rattrs.size != 1) { None } else if (canEvaluate(l, left) && canEvaluate(r, right)) {