Skip to content

Commit

Permalink
Add Handle for effect handlers
Browse files Browse the repository at this point in the history
Allow resuming Cont with exception
  • Loading branch information
kyay10 committed Jun 8, 2024
1 parent ea1ee91 commit 5627d99
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 11 deletions.
64 changes: 64 additions & 0 deletions library/src/commonMain/kotlin/handle.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import kotlin.jvm.JvmInline

@JvmInline
public value class Handle<Error, T> internal constructor(private val reader: Reader<suspend (Error) -> T>) {
public constructor() : this(Reader())

@ResetDsl
public suspend fun call(error: Error): T = reader.ask()(error)

@ResetDsl
public suspend fun <R> handle(
handler: suspend (Error, Cont<T, R>) -> R, body: suspend () -> R
): R {
val prompt = Prompt<R>()
return prompt.pushPrompt(
extraContext = reader.context(DeepHandler(prompt, this, handler)::invoke), body = body
)
}

private class DeepHandler<Error, T, R>(
private val prompt: Prompt<R>,
private val handle: Handle<Error, T>,
private val handler: suspend (Error, Cont<T, R>) -> R
) {
suspend operator fun invoke(error: Error): T = prompt.takeSubCont { k ->
handler(error) {
k.pushSubContWith(it, isDelimiting = true, extraContext = handle.reader.context(::invoke))
}
}
}

@ResetDsl
public suspend fun <R> handleShallow(
handler: suspend (Error, Cont<T, R>) -> R, body: suspend () -> R
): R {
val prompt = Prompt<R>()
return prompt.pushPrompt(extraContext = reader.context(ShallowHandler(prompt, handler)::invoke), body = body)
}

private class ShallowHandler<Error, T, R>(
private val prompt: Prompt<R>,
private val handler: suspend (Error, Cont<T, R>) -> R
) {
suspend operator fun invoke(error: Error): T = prompt.takeSubCont { k ->
handler(error) { result ->
k.pushSubContWith(result)
}
}
}
}

@ResetDsl
public suspend fun <Error, T, R> newHandle(
handler: suspend (Error, Cont<T, R>) -> R, body: suspend Handle<Error, T>.() -> R
): R = with(Handle<Error, T>()) {
handle(handler) { body() }
}

@ResetDsl
public suspend fun <Error, T, R> newHandleShallow(
handler: suspend Handle<Error, T>.(Error, Cont<T, R>) -> R, body: suspend Handle<Error, T>.() -> R
): R = with(Handle<Error, T>()) {
handleShallow({ e, k -> handler(e, k) }) { body() }
}
14 changes: 9 additions & 5 deletions library/src/commonMain/kotlin/resetDsl.kt
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
public typealias Cont<T, R> = suspend (T) -> R
public fun interface Cont<in T, out R> {
public suspend fun with(value: Result<T>): R
public suspend operator fun invoke(value: T): R = with(Result.success(value))
public suspend fun withException(exception: Throwable): R = with(Result.failure(exception))
}

