From 1c08148730c82aa245393a748ddec2d8ab905c46 Mon Sep 17 00:00:00 2001 From: John Ed Quinn <40360967+johnedquinn@users.noreply.github.com> Date: Tue, 19 Sep 2023 16:07:15 -0700 Subject: [PATCH] Adds support for thread interruption in compilation and execution (#1211) * Adds support for thread interruption in compilation and execution * Adds support for CLI users to use CTRL-C --- CHANGELOG.md | 4 + .../partiql/cli/pipeline/AbstractPipeline.kt | 1 + .../org/partiql/cli/shell/RunnablePipeline.kt | 56 ++++ .../org/partiql/cli/shell/RunnableWriter.kt | 53 +++ .../kotlin/org/partiql/cli/shell/Shell.kt | 316 +++++++++++------- .../CompilerInterruptionBenchmark.kt | 291 ++++++++++++++++ .../org/partiql/lang/eval/CompileOptions.kt | 8 +- .../partiql/lang/eval/EvaluatingCompiler.kt | 27 +- .../kotlin/org/partiql/lang/eval/Thunk.kt | 8 +- .../eval/EvaluatingCompilerInterruptTests.kt | 192 +++++++++++ .../PartiQLPigParserThreadInterruptTests.kt | 4 +- 11 files changed, 838 insertions(+), 122 deletions(-) create mode 100644 partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnablePipeline.kt create mode 100644 partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnableWriter.kt create mode 100644 partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/CompilerInterruptionBenchmark.kt create mode 100644 partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index 49f416af9a..c853179b1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,10 @@ Thank you to all who have contributed! ## [Unreleased] ### Added +- Adds `isInterruptible` property to `CompileOptions`. The default value is `false`. Please see the KDocs for more information. +- Adds support for thread interruption in compilation and execution. If you'd like to opt-in to this addition, please see + the `isInterruptible` addition above for more information. +- Adds support for CLI users to use CTRL-C to cancel long-running compilation/execution of queries ### Changed diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt index 925a2dbb14..ac8dcbc6c7 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/AbstractPipeline.kt @@ -145,6 +145,7 @@ internal sealed class AbstractPipeline(open val options: PipelineOptions) { projectionIteration(options.projectionIterationBehavior) undefinedVariable(options.undefinedVariableBehavior) typingMode(options.typingMode) + isInterruptible(true) } private val compilerPipeline = CompilerPipeline.build { diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnablePipeline.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnablePipeline.kt new file mode 100644 index 0000000000..e52b396fe1 --- /dev/null +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnablePipeline.kt @@ -0,0 +1,56 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.cli.shell + +import org.partiql.cli.pipeline.AbstractPipeline +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.PartiQLResult +import java.util.concurrent.BlockingQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +/** + * A wrapper over [AbstractPipeline]. It constantly grabs input queries from [inputs] and places the [PartiQLResult] + * in [results]. When it is done compiling a single statement, it sets [doneCompiling] to true. + */ +internal class RunnablePipeline( + private val inputs: BlockingQueue, + private val results: BlockingQueue, + val pipeline: AbstractPipeline, + private val doneCompiling: AtomicBoolean +) : Runnable { + /** + * When the Thread running this [Runnable] is interrupted, the underlying [AbstractPipeline] should catch the + * interruption and fail with some exception. Then, this will break out of [run]. + */ + override fun run() { + while (true) { + val input = inputs.poll(3, TimeUnit.SECONDS) + if (input != null) { + val result = pipeline.compile(input.input, input.session) + results.put(result) + doneCompiling.set(true) + } + } + } + + /** + * Represents a PartiQL statement ([input]) and the [EvaluationSession] to evaluate with. + */ + internal data class Input( + val input: String, + val session: EvaluationSession + ) +} diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnableWriter.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnableWriter.kt new file mode 100644 index 0000000000..caae3e1069 --- /dev/null +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/RunnableWriter.kt @@ -0,0 +1,53 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.cli.shell + +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.util.ExprValueFormatter +import java.io.PrintStream +import java.util.concurrent.BlockingQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +/** + * A wrapper over [ExprValueFormatter]. It constantly grabs [ExprValue]s from [values] and writes them to [out]. + * When it is done printing a single item, it sets [donePrinting] to true. + */ +internal class RunnableWriter( + private val out: PrintStream, + private val formatter: ExprValueFormatter, + private val values: BlockingQueue, + private val donePrinting: AtomicBoolean +) : Runnable { + + /** + * When the Thread running this [Runnable] is interrupted, the underlying formatter should check the interruption + * flag and fail with some exception. The formatter itself doesn't do this, but, since [ExprValue]s are lazily created, + * the creation of the [ExprValue] (by means of the thunks produced by the EvaluatingCompiler) should throw an exception + * when the thread is interrupted. Then, this will break out of [run]. + */ + override fun run() { + while (true) { + val value = values.poll(3, TimeUnit.SECONDS) + if (value != null) { + out.info(BAR_1) + formatter.formatTo(value, out) + out.println() + out.info(BAR_2) + donePrinting.set(true) + } + } + } +} diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt index 63fc5b914d..bb78d782ea 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt @@ -23,6 +23,7 @@ import org.jline.reader.LineReader import org.jline.reader.LineReaderBuilder import org.jline.reader.UserInterruptException import org.jline.reader.impl.completer.AggregateCompleter +import org.jline.terminal.Terminal import org.jline.terminal.TerminalBuilder import org.jline.utils.AttributedString import org.jline.utils.AttributedStringBuilder @@ -43,7 +44,6 @@ import org.partiql.lang.graph.ExternalGraphException import org.partiql.lang.graph.ExternalGraphReader import org.partiql.lang.syntax.PartiQLParserBuilder import org.partiql.lang.util.ConfigurableExprValueFormatter -import org.partiql.lang.util.ExprValueFormatter import java.io.Closeable import java.io.File import java.io.OutputStream @@ -52,15 +52,19 @@ import java.nio.file.Path import java.nio.file.Paths import java.util.Locale import java.util.Properties +import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.BlockingQueue import java.util.concurrent.CountDownLatch +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy private const val PROMPT_1 = "PartiQL> " private const val PROMPT_2 = " | " -private const val BAR_1 = "===' " -private const val BAR_2 = "--- " +internal const val BAR_1 = "===' " +internal const val BAR_2 = "--- " private const val WELCOME_MSG = "Welcome to the PartiQL shell!" private const val DEBUG_MSG = """ ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ @@ -98,10 +102,15 @@ private val EXIT_DELAY: Duration = Duration(3000) * Initial work to replace the REPL with JLine3. I have attempted to keep this similar to Repl.kt, but have some * opinions on ways to clean this up in later PRs. */ + +val exiting = AtomicBoolean(false) +val doneCompiling = AtomicBoolean(false) +val donePrinting = AtomicBoolean(false) + internal class Shell( - private val output: OutputStream, + output: OutputStream, private val compiler: AbstractPipeline, - private val initialGlobal: Bindings, + initialGlobal: Bindings, private val config: ShellConfiguration = ShellConfiguration() ) { private val homeDir: Path = Paths.get(System.getProperty("user.home")) @@ -110,10 +119,17 @@ internal class Shell( private val out = PrintStream(output) private val currentUser = System.getProperty("user.name") + private val inputs: BlockingQueue = ArrayBlockingQueue(1) + private val results: BlockingQueue = ArrayBlockingQueue(1) + private var pipelineService: ExecutorService = Executors.newFixedThreadPool(1) + private val values: BlockingQueue = ArrayBlockingQueue(1) + private var printingService: ExecutorService = Executors.newFixedThreadPool(1) + fun start() { - val exiting = AtomicBoolean() val interrupter = ThreadInterrupter() val exited = CountDownLatch(1) + pipelineService.submit(RunnablePipeline(inputs, results, compiler, doneCompiling)) + printingService.submit(RunnableWriter(out, ConfigurableExprValueFormatter.pretty, values, donePrinting)) Runtime .getRuntime() .addShutdownHook( @@ -135,119 +151,135 @@ internal class Shell( } } - private fun run(exiting: AtomicBoolean) = TerminalBuilder.builder().build().use { terminal -> - val highlighter = when { - this.config.isMonochrome -> null - else -> ShellHighlighter() - } - val completer = AggregateCompleter(CompleterDefault()) - val reader = LineReaderBuilder.builder() - .terminal(terminal) - .parser(ShellParser) - .completer(completer) - .option(LineReader.Option.GROUP_PERSIST, true) - .option(LineReader.Option.AUTO_LIST, true) - .option(LineReader.Option.CASE_INSENSITIVE, true) - .variable(LineReader.LIST_MAX, 10) - .highlighter(highlighter) - .expander(ShellExpander) - .variable(LineReader.HISTORY_FILE, homeDir.resolve(".partiql/.history")) - .variable(LineReader.SECONDARY_PROMPT_PATTERN, PROMPT_2) - .build() - - out.info(WELCOME_MSG) - out.info("Typing mode: ${compiler.options.typingMode.name}") - out.info("Using version: ${retrievePartiQLVersionAndHash()}") - if (compiler is AbstractPipeline.PipelineDebug) { - out.println("\n\n") - out.success(DEBUG_MSG) - out.println("\n\n") + private val signalHandler = Terminal.SignalHandler { sig -> + if (sig == Terminal.Signal.INT) { + exiting.set(true) } + } - while (!exiting.get()) { - val line: String = try { - reader.readLine(PROMPT_1) - } catch (ex: UserInterruptException) { - if (ex.partialLine.isNotEmpty()) { - reader.history.add(ex.partialLine) - } - continue - } catch (ex: EndOfFileException) { - out.info("^D") - return - } + private fun run(exiting: AtomicBoolean) = TerminalBuilder.builder() + .name("PartiQL") + .nativeSignals(true) + .signalHandler(signalHandler) + .build().use { terminal -> - // Pretty print AST - if (line.endsWith("\n!!")) { - printAST(line.removeSuffix("!!")) - continue + val highlighter = when { + this.config.isMonochrome -> null + else -> ShellHighlighter() } - - if (line.isBlank()) { - out.success("OK!") - continue + val completer = AggregateCompleter(CompleterDefault()) + val reader = LineReaderBuilder.builder() + .terminal(terminal) + .parser(ShellParser) + .completer(completer) + .option(LineReader.Option.GROUP_PERSIST, true) + .option(LineReader.Option.AUTO_LIST, true) + .option(LineReader.Option.CASE_INSENSITIVE, true) + .variable(LineReader.LIST_MAX, 10) + .highlighter(highlighter) + .expander(ShellExpander) + .variable(LineReader.HISTORY_FILE, homeDir.resolve(".partiql/.history")) + .variable(LineReader.SECONDARY_PROMPT_PATTERN, PROMPT_2) + .build() + + out.info(WELCOME_MSG) + out.info("Typing mode: ${compiler.options.typingMode.name}") + out.info("Using version: ${retrievePartiQLVersionAndHash()}") + if (compiler is AbstractPipeline.PipelineDebug) { + out.println("\n\n") + out.success(DEBUG_MSG) + out.println("\n\n") } - // Handle commands - val command = when (val end: Int = CharMatcher.`is`(';').or(CharMatcher.whitespace()).indexIn(line)) { - -1 -> "" - else -> line.substring(0, end) - }.lowercase(Locale.ENGLISH).trim() - when (command) { - "!exit" -> return - "!add_to_global_env" -> { - // Consider PicoCLI + Jline, but it doesn't easily place nice with commands + raw SQL - // https://github.com/partiql/partiql-lang-kotlin/issues/63 - val arg = requireInput(line, command) ?: continue - executeAndPrint { - val locals = refreshBindings() - val result = evaluatePartiQL(arg, locals) as PartiQLResult.Value - globals.add(result.value.bindings) - result + while (!exiting.get()) { + val line: String = try { + reader.readLine(PROMPT_1) + } catch (ex: UserInterruptException) { + if (ex.partialLine.isNotEmpty()) { + reader.history.add(ex.partialLine) } continue + } catch (ex: EndOfFileException) { + out.info("^D") + return } - "!add_graph" -> { - val input = requireInput(line, command) ?: continue - val (name, graphStr) = requireTokenAndMore(input, command) ?: continue - bringGraph(name, graphStr) - continue - } - "!add_graph_from_file" -> { - val input = requireInput(line, command) ?: continue - val (name, filename) = requireTokenAndMore(input, command) ?: continue - val graphStr = readTextFile(filename) ?: continue - bringGraph(name, graphStr) - continue - } - "!global_env" -> { - executeAndPrint { AbstractPipeline.convertExprValue(globals.asExprValue()) } + + // Pretty print AST + if (line.endsWith("\n!!")) { + printAST(line.removeSuffix("!!")) continue } - "!clear" -> { - terminal.puts(InfoCmp.Capability.clear_screen) - terminal.flush() + + if (line.isBlank()) { + out.success("OK!") continue } - "!history" -> { - for (entry in reader.history) { - out.println(entry.pretty()) + + // Handle commands + val command = when (val end: Int = CharMatcher.`is`(';').or(CharMatcher.whitespace()).indexIn(line)) { + -1 -> "" + else -> line.substring(0, end) + }.lowercase(Locale.ENGLISH).trim() + when (command) { + "!exit" -> return + "!add_to_global_env" -> { + // Consider PicoCLI + Jline, but it doesn't easily place nice with commands + raw SQL + // https://github.com/partiql/partiql-lang-kotlin/issues/63 + val arg = requireInput(line, command) ?: continue + executeAndPrint { + val locals = refreshBindings() + val result = evaluatePartiQL( + arg, + locals, + exiting + ) as PartiQLResult.Value + globals.add(result.value.bindings) + result + } + continue + } + "!add_graph" -> { + val input = requireInput(line, command) ?: continue + val (name, graphStr) = requireTokenAndMore(input, command) ?: continue + bringGraph(name, graphStr) + continue + } + "!add_graph_from_file" -> { + val input = requireInput(line, command) ?: continue + val (name, filename) = requireTokenAndMore(input, command) ?: continue + val graphStr = readTextFile(filename) ?: continue + bringGraph(name, graphStr) + continue + } + "!global_env" -> { + executeAndPrint { AbstractPipeline.convertExprValue(globals.asExprValue()) } + continue + } + "!clear" -> { + terminal.puts(InfoCmp.Capability.clear_screen) + terminal.flush() + continue + } + "!history" -> { + for (entry in reader.history) { + out.println(entry.pretty()) + } + continue + } + "!list_commands", "!help" -> { + out.info(HELP) + continue } - continue - } - "!list_commands", "!help" -> { - out.info(HELP) - continue } - } - // Execute PartiQL - executeAndPrint { - val locals = refreshBindings() - evaluatePartiQL(line, locals) + // Execute PartiQL + executeAndPrint { + val locals = refreshBindings() + evaluatePartiQL(line, locals, exiting) + } } + out.println("Thanks for using PartiQL!") } - } /** After a command [detectedCommand] has been detected to start the user input, * analyze the entire [wholeLine] user input again, expecting to find more input after the command. @@ -278,7 +310,7 @@ internal class Shell( val file = File(filename) file.readText() } catch (ex: Exception) { - out.error("Could not read text from file '$filename'${ex.message?.let { ":\n$it"} ?: "."}") + out.error("Could not read text from file '$filename'${ex.message?.let { ":\n$it" } ?: "."}") null } @@ -292,14 +324,31 @@ internal class Shell( } /** Evaluate a textual PartiQL query [textPartiQL] in the context of given [bindings]. */ - private fun evaluatePartiQL(textPartiQL: String, bindings: Bindings): PartiQLResult = - compiler.compile( - textPartiQL, - EvaluationSession.build { - globals(bindings) - user(currentUser) - } + private fun evaluatePartiQL( + textPartiQL: String, + bindings: Bindings, + exiting: AtomicBoolean + ): PartiQLResult { + doneCompiling.set(false) + inputs.put( + RunnablePipeline.Input( + textPartiQL, + EvaluationSession.build { + globals(bindings) + user(currentUser) + } + ) ) + return catchCancellation( + doneCompiling, + exiting, + pipelineService, + PartiQLResult.Value(value = ExprValue.newString("Compilation cancelled.")) + ) { + pipelineService = Executors.newFixedThreadPool(1) + pipelineService.submit(RunnablePipeline(inputs, results, compiler, doneCompiling)) + } ?: results.poll(5, TimeUnit.SECONDS)!! + } private fun bringGraph(name: String, graphIonText: String) { try { @@ -333,7 +382,19 @@ internal class Shell( } is PartiQLResult.Value -> { try { - printExprValue(ConfigurableExprValueFormatter.pretty, result.value) + donePrinting.set(false) + values.put(result.value) + catchCancellation(donePrinting, exiting, printingService, 1) { + printingService = Executors.newFixedThreadPool(1) + printingService.submit( + RunnableWriter( + out, + ConfigurableExprValueFormatter.pretty, + values, + donePrinting + ) + ) + } } catch (ex: EvaluationException) { // should not need to do this here; see https://github.com/partiql/partiql-lang-kotlin/issues/1002 out.error(ex.generateMessage()) out.error(ex.message) @@ -355,12 +416,33 @@ internal class Shell( out.flush() } - private fun printExprValue(formatter: ExprValueFormatter, result: ExprValue) { - out.info(BAR_1) - formatter.formatTo(result, out) - out.println() - out.info(BAR_2) - previousResult = result + /** + * If nothing was caught and execution finished: return null + * If something was caught: resets service and returns defaultReturn + */ + private fun catchCancellation( + doneExecuting: AtomicBoolean, + cancellationFlag: AtomicBoolean, + service: ExecutorService, + defaultReturn: T, + resetService: () -> Unit + ): T? { + while (!doneExecuting.get()) { + if (exiting.get()) { + service.shutdown() + service.shutdownNow() + when (service.awaitTermination(2, TimeUnit.SECONDS)) { + true -> { + cancellationFlag.set(false) + doneExecuting.set(false) + resetService() + return defaultReturn + } + false -> throw Exception("Printing service couldn't terminate") + } + } + } + return null } private fun retrievePartiQLVersionAndHash(): String { @@ -418,7 +500,7 @@ private fun PrintStream.success(string: String) = this.println(ansi(string, SUCC private fun PrintStream.error(string: String) = this.println(ansi(string, ERROR)) -private fun PrintStream.info(string: String) = this.println(ansi(string, INFO)) +internal fun PrintStream.info(string: String) = this.println(ansi(string, INFO)) private fun PrintStream.warn(string: String) = this.println(ansi(string, WARN)) diff --git a/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/CompilerInterruptionBenchmark.kt b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/CompilerInterruptionBenchmark.kt new file mode 100644 index 0000000000..15a4e6c0bb --- /dev/null +++ b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/CompilerInterruptionBenchmark.kt @@ -0,0 +1,291 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.jmh.benchmarks + +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.BenchmarkMode +import org.openjdk.jmh.annotations.Fork +import org.openjdk.jmh.annotations.Measurement +import org.openjdk.jmh.annotations.Mode +import org.openjdk.jmh.annotations.OutputTimeUnit +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import org.openjdk.jmh.annotations.Warmup +import org.openjdk.jmh.infra.Blackhole +import org.partiql.jmh.utils.FORK_VALUE_RECOMMENDED +import org.partiql.jmh.utils.MEASUREMENT_ITERATION_VALUE_RECOMMENDED +import org.partiql.jmh.utils.MEASUREMENT_TIME_VALUE_RECOMMENDED +import org.partiql.jmh.utils.WARMUP_ITERATION_VALUE_RECOMMENDED +import org.partiql.jmh.utils.WARMUP_TIME_VALUE_RECOMMENDED +import org.partiql.lang.CompilerPipeline +import org.partiql.lang.eval.CompileOptions +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.PartiQLResult +import org.partiql.lang.syntax.PartiQLParserBuilder +import java.util.concurrent.TimeUnit + +/** + * These are the sample benchmarks to demonstrate how JMH benchmarks in PartiQL should be set up. + * Refer this [JMH tutorial](http://tutorials.jenkov.com/java-performance/jmh.html) for more information on [Benchmark]s, + * [BenchmarkMode]s, etc. + */ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +open class CompilerInterruptionBenchmark { + + companion object { + private const val FORK_VALUE: Int = FORK_VALUE_RECOMMENDED + private const val MEASUREMENT_ITERATION_VALUE: Int = MEASUREMENT_ITERATION_VALUE_RECOMMENDED + private const val MEASUREMENT_TIME_VALUE: Int = MEASUREMENT_TIME_VALUE_RECOMMENDED + private const val WARMUP_ITERATION_VALUE: Int = WARMUP_ITERATION_VALUE_RECOMMENDED + private const val WARMUP_TIME_VALUE: Int = WARMUP_TIME_VALUE_RECOMMENDED + } + + @State(Scope.Thread) + open class MyState { + val parser = PartiQLParserBuilder.standard().build() + val session = EvaluationSession.standard() + val pipeline = CompilerPipeline.standard() + val pipelineWithoutInterruption = CompilerPipeline.build { + compileOptions(CompileOptions.standard().copy(interruptible = false)) + } + + val crossJoins = """ + SELECT + * + FROM + ([1, 2, 3, 4]) as x1, + ([1, 2, 3, 4]) as x2, + ([1, 2, 3, 4]) as x3, + ([1, 2, 3, 4]) as x4, + ([1, 2, 3, 4]) as x5, + ([1, 2, 3, 4]) as x6, + ([1, 2, 3, 4]) as x7, + ([1, 2, 3, 4]) as x8 + """.trimIndent() + val crossJoinsAst = parser.parseAstStatement(crossJoins) + + val crossJoinsWithAggFunction = """ + SELECT + COUNT(*) + FROM + ([1, 2, 3, 4]) as x1, + ([1, 2, 3, 4]) as x2, + ([1, 2, 3, 4]) as x3, + ([1, 2, 3, 4]) as x4, + ([1, 2, 3, 4]) as x5, + ([1, 2, 3, 4]) as x6, + ([1, 2, 3, 4]) as x7, + ([1, 2, 3, 4]) as x8, + ([1, 2, 3, 4]) as x9, + ([1, 2, 3, 4]) as x10, + ([1, 2, 3, 4]) as x11 + """.trimIndent() + val crossJoinsAggAst = parser.parseAstStatement(crossJoinsWithAggFunction) + + val crossJoinsWithAggFunctionAndGroupBy = """ + SELECT + COUNT(*) + FROM + ([1, 2, 3, 4]) as x1, + ([1, 2, 3, 4]) as x2, + ([1, 2, 3, 4]) as x3, + ([1, 2, 3, 4]) as x4, + ([1, 2, 3, 4]) as x5, + ([1, 2, 3, 4]) as x6, + ([1, 2, 3, 4]) as x7, + ([1, 2, 3, 4]) as x8, + ([1, 2, 3, 4]) as x9, + ([1, 2, 3, 4]) as x10, + ([1, 2, 3, 4]) as x11 + GROUP BY x1._1 + """.trimIndent() + val crossJoinsAggGroupAst = parser.parseAstStatement(crossJoinsWithAggFunctionAndGroupBy) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun compileCrossJoinWithInterruptible(state: MyState, blackhole: Blackhole) { + val exprValue = state.pipeline.compile(state.crossJoins) + blackhole.consume(exprValue) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun compileCrossJoinWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val exprValue = state.pipelineWithoutInterruption.compile(state.crossJoins) + blackhole.consume(exprValue) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun compileCrossJoinAggFuncWithInterruptible(state: MyState, blackhole: Blackhole) { + val exprValue = state.pipeline.compile(state.crossJoinsWithAggFunction) + blackhole.consume(exprValue) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun compileCrossJoinAggFuncWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val exprValue = state.pipelineWithoutInterruption.compile(state.crossJoinsWithAggFunction) + blackhole.consume(exprValue) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun compileCrossJoinAggFuncGroupingWithInterruptible(state: MyState, blackhole: Blackhole) { + val exprValue = state.pipeline.compile(state.crossJoinsWithAggFunctionAndGroupBy) + blackhole.consume(exprValue) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun compileCrossJoinAggFuncGroupingWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val exprValue = state.pipelineWithoutInterruption.compile(state.crossJoinsWithAggFunctionAndGroupBy) + blackhole.consume(exprValue) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun evalCrossJoinWithInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipeline.compile(state.crossJoinsAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + blackhole.consume(value) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun evalCrossJoinWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + blackhole.consume(value) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun evalCrossJoinAggWithInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipeline.compile(state.crossJoinsAggAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + blackhole.consume(value) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun evalCrossJoinAggWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAggAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + blackhole.consume(value) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun evalCrossJoinAggGroupWithInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipeline.compile(state.crossJoinsAggGroupAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + blackhole.consume(value) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun evalCrossJoinAggGroupWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAggGroupAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + blackhole.consume(value) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun iterCrossJoinWithInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipeline.compile(state.crossJoinsAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + value.forEach { blackhole.consume(it) } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun iterCrossJoinWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + value.forEach { blackhole.consume(it) } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun iterCrossJoinAggWithInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipeline.compile(state.crossJoinsAggAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + value.forEach { blackhole.consume(it) } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun iterCrossJoinAggWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAggAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + value.forEach { blackhole.consume(it) } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun iterCrossJoinAggGroupWithInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipeline.compile(state.crossJoinsAggGroupAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + value.forEach { blackhole.consume(it) } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + fun iterCrossJoinAggGroupWithoutInterruptible(state: MyState, blackhole: Blackhole) { + val result = state.pipelineWithoutInterruption.compile(state.crossJoinsAggGroupAst).evaluate(state.session) as PartiQLResult.Value + val value = result.value + value.forEach { blackhole.consume(it) } + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/CompileOptions.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/CompileOptions.kt index 59f570f15b..7083c9dc12 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/CompileOptions.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/CompileOptions.kt @@ -137,6 +137,10 @@ enum class ThunkReturnTypeAssertions { * * @param defaultTimezoneOffset Default timezone offset to be used when TIME WITH TIME ZONE does not explicitly * specify the time zone. Defaults to [ZoneOffset.UTC] + * @param interruptible specifies whether the compilation and execution of the compiled statement is interruptible. If + * set to true, the compilation and execution of statements will check [Thread.interrupted] frequently. If set to + * false, the compilation and execution of statements is not guaranteed to be interruptible. It *may* still be interrupted, + * however, it is not guaranteed. The default is false. */ @Suppress("DataClassPrivateConstructor") data class CompileOptions private constructor ( @@ -146,7 +150,8 @@ data class CompileOptions private constructor ( val thunkOptions: ThunkOptions = ThunkOptions.standard(), val typingMode: TypingMode = TypingMode.LEGACY, val typedOpBehavior: TypedOpBehavior = TypedOpBehavior.HONOR_PARAMETERS, - val defaultTimezoneOffset: ZoneOffset = ZoneOffset.UTC + val defaultTimezoneOffset: ZoneOffset = ZoneOffset.UTC, + val interruptible: Boolean = false ) { companion object { @@ -193,6 +198,7 @@ data class CompileOptions private constructor ( fun thunkOptions(value: ThunkOptions) = set { copy(thunkOptions = value) } fun thunkOptions(build: ThunkOptions.Builder.() -> Unit) = set { copy(thunkOptions = ThunkOptions.build(build)) } fun defaultTimezoneOffset(value: ZoneOffset) = set { copy(defaultTimezoneOffset = value) } + fun isInterruptible(value: Boolean) = set { copy(interruptible = value) } private inline fun set(block: CompileOptions.() -> CompileOptions): Builder { options = block(options) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/EvaluatingCompiler.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/EvaluatingCompiler.kt index cd0e316ff8..630cc155fb 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/EvaluatingCompiler.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/EvaluatingCompiler.kt @@ -150,6 +150,25 @@ internal open class EvaluatingCompiler( "compilationContextStack was empty.", ErrorCode.EVALUATOR_UNEXPECTED_VALUE, internal = true ) + /** + * This checks whether the thread has been interrupted. Specifically, it currently checks during the compilation + * of aggregations and joins, the "evaluation" of aggregations and joins, and the materialization of joins + * and from source scans. + * + * Note: This is essentially a way to avoid constantly checking [CompileOptions.interruptible]. By writing it this + * way, we statically determine whether to introduce checks. If the compiler has specified + * [CompileOptions.interruptible], the invocation of this function will insert a Thread interruption check. If not + * specified, it will not perform the check during compilation/evaluation/materialization. + */ + private val interruptionCheck: () -> Unit = when (compileOptions.interruptible) { + true -> { -> + if (Thread.interrupted()) { + throw InterruptedException() + } + } + false -> { -> Unit } + } + // Note: please don't make this inline -- it messes up [EvaluationException] stack traces and // isn't a huge benefit because this is only used at SQL-compile time anyway. internal fun nestCompilationContext( @@ -1824,7 +1843,10 @@ internal open class EvaluatingCompiler( // Grouping is not needed -- simply project the results from the FROM clause directly. thunkFactory.thunkEnv(metas) { env -> - val sourcedRows = sourceThunks(env) + val sourcedRows = sourceThunks(env).map { + interruptionCheck() + it + } val orderedRows = when (orderByThunk) { null -> sourcedRows @@ -1908,6 +1930,7 @@ internal open class EvaluatingCompiler( // iterate over the values from the FROM clause and populate our // aggregate register values. fromProductions.forEach { fromProduction -> + interruptionCheck() compiledAggregates?.forEachIndexed { index, ca -> registers[index].aggregator.next(ca.argThunk(fromProduction.env)) } @@ -2473,6 +2496,7 @@ internal open class EvaluatingCompiler( // compute the join over the data sources var seq = compiledSources .foldLeftProduct({ env: Environment -> env }) { currEnvT: (Environment) -> Environment, currSource: CompiledFromSource -> + interruptionCheck() // [currSource] - the next FROM currSource to add to the join // [currEnvT] - the environment add-on that previous sources of the join have constructed // and that can be used for evaluating this [currSource] (if it depends on the previous sources) @@ -2521,6 +2545,7 @@ internal open class EvaluatingCompiler( } .asSequence() .map { joinedValues -> + interruptionCheck() // bind the joined value to the bindings for the filter/project FromProduction(joinedValues, fromEnv.nest(localsBinder.bindLocals(joinedValues))) } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/Thunk.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/Thunk.kt index 07a51d512d..3b489a7d5d 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/Thunk.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/Thunk.kt @@ -118,7 +118,13 @@ internal val DEFAULT_EXCEPTION_HANDLER_FOR_LEGACY_MODE: ThunkExceptionHandlerFor ) } -internal val DEFAULT_EXCEPTION_HANDLER_FOR_PERMISSIVE_MODE: ThunkExceptionHandlerForPermissiveMode = { _, _ -> } +internal val DEFAULT_EXCEPTION_HANDLER_FOR_PERMISSIVE_MODE: ThunkExceptionHandlerForPermissiveMode = { e, _ -> + when (e) { + is InterruptedException -> { throw e } + is StackOverflowError -> { throw e } + else -> {} + } +} /** * An extension method for creating [ThunkFactory] based on the type of [TypingMode] diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt new file mode 100644 index 0000000000..b8d1063d35 --- /dev/null +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt @@ -0,0 +1,192 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.lang.eval + +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.partiql.lang.syntax.PartiQLParserBuilder +import org.partiql.lang.syntax.impl.INTERRUPT_AFTER_MS +import org.partiql.lang.syntax.impl.WAIT_FOR_THREAD_TERMINATION_MS +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.concurrent.thread + +/** + * Making sure we can interrupt the [EvaluatingCompiler]. + */ +class EvaluatingCompilerInterruptTests { + + private val parser = PartiQLParserBuilder.standard().build() + private val session = EvaluationSession.standard() + private val options = CompileOptions.standard().copy(interruptible = true) + private val compiler = EvaluatingCompiler( + emptyList(), + emptyMap(), + emptyMap(), + options + ) + + /** + * Joins are only evaluated during the materialization of the ExprValue's elements. Cross Joins. + */ + @Test + fun evalCrossJoins() { + val query = """ + SELECT + * + FROM + ([1, 2, 3, 4]) as x1, + ([1, 2, 3, 4]) as x2, + ([1, 2, 3, 4]) as x3, + ([1, 2, 3, 4]) as x4, + ([1, 2, 3, 4]) as x5, + ([1, 2, 3, 4]) as x6, + ([1, 2, 3, 4]) as x7, + ([1, 2, 3, 4]) as x8, + ([1, 2, 3, 4]) as x9, + ([1, 2, 3, 4]) as x10, + ([1, 2, 3, 4]) as x11, + ([1, 2, 3, 4]) as x12, + ([1, 2, 3, 4]) as x13, + ([1, 2, 3, 4]) as x14, + ([1, 2, 3, 4]) as x15 + """.trimIndent() + val ast = parser.parseAstStatement(query) + val expression = compiler.compile(ast) + val result = expression.evaluate(session) as PartiQLResult.Value + testThreadInterrupt { + result.value.forEach { it } + } + } + + /** + * Joins are only evaluated during the materialization of the ExprValue's elements. Making sure left + * joins can be interrupted. + */ + @Test + fun evalLeftJoins() { + val query = """ + SELECT + * + FROM + ([1, 2, 3, 4]) as x1 LEFT JOIN + ([1, 2, 3, 4]) as x2 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x3 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x4 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x5 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x6 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x7 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x8 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x9 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x10 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x11 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x12 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x13 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x14 ON TRUE LEFT JOIN + ([1, 2, 3, 4]) as x15 ON TRUE + """.trimIndent() + val ast = parser.parseAstStatement(query) + val expression = compiler.compile(ast) + val result = expression.evaluate(session) as PartiQLResult.Value + testThreadInterrupt { + result.value.forEach { it } + } + } + + /** + * Aggregations currently get materialized during [Expression.evaluate], so we need to check that we can + * interrupt there. + */ + @Test + fun compileLargeAggregation() { + val query = """ + SELECT + COUNT(*) + FROM + ([1, 2, 3, 4]) as x1, + ([1, 2, 3, 4]) as x2, + ([1, 2, 3, 4]) as x3, + ([1, 2, 3, 4]) as x4, + ([1, 2, 3, 4]) as x5, + ([1, 2, 3, 4]) as x6, + ([1, 2, 3, 4]) as x7, + ([1, 2, 3, 4]) as x8, + ([1, 2, 3, 4]) as x9, + ([1, 2, 3, 4]) as x10, + ([1, 2, 3, 4]) as x11, + ([1, 2, 3, 4]) as x12, + ([1, 2, 3, 4]) as x13, + ([1, 2, 3, 4]) as x14, + ([1, 2, 3, 4]) as x15 + """.trimIndent() + val ast = parser.parseAstStatement(query) + val expression = compiler.compile(ast) + testThreadInterrupt { + expression.evaluate(session) as PartiQLResult.Value + } + } + + /** + * We need to make sure that we can end a never-ending query. These sorts of queries get materialized during the + * iteration of [ExprValue]. + */ + @Test + fun neverEndingScan() { + val indefiniteCollection = ExprValue.newBag( + sequence { + while (true) { + yield(ExprValue.nullValue) + } + } + ) + val query = """ + SELECT * + FROM ? + """.trimIndent() + val session = EvaluationSession.build { + parameters(listOf(indefiniteCollection)) + } + val ast = parser.parseAstStatement(query) + val expression = compiler.compile(ast) + val result = expression.evaluate(session) as PartiQLResult.Value + testThreadInterrupt { + result.value.forEach { it } + } + } + + private fun testThreadInterrupt( + interruptAfter: Long = INTERRUPT_AFTER_MS, + interruptWait: Long = WAIT_FOR_THREAD_TERMINATION_MS, + block: () -> Unit + ) { + val wasInterrupted = AtomicBoolean(false) + val t = thread(start = false) { + try { + block() + } catch (_: InterruptedException) { + wasInterrupted.set(true) + } catch (e: EvaluationException) { + if (e.cause is InterruptedException) { + wasInterrupted.set(true) + } + } + } + t.setUncaughtExceptionHandler { _, ex -> throw ex } + t.start() + Thread.sleep(interruptAfter) + t.interrupt() + t.join(interruptWait) + Assertions.assertTrue(wasInterrupted.get(), "Thread should have been interrupted.") + } +} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParserThreadInterruptTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParserThreadInterruptTests.kt index 8384d7de3a..1c463fb1a3 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParserThreadInterruptTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/impl/PartiQLPigParserThreadInterruptTests.kt @@ -37,10 +37,10 @@ import java.util.concurrent.atomic.AtomicBoolean import kotlin.concurrent.thread /** How long (in millis) to wait after starting a thread to set the interrupted flag. */ -private const val INTERRUPT_AFTER_MS: Long = 100 +const val INTERRUPT_AFTER_MS: Long = 100 /** How long (in millis) to wait for a thread to terminate after setting the interrupted flag. */ -private const val WAIT_FOR_THREAD_TERMINATION_MS: Long = 1000 +const val WAIT_FOR_THREAD_TERMINATION_MS: Long = 1000 /** * At various locations in this codebase we check the state of [Thread.interrupted] and throw an