Skip to content
This repository has been archived by the owner on Dec 21, 2024. It is now read-only.

Commit

Permalink
fix: gracefully handle all ffi unwraps
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanFlurry committed Oct 7, 2024
1 parent 0009a79 commit 0f30ad8
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 51 deletions.
181 changes: 136 additions & 45 deletions packages/toolchain-ffi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
mod runtime;

use std::sync::mpsc::Receiver;
use std::sync::mpsc::TryRecvError;
use std::{
collections::HashMap,
ffi::{CStr, CString},
os::raw::c_char,
sync::atomic::{AtomicU64, Ordering},
sync::{mpsc, Mutex},
thread,
};
use tokio::sync::mpsc as tokio_mpsc;
use toolchain::util::task;
Expand All @@ -16,26 +17,56 @@ type TaskId = u64;

struct TaskHandle {
abort_tx: tokio_mpsc::Sender<()>,
event_rx: Receiver<task::TaskEvent>,
}

lazy_static::lazy_static! {
static ref TASK_HANDLES: Mutex<HashMap<TaskId, TaskHandle>> = Mutex::new(HashMap::new());
}

/// Callback type used to receive events from tasks.
type EventCallback = extern "C" fn(TaskId, *const c_char);
#[repr(u8)]
pub enum ErrorCode {
Success = 0,
NullPointer = 1,
ParseError = 2,
LockError = 3,
CStringNew = 4,
TaskNotFound = 5,
}

#[repr(C)]
pub struct RunTaskResult {
task_id: TaskId,
error_code: u8,
}

