Skip to content

Commit

Permalink
Merge pull request #39 from sportsball-ai/arbitrary-work-waiter
Browse files Browse the repository at this point in the history
Allow to replace yielding behavior with arbitrary future generator
  • Loading branch information
cksac authored Oct 29, 2024
2 parents 5c5dd8e + c124896 commit f8b87c8
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 39 deletions.
37 changes: 17 additions & 20 deletions src/cached.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::runtime::{yield_now, Arc, Mutex};
use crate::BatchFn;
use crate::runtime::{Arc, Mutex};
use crate::{yield_fn, BatchFn, WaitForWorkFn};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::hash::{BuildHasher, Hash};
Expand Down Expand Up @@ -72,7 +72,7 @@ where
{
state: Arc<Mutex<State<K, V, C>>>,
load_fn: Arc<Mutex<F>>,
yield_count: usize,
wait_for_work_fn: Arc<dyn WaitForWorkFn>,
max_batch_size: usize,
}

Expand All @@ -88,7 +88,7 @@ where
state: self.state.clone(),
max_batch_size: self.max_batch_size,
load_fn: self.load_fn.clone(),
yield_count: self.yield_count,
wait_for_work_fn: self.wait_for_work_fn.clone(),
}
}
}
Expand Down Expand Up @@ -117,7 +117,7 @@ where
state: Arc::new(Mutex::new(State::with_cache(cache))),
load_fn: Arc::new(Mutex::new(load_fn)),
max_batch_size: 200,
yield_count: 10,
wait_for_work_fn: Arc::new(yield_fn(10)),
}
}

Expand All @@ -127,10 +127,17 @@ where
}

pub fn with_yield_count(mut self, yield_count: usize) -> Self {
self.yield_count = yield_count;
self.wait_for_work_fn = Arc::new(yield_fn(yield_count));
self
}

/// Replaces the yielding for work behavior with an arbitrary future. Rather than yielding
/// the runtime repeatedly this will generate and `.await` a future of your choice.
/// ***This is incompatible with*** [`Self::with_yield_count()`].
pub fn with_custom_wait_for_work(mut self, wait_for_work_fn: impl WaitForWorkFn) {
self.wait_for_work_fn = Arc::new(wait_for_work_fn);
}

pub fn max_batch_size(&self) -> usize {
self.max_batch_size
}
Expand All @@ -141,7 +148,7 @@ where
return Ok((*v).clone());
}

if state.pending.get(&key).is_none() {
if !state.pending.contains(&key) {
state.pending.insert(key.clone());
if state.pending.len() >= self.max_batch_size {
let keys = state.pending.drain().collect::<Vec<K>>();
Expand All @@ -159,12 +166,7 @@ where
}
drop(state);

// yield for other load to append request
let mut i = 0;
while i < self.yield_count {
yield_now().await;
i += 1;
}
(self.wait_for_work_fn)().await;

let mut state = self.state.lock().await;
if let Some(v) = state.completed.get(&key) {
Expand Down Expand Up @@ -200,7 +202,7 @@ where
ret.insert(key, v);
continue;
}
if state.pending.get(&key).is_none() {
if !state.pending.contains(&key) {
state.pending.insert(key.clone());
if state.pending.len() >= self.max_batch_size {
let keys = state.pending.drain().collect::<Vec<K>>();
Expand All @@ -216,12 +218,7 @@ where
}
drop(state);

// yield for other load to append request
let mut i = 0;
while i < self.yield_count {
yield_now().await;
i += 1;
}
(self.wait_for_work_fn)().await;

if !rest.is_empty() {
let mut state = self.state.lock().await;
Expand Down
24 changes: 24 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,27 @@ pub mod non_cached;
mod runtime;

pub use batch_fn::BatchFn;

use std::{future::Future, pin::Pin};

/// A trait alias. Read as "a function which returns a pinned box containing a future"
pub trait WaitForWorkFn:
Fn() -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> + Send + Sync + 'static
{
}

impl<T> WaitForWorkFn for T where
T: Fn() -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> + Send + Sync + 'static
{
}

pub(crate) fn yield_fn(count: usize) -> impl WaitForWorkFn {
move || {
Box::pin(async move {
// yield for other load to append request
for _ in 0..count {
runtime::yield_now().await;
}
})
}
}
35 changes: 16 additions & 19 deletions src/non_cached.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::runtime::{yield_now, Arc, Mutex};
use crate::BatchFn;
use crate::runtime::{Arc, Mutex};
use crate::{yield_fn, BatchFn, WaitForWorkFn};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::hash::Hash;
Expand Down Expand Up @@ -37,7 +37,7 @@ where
{
state: Arc<Mutex<State<K, V>>>,
load_fn: Arc<Mutex<F>>,
yield_count: usize,
wait_for_work_fn: Arc<dyn WaitForWorkFn>,
max_batch_size: usize,
}

Expand All @@ -52,7 +52,7 @@ where
state: self.state.clone(),
load_fn: self.load_fn.clone(),
max_batch_size: self.max_batch_size,
yield_count: self.yield_count,
wait_for_work_fn: self.wait_for_work_fn.clone(),
}
}
}
Expand All @@ -68,7 +68,7 @@ where
state: Arc::new(Mutex::new(State::new())),
load_fn: Arc::new(Mutex::new(load_fn)),
max_batch_size: 200,
yield_count: 10,
wait_for_work_fn: Arc::new(yield_fn(10)),
}
}

Expand All @@ -78,10 +78,17 @@ where
}

pub fn with_yield_count(mut self, yield_count: usize) -> Self {
self.yield_count = yield_count;
self.wait_for_work_fn = Arc::new(yield_fn(yield_count));
self
}

/// Replaces the yielding for work behavior with an arbitrary future. Rather than yielding
/// the runtime repeatedly this will generate and `.await` a future of your choice.
/// ***This is incompatible with*** [`Self::with_yield_count()`].
pub fn with_custom_wait_for_work(mut self, wait_for_work_fn: impl WaitForWorkFn) {
self.wait_for_work_fn = Arc::new(wait_for_work_fn);
}

pub fn max_batch_size(&self) -> usize {
self.max_batch_size
}
Expand Down Expand Up @@ -122,16 +129,11 @@ where
}
drop(state);

// yield for other load to append request
let mut i = 0;
while i < self.yield_count {
yield_now().await;
i += 1;
}
(self.wait_for_work_fn)().await;

let mut state = self.state.lock().await;

if state.completed.get(&request_id).is_none() {
if !state.completed.contains_key(&request_id) {
let batch = state.pending.drain().collect::<HashMap<usize, K>>();
if !batch.is_empty() {
let keys: Vec<K> = batch
Expand Down Expand Up @@ -208,12 +210,7 @@ where

drop(state);

// yield for other load to append request
let mut i = 0;
while i < self.yield_count {
yield_now().await;
i += 1;
}
(self.wait_for_work_fn)().await;

let mut state = self.state.lock().await;

Expand Down

0 comments on commit f8b87c8

Please sign in to comment.