[SPARK-50559][SQL] Store Except, Intersect and Union's outputs as lazy vals

### What changes were proposed in this pull request?

Store `Except`, `Intersect` and `Union`'s outputs as lazy vals.

### Why are the changes needed?

Currently `Union`'s `output` (likewise for `Except` and `Intersect`) is a `def`. This creates performance issues for queries with a large number of stacked `UNION`s, because rules like `WidenSetOperationTypes` traverse the logical plan and call `output` on each `Union` node, and every such call recomputes the outputs of all the `Union`s below it. The traversal therefore has quadratic complexity: O(1 + 2 + 3 + ... + number_of_unions) = O(number_of_unions^2).
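
To see where the quadratic cost comes from, here is a toy model (names like `LazyOutputSketch` are made up for this sketch; none of this is Spark code):

```scala
// A chain of N stacked nodes. With a `def`, every `output` call recomputes the
// whole chain below the node, so a rule visiting all N nodes pays ~N^2/2
// computations; with a `lazy val`, each node computes its output at most once.
object LazyOutputSketch extends App {
  var computations = 0

  sealed trait Node {
    def output: Seq[String]
    protected def compute(child: Option[Node]): Seq[String] = {
      computations += 1
      child.map(_.output).getOrElse(Seq("c0"))
    }
  }
  case object Leaf extends Node {
    def output: Seq[String] = compute(None)
  }
  // Like Union's output before this patch: a `def`, recomputed on every call.
  final case class EagerNode(child: Node) extends Node {
    def output: Seq[String] = compute(Some(child))
  }
  // Like Union's output after this patch: a `lazy val`, memoized per instance.
  final case class LazyNode(child: Node) extends Node {
    lazy val output: Seq[String] = compute(Some(child))
  }

  def stack(n: Int, wrap: Node => Node): Seq[Node] =
    Iterator.iterate(Leaf: Node)(wrap).take(n + 1).toSeq

  def visitAll(nodes: Seq[Node]): Int = {
    computations = 0
    nodes.foreach(_.output) // like an analyzer rule calling output on every node
    computations
  }

  println(s"def:      ${visitAll(stack(500, EagerNode(_)))}") // ~125k computations
  println(s"lazy val: ${visitAll(stack(500, LazyNode(_)))}")  // ~500 computations
}
```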

Profile:
![image](https://github.com/user-attachments/assets/97192bf3-c38e-47dd-81ac-49ee9a546525)
![image](https://github.com/user-attachments/assets/68ed13d7-b108-4c8d-b156-1dcec07c324b)

[flamegraph.tar.gz](https://github.com/user-attachments/files/18118260/flamegraph.tar.gz)

The improvement in parsing + analysis wall-clock time for a query with 500 UNIONs over 30 columns each is 13x (5.5s -> 400ms):
![image](https://github.com/user-attachments/assets/a824c693-0a6b-4c6b-8a90-a783a3c44d6d)

Repro:
```scala
// Build VALUES (0, 1, ..., num - 1) and chain `numUnions` of them with UNION ALL.
def genValues(num: Int) = s"VALUES (${(0 until num).mkString(", ")})"
def genUnions(numUnions: Int, numValues: Int) =
  (0 until numUnions).map(_ => genValues(numValues)).mkString(" UNION ALL ")

// Measure parsing + analysis only (no execution).
spark.time { spark.sql(s"SELECT * FROM ${genUnions(numUnions = 500, numValues = 30)}").queryExecution.analyzed }
```

For `EXCEPT` the perf difference is less noticeable, likely because its `output` reuses the same `Seq` (it just calls `left.output`).

### Does this PR introduce _any_ user-facing change?

No, this is an optimization.

### How was this patch tested?

- Ran the async-profiler.
- Ran the benchmark in spark-shell.
- Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

copilot.nvim.

Closes apache#49166 from vladimirg-db/vladimirg-db/store-union-output-as-lazy-val.

Authored-by: Vladimir Golubev <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
vladimirg-db authored and MaxGekk committed Dec 14, 2024
1 parent 2b9eb08 commit 976192a
Showing 5 changed files with 177 additions and 16 deletions.
```diff
@@ -376,10 +376,13 @@ case class Intersect(
 
   final override val nodePatterns: Seq[TreePattern] = Seq(INTERSECT)
 
-  override def output: Seq[Attribute] =
-    left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
-      leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
+  override def output: Seq[Attribute] = {
+    if (conf.getConf(SQLConf.LAZY_SET_OPERATOR_OUTPUT)) {
+      lazyOutput
+    } else {
+      computeOutput()
     }
+  }
 
   override def metadataOutput: Seq[Attribute] = Nil
 
@@ -396,15 +399,29 @@ case class Intersect(
 
   override protected def withNewChildrenInternal(
     newLeft: LogicalPlan, newRight: LogicalPlan): Intersect = copy(left = newLeft, right = newRight)
+
+  private lazy val lazyOutput: Seq[Attribute] = computeOutput()
+
+  /** We don't use right.output because those rows get excluded from the set. */
+  private def computeOutput(): Seq[Attribute] =
+    left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
+      leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
+    }
 }
 
 case class Except(
     left: LogicalPlan,
     right: LogicalPlan,
     isAll: Boolean) extends SetOperation(left, right) {
   override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) " All" else "" )
-  /** We don't use right.output because those rows get excluded from the set. */
-  override def output: Seq[Attribute] = left.output
+
+  override def output: Seq[Attribute] = {
+    if (conf.getConf(SQLConf.LAZY_SET_OPERATOR_OUTPUT)) {
+      lazyOutput
+    } else {
+      computeOutput()
+    }
+  }
 
   override def metadataOutput: Seq[Attribute] = Nil
 
@@ -416,6 +433,11 @@ case class Except(
 
   override protected def withNewChildrenInternal(
     newLeft: LogicalPlan, newRight: LogicalPlan): Except = copy(left = newLeft, right = newRight)
+
+  private lazy val lazyOutput: Seq[Attribute] = computeOutput()
+
+  /** We don't use right.output because those rows get excluded from the set. */
+  private def computeOutput(): Seq[Attribute] = left.output
 }
 
 /** Factory for constructing new `Union` nodes. */
@@ -479,18 +501,11 @@ case class Union(
     AttributeSet.fromAttributeSets(children.map(_.outputSet)).size
   }
 
-  // updating nullability to make all the children consistent
   override def output: Seq[Attribute] = {
-    children.map(_.output).transpose.map { attrs =>
-      val firstAttr = attrs.head
-      val nullable = attrs.exists(_.nullable)
-      val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
-      if (firstAttr.dataType == newDt) {
-        firstAttr.withNullability(nullable)
-      } else {
-        AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)(
-          firstAttr.exprId, firstAttr.qualifier)
-      }
+    if (conf.getConf(SQLConf.LAZY_SET_OPERATOR_OUTPUT)) {
+      lazyOutput
+    } else {
+      computeOutput()
     }
   }
 
@@ -509,6 +524,23 @@ case class Union(
     children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
   }
 
+  private lazy val lazyOutput: Seq[Attribute] = computeOutput()
+
+  // updating nullability to make all the children consistent
+  private def computeOutput(): Seq[Attribute] = {
+    children.map(_.output).transpose.map { attrs =>
+      val firstAttr = attrs.head
+      val nullable = attrs.exists(_.nullable)
+      val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
+      if (firstAttr.dataType == newDt) {
+        firstAttr.withNullability(nullable)
+      } else {
+        AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)(
+          firstAttr.exprId, firstAttr.qualifier)
+      }
+    }
+  }
+
   /**
    * Maps the constraints containing a given (original) sequence of attributes to those with a
    * given (reference) sequence of attributes. Given the nature of union, we expect that the
```
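
A note on why the caching is safe here (my sketch below, not part of the diff): logical plan nodes are immutable, and rewrites go through `withNewChildrenInternal`, which builds a fresh instance via `copy`. In Scala, `copy` re-runs the constructor and lazy-val state is per-instance, so a rewritten node starts with an unevaluated `lazyOutput` and the cache can never go stale:

```scala
// Minimal illustration of per-instance lazy-val state (hypothetical `Box` class).
case class Box(children: Seq[Int]) {
  lazy val cachedSum: Int = children.sum // stands in for lazyOutput
}

val a = Box(Seq(1, 2))
println(a.cachedSum) // 3, computed once and memoized on `a`
val b = a.copy(children = Seq(1, 2, 3))
println(b.cachedSum) // 6: the copy starts with its own unevaluated lazy val
```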
```diff
@@ -5306,6 +5306,19 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val LAZY_SET_OPERATOR_OUTPUT = buildConf("spark.sql.lazySetOperatorOutput.enabled")
+    .internal()
+    .doc(
+      "When set to true, Except/Intersect/Union operator's output will be a lazy val. It " +
+      "is a performance optimization for queries with a large number of stacked set operators. " +
+      "This is because of rules like WidenSetOperationTypes that traverse the logical plan tree " +
+      "and call output on each Except/Intersect/Union node. Such traversal has quadratic " +
+      "complexity: O(1 + 2 + 3 + ... + number_of_nodes) = O(number_of_nodes^2)."
+    )
+    .version("4.0.0")
+    .booleanConf
+    .createWithDefault(true)
+
   /**
    * Holds information about keys that have been deprecated.
    *
```
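
The flag defaults to `true`, so the cached output is on everywhere. If the caching ever needs to be ruled out while debugging, a session can presumably fall back to the old per-call behavior, since this is a regular (if internal) runtime conf:

```scala
// Escape hatch (sketch): the conf is internal, so it does not appear in public docs.
spark.conf.set("spark.sql.lazySetOperatorOutput.enabled", "false")
// ... re-run the repro; Except/Intersect/Union now recompute `output` on every call ...
spark.conf.set("spark.sql.lazySetOperatorOutput.enabled", "true")
```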
13 changes: 13 additions & 0 deletions sql/core/benchmarks/SetOperationsBenchmark-jdk21-results.txt
```
================================================================================================
Set Operations Benchmark
================================================================================================

OpenJDK 64-Bit Server VM 21.0.5+11-Ubuntu-1ubuntu120.04 on Linux 5.4.0-1131-aws-fips
Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
Parsing + Analysis:                       Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
UNION ALL                                           342            358          22          0.0       22784.2       1.0X
EXCEPT ALL                                          310            351          44          0.0       20675.4       1.1X
INTERSECT ALL                                       305            309           5          0.0       20301.6       1.1X
```

13 changes: 13 additions & 0 deletions sql/core/benchmarks/SetOperationsBenchmark-results.txt
```
================================================================================================
Set Operations Benchmark
================================================================================================

OpenJDK 64-Bit Server VM 17.0.12+7-Ubuntu-1ubuntu220.04 on Linux 5.4.0-1131-aws-fips
Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
Parsing + Analysis:                       Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
UNION ALL                                           360            423          70          0.0       24019.4       1.0X
EXCEPT ALL                                          322            328           5          0.0       21463.2       1.1X
INTERSECT ALL                                       327            360          33          0.0       21777.2       1.1X
```


New file: SetOperationsBenchmark.scala (90 additions)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.benchmark

import org.apache.spark.benchmark.Benchmark

/**
 * Benchmark to measure performance for set operations.
 * To run this benchmark:
 * {{{
 *   1. without sbt:
 *      bin/spark-submit --class <this class>
 *        --jars <spark core test jar>,<spark catalyst test jar> <spark sql test jar>
 *   2. build/sbt "sql/Test/runMain <this class>"
 *   3. generate result:
 *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>"
 *      Results will be written to "benchmarks/SetOperationsBenchmark-results.txt".
 * }}}
 */
object SetOperationsBenchmark extends SqlBasedBenchmark {
  private val setOperations = Seq("UNION ALL", "EXCEPT ALL", "INTERSECT ALL")

  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
    runBenchmark("Set Operations Benchmark") {
      val numOperations = 500
      val numValues = 30

      val benchmark =
        new Benchmark(
          "Parsing + Analysis",
          valuesPerIteration = numOperations * numValues,
          output = output
        )

      for (operation <- setOperations) {
        benchmark.addCase(operation) { _ =>
          spark
            .sql(
              generateQuery(
                operation = operation,
                numOperations = numOperations,
                numValues = numValues
              )
            )
            .queryExecution
            .analyzed
          ()
        }
      }

      benchmark.run()
    }
  }

  private def generateQuery(operation: String, numOperations: Int, numValues: Int) = {
    s"""
    SELECT
      *
    FROM
      ${generateOperations(
      operation = operation,
      numOperations = numOperations,
      numValues = numValues
    )}
    """
  }

  private def generateOperations(operation: String, numOperations: Int, numValues: Int) = {
    (0 until numOperations).map(_ => generateValues(numValues)).mkString(s" ${operation} ")
  }

  private def generateValues(num: Int) = {
    s"VALUES (${(0 until num).mkString(", ")})"
  }
}
```
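
For reference, plugging this benchmark's fully qualified name into the `<this class>` placeholder from the scaladoc above, regenerating the result files would presumably be:

```
SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain org.apache.spark.sql.execution.benchmark.SetOperationsBenchmark"
```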
