From 1c08148730c82aa245393a748ddec2d8ab905c46 Mon Sep 17 00:00:00 2001
From: John Ed Quinn <40360967+johnedquinn@users.noreply.github.com>
Date: Tue, 19 Sep 2023 16:07:15 -0700
Subject: [PATCH] Adds support for thread interruption in compilation and
execution (#1211)
* Adds support for thread interruption in compilation and execution
* Adds support for CLI users to use CTRL-C
---
CHANGELOG.md | 4 +
.../partiql/cli/pipeline/AbstractPipeline.kt | 1 +
.../org/partiql/cli/shell/RunnablePipeline.kt | 56 ++++
.../org/partiql/cli/shell/RunnableWriter.kt | 53 +++
.../kotlin/org/partiql/cli/shell/Shell.kt | 316 +++++++++++-------
.../CompilerInterruptionBenchmark.kt | 291 ++++++++++++++++
.../org/partiql/lang/eval/CompileOptions.kt | 8 +-
.../partiql/lang/eval/EvaluatingCompiler.kt | 27 +-
.../kotlin/org/partiql/lang/eval/Thunk.kt | 8 +-
.../eval/EvaluatingCompilerInterruptTests.kt | 192 +++++++++++
.../PartiQLPigParserThreadInterruptTests.kt | 4 +-
11 files changed, 838 insertions(+), 122 deletions(-)
create mode 100644 partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnablePipeline.kt
create mode 100644 partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnableWriter.kt
create mode 100644 partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/CompilerInterruptionBenchmark.kt
create mode 100644 partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 49f416af9a..c853179b1e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,10 @@ Thank you to all who have contributed!
## [Unreleased]
### Added
+- Adds `isInterruptible` property to `CompileOptions`. The default value is `false`. Please see the KDocs for more information.
+- Adds support for thread interruption in compilation and execution. If you'd like to opt in to this addition, please see
+  the `isInterruptible` addition above for more information.
+- Adds support for CLI users to use CTRL-C to cancel long-running compilation/execution of queries.
### Changed
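For a library consumer, opting in is a one-line change when building a pipeline. A minimal sketch, using only APIs that appear in this patch (the query string is illustrative):

    import org.partiql.lang.CompilerPipeline
    import org.partiql.lang.eval.CompileOptions
    import org.partiql.lang.eval.EvaluationSession
    import org.partiql.lang.eval.PartiQLResult

    fun main() {
        // Build a pipeline whose compiled statements check Thread.interrupted() frequently.
        val pipeline = CompilerPipeline.build {
            compileOptions(CompileOptions.standard().copy(interruptible = true))
        }
        val expression = pipeline.compile("SELECT x FROM ([1, 2, 3]) AS x")
        // Interrupting this thread during evaluation/materialization now raises an exception
        // instead of letting the query run to completion.
        val result = expression.evaluate(EvaluationSession.standard()) as PartiQLResult.Value
        result.value.forEach { println(it) }
    }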
diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt
index 925a2dbb14..ac8dcbc6c7 100644
--- a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt
+++ b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt
@@ -145,6 +145,7 @@ internal sealed class AbstractPipeline(open val options: PipelineOptions) {
projectionIteration(options.projectionIterationBehavior)
undefinedVariable(options.undefinedVariableBehavior)
typingMode(options.typingMode)
+ isInterruptible(true)
}
private val compilerPipeline = CompilerPipeline.build {
diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnablePipeline.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnablePipeline.kt
new file mode 100644
index 0000000000..e52b396fe1
--- /dev/null
+++ b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnablePipeline.kt
@@ -0,0 +1,56 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at:
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
+ * language governing permissions and limitations under the License.
+ */
+
+package org.partiql.cli.shell
+
+import org.partiql.cli.pipeline.AbstractPipeline
+import org.partiql.lang.eval.EvaluationSession
+import org.partiql.lang.eval.PartiQLResult
+import java.util.concurrent.BlockingQueue
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
+
+/**
+ * A wrapper over [AbstractPipeline]. It constantly grabs input queries from [inputs] and places the [PartiQLResult]
+ * in [results]. When it is done compiling a single statement, it sets [doneCompiling] to true.
+ */
+internal class RunnablePipeline(
+ private val inputs: BlockingQueue<Input>,
+ private val results: BlockingQueue<PartiQLResult>,
+ val pipeline: AbstractPipeline,
+ private val doneCompiling: AtomicBoolean
+) : Runnable {
+ /**
+ * When the Thread running this [Runnable] is interrupted, the underlying [AbstractPipeline] should catch the
+ * interruption and fail with some exception. Then, this will break out of [run].
+ */
+ override fun run() {
+ while (true) {
+ val input = inputs.poll(3, TimeUnit.SECONDS)
+ if (input != null) {
+ val result = pipeline.compile(input.input, input.session)
+ results.put(result)
+ doneCompiling.set(true)
+ }
+ }
+ }
+
+ /**
+ * Represents a PartiQL statement ([input]) and the [EvaluationSession] to evaluate with.
+ */
+ internal data class Input(
+ val input: String,
+ val session: EvaluationSession
+ )
+}
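A stripped-down sketch of how the shell is expected to drive this runnable (the full wiring appears in the Shell.kt changes below); the `somePipeline` parameter and the query are illustrative placeholders:

    package org.partiql.cli.shell

    import org.partiql.cli.pipeline.AbstractPipeline
    import org.partiql.lang.eval.EvaluationSession
    import org.partiql.lang.eval.PartiQLResult
    import java.util.concurrent.ArrayBlockingQueue
    import java.util.concurrent.Executors
    import java.util.concurrent.atomic.AtomicBoolean

    internal fun submitQuery(somePipeline: AbstractPipeline) {
        val inputs = ArrayBlockingQueue<RunnablePipeline.Input>(1)
        val results = ArrayBlockingQueue<PartiQLResult>(1)
        val doneCompiling = AtomicBoolean(false)
        val pipelineService = Executors.newFixedThreadPool(1)

        // The worker blocks on `inputs` and publishes each PartiQLResult to `results`.
        pipelineService.submit(RunnablePipeline(inputs, results, somePipeline, doneCompiling))
        inputs.put(RunnablePipeline.Input("SELECT x FROM ([1, 2, 3]) AS x", EvaluationSession.standard()))
        // Poll `results` for the answer, or call pipelineService.shutdownNow() to interrupt
        // a long-running compilation/evaluation.
    }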
diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnableWriter.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnableWriter.kt
new file mode 100644
index 0000000000..caae3e1069
--- /dev/null
+++ b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnableWriter.kt
@@ -0,0 +1,53 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at:
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
+ * language governing permissions and limitations under the License.
+ */
+
+package org.partiql.cli.shell
+
+import org.partiql.lang.eval.ExprValue
+import org.partiql.lang.util.ExprValueFormatter
+import java.io.PrintStream
+import java.util.concurrent.BlockingQueue
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
+
+/**
+ * A wrapper over [ExprValueFormatter]. It constantly grabs [ExprValue]s from [values] and writes them to [out].
+ * When it is done printing a single item, it sets [donePrinting] to true.
+ */
+internal class RunnableWriter(
+ private val out: PrintStream,
+ private val formatter: ExprValueFormatter,
+ private val values: BlockingQueue<ExprValue>,
+ private val donePrinting: AtomicBoolean
+) : Runnable {
+
+ /**
+ * When the Thread running this [Runnable] is interrupted, the interruption surfaces while writing the value. The
+ * formatter itself never checks the interruption flag, but, because [ExprValue]s are lazily materialized, the thunks
+ * produced by the EvaluatingCompiler throw an exception when the thread is interrupted during creation of the
+ * [ExprValue]. That exception breaks out of [run].
+ */
+ override fun run() {
+ while (true) {
+ val value = values.poll(3, TimeUnit.SECONDS)
+ if (value != null) {
+ out.info(BAR_1)
+ formatter.formatTo(value, out)
+ out.println()
+ out.info(BAR_2)
+ donePrinting.set(true)
+ }
+ }
+ }
+}
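The lazy-materialization point above is why this works: the writer never checks the interruption flag itself, yet iteration still stops. A standalone Kotlin sketch of the same mechanism, with no PartiQL types involved (the producer plays the role of the interruptible thunks):

    import java.util.concurrent.atomic.AtomicBoolean
    import kotlin.concurrent.thread

    fun main() {
        // A lazily produced sequence whose generator checks Thread.interrupted(),
        // analogous to the thunks behind a lazily materialized ExprValue.
        val lazyValues = sequence {
            var i = 0
            while (true) {
                if (Thread.interrupted()) throw InterruptedException()
                yield(i++)
            }
        }
        val interrupted = AtomicBoolean(false)
        val consumer = thread {
            try {
                lazyValues.forEach { /* consume each lazily produced element */ }
            } catch (_: InterruptedException) {
                interrupted.set(true)
            }
        }
        Thread.sleep(100)
        consumer.interrupt()
        consumer.join(1000)
        println("interrupted = ${interrupted.get()}") // expected: true
    }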
diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt
index 63fc5b914d..bb78d782ea 100644
--- a/partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt
+++ b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt
@@ -23,6 +23,7 @@ import org.jline.reader.LineReader
import org.jline.reader.LineReaderBuilder
import org.jline.reader.UserInterruptException
import org.jline.reader.impl.completer.AggregateCompleter
+import org.jline.terminal.Terminal
import org.jline.terminal.TerminalBuilder
import org.jline.utils.AttributedString
import org.jline.utils.AttributedStringBuilder
@@ -43,7 +44,6 @@ import org.partiql.lang.graph.ExternalGraphException
import org.partiql.lang.graph.ExternalGraphReader
import org.partiql.lang.syntax.PartiQLParserBuilder
import org.partiql.lang.util.ConfigurableExprValueFormatter
-import org.partiql.lang.util.ExprValueFormatter
import java.io.Closeable
import java.io.File
import java.io.OutputStream
@@ -52,15 +52,19 @@ import java.nio.file.Path
import java.nio.file.Paths
import java.util.Locale
import java.util.Properties
+import java.util.concurrent.ArrayBlockingQueue
+import java.util.concurrent.BlockingQueue
import java.util.concurrent.CountDownLatch
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy
private const val PROMPT_1 = "PartiQL> "
private const val PROMPT_2 = " | "
-private const val BAR_1 = "===' "
-private const val BAR_2 = "--- "
+internal const val BAR_1 = "===' "
+internal const val BAR_2 = "--- "
private const val WELCOME_MSG = "Welcome to the PartiQL shell!"
private const val DEBUG_MSG = """
■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
@@ -98,10 +102,15 @@ private val EXIT_DELAY: Duration = Duration(3000)
* Initial work to replace the REPL with JLine3. I have attempted to keep this similar to Repl.kt, but have some
* opinions on ways to clean this up in later PRs.
*/
+
+val exiting = AtomicBoolean(false)
+val doneCompiling = AtomicBoolean(false)
+val donePrinting = AtomicBoolean(false)
+
internal class Shell(
- private val output: OutputStream,
+ output: OutputStream,
private val compiler: AbstractPipeline,
- private val initialGlobal: Bindings,
+ initialGlobal: Bindings,
private val config: ShellConfiguration = ShellConfiguration()
) {
private val homeDir: Path = Paths.get(System.getProperty("user.home"))
@@ -110,10 +119,17 @@ internal class Shell(
private val out = PrintStream(output)
private val currentUser = System.getProperty("user.name")
+ private val inputs: BlockingQueue<RunnablePipeline.Input> = ArrayBlockingQueue(1)
+ private val results: BlockingQueue<PartiQLResult> = ArrayBlockingQueue(1)
+ private var pipelineService: ExecutorService = Executors.newFixedThreadPool(1)
+ private val values: BlockingQueue<ExprValue> = ArrayBlockingQueue(1)
+ private var printingService: ExecutorService = Executors.newFixedThreadPool(1)
+
fun start() {
- val exiting = AtomicBoolean()
val interrupter = ThreadInterrupter()
val exited = CountDownLatch(1)
+ pipelineService.submit(RunnablePipeline(inputs, results, compiler, doneCompiling))
+ printingService.submit(RunnableWriter(out, ConfigurableExprValueFormatter.pretty, values, donePrinting))
Runtime
.getRuntime()
.addShutdownHook(
@@ -135,119 +151,135 @@ internal class Shell(
}
}
- private fun run(exiting: AtomicBoolean) = TerminalBuilder.builder().build().use { terminal ->
- val highlighter = when {
- this.config.isMonochrome -> null
- else -> ShellHighlighter()
- }
- val completer = AggregateCompleter(CompleterDefault())
- val reader = LineReaderBuilder.builder()
- .terminal(terminal)
- .parser(ShellParser)
- .completer(completer)
- .option(LineReader.Option.GROUP_PERSIST, true)
- .option(LineReader.Option.AUTO_LIST, true)
- .option(LineReader.Option.CASE_INSENSITIVE, true)
- .variable(LineReader.LIST_MAX, 10)
- .highlighter(highlighter)
- .expander(ShellExpander)
- .variable(LineReader.HISTORY_FILE, homeDir.resolve(".partiql/.history"))
- .variable(LineReader.SECONDARY_PROMPT_PATTERN, PROMPT_2)
- .build()
-
- out.info(WELCOME_MSG)
- out.info("Typing mode: ${compiler.options.typingMode.name}")
- out.info("Using version: ${retrievePartiQLVersionAndHash()}")
- if (compiler is AbstractPipeline.PipelineDebug) {
- out.println("\n\n")
- out.success(DEBUG_MSG)
- out.println("\n\n")
+ private val signalHandler = Terminal.SignalHandler { sig ->
+ if (sig == Terminal.Signal.INT) {
+ exiting.set(true)
}
+ }
- while (!exiting.get()) {
- val line: String = try {
- reader.readLine(PROMPT_1)
- } catch (ex: UserInterruptException) {
- if (ex.partialLine.isNotEmpty()) {
- reader.history.add(ex.partialLine)
- }
- continue
- } catch (ex: EndOfFileException) {
- out.info("^D")
- return
- }
+ private fun run(exiting: AtomicBoolean) = TerminalBuilder.builder()
+ .name("PartiQL")
+ .nativeSignals(true)
+ .signalHandler(signalHandler)
+ .build().use { terminal ->
- // Pretty print AST
- if (line.endsWith("\n!!")) {
- printAST(line.removeSuffix("!!"))
- continue
+ val highlighter = when {
+ this.config.isMonochrome -> null
+ else -> ShellHighlighter()
}
-
- if (line.isBlank()) {
- out.success("OK!")
- continue
+ val completer = AggregateCompleter(CompleterDefault())
+ val reader = LineReaderBuilder.builder()
+ .terminal(terminal)
+ .parser(ShellParser)
+ .completer(completer)
+ .option(LineReader.Option.GROUP_PERSIST, true)
+ .option(LineReader.Option.AUTO_LIST, true)
+ .option(LineReader.Option.CASE_INSENSITIVE, true)
+ .variable(LineReader.LIST_MAX, 10)
+ .highlighter(highlighter)
+ .expander(ShellExpander)
+ .variable(LineReader.HISTORY_FILE, homeDir.resolve(".partiql/.history"))
+ .variable(LineReader.SECONDARY_PROMPT_PATTERN, PROMPT_2)
+ .build()
+
+ out.info(WELCOME_MSG)
+ out.info("Typing mode: ${compiler.options.typingMode.name}")
+ out.info("Using version: ${retrievePartiQLVersionAndHash()}")
+ if (compiler is AbstractPipeline.PipelineDebug) {
+ out.println("\n\n")
+ out.success(DEBUG_MSG)
+ out.println("\n\n")
}
- // Handle commands
- val command = when (val end: Int = CharMatcher.`is`(';').or(CharMatcher.whitespace()).indexIn(line)) {
- -1 -> ""
- else -> line.substring(0, end)
- }.lowercase(Locale.ENGLISH).trim()
- when (command) {
- "!exit" -> return
- "!add_to_global_env" -> {
- // Consider PicoCLI + Jline, but it doesn't easily place nice with commands + raw SQL
- // https://github.com/partiql/partiql-lang-kotlin/issues/63
- val arg = requireInput(line, command) ?: continue
- executeAndPrint {
- val locals = refreshBindings()
- val result = evaluatePartiQL(arg, locals) as PartiQLResult.Value
- globals.add(result.value.bindings)
- result
+ while (!exiting.get()) {
+ val line: String = try {
+ reader.readLine(PROMPT_1)
+ } catch (ex: UserInterruptException) {
+ if (ex.partialLine.isNotEmpty()) {
+ reader.history.add(ex.partialLine)
}
continue
+ } catch (ex: EndOfFileException) {
+ out.info("^D")
+ return
}
- "!add_graph" -> {
- val input = requireInput(line, command) ?: continue
- val (name, graphStr) = requireTokenAndMore(input, command) ?: continue
- bringGraph(name, graphStr)
- continue
- }
- "!add_graph_from_file" -> {
- val input = requireInput(line, command) ?: continue
- val (name, filename) = requireTokenAndMore(input, command) ?: continue
- val graphStr = readTextFile(filename) ?: continue
- bringGraph(name, graphStr)
- continue
- }
- "!global_env" -> {
- executeAndPrint { AbstractPipeline.convertExprValue(globals.asExprValue()) }
+
+ // Pretty print AST
+ if (line.endsWith("\n!!")) {
+ printAST(line.removeSuffix("!!"))
continue
}
- "!clear" -> {
- terminal.puts(InfoCmp.Capability.clear_screen)
- terminal.flush()
+
+ if (line.isBlank()) {
+ out.success("OK!")
continue
}
- "!history" -> {
- for (entry in reader.history) {
- out.println(entry.pretty())
+
+ // Handle commands
+ val command = when (val end: Int = CharMatcher.`is`(';').or(CharMatcher.whitespace()).indexIn(line)) {
+ -1 -> ""
+ else -> line.substring(0, end)
+ }.lowercase(Locale.ENGLISH).trim()
+ when (command) {
+ "!exit" -> return
+ "!add_to_global_env" -> {
+ // Consider PicoCLI + Jline, but it doesn't easily place nice with commands + raw SQL
+ // https://github.com/partiql/partiql-lang-kotlin/issues/63
+ val arg = requireInput(line, command) ?: continue
+ executeAndPrint {
+ val locals = refreshBindings()
+ val result = evaluatePartiQL(
+ arg,
+ locals,
+ exiting
+ ) as PartiQLResult.Value
+ globals.add(result.value.bindings)
+ result
+ }
+ continue
+ }
+ "!add_graph" -> {
+ val input = requireInput(line, command) ?: continue
+ val (name, graphStr) = requireTokenAndMore(input, command) ?: continue
+ bringGraph(name, graphStr)
+ continue
+ }
+ "!add_graph_from_file" -> {
+ val input = requireInput(line, command) ?: continue
+ val (name, filename) = requireTokenAndMore(input, command) ?: continue
+ val graphStr = readTextFile(filename) ?: continue
+ bringGraph(name, graphStr)
+ continue
+ }
+ "!global_env" -> {
+ executeAndPrint { AbstractPipeline.convertExprValue(globals.asExprValue()) }
+ continue
+ }
+ "!clear" -> {
+ terminal.puts(InfoCmp.Capability.clear_screen)
+ terminal.flush()
+ continue
+ }
+ "!history" -> {
+ for (entry in reader.history) {
+ out.println(entry.pretty())
+ }
+ continue
+ }
+ "!list_commands", "!help" -> {
+ out.info(HELP)
+ continue
}
- continue
- }
- "!list_commands", "!help" -> {
- out.info(HELP)
- continue
}
- }
- // Execute PartiQL
- executeAndPrint {
- val locals = refreshBindings()
- evaluatePartiQL(line, locals)
+ // Execute PartiQL
+ executeAndPrint {
+ val locals = refreshBindings()
+ evaluatePartiQL(line, locals, exiting)
+ }
}
+ out.println("Thanks for using PartiQL!")
}
- }
/** After a command [detectedCommand] has been detected to start the user input,
* analyze the entire [wholeLine] user input again, expecting to find more input after the command.
@@ -278,7 +310,7 @@ internal class Shell(
val file = File(filename)
file.readText()
} catch (ex: Exception) {
- out.error("Could not read text from file '$filename'${ex.message?.let { ":\n$it"} ?: "."}")
+ out.error("Could not read text from file '$filename'${ex.message?.let { ":\n$it" } ?: "."}")
null
}
@@ -292,14 +324,31 @@ internal class Shell(
}
/** Evaluate a textual PartiQL query [textPartiQL] in the context of given [bindings]. */
- private fun evaluatePartiQL(textPartiQL: String, bindings: Bindings): PartiQLResult =
- compiler.compile(
- textPartiQL,
- EvaluationSession.build {
- globals(bindings)
- user(currentUser)
- }
+ private fun evaluatePartiQL(
+ textPartiQL: String,
+ bindings: Bindings,
+ exiting: AtomicBoolean
+ ): PartiQLResult {
+ doneCompiling.set(false)
+ inputs.put(
+ RunnablePipeline.Input(
+ textPartiQL,
+ EvaluationSession.build {
+ globals(bindings)
+ user(currentUser)
+ }
+ )
)
+ return catchCancellation(
+ doneCompiling,
+ exiting,
+ pipelineService,
+ PartiQLResult.Value(value = ExprValue.newString("Compilation cancelled."))
+ ) {
+ pipelineService = Executors.newFixedThreadPool(1)
+ pipelineService.submit(RunnablePipeline(inputs, results, compiler, doneCompiling))
+ } ?: results.poll(5, TimeUnit.SECONDS)!!
+ }
private fun bringGraph(name: String, graphIonText: String) {
try {
@@ -333,7 +382,19 @@ internal class Shell(
}
is PartiQLResult.Value -> {
try {
- printExprValue(ConfigurableExprValueFormatter.pretty, result.value)
+ donePrinting.set(false)
+ values.put(result.value)
+ catchCancellation(donePrinting, exiting, printingService, 1) {
+ printingService = Executors.newFixedThreadPool(1)
+ printingService.submit(
+ RunnableWriter(
+ out,
+ ConfigurableExprValueFormatter.pretty,
+ values,
+ donePrinting
+ )
+ )
+ }
} catch (ex: EvaluationException) { // should not need to do this here; see https://github.com/partiql/partiql-lang-kotlin/issues/1002
out.error(ex.generateMessage())
out.error(ex.message)
@@ -355,12 +416,33 @@ internal class Shell(
out.flush()
}
- private fun printExprValue(formatter: ExprValueFormatter, result: ExprValue) {
- out.info(BAR_1)
- formatter.formatTo(result, out)
- out.println()
- out.info(BAR_2)
- previousResult = result
+ /**
+ * Returns null if execution finished without a cancellation request. If a cancellation was caught, shuts down the
+ * service, resets it via [resetService], and returns [defaultReturn].
+ */
+ private fun <T> catchCancellation(
+ doneExecuting: AtomicBoolean,
+ cancellationFlag: AtomicBoolean,
+ service: ExecutorService,
+ defaultReturn: T,
+ resetService: () -> Unit
+ ): T? {
+ while (!doneExecuting.get()) {
+ if (exiting.get()) {
+ service.shutdown()
+ service.shutdownNow()
+ when (service.awaitTermination(2, TimeUnit.SECONDS)) {
+ true -> {
+ cancellationFlag.set(false)
+ doneExecuting.set(false)
+ resetService()
+ return defaultReturn
+ }
+ false -> throw Exception("Printing service couldn't terminate")
+ }
+ }
+ }
+ return null
}
private fun retrievePartiQLVersionAndHash(): String {
@@ -418,7 +500,7 @@ private fun PrintStream.success(string: String) = this.println(ansi(string, SUCC
private fun PrintStream.error(string: String) = this.println(ansi(string, ERROR))
-private fun PrintStream.info(string: String) = this.println(ansi(string, INFO))
+internal fun PrintStream.info(string: String) = this.println(ansi(string, INFO))
private fun PrintStream.warn(string: String) = this.println(ansi(string, WARN))
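The cancellation path through Shell.kt can be summarized as: the SIGINT handler sets `exiting`, `catchCancellation` notices it while the worker is still busy, and shutting the executor down delivers a thread interrupt that the (now interruptible) compiled query observes. A standalone sketch of that core interaction, independent of the PartiQL types:

    import java.util.concurrent.Executors
    import java.util.concurrent.TimeUnit

    fun main() {
        val service = Executors.newFixedThreadPool(1)
        service.submit {
            try {
                // Stand-in for a long-running, interruption-aware evaluation.
                while (true) {
                    if (Thread.interrupted()) throw InterruptedException()
                }
            } catch (_: InterruptedException) {
                println("worker interrupted")
            }
        }
        Thread.sleep(100)
        service.shutdownNow()                        // delivers Thread.interrupt() to the worker
        service.awaitTermination(2, TimeUnit.SECONDS)
    }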
diff --git a/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/CompilerInterruptionBenchmark.kt b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/CompilerInterruptionBenchmark.kt
new file mode 100644
index 0000000000..15a4e6c0bb
--- /dev/null
+++ b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/CompilerInterruptionBenchmark.kt
@@ -0,0 +1,291 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at:
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
+ * language governing permissions and limitations under the License.
+ */
+
+package org.partiql.jmh.benchmarks
+
+import org.openjdk.jmh.annotations.Benchmark
+import org.openjdk.jmh.annotations.BenchmarkMode
+import org.openjdk.jmh.annotations.Fork
+import org.openjdk.jmh.annotations.Measurement
+import org.openjdk.jmh.annotations.Mode
+import org.openjdk.jmh.annotations.OutputTimeUnit
+import org.openjdk.jmh.annotations.Scope
+import org.openjdk.jmh.annotations.State
+import org.openjdk.jmh.annotations.Warmup
+import org.openjdk.jmh.infra.Blackhole
+import org.partiql.jmh.utils.FORK_VALUE_RECOMMENDED
+import org.partiql.jmh.utils.MEASUREMENT_ITERATION_VALUE_RECOMMENDED
+import org.partiql.jmh.utils.MEASUREMENT_TIME_VALUE_RECOMMENDED
+import org.partiql.jmh.utils.WARMUP_ITERATION_VALUE_RECOMMENDED
+import org.partiql.jmh.utils.WARMUP_TIME_VALUE_RECOMMENDED
+import org.partiql.lang.CompilerPipeline
+import org.partiql.lang.eval.CompileOptions
+import org.partiql.lang.eval.EvaluationSession
+import org.partiql.lang.eval.PartiQLResult
+import org.partiql.lang.syntax.PartiQLParserBuilder
+import java.util.concurrent.TimeUnit
+
+/**
+ * Benchmarks measuring the cost of enabling [CompileOptions.interruptible]: each query is compiled, evaluated, and
+ * iterated both with and without interruption support. Refer to this
+ * [JMH tutorial](http://tutorials.jenkov.com/java-performance/jmh.html) for more information on [Benchmark]s, [BenchmarkMode]s, etc.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.MICROSECONDS)
+open class CompilerInterruptionBenchmark {
+
+ companion object {
+ private const val FORK_VALUE: Int = FORK_VALUE_RECOMMENDED
+ private const val MEASUREMENT_ITERATION_VALUE: Int = MEASUREMENT_ITERATION_VALUE_RECOMMENDED
+ private const val MEASUREMENT_TIME_VALUE: Int = MEASUREMENT_TIME_VALUE_RECOMMENDED
+ private const val WARMUP_ITERATION_VALUE: Int = WARMUP_ITERATION_VALUE_RECOMMENDED
+ private const val WARMUP_TIME_VALUE: Int = WARMUP_TIME_VALUE_RECOMMENDED
+ }
+
+ @State(Scope.Thread)
+ open class MyState {
+ val parser = PartiQLParserBuilder.standard().build()
+ val session = EvaluationSession.standard()
+ val pipeline = CompilerPipeline.standard()
+ val pipelineWithoutInterruption = CompilerPipeline.build {
+ compileOptions(CompileOptions.standard().copy(interruptible = false))
+ }
+
+ val crossJoins = """
+ SELECT
+ *
+ FROM
+ ([1, 2, 3, 4]) as x1,
+ ([1, 2, 3, 4]) as x2,
+ ([1, 2, 3, 4]) as x3,
+ ([1, 2, 3, 4]) as x4,
+ ([1, 2, 3, 4]) as x5,
+ ([1, 2, 3, 4]) as x6,
+ ([1, 2, 3, 4]) as x7,
+ ([1, 2, 3, 4]) as x8
+ """.trimIndent()
+ val crossJoinsAst = parser.parseAstStatement(crossJoins)
+
+ val crossJoinsWithAggFunction = """
+ SELECT
+ COUNT(*)
+ FROM
+ ([1, 2, 3, 4]) as x1,
+ ([1, 2, 3, 4]) as x2,
+ ([1, 2, 3, 4]) as x3,
+ ([1, 2, 3, 4]) as x4,
+ ([1, 2, 3, 4]) as x5,
+ ([1, 2, 3, 4]) as x6,
+ ([1, 2, 3, 4]) as x7,
+ ([1, 2, 3, 4]) as x8,
+ ([1, 2, 3, 4]) as x9,
+ ([1, 2, 3, 4]) as x10,
+ ([1, 2, 3, 4]) as x11
+ """.trimIndent()
+ val crossJoinsAggAst = parser.parseAstStatement(crossJoinsWithAggFunction)
+
+ val crossJoinsWithAggFunctionAndGroupBy = """
+ SELECT
+ COUNT(*)
+ FROM
+ ([1, 2, 3, 4]) as x1,
+ ([1, 2, 3, 4]) as x2,
+ ([1, 2, 3, 4]) as x3,
+ ([1, 2, 3, 4]) as x4,
+ ([1, 2, 3, 4]) as x5,
+ ([1, 2, 3, 4]) as x6,
+ ([1, 2, 3, 4]) as x7,
+ ([1, 2, 3, 4]) as x8,
+ ([1, 2, 3, 4]) as x9,
+ ([1, 2, 3, 4]) as x10,
+ ([1, 2, 3, 4]) as x11
+ GROUP BY x1._1
+ """.trimIndent()
+ val crossJoinsAggGroupAst = parser.parseAstStatement(crossJoinsWithAggFunctionAndGroupBy)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun compileCrossJoinWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val exprValue = state.pipeline.compile(state.crossJoins)
+ blackhole.consume(exprValue)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun compileCrossJoinWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val exprValue = state.pipelineWithoutInterruption.compile(state.crossJoins)
+ blackhole.consume(exprValue)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun compileCrossJoinAggFuncWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val exprValue = state.pipeline.compile(state.crossJoinsWithAggFunction)
+ blackhole.consume(exprValue)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun compileCrossJoinAggFuncWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val exprValue = state.pipelineWithoutInterruption.compile(state.crossJoinsWithAggFunction)
+ blackhole.consume(exprValue)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun compileCrossJoinAggFuncGroupingWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val exprValue = state.pipeline.compile(state.crossJoinsWithAggFunctionAndGroupBy)
+ blackhole.consume(exprValue)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun compileCrossJoinAggFuncGroupingWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val exprValue = state.pipelineWithoutInterruption.compile(state.crossJoinsWithAggFunctionAndGroupBy)
+ blackhole.consume(exprValue)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun evalCrossJoinWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipeline.compile(state.crossJoinsAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ blackhole.consume(value)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun evalCrossJoinWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ blackhole.consume(value)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun evalCrossJoinAggWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipeline.compile(state.crossJoinsAggAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ blackhole.consume(value)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun evalCrossJoinAggWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAggAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ blackhole.consume(value)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun evalCrossJoinAggGroupWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipeline.compile(state.crossJoinsAggGroupAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ blackhole.consume(value)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun evalCrossJoinAggGroupWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAggGroupAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ blackhole.consume(value)
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun iterCrossJoinWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipeline.compile(state.crossJoinsAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ value.forEach { blackhole.consume(it) }
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun iterCrossJoinWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ value.forEach { blackhole.consume(it) }
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun iterCrossJoinAggWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipeline.compile(state.crossJoinsAggAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ value.forEach { blackhole.consume(it) }
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun iterCrossJoinAggWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAggAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ value.forEach { blackhole.consume(it) }
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun iterCrossJoinAggGroupWithInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipeline.compile(state.crossJoinsAggGroupAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ value.forEach { blackhole.consume(it) }
+ }
+
+ @Benchmark
+ @Fork(value = FORK_VALUE)
+ @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE)
+ @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE)
+ fun iterCrossJoinAggGroupWithoutInterruptible(state: MyState, blackhole: Blackhole) {
+ val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAggGroupAst).evaluate(state.session) as PartiQLResult.Value
+ val value = result.value
+ value.forEach { blackhole.consume(it) }
+ }
+}
diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/CompileOptions.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/CompileOptions.kt
index 59f570f15b..7083c9dc12 100644
--- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/CompileOptions.kt
+++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/CompileOptions.kt
@@ -137,6 +137,10 @@ enum class ThunkReturnTypeAssertions {
*
* @param defaultTimezoneOffset Default timezone offset to be used when TIME WITH TIME ZONE does not explicitly
* specify the time zone. Defaults to [ZoneOffset.UTC]
+ * @param interruptible specifies whether compilation and execution of the compiled statement can be interrupted. If
+ * set to true, compilation and execution of statements will check [Thread.interrupted] frequently. If set to false,
+ * interruption *may* still occur, but it is not guaranteed. The default is false.
*/
@Suppress("DataClassPrivateConstructor")
data class CompileOptions private constructor (
@@ -146,7 +150,8 @@ data class CompileOptions private constructor (
val thunkOptions: ThunkOptions = ThunkOptions.standard(),
val typingMode: TypingMode = TypingMode.LEGACY,
val typedOpBehavior: TypedOpBehavior = TypedOpBehavior.HONOR_PARAMETERS,
- val defaultTimezoneOffset: ZoneOffset = ZoneOffset.UTC
+ val defaultTimezoneOffset: ZoneOffset = ZoneOffset.UTC,
+ val interruptible: Boolean = false
) {
companion object {
@@ -193,6 +198,7 @@ data class CompileOptions private constructor (
fun thunkOptions(value: ThunkOptions) = set { copy(thunkOptions = value) }
fun thunkOptions(build: ThunkOptions.Builder.() -> Unit) = set { copy(thunkOptions = ThunkOptions.build(build)) }
fun defaultTimezoneOffset(value: ZoneOffset) = set { copy(defaultTimezoneOffset = value) }
+ fun isInterruptible(value: Boolean) = set { copy(interruptible = value) }
private inline fun set(block: CompileOptions.() -> CompileOptions): Builder {
options = block(options)
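The new builder method offers an alternative to `copy(interruptible = true)` for callers that already configure options through the builder. A brief sketch, assuming the existing `CompileOptions.build { }` entry point alongside the `isInterruptible` function added here:

    import org.partiql.lang.eval.CompileOptions

    val options: CompileOptions = CompileOptions.build {
        isInterruptible(true)
    }
    // The resulting options can be supplied wherever CompileOptions are accepted,
    // e.g. CompilerPipeline.build { compileOptions(options) }.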
diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/EvaluatingCompiler.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/EvaluatingCompiler.kt
index cd0e316ff8..630cc155fb 100644
--- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/EvaluatingCompiler.kt
+++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/EvaluatingCompiler.kt
@@ -150,6 +150,25 @@ internal open class EvaluatingCompiler(
"compilationContextStack was empty.", ErrorCode.EVALUATOR_UNEXPECTED_VALUE, internal = true
)
+ /**
+ * Checks whether the current thread has been interrupted. It is currently invoked during the compilation of
+ * aggregations and joins, the "evaluation" of aggregations and joins, and the materialization of joins and
+ * FROM-source scans.
+ *
+ * Note: This is essentially a way to avoid repeatedly checking [CompileOptions.interruptible]. By writing it this
+ * way, we statically determine whether to introduce checks. If the caller has specified
+ * [CompileOptions.interruptible], each invocation of this function performs a thread-interruption check. If not
+ * specified, no check is performed during compilation/evaluation/materialization.
+ */
+ private val interruptionCheck: () -> Unit = when (compileOptions.interruptible) {
+ true -> { ->
+ if (Thread.interrupted()) {
+ throw InterruptedException()
+ }
+ }
+ false -> { -> Unit }
+ }
+
// Note: please don't make this inline -- it messes up [EvaluationException] stack traces and
// isn't a huge benefit because this is only used at SQL-compile time anyway.
internal fun nestCompilationContext(
@@ -1824,7 +1843,10 @@ internal open class EvaluatingCompiler(
// Grouping is not needed -- simply project the results from the FROM clause directly.
thunkFactory.thunkEnv(metas) { env ->
- val sourcedRows = sourceThunks(env)
+ val sourcedRows = sourceThunks(env).map {
+ interruptionCheck()
+ it
+ }
val orderedRows = when (orderByThunk) {
null -> sourcedRows
@@ -1908,6 +1930,7 @@ internal open class EvaluatingCompiler(
// iterate over the values from the FROM clause and populate our
// aggregate register values.
fromProductions.forEach { fromProduction ->
+ interruptionCheck()
compiledAggregates?.forEachIndexed { index, ca ->
registers[index].aggregator.next(ca.argThunk(fromProduction.env))
}
@@ -2473,6 +2496,7 @@ internal open class EvaluatingCompiler(
// compute the join over the data sources
var seq = compiledSources
.foldLeftProduct({ env: Environment -> env }) { currEnvT: (Environment) -> Environment, currSource: CompiledFromSource ->
+ interruptionCheck()
// [currSource] - the next FROM currSource to add to the join
// [currEnvT] - the environment add-on that previous sources of the join have constructed
// and that can be used for evaluating this [currSource] (if it depends on the previous sources)
@@ -2521,6 +2545,7 @@ internal open class EvaluatingCompiler(
}
.asSequence()
.map { joinedValues ->
+ interruptionCheck()
// bind the joined value to the bindings for the filter/project
FromProduction(joinedValues, fromEnv.nest(localsBinder.bindLocals(joinedValues)))
}
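The `interruptionCheck` lambda above is a small but useful pattern: the decision is made once, when the compiler is constructed, so the per-row cost on the hot path is a single indirect call that is a no-op unless interruption was requested. A standalone sketch of the same idea outside the compiler:

    class RowProcessor(interruptible: Boolean) {
        // Bind either a real check or a no-op once, mirroring interruptionCheck above.
        private val interruptionCheck: () -> Unit = when (interruptible) {
            true -> { ->
                if (Thread.interrupted()) throw InterruptedException()
            }
            false -> { -> Unit }
        }

        fun process(rows: Sequence<Int>): Int {
            var sum = 0
            for (row in rows) {
                interruptionCheck() // no-op unless interruptible was requested
                sum += row
            }
            return sum
        }
    }

    fun main() {
        println(RowProcessor(interruptible = true).process((1..10).asSequence())) // 55
    }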
diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/Thunk.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/Thunk.kt
index 07a51d512d..3b489a7d5d 100644
--- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/Thunk.kt
+++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/Thunk.kt
@@ -118,7 +118,13 @@ internal val DEFAULT_EXCEPTION_HANDLER_FOR_LEGACY_MODE: ThunkExceptionHandlerFor
)
}
-internal val DEFAULT_EXCEPTION_HANDLER_FOR_PERMISSIVE_MODE: ThunkExceptionHandlerForPermissiveMode = { _, _ -> }
+internal val DEFAULT_EXCEPTION_HANDLER_FOR_PERMISSIVE_MODE: ThunkExceptionHandlerForPermissiveMode = { e, _ ->
+ when (e) {
+ is InterruptedException -> { throw e }
+ is StackOverflowError -> { throw e }
+ else -> {}
+ }
+}
/**
* An extension method for creating [ThunkFactory] based on the type of [TypingMode]
diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt
new file mode 100644
index 0000000000..b8d1063d35
--- /dev/null
+++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt
@@ -0,0 +1,192 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at:
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
+ * language governing permissions and limitations under the License.
+ */
+
+package org.partiql.lang.eval
+
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.Test
+import org.partiql.lang.syntax.PartiQLParserBuilder
+import org.partiql.lang.syntax.impl.INTERRUPT_AFTER_MS
+import org.partiql.lang.syntax.impl.WAIT_FOR_THREAD_TERMINATION_MS
+import java.util.concurrent.atomic.AtomicBoolean
+import kotlin.concurrent.thread
+
+/**
+ * Making sure we can interrupt the [EvaluatingCompiler].
+ */
+class EvaluatingCompilerInterruptTests {
+
+ private val parser = PartiQLParserBuilder.standard().build()
+ private val session = EvaluationSession.standard()
+ private val options = CompileOptions.standard().copy(interruptible = true)
+ private val compiler = EvaluatingCompiler(
+ emptyList(),
+ emptyMap(),
+ emptyMap(),
+ options
+ )
+
+ /**
+ * Joins are only evaluated during the materialization of the ExprValue's elements. Making sure cross joins
+ * can be interrupted.
+ */
+ @Test
+ fun evalCrossJoins() {
+ val query = """
+ SELECT
+ *
+ FROM
+ ([1, 2, 3, 4]) as x1,
+ ([1, 2, 3, 4]) as x2,
+ ([1, 2, 3, 4]) as x3,
+ ([1, 2, 3, 4]) as x4,
+ ([1, 2, 3, 4]) as x5,
+ ([1, 2, 3, 4]) as x6,
+ ([1, 2, 3, 4]) as x7,
+ ([1, 2, 3, 4]) as x8,
+ ([1, 2, 3, 4]) as x9,
+ ([1, 2, 3, 4]) as x10,
+ ([1, 2, 3, 4]) as x11,
+ ([1, 2, 3, 4]) as x12,
+ ([1, 2, 3, 4]) as x13,
+ ([1, 2, 3, 4]) as x14,
+ ([1, 2, 3, 4]) as x15
+ """.trimIndent()
+ val ast = parser.parseAstStatement(query)
+ val expression = compiler.compile(ast)
+ val result = expression.evaluate(session) as PartiQLResult.Value
+ testThreadInterrupt {
+ result.value.forEach { it }
+ }
+ }
+
+ /**
+ * Joins are only evaluated during the materialization of the ExprValue's elements. Making sure left
+ * joins can be interrupted.
+ */
+ @Test
+ fun evalLeftJoins() {
+ val query = """
+ SELECT
+ *
+ FROM
+ ([1, 2, 3, 4]) as x1 LEFT JOIN
+ ([1, 2, 3, 4]) as x2 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x3 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x4 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x5 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x6 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x7 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x8 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x9 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x10 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x11 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x12 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x13 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x14 ON TRUE LEFT JOIN
+ ([1, 2, 3, 4]) as x15 ON TRUE
+ """.trimIndent()
+ val ast = parser.parseAstStatement(query)
+ val expression = compiler.compile(ast)
+ val result = expression.evaluate(session) as PartiQLResult.Value
+ testThreadInterrupt {
+ result.value.forEach { it }
+ }
+ }
+
+ /**
+ * Aggregations currently get materialized during [Expression.evaluate], so we need to check that we can
+ * interrupt there.
+ */
+ @Test
+ fun compileLargeAggregation() {
+ val query = """
+ SELECT
+ COUNT(*)
+ FROM
+ ([1, 2, 3, 4]) as x1,
+ ([1, 2, 3, 4]) as x2,
+ ([1, 2, 3, 4]) as x3,
+ ([1, 2, 3, 4]) as x4,
+ ([1, 2, 3, 4]) as x5,
+ ([1, 2, 3, 4]) as x6,
+ ([1, 2, 3, 4]) as x7,
+ ([1, 2, 3, 4]) as x8,
+ ([1, 2, 3, 4]) as x9,
+ ([1, 2, 3, 4]) as x10,
+ ([1, 2, 3, 4]) as x11,
+ ([1, 2, 3, 4]) as x12,
+ ([1, 2, 3, 4]) as x13,
+ ([1, 2, 3, 4]) as x14,
+ ([1, 2, 3, 4]) as x15
+ """.trimIndent()
+ val ast = parser.parseAstStatement(query)
+ val expression = compiler.compile(ast)
+ testThreadInterrupt {
+ expression.evaluate(session) as PartiQLResult.Value
+ }
+ }
+
+ /**
+ * We need to make sure that we can end a never-ending query. These sorts of queries get materialized during the
+ * iteration of [ExprValue].
+ */
+ @Test
+ fun neverEndingScan() {
+ val indefiniteCollection = ExprValue.newBag(
+ sequence {
+ while (true) {
+ yield(ExprValue.nullValue)
+ }
+ }
+ )
+ val query = """
+ SELECT *
+ FROM ?
+ """.trimIndent()
+ val session = EvaluationSession.build {
+ parameters(listOf(indefiniteCollection))
+ }
+ val ast = parser.parseAstStatement(query)
+ val expression = compiler.compile(ast)
+ val result = expression.evaluate(session) as PartiQLResult.Value
+ testThreadInterrupt {
+ result.value.forEach { it }
+ }
+ }
+
+ private fun testThreadInterrupt(
+ interruptAfter: Long = INTERRUPT_AFTER_MS,
+ interruptWait: Long = WAIT_FOR_THREAD_TERMINATION_MS,
+ block: () -> Unit
+ ) {
+ val wasInterrupted = AtomicBoolean(false)
+ val t = thread(start = false) {
+ try {
+ block()
+ } catch (_: InterruptedException) {
+ wasInterrupted.set(true)
+ } catch (e: EvaluationException) {
+ if (e.cause is InterruptedException) {
+ wasInterrupted.set(true)
+ }
+ }
+ }
+ t.setUncaughtExceptionHandler { _, ex -> throw ex }
+ t.start()
+ Thread.sleep(interruptAfter)
+ t.interrupt()
+ t.join(interruptWait)
+ Assertions.assertTrue(wasInterrupted.get(), "Thread should have been interrupted.")
+ }
+}
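Outside of a unit test, the same behaviour can be reached without calling `Thread.interrupt()` directly: `Future.cancel(true)` interrupts the thread running the task, which the interruptible evaluation then observes. A hedged sketch using only APIs exercised elsewhere in this patch (the query and the timings are illustrative):

    import org.partiql.lang.CompilerPipeline
    import org.partiql.lang.eval.CompileOptions
    import org.partiql.lang.eval.EvaluationSession
    import org.partiql.lang.eval.PartiQLResult
    import java.util.concurrent.Executors
    import java.util.concurrent.TimeUnit

    fun main() {
        val pipeline = CompilerPipeline.build {
            compileOptions(CompileOptions.standard().copy(interruptible = true))
        }
        val query = """
            SELECT COUNT(*)
            FROM ([1, 2, 3, 4]) AS x1, ([1, 2, 3, 4]) AS x2, ([1, 2, 3, 4]) AS x3,
                 ([1, 2, 3, 4]) AS x4, ([1, 2, 3, 4]) AS x5, ([1, 2, 3, 4]) AS x6,
                 ([1, 2, 3, 4]) AS x7, ([1, 2, 3, 4]) AS x8, ([1, 2, 3, 4]) AS x9,
                 ([1, 2, 3, 4]) AS x10, ([1, 2, 3, 4]) AS x11, ([1, 2, 3, 4]) AS x12
        """.trimIndent()
        val executor = Executors.newSingleThreadExecutor()
        val future = executor.submit {
            val result = pipeline.compile(query).evaluate(EvaluationSession.standard()) as PartiQLResult.Value
            result.value.forEach { }
        }
        Thread.sleep(100)
        future.cancel(true) // interrupts the evaluating thread; the aggregation stops early
        executor.shutdown()
        executor.awaitTermination(5, TimeUnit.SECONDS)
    }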
diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParserThreadInterruptTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParserThreadInterruptTests.kt
index 8384d7de3a..1c463fb1a3 100644
--- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParserThreadInterruptTests.kt
+++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParserThreadInterruptTests.kt
@@ -37,10 +37,10 @@ import java.util.concurrent.atomic.AtomicBoolean
import kotlin.concurrent.thread
/** How long (in millis) to wait after starting a thread to set the interrupted flag. */
-private const val INTERRUPT_AFTER_MS: Long = 100
+const val INTERRUPT_AFTER_MS: Long = 100
/** How long (in millis) to wait for a thread to terminate after setting the interrupted flag. */
-private const val WAIT_FOR_THREAD_TERMINATION_MS: Long = 1000
+const val WAIT_FOR_THREAD_TERMINATION_MS: Long = 1000
/**
* At various locations in this codebase we check the state of [Thread.interrupted] and throw an