diff --git a/packages/toolchain-ffi/src/lib.rs b/packages/toolchain-ffi/src/lib.rs index 63df361c..d76e2fbc 100644 --- a/packages/toolchain-ffi/src/lib.rs +++ b/packages/toolchain-ffi/src/lib.rs @@ -1,12 +1,13 @@ mod runtime; +use std::sync::mpsc::Receiver; +use std::sync::mpsc::TryRecvError; use std::{ collections::HashMap, ffi::{CStr, CString}, os::raw::c_char, sync::atomic::{AtomicU64, Ordering}, sync::{mpsc, Mutex}, - thread, }; use tokio::sync::mpsc as tokio_mpsc; use toolchain::util::task; @@ -16,26 +17,56 @@ type TaskId = u64; struct TaskHandle { abort_tx: tokio_mpsc::Sender<()>, + event_rx: Receiver, } lazy_static::lazy_static! { static ref TASK_HANDLES: Mutex> = Mutex::new(HashMap::new()); } -/// Callback type used to receive events from tasks. -type EventCallback = extern "C" fn(TaskId, *const c_char); +#[repr(u8)] +pub enum ErrorCode { + Success = 0, + NullPointer = 1, + ParseError = 2, + LockError = 3, + CStringNew = 4, + TaskNotFound = 5, +} + +#[repr(C)] +pub struct RunTaskResult { + task_id: TaskId, + error_code: u8, +} #[no_mangle] -pub extern "C" fn rivet_run_task( - name: *const c_char, - input_json: *const c_char, - callback: EventCallback, -) -> TaskId { +pub extern "C" fn rivet_run_task(name: *const c_char, input_json: *const c_char) -> RunTaskResult { + match inner_run_task(name, input_json) { + Ok(task_id) => RunTaskResult { + task_id, + error_code: ErrorCode::Success as u8, + }, + Err(error_code) => RunTaskResult { + task_id: 0, + error_code: error_code as u8, + }, + } +} + +fn inner_run_task(name: *const c_char, input_json: *const c_char) -> Result { + // Handle null pointers + if name.is_null() || input_json.is_null() { + return Err(ErrorCode::NullPointer); + } + // Parse input - let name_tmp = unsafe { CStr::from_ptr(name).to_str().unwrap() }; - let name = name_tmp.to_string(); - let input_json_tmp = unsafe { CStr::from_ptr(input_json).to_str().unwrap() }; - let input_json = input_json_tmp.to_string(); + let name = unsafe { CStr::from_ptr(name).to_str() } + .map_err(|_| ErrorCode::ParseError)? + .to_string(); + let input_json = unsafe { CStr::from_ptr(input_json).to_str() } + .map_err(|_| ErrorCode::ParseError)? + .to_string(); runtime::setup(); @@ -44,13 +75,17 @@ pub extern "C" fn rivet_run_task( let (output_tx, output_rx) = mpsc::channel(); let (run_config, mut handles) = task::RunConfig::build(); - // Store abort sender - TASK_HANDLES.lock().unwrap().insert( - task_id, - TaskHandle { - abort_tx: handles.abort_tx.clone(), - }, - ); + // Store abort sender and event receiver + TASK_HANDLES + .lock() + .map_err(|_| ErrorCode::LockError)? + .insert( + task_id, + TaskHandle { + abort_tx: handles.abort_tx.clone(), + event_rx: output_rx, + }, + ); // Run the task runtime::spawn(Box::pin(async move { @@ -71,40 +106,96 @@ pub extern "C" fn rivet_run_task( } })); - thread::spawn(move || { - // Pass events to callback - while let Ok(event) = output_rx.recv() { - // Serialize event - let event_json = match serde_json::to_string(&event) { - Ok(x) => x, - Err(err) => { - eprintln!("error with event: {err:?}"); - return; + Ok(task_id) +} + +#[repr(C)] +pub struct TaskEvent { + task_id: TaskId, + event_json: *mut c_char, +} + +#[repr(C)] +pub struct PollTaskEventsResult { + count: usize, + error_code: u8, +} + +#[no_mangle] +pub extern "C" fn rivet_poll_task_events( + events: *mut TaskEvent, + max_events: usize, +) -> PollTaskEventsResult { + match inner_poll_task_events(events, max_events) { + Ok(count) => PollTaskEventsResult { + count, + error_code: ErrorCode::Success as u8, + }, + Err(error_code) => PollTaskEventsResult { + count: 0, + error_code: error_code as u8, + }, + } +} + +fn inner_poll_task_events(events: *mut TaskEvent, max_events: usize) -> Result { + let mut task_handles = TASK_HANDLES.lock().map_err(|_| ErrorCode::LockError)?; + + let mut count = 0; + let mut completed_tasks = Vec::new(); + + for (task_id, handle) in task_handles.iter_mut() { + match handle.event_rx.try_recv() { + Ok(event) => { + let event_json = + serde_json::to_string(&event).map_err(|_| ErrorCode::ParseError)?; + + // Store event in TaskEvent + let event_ptr = CString::new(event_json) + .map_err(|_| ErrorCode::CStringNew)? + .into_raw(); + unsafe { + if events.is_null() { + return Err(ErrorCode::NullPointer); + } + (*events.add(count)).task_id = *task_id; + (*events.add(count)).event_json = event_ptr; } - }; - // Call the callback function - let c_str = CString::new(event_json).unwrap(); - callback(task_id, c_str.into_raw()); + count += 1; + if count >= max_events { + break; + } + } + Err(TryRecvError::Empty) => continue, + Err(TryRecvError::Disconnected) => { + completed_tasks.push(*task_id); + } } + } - // Remove task handle when finished - TASK_HANDLES.lock().unwrap().remove(&task_id); - }); + // Remove completed tasks + for task_id in completed_tasks { + task_handles.remove(&task_id); + } - task_id + Ok(count) } #[no_mangle] -pub extern "C" fn rivet_abort_task(task_id: TaskId) -> bool { - if let Some(handle) = TASK_HANDLES.lock().unwrap().remove(&task_id) { - runtime::spawn(async move { - let _ = handle.abort_tx.send(()).await; - }); - true - } else { - eprintln!("failed to abort event"); - false +pub extern "C" fn rivet_abort_task(task_id: TaskId) -> u8 { + match TASK_HANDLES.lock() { + Result::Ok(mut lock) => { + if let Some(handle) = lock.remove(&task_id) { + runtime::spawn(async move { + let _ = handle.abort_tx.send(()).await; + }); + ErrorCode::Success as u8 + } else { + ErrorCode::TaskNotFound as u8 + } + } + Err(_) => ErrorCode::LockError as u8, } } diff --git a/packages/toolchain/src/tasks/game_server/start.rs b/packages/toolchain/src/tasks/game_server/start.rs index f1277fc6..3ce0e694 100644 --- a/packages/toolchain/src/tasks/game_server/start.rs +++ b/packages/toolchain/src/tasks/game_server/start.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use anyhow::*; use serde::{Deserialize, Serialize}; @@ -8,6 +10,7 @@ pub struct Input { pub cmd: String, pub args: Vec, pub cwd: String, + pub envs: HashMap, } #[derive(Serialize)] @@ -31,7 +34,7 @@ impl task::Task for Task { Ok(CommandOpts { command: input.cmd, args: input.args, - envs: Vec::new(), + envs: input.envs.into_iter().collect(), current_dir: input.cwd, }) }) diff --git a/packages/toolchain/src/util/process_manager/mod.rs b/packages/toolchain/src/util/process_manager/mod.rs index 075e97b2..3f868fdb 100644 --- a/packages/toolchain/src/util/process_manager/mod.rs +++ b/packages/toolchain/src/util/process_manager/mod.rs @@ -98,7 +98,7 @@ pub struct ProcessManager { impl ProcessManager { pub fn new(key: &'static str, kill_grace: Duration) -> Arc { let (status_tx, status_rx) = watch::channel(ProcessStatus::NotRunning); - let (event_tx, event_rx) = broadcast::channel(16); + let (event_tx, event_rx) = broadcast::channel(1024); Arc::new(Self { key, kill_grace, @@ -175,14 +175,20 @@ impl ProcessManager { } // Wait for events - while let Result::Ok(event) = event_rx.recv().await { - match event { - ProcessEvent::Log(ProcessLog::Stdout(line)) => { + loop { + match event_rx.recv().await { + Result::Ok(ProcessEvent::Log(ProcessLog::Stdout(line))) => { task_ctx.log(format!("[stdout] {line}")); } - ProcessEvent::Log(ProcessLog::Stderr(line)) => { + Result::Ok(ProcessEvent::Log(ProcessLog::Stderr(line))) => { task_ctx.log(format!("[stderr] {line}")); } + Err(broadcast::error::RecvError::Lagged(amount)) => { + eprintln!("event_rx lagged by {amount}"); + } + Err(broadcast::error::RecvError::Closed) => { + break; + } } } };