Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

feat(core): use custom scheduler to avoid stack overflow #273

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion optd-core/src/cascades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

mod memo;
mod optimizer;
pub mod rule_match;
pub(crate) mod rule_match;
pub(crate) mod scheduler;
mod tasks2;

pub use memo::{Memo, NaiveMemo};
Expand Down
9 changes: 1 addition & 8 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt::Display;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use anyhow::Result;
Expand Down Expand Up @@ -293,14 +291,9 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
}

pub fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> {
use pollster::FutureExt as _;
trace!(event = "fire_optimize_tasks", root_group_id = %group_id);
let mut task = TaskContext::new(self);
// 32MB stack for the optimization process, TODO: reduce memory footprint
stacker::grow(32 * 1024 * 1024, || {
let fut: Pin<Box<dyn Future<Output = ()>>> = Box::pin(task.fire_optimize(group_id));
fut.block_on();
});
task.fire_optimize(group_id);
Ok(())
}

Expand Down
101 changes: 101 additions & 0 deletions optd-core/src/cascades/scheduler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) 2023-2024 CMU Database Group
//
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

//! A single-thread scheduler for the cascades tasks. The tasks are queued in a stack of `Vec` so that
//! we won't overflow the system stack. The cascades task are compute-only and don't have I/O.

use std::{
cell::RefCell,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Wake},
};

struct Task {
// The task to be executed.
inner: Pin<Box<dyn Future<Output = ()> + 'static>>,
}

pub struct Executor {}

impl Wake for Task {
fn wake(self: Arc<Self>) {
unreachable!("cascades tasks shouldn't yield");
}
}

// This needs nightly feature and we use stable Rust, so we had to copy-paste it here. TODO: license

mod optd_futures_task {
use std::{
ptr,
task::{RawWaker, RawWakerVTable, Waker},
};
const NOOP: RawWaker = {
const VTABLE: RawWakerVTable = RawWakerVTable::new(
// Cloning just returns a new no-op raw waker
|_| NOOP,
// `wake` does nothing
|_| {},
// `wake_by_ref` does nothing
|_| {},
// Dropping does nothing as we don't allocate anything
|_| {},
);
RawWaker::new(ptr::null(), &VTABLE)
};

#[inline]
#[must_use]
pub const fn noop() -> &'static Waker {
const WAKER: &Waker = &unsafe { Waker::from_raw(NOOP) };
WAKER
}
}

thread_local! {
pub static OPTD_SCHEDULER_QUEUE: RefCell<Vec<Task>> = RefCell::new(Vec::new());
}

pub fn spawn<F>(task: F)
where
F: Future<Output = ()> + 'static,
{
OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| {
tasks.push(
Task {
inner: Box::pin(task),
}
.into(),
)
});
}

impl Executor {
pub fn new() -> Self {
Executor {}
}

pub fn spawn<F>(&self, task: F)
where
F: Future<Output = ()> + 'static,
{
spawn(task);
}

/// SAFETY: The caller must ensure all futures running on this runtime does not have I/O. Otherwise it will deadloop
/// with all futures pending.
pub fn run(&self) {
let waker = optd_futures_task::noop();
let mut cx: Context<'_> = Context::from_waker(&waker);

while let Some(mut task) = OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| tasks.pop()) {
if task.inner.as_mut().poll(&mut cx).is_pending() {
OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| tasks.push(task))
}
}
}
}
72 changes: 61 additions & 11 deletions optd-core/src/cascades/tasks2.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
// Copyright (c) 2023-2024 CMU Database Group
//
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

//! The v2 implementation of cascades tasks. The code uses Rust async/await to generate the state machine,
//! so that the logic is much more clear and easier to follow.

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use itertools::Itertools;
use tracing::trace;

use super::memo::MemoPlanNode;
use super::rule_match::match_and_pick_expr;
use super::scheduler::{self, Executor};
use super::{optimizer::RuleId, CascadesOptimizer, ExprId, GroupId, Memo};
use crate::cascades::{
memo::{Winner, WinnerInfo},
Expand All @@ -31,6 +42,12 @@ pub enum TaskDesc {
OptimizeInput(ExprId, GroupId),
}

unsafe fn extend_to_static<'x>(
f: Pin<Box<dyn Future<Output = ()> + 'x>>,
) -> Pin<Box<dyn Future<Output = ()> + 'static>> {
unsafe { std::mem::transmute(f) }
}

impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
pub fn new(optimizer: &'a mut CascadesOptimizer<T, M>) -> Self {
Self {
Expand All @@ -39,24 +56,45 @@ impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
}
}

pub async fn fire_optimize(&mut self, group_id: GroupId) {
self.optimize_group(SearchContext {
group_id,
upper_bound: None,
})
.await;
pub fn fire_optimize(&mut self, group_id: GroupId) {
let executor = Executor::new();
executor.spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.optimize_group(SearchContext {
group_id,
upper_bound: None,
})) as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
executor.run();
}

async fn optimize_group(&mut self, ctx: SearchContext) {
Box::pin(self.optimize_group_inner(ctx)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.optimize_group_inner(ctx)) as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
}

async fn optimize_expr(&mut self, ctx: SearchContext, expr_id: ExprId, exploring: bool) {
Box::pin(self.optimize_expr_inner(ctx, expr_id, exploring)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.optimize_expr_inner(ctx, expr_id, exploring))
as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
}

async fn explore_group(&mut self, ctx: SearchContext) {
Box::pin(self.explore_group_inner(ctx)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.explore_group_inner(ctx)) as Pin<Box<dyn Future<Output = ()>>>).await
}))
});
}

async fn apply_rule(
Expand All @@ -66,11 +104,23 @@ impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
expr_id: ExprId,
exploring: bool,
) {
Box::pin(self.apply_rule_inner(ctx, rule_id, expr_id, exploring)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.apply_rule_inner(ctx, rule_id, expr_id, exploring))
as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
}

async fn optimize_input(&mut self, ctx: SearchContext, expr_id: ExprId) {
Box::pin(self.optimize_input_inner(ctx, expr_id)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.optimize_input_inner(ctx, expr_id))
as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
}

async fn optimize_group_inner(&mut self, ctx: SearchContext) {
Expand Down
Loading