@ResetDsl
public suspend fun <R> Prompt<R>.reset(body: suspend () -> R): R = pushPrompt(body = body)
Expand All @@ -10,19 +14,19 @@ public suspend fun <R> topReset(body: suspend Prompt<R>.() -> R): R = runCC { ne

@ResetDsl
public suspend inline fun <T, R> Prompt<R>.shift(crossinline block: suspend (Cont<T, R>) -> R): T =
takeSubCont(deleteDelimiter = false) { sk -> block { sk.pushSubContWith(Result.success(it), isDelimiting = true) } }
takeSubCont(deleteDelimiter = false) { sk -> block { sk.pushSubContWith(it, isDelimiting = true) } }

@ResetDsl
public suspend inline fun <T, R> Prompt<R>.control(crossinline block: suspend (Cont<T, R>) -> R): T =
takeSubCont(deleteDelimiter = false) { sk -> block { sk.pushSubContWith(Result.success(it)) } }
takeSubCont(deleteDelimiter = false) { sk -> block { sk.pushSubContWith(it) } }

@ResetDsl
public suspend inline fun <T, R> Prompt<R>.shift0(crossinline block: suspend (Cont<T, R>) -> R): T =
takeSubCont { sk -> block { sk.pushSubContWith(Result.success(it), isDelimiting = true) } }
takeSubCont { sk -> block { sk.pushSubContWith(it, isDelimiting = true) } }

@ResetDsl
public suspend inline fun <T, R> Prompt<R>.control0(crossinline block: suspend (Cont<T, R>) -> R): T =
takeSubCont { sk -> block { sk.pushSubContWith(Result.success(it)) } }
takeSubCont { sk -> block { sk.pushSubContWith(it) } }

@ResetDsl
public fun <R> Prompt<R>.abortWith(value: Result<R>): Nothing = abortWith(deleteDelimiter = false, value)
Expand Down
70 changes: 70 additions & 0 deletions library/src/commonTest/kotlin/HandleTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import io.kotest.matchers.shouldBe
import kotlinx.coroutines.test.runTest
import kotlin.test.Test

class HandleTest {
@Test
fun coroutineAndReader() = runTest {
val printed = mutableListOf<Int>()
runCC {
runReader(10) {
suspend fun Handle<Int, Unit>.handler(error: Int, cont: Cont<Unit, Unit>) {
printed.add(error)
handleShallow(::handler) { pushReader(ask() + 1) { cont(Unit) } }
}
newHandleShallow<Int, Unit, _>(Handle<Int, Unit>::handler) {
call(ask())
call(ask())
pushReader(ask() + 10) {
call(ask())
call(ask())
}
}
}
}
printed shouldBe listOf(10, 11, 21, 21)
}

@Test
fun coroutineAndReaderWithNestedHandler() = runTest {
val printed = mutableListOf<Int>()
runCC {
runReader(10) {
suspend fun Handle<Int, Unit>.handler(error: Int, cont: Cont<Unit, Unit>) {
handleShallow({ e, k ->
printed.add(e)
handler(e, k)
}) { pushReader(ask() + 1) { cont(Unit) } }
}
newHandleShallow<Int, Unit, _>(Handle<Int, Unit>::handler) {
call(ask())
call(ask())
pushReader(ask() + 10) {
call(ask())
call(ask())
}
}
}
}
printed shouldBe listOf(11, 21, 21)
}

@Test
fun readerSimulation() = runTest {
runCC {
newHandle<Unit, Int, _>({ _, cont -> cont(42) }) {
call(Unit) shouldBe 42
handleShallow({ _, cont -> cont(43) }) {
call(Unit) shouldBe 43
call(Unit) shouldBe 42
}
call(Unit) shouldBe 42
handle({ _, cont -> cont(44) }) {
call(Unit) shouldBe 44
call(Unit) shouldBe 44
}
call(Unit) shouldBe 42
}
}
}
}
7 changes: 3 additions & 4 deletions library/src/commonTest/kotlin/MonadTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class MonadTest {
}
}

suspend fun <S, A, B> Prompt<State<S, A>>.bind(state: State<S, B>): B = shift(state::flatMap)
suspend fun <S, A, B> Prompt<State<S, A>>.bind(state: State<S, B>): B = shift { k -> state.flatMap { k(it) } }

suspend fun <S, R> stateReset(body: suspend Prompt<State<S, R>>.() -> R): State<S, R> =
newReset { State.of(body(this)) }
Expand All @@ -40,8 +40,7 @@ class MonadTest {
}.run(CounterState(0))
}

result shouldBe incrementCounter().flatMap { doubleCounter().flatMap { doubleCounter() } }
.run(CounterState(0))
result shouldBe incrementCounter().flatMap { doubleCounter().flatMap { doubleCounter() } }.run(CounterState(0))
}

class Reader<R, A>(val reader: suspend (R) -> A) {
Expand All @@ -55,7 +54,7 @@ class MonadTest {
}
}

suspend fun <R, A, B> Prompt<Reader<R, A>>.bind(reader: Reader<R, B>): B = shift(reader::flatMap)
suspend fun <R, A, B> Prompt<Reader<R, A>>.bind(reader: Reader<R, B>): B = shift { k -> reader.flatMap { k(it) } }

suspend fun <R, A> readerReset(body: suspend Prompt<Reader<R, A>>.() -> A): Reader<R, A> =
newReset { Reader.of(body(this)) }
Expand Down
2 changes: 1 addition & 1 deletion library/src/commonTest/kotlin/ReaderTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ReaderTest {

sealed class R<in A, out B> {
data class R<out B>(val b: B) : ReaderTest.R<Any?, B>()
data class J<in A, out B>(val f: suspend (A) -> ReaderTest.R<A, B>) : ReaderTest.R<A, B>()
data class J<in A, out B>(val f: Cont<A, ReaderTest.R<A, B>>) : ReaderTest.R<A, B>()
}

// https://www.brinckerhoff.org/clements/csc530-sp08/Readings/kiselyov-2006.pdf
Expand Down
2 changes: 1 addition & 1 deletion library/src/jsMain/kotlin/compilerCloning.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ private val contClass: KClass<*> = run {
var c: Continuation<*>? = null
suspend { c = foo() }.startCoroutineUninterceptedOrReturn(Continuation(EmptyCoroutineContext) { })
val cont = c!!
(js("Object.getPrototypeOf(Object.getPrototypeOf(cont)).constructor") as JsClass<*>).kotlin
js("Object.getPrototypeOf(Object.getPrototypeOf(cont)).constructor").unsafeCast<JsClass<*>>().kotlin
}

private suspend fun foo(): Continuation<*> = suspendCoroutineUninterceptedOrReturn { it }
Expand Down

0 comments on commit 5627d99

Please sign in to comment.