#[no_mangle]
pub extern "C" fn rivet_run_task(
name: *const c_char,
input_json: *const c_char,
callback: EventCallback,
) -> TaskId {
pub extern "C" fn rivet_run_task(name: *const c_char, input_json: *const c_char) -> RunTaskResult {
match inner_run_task(name, input_json) {
Ok(task_id) => RunTaskResult {
task_id,
error_code: ErrorCode::Success as u8,
},
Err(error_code) => RunTaskResult {
task_id: 0,
error_code: error_code as u8,
},
}
}

fn inner_run_task(name: *const c_char, input_json: *const c_char) -> Result<TaskId, ErrorCode> {
// Handle null pointers
if name.is_null() || input_json.is_null() {
return Err(ErrorCode::NullPointer);
}

// Parse input
let name_tmp = unsafe { CStr::from_ptr(name).to_str().unwrap() };
let name = name_tmp.to_string();
let input_json_tmp = unsafe { CStr::from_ptr(input_json).to_str().unwrap() };
let input_json = input_json_tmp.to_string();
let name = unsafe { CStr::from_ptr(name).to_str() }
.map_err(|_| ErrorCode::ParseError)?
.to_string();
let input_json = unsafe { CStr::from_ptr(input_json).to_str() }
.map_err(|_| ErrorCode::ParseError)?
.to_string();

runtime::setup();

Expand All @@ -44,13 +75,17 @@ pub extern "C" fn rivet_run_task(
let (output_tx, output_rx) = mpsc::channel();
let (run_config, mut handles) = task::RunConfig::build();

// Store abort sender
TASK_HANDLES.lock().unwrap().insert(
task_id,
TaskHandle {
abort_tx: handles.abort_tx.clone(),
},
);
// Store abort sender and event receiver
TASK_HANDLES
.lock()
.map_err(|_| ErrorCode::LockError)?
.insert(
task_id,
TaskHandle {
abort_tx: handles.abort_tx.clone(),
event_rx: output_rx,
},
);

// Run the task
runtime::spawn(Box::pin(async move {
Expand All @@ -71,40 +106,96 @@ pub extern "C" fn rivet_run_task(
}
}));

thread::spawn(move || {
// Pass events to callback
while let Ok(event) = output_rx.recv() {
// Serialize event
let event_json = match serde_json::to_string(&event) {
Ok(x) => x,
Err(err) => {
eprintln!("error with event: {err:?}");
return;
Ok(task_id)
}

#[repr(C)]
pub struct TaskEvent {
task_id: TaskId,
event_json: *mut c_char,
}

#[repr(C)]
pub struct PollTaskEventsResult {
count: usize,
error_code: u8,
}

#[no_mangle]
pub extern "C" fn rivet_poll_task_events(
events: *mut TaskEvent,
max_events: usize,
) -> PollTaskEventsResult {
match inner_poll_task_events(events, max_events) {
Ok(count) => PollTaskEventsResult {
count,
error_code: ErrorCode::Success as u8,
},
Err(error_code) => PollTaskEventsResult {
count: 0,
error_code: error_code as u8,
},
}
}

fn inner_poll_task_events(events: *mut TaskEvent, max_events: usize) -> Result<usize, ErrorCode> {
let mut task_handles = TASK_HANDLES.lock().map_err(|_| ErrorCode::LockError)?;

let mut count = 0;
let mut completed_tasks = Vec::new();

for (task_id, handle) in task_handles.iter_mut() {
match handle.event_rx.try_recv() {
Ok(event) => {
let event_json =
serde_json::to_string(&event).map_err(|_| ErrorCode::ParseError)?;

// Store event in TaskEvent
let event_ptr = CString::new(event_json)
.map_err(|_| ErrorCode::CStringNew)?
.into_raw();
unsafe {
if events.is_null() {
return Err(ErrorCode::NullPointer);
}
(*events.add(count)).task_id = *task_id;
(*events.add(count)).event_json = event_ptr;
}
};

// Call the callback function
let c_str = CString::new(event_json).unwrap();
callback(task_id, c_str.into_raw());
count += 1;
if count >= max_events {
break;
}
}
Err(TryRecvError::Empty) => continue,
Err(TryRecvError::Disconnected) => {
completed_tasks.push(*task_id);
}
}
}

// Remove task handle when finished
TASK_HANDLES.lock().unwrap().remove(&task_id);
});
// Remove completed tasks
for task_id in completed_tasks {
task_handles.remove(&task_id);
}

task_id
Ok(count)
}

#[no_mangle]
pub extern "C" fn rivet_abort_task(task_id: TaskId) -> bool {
if let Some(handle) = TASK_HANDLES.lock().unwrap().remove(&task_id) {
runtime::spawn(async move {
let _ = handle.abort_tx.send(()).await;
});
true
} else {
eprintln!("failed to abort event");
false
pub extern "C" fn rivet_abort_task(task_id: TaskId) -> u8 {
match TASK_HANDLES.lock() {
Result::Ok(mut lock) => {
if let Some(handle) = lock.remove(&task_id) {
runtime::spawn(async move {
let _ = handle.abort_tx.send(()).await;
});
ErrorCode::Success as u8
} else {
ErrorCode::TaskNotFound as u8
}
}
Err(_) => ErrorCode::LockError as u8,
}
}

Expand Down
5 changes: 4 additions & 1 deletion packages/toolchain/src/tasks/game_server/start.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use anyhow::*;
use serde::{Deserialize, Serialize};

Expand All @@ -8,6 +10,7 @@ pub struct Input {
pub cmd: String,
pub args: Vec<String>,
pub cwd: String,
pub envs: HashMap<String, String>,
}

#[derive(Serialize)]
Expand All @@ -31,7 +34,7 @@ impl task::Task for Task {
Ok(CommandOpts {
command: input.cmd,
args: input.args,
envs: Vec::new(),
envs: input.envs.into_iter().collect(),
current_dir: input.cwd,
})
})
Expand Down
16 changes: 11 additions & 5 deletions packages/toolchain/src/util/process_manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ pub struct ProcessManager {
impl ProcessManager {
pub fn new(key: &'static str, kill_grace: Duration) -> Arc<Self> {
let (status_tx, status_rx) = watch::channel(ProcessStatus::NotRunning);
let (event_tx, event_rx) = broadcast::channel(16);
let (event_tx, event_rx) = broadcast::channel(1024);
Arc::new(Self {
key,
kill_grace,
Expand Down Expand Up @@ -175,14 +175,20 @@ impl ProcessManager {
}

// Wait for events
while let Result::Ok(event) = event_rx.recv().await {
match event {
ProcessEvent::Log(ProcessLog::Stdout(line)) => {
loop {
match event_rx.recv().await {
Result::Ok(ProcessEvent::Log(ProcessLog::Stdout(line))) => {
task_ctx.log(format!("[stdout] {line}"));
}
ProcessEvent::Log(ProcessLog::Stderr(line)) => {
Result::Ok(ProcessEvent::Log(ProcessLog::Stderr(line))) => {
task_ctx.log(format!("[stderr] {line}"));
}
Err(broadcast::error::RecvError::Lagged(amount)) => {
eprintln!("event_rx lagged by {amount}");
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
};
Expand Down

0 comments on commit 0f30ad8

Please sign in to comment.