Skip to content

Commit

Permalink
Add Send + Sync + Clone bounds to Tape (#203)
Browse files Browse the repository at this point in the history
Closes #201
  • Loading branch information
mkeeter authored Nov 25, 2024
1 parent fdfe9bc commit 5a1bb4a
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 52 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
- Add `vars()` to `Function` trait, because there are cases where we want to get
the variable map without building a tape (and it must always be the same).
- Fix soundness bug in `Mmap` (probably not user-visible)
- Add `Send + Sync + Clone` bounds to the `trait Tape`, to make them easily
shared between threads. Previously, we used an `Arc<Tape>` to share tapes
between threads, but tapes were _already_ using an `Arc<..>` under the hood.
- Changed `Tape::recycle` from returning a `Storage` to returning an
`Option<Storage>`, as tapes may now be shared between threads.

# 0.3.3
- `Function` and evaluator types now produce multiple outputs
Expand Down
2 changes: 1 addition & 1 deletion fidget/src/core/eval/bulk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub trait BulkEvaluator: Default {
///
/// This may be a literal instruction tape (in the case of VM evaluation),
/// or a metaphorical instruction tape (e.g. a JIT function).
type Tape: Tape<Storage = Self::TapeStorage> + Send + Sync;
type Tape: Tape<Storage = Self::TapeStorage>;

/// Associated type for tape storage
///
Expand Down
9 changes: 6 additions & 3 deletions fidget/src/core/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@ pub use tracing::TracingEvaluator;
///
/// It includes some kind of storage (which could be empty) and the ability to
/// look up variable mapping.
pub trait Tape {
///
/// Tapes may be shared between threads, so they should be cheap to clone (i.e.
/// a wrapper around an `Arc<..>`).
pub trait Tape: Send + Sync + Clone {
/// Associated type for this tape's data storage
type Storage: Default;

/// Retrieves the internal storage from this tape
/// Tries to retrieve the internal storage from this tape
///
/// This matters most for JIT evaluators, whose tapes are regions of
/// executable memory-mapped RAM (which is expensive to map and unmap).
fn recycle(self) -> Self::Storage;
fn recycle(self) -> Option<Self::Storage>;

/// Returns a mapping from [`Var`](crate::var::Var) to evaluation index
///
Expand Down
2 changes: 1 addition & 1 deletion fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
assert_eq!(&out[0], [0.0, 1.0, 2.0, 3.0]);

// TODO: reuse tape data here
let t = tape.recycle();
let t = tape.recycle().unwrap();

let tape = shape_x1.float_slice_tape(t);
let out = eval
Expand Down
8 changes: 4 additions & 4 deletions fidget/src/core/eval/test/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ where
args[i_index] = lhs;

let (out, _trace) = eval.eval(&tape, &args).unwrap();
tape_data = Some(tape.recycle());
tape_data = Some(tape.recycle().unwrap());

Self::compare_interval_results(
lhs,
Expand Down Expand Up @@ -1004,7 +1004,7 @@ where

let args = [lhs];
let (out, _trace) = eval.eval(&tape, &args).unwrap();
tape_data = Some(tape.recycle());
tape_data = Some(tape.recycle().unwrap());

Self::compare_interval_results(
lhs,
Expand Down Expand Up @@ -1035,7 +1035,7 @@ where
let tape = shape.interval_tape(tape_data.unwrap_or_default());

let (out, _trace) = eval.eval(&tape, &[lhs]).unwrap();
tape_data = Some(tape.recycle());
tape_data = Some(tape.recycle().unwrap());

Self::compare_interval_results(
lhs,
Expand Down Expand Up @@ -1066,7 +1066,7 @@ where
let tape = shape.interval_tape(tape_data.unwrap_or_default());

let (out, _trace) = eval.eval(&tape, &[rhs]).unwrap();
tape_data = Some(tape.recycle());
tape_data = Some(tape.recycle().unwrap());

Self::compare_interval_results(
lhs.into(),
Expand Down
2 changes: 1 addition & 1 deletion fidget/src/core/eval/tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub trait TracingEvaluator: Default {
///
/// This may be a literal instruction tape (in the case of VM evaluation),
/// or a metaphorical instruction tape (e.g. a JIT function).
type Tape: Tape<Storage = Self::TapeStorage> + Send + Sync;
type Tape: Tape<Storage = Self::TapeStorage>;

/// Associated type for tape storage
///
Expand Down
3 changes: 2 additions & 1 deletion fidget/src/core/shape/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ impl<F: MathFunction> From<Tree> for Shape<F> {
}

/// Wrapper around a function tape, with axes and an optional transform matrix
#[derive(Clone)]
pub struct ShapeTape<T> {
tape: T,

Expand All @@ -373,7 +374,7 @@ pub struct ShapeTape<T> {

impl<T: Tape> ShapeTape<T> {
/// Recycles the inner tape's storage for reuse
pub fn recycle(self) -> T::Storage {
pub fn recycle(self) -> Option<T::Storage> {
self.tape.recycle()
}

Expand Down
5 changes: 3 additions & 2 deletions fidget/src/core/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub struct EmptyTapeStorage;
///
/// This tape type is equivalent to a [`GenericVmFunction`], but implements
/// different traits ([`Tape`] instead of [`Function`]).
#[derive(Clone)]
pub struct GenericVmTape<const N: usize>(Arc<VmData<N>>);

impl<const N: usize> GenericVmTape<N> {
Expand All @@ -53,8 +54,8 @@ impl<const N: usize> GenericVmTape<N> {

impl<const N: usize> Tape for GenericVmTape<N> {
type Storage = EmptyTapeStorage;
fn recycle(self) -> Self::Storage {
EmptyTapeStorage
fn recycle(self) -> Option<Self::Storage> {
Some(EmptyTapeStorage)
}

fn vars(&self) -> &VarMap {
Expand Down
24 changes: 12 additions & 12 deletions fidget/src/jit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ impl JitFunction {
let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
let ptr = f.as_ptr();
JitTracingFn {
mmap: f,
mmap: f.into(),
vars: self.0.data().vars.clone(),
choice_count: self.0.choice_count(),
output_count: self.0.output_count(),
Expand All @@ -842,7 +842,7 @@ impl JitFunction {
let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
let ptr = f.as_ptr();
JitBulkFn {
mmap: f,
mmap: f.into(),
output_count: self.0.output_count(),
vars: self.0.data().vars.clone(),
fn_bulk: unsafe {
Expand Down Expand Up @@ -994,19 +994,19 @@ pub type JitTracingFnPointer<T> = jit_fn!(
);

/// Handle to an owned function pointer for tracing evaluation
#[derive(Clone)]
pub struct JitTracingFn<T> {
#[allow(unused)]
mmap: Mmap,
mmap: Arc<Mmap>,
choice_count: usize,
output_count: usize,
vars: Arc<VarMap>,
fn_trace: JitTracingFnPointer<T>,
}

impl<T> Tape for JitTracingFn<T> {
impl<T: Clone> Tape for JitTracingFn<T> {
type Storage = Mmap;
fn recycle(self) -> Self::Storage {
self.mmap
fn recycle(self) -> Option<Self::Storage> {
Arc::into_inner(self.mmap)
}

fn vars(&self) -> &VarMap {
Expand Down Expand Up @@ -1105,18 +1105,18 @@ pub type JitBulkFnPointer<T> = jit_fn!(
);

/// Handle to an owned function pointer for bulk evaluation
#[derive(Clone)]
pub struct JitBulkFn<T> {
#[allow(unused)]
mmap: Mmap,
mmap: Arc<Mmap>,
vars: Arc<VarMap>,
output_count: usize,
fn_bulk: JitBulkFnPointer<T>,
}

impl<T> Tape for JitBulkFn<T> {
impl<T: Clone> Tape for JitBulkFn<T> {
type Storage = Mmap;
fn recycle(self) -> Self::Storage {
self.mmap
fn recycle(self) -> Option<Self::Storage> {
Arc::into_inner(self.mmap)
}

fn vars(&self) -> &VarMap {
Expand Down
6 changes: 3 additions & 3 deletions fidget/src/mesh/octree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -948,13 +948,13 @@ impl<F: Function + RenderHints> OctreeBuilder<F> {
self.shape_storage.push(s);
}
if let Some(i_tape) = e.interval.take() {
self.tape_storage.push(i_tape.recycle());
self.tape_storage.push(i_tape.recycle().unwrap());
}
if let Some(f_tape) = e.float_slice.take() {
self.tape_storage.push(f_tape.recycle());
self.tape_storage.push(f_tape.recycle().unwrap());
}
if let Some(g_tape) = e.grad_slice.take() {
self.tape_storage.push(g_tape.recycle());
self.tape_storage.push(g_tape.recycle().unwrap());
}
}
}
Expand Down
35 changes: 11 additions & 24 deletions fidget/src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::{
shape::{Shape, ShapeTape},
Error,
};
use std::sync::Arc;

mod config;
mod region;
Expand Down Expand Up @@ -35,9 +34,9 @@ pub use render2d::{
pub struct RenderHandle<F: Function> {
shape: Shape<F>,

i_tape: Option<Arc<ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape>>>,
f_tape: Option<Arc<ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape>>>,
g_tape: Option<Arc<ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape>>>,
i_tape: Option<ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape>>,
f_tape: Option<ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape>>,
g_tape: Option<ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape>>,

next: Option<(F::Trace, Box<Self>)>,
}
Expand Down Expand Up @@ -74,9 +73,7 @@ impl<F: Function> RenderHandle<F> {
storage: &mut Vec<F::TapeStorage>,
) -> &ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
self.i_tape.get_or_insert_with(|| {
Arc::new(
self.shape.interval_tape(storage.pop().unwrap_or_default()),
)
self.shape.interval_tape(storage.pop().unwrap_or_default())
})
}

Expand All @@ -86,10 +83,8 @@ impl<F: Function> RenderHandle<F> {
storage: &mut Vec<F::TapeStorage>,
) -> &ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
self.f_tape.get_or_insert_with(|| {
Arc::new(
self.shape
.float_slice_tape(storage.pop().unwrap_or_default()),
)
self.shape
.float_slice_tape(storage.pop().unwrap_or_default())
})
}

Expand All @@ -99,10 +94,8 @@ impl<F: Function> RenderHandle<F> {
storage: &mut Vec<F::TapeStorage>,
) -> &ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
self.g_tape.get_or_insert_with(|| {
Arc::new(
self.shape
.grad_slice_tape(storage.pop().unwrap_or_default()),
)
self.shape
.grad_slice_tape(storage.pop().unwrap_or_default())
})
}

Expand Down Expand Up @@ -178,19 +171,13 @@ impl<F: Function> RenderHandle<F> {
}

if let Some(i_tape) = self.i_tape.take() {
if let Ok(i_tape) = Arc::try_unwrap(i_tape) {
tape_storage.push(i_tape.recycle());
}
tape_storage.extend(i_tape.recycle());
}
if let Some(g_tape) = self.g_tape.take() {
if let Ok(g_tape) = Arc::try_unwrap(g_tape) {
tape_storage.push(g_tape.recycle());
}
tape_storage.extend(g_tape.recycle());
}
if let Some(f_tape) = self.f_tape.take() {
if let Ok(f_tape) = Arc::try_unwrap(f_tape) {
tape_storage.push(f_tape.recycle());
}
tape_storage.extend(f_tape.recycle());
}

// Do this step last because the evaluators may borrow the shape
Expand Down

0 comments on commit 5a1bb4a

Please sign in to comment.