Skip to content

Commit

Permalink
Feat/custom training loop & benchmark API & fix BatchShuffledDataload…
Browse files Browse the repository at this point in the history
…erIterator next() call get() twice. (#163)

Co-authored-by: Asuka Minato <[email protected]>
  • Loading branch information
L-M-Sherlock and asukaminato0721 authored Mar 7, 2024
1 parent 3d0dd3a commit f1afdd7
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 283 deletions.
12 changes: 11 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.4.6"
version = "0.5.0"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down Expand Up @@ -45,6 +45,7 @@ strum = { version = "0.26.1", features = ["derive"] }
chrono = { version = "0.4.31", default-features = false, features = ["std", "clock"] }
chrono-tz = "0.8.4"
criterion = { version = "0.5.1" }
fern = "0.6.0"
rusqlite = { version = "0.30.0" }

[[bench]]
Expand Down
191 changes: 45 additions & 146 deletions src/batch_shuffle.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,27 @@
use burn::data::{
dataloader::{
batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, FixBatchStrategy,
MultiThreadDataLoader, Progress,
batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, FixBatchStrategy, Progress,
},
dataset::{transform::PartialDataset, Dataset},
dataset::Dataset,
};

use rand::{
distributions::Standard,
prelude::{Distribution, SliceRandom},
rngs::StdRng,
Rng, SeedableRng,
};
use rand::{distributions::Standard, prelude::SliceRandom, rngs::StdRng, Rng, SeedableRng};
use std::{
marker::PhantomData,
sync::{Arc, Mutex},
};

pub(crate) struct BatchShuffledDataset<D, I> {
dataset: D,
use crate::{dataset::FSRSDataset, FSRSItem};

pub(crate) struct BatchShuffledDataset<I> {
dataset: Arc<FSRSDataset>,
indices: Vec<usize>,
input: PhantomData<I>,
}

impl<D, I> BatchShuffledDataset<D, I>
where
D: Dataset<I>,
{
impl<FSRSItem> BatchShuffledDataset<FSRSItem> {
/// Creates a new shuffled dataset.
pub fn new(dataset: D, batch_size: usize, rng: &mut StdRng) -> Self {
pub fn new(dataset: Arc<FSRSDataset>, batch_size: usize, rng: &mut StdRng) -> Self {
let len = dataset.len();

// Calculate the number of batches
Expand Down Expand Up @@ -57,22 +50,22 @@ where
}

/// Creates a new shuffled dataset with a fixed seed.
pub fn with_seed(dataset: D, batch_size: usize, seed: u64) -> Self {
pub fn with_seed(dataset: Arc<FSRSDataset>, batch_size: usize, seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
Self::new(dataset, batch_size, &mut rng)
}
}

impl<D, I> Dataset<I> for BatchShuffledDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
let Some(index) = self.indices.get(index) else {
impl Dataset<FSRSItem> for BatchShuffledDataset<FSRSItem> {
fn get(&self, index: usize) -> Option<FSRSItem> {
let Some(shuffled_index) = self.indices.get(index) else {
return None;
};
self.dataset.get(*index)
// info!(
// "original index: {}, shuffled index: {}",
// index, shuffled_index
// );
self.dataset.get(*shuffled_index)
}

fn len(&self) -> usize {
Expand All @@ -83,9 +76,9 @@ where
/// A data loader that can be used to iterate over a dataset in batches.
pub struct BatchShuffledDataLoader<I, O> {
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
dataset: Arc<FSRSDataset>,
batcher: Arc<dyn Batcher<I, O>>,
rng: Option<Mutex<rand::rngs::StdRng>>,
rng: Mutex<rand::rngs::StdRng>,
batch_size: usize,
}

Expand All @@ -105,16 +98,16 @@ impl<I, O> BatchShuffledDataLoader<I, O> {
/// The batch data loader.
pub fn new(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
dataset: Arc<FSRSDataset>,
batcher: Arc<dyn Batcher<I, O>>,
rng: Option<rand::rngs::StdRng>,
rng: rand::rngs::StdRng,
batch_size: usize,
) -> Self {
Self {
strategy,
dataset,
batcher,
rng: rng.map(Mutex::new),
rng: Mutex::new(rng),
batch_size,
}
}
Expand All @@ -128,77 +121,21 @@ struct BatchShuffledDataloaderIterator<I, O> {
batcher: Arc<dyn Batcher<I, O>>,
}

impl<I, O> BatchShuffledDataLoader<I, O>
where
I: Send + Sync + Clone + 'static,
O: Send + Sync + Clone + 'static,
{
/// Creates a new multi-threaded batch data loader.
///
/// # Arguments
///
/// * `strategy` - The batch strategy.
/// * `dataset` - The dataset.
/// * `batcher` - The batcher.
/// * `num_threads` - The number of threads.
///
/// # Returns
///
/// The multi-threaded batch data loader.
pub fn multi_thread(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
num_threads: usize,
mut rng: Option<rand::rngs::StdRng>,
batch_size: usize,
) -> MultiThreadDataLoader<O> {
let datasets = PartialDataset::split(dataset, num_threads);

let mut dataloaders: Vec<Arc<dyn DataLoader<_> + Send + Sync>> =
Vec::with_capacity(num_threads);

// Create more rngs from the first one, one for each new dataloader.
let rngs = (0..num_threads).map(|_| {
rng.as_mut()
.map(|rng| StdRng::seed_from_u64(Distribution::sample(&Standard, rng)))
});

for (dataset, rng) in datasets.into_iter().zip(rngs) {
let strategy = strategy.new_like();
let dataloader = Self::new(
strategy,
Arc::new(dataset),
batcher.clone(),
rng,
batch_size,
);
let dataloader = Arc::new(dataloader);
dataloaders.push(dataloader);
}
MultiThreadDataLoader::new(dataloaders)
}
}

impl<I: Send + Sync + Clone + 'static, O: Send + Sync> DataLoader<O>
for BatchShuffledDataLoader<I, O>
where
BatchShuffledDataset<I>: Dataset<I>,
{
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
// When starting a new iteration, we first check if the dataloader was created with an rng,
// implying that we should shuffle the dataset beforehand, while advancing the current
// rng to ensure that each new iteration shuffles the dataset differently.
let dataset = match &self.rng {
Some(rng) => {
let mut rng = rng.lock().unwrap();

Arc::new(BatchShuffledDataset::with_seed(
self.dataset.clone(),
self.batch_size,
rng.sample(Standard),
))
}
None => self.dataset.clone(),
};
let mut rng = self.rng.lock().unwrap();
let dataset = Arc::new(BatchShuffledDataset::with_seed(
self.dataset.clone(),
self.batch_size,
rng.sample(Standard),
));
Box::new(BatchShuffledDataloaderIterator::new(
self.strategy.new_like(),
dataset,
Expand All @@ -211,7 +148,10 @@ impl<I: Send + Sync + Clone + 'static, O: Send + Sync> DataLoader<O>
}
}

impl<I, O> BatchShuffledDataloaderIterator<I, O> {
impl<I: 'static, O> BatchShuffledDataloaderIterator<I, O>
where
BatchShuffledDataset<I>: Dataset<I>,
{
/// Creates a new batch data loader iterator.
///
/// # Arguments
Expand All @@ -225,7 +165,7 @@ impl<I, O> BatchShuffledDataloaderIterator<I, O> {
/// The batch data loader iterator.
pub fn new(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
dataset: Arc<BatchShuffledDataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
) -> Self {
Self {
Expand Down Expand Up @@ -271,14 +211,13 @@ impl<I, O> DataLoaderIterator<O> for BatchShuffledDataloaderIterator<I, O> {
pub struct BatchShuffledDataLoaderBuilder<I, O> {
strategy: Option<Box<dyn BatchStrategy<I>>>,
batcher: Arc<dyn Batcher<I, O>>,
num_threads: Option<usize>,
shuffle: Option<u64>,
}

impl<I, O> BatchShuffledDataLoaderBuilder<I, O>
where
I: Send + Sync + Clone + std::fmt::Debug + 'static,
O: Send + Sync + Clone + std::fmt::Debug + 'static,
BatchShuffledDataset<I>: Dataset<I>,
{
/// Creates a new data loader builder.
///
Expand All @@ -296,8 +235,6 @@ where
Self {
batcher: Arc::new(batcher),
strategy: None,
num_threads: None,
shuffle: None,
}
}

Expand All @@ -316,36 +253,6 @@ where
self
}

/// Sets the seed for shuffling.
///
/// Each time the dataloader starts a new iteration, the dataset will be shuffled.
///
/// # Arguments
///
/// * `seed` - The seed.
///
/// # Returns
///
/// The data loader builder.
pub const fn shuffle(mut self, seed: u64) -> Self {
self.shuffle = Some(seed);
self
}

/// Sets the number of workers.
///
/// # Arguments
///
/// * `num_workers` - The number of workers.
///
/// # Returns
///
/// The data loader builder.
pub const fn num_workers(mut self, num_workers: usize) -> Self {
self.num_threads = Some(num_workers);
self
}

/// Builds the data loader.
///
/// # Arguments
Expand All @@ -355,27 +262,19 @@ where
/// # Returns
///
/// The data loader.
pub fn build<D>(self, dataset: D, batch_size: usize) -> Arc<dyn DataLoader<O>>
where
D: Dataset<I> + 'static,
{
pub fn build(
self,
dataset: FSRSDataset,
batch_size: usize,
seed: u64,
) -> Arc<dyn DataLoader<O>> {
let dataset = Arc::new(dataset);

let rng = self.shuffle.map(StdRng::seed_from_u64);
let rng = StdRng::seed_from_u64(seed);
let strategy = match self.strategy {
Some(strategy) => strategy,
None => Box::new(FixBatchStrategy::new(1)),
};
if let Some(num_threads) = self.num_threads {
return Arc::new(BatchShuffledDataLoader::multi_thread(
strategy,
dataset,
self.batcher,
num_threads,
rng,
batch_size,
));
}

Arc::new(BatchShuffledDataLoader::new(
strategy,
Expand All @@ -395,7 +294,7 @@ mod tests {
#[test]
fn batch_shuffle() {
use crate::dataset::FSRSDataset;
let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs());
let dataset = Arc::new(FSRSDataset::from(anki21_sample_file_converted_to_fsrs()));
let batch_size = 10;
let seed = 42;
let batch_shuffled_dataset = BatchShuffledDataset::with_seed(dataset, batch_size, seed);
Expand Down
3 changes: 1 addition & 2 deletions src/cosine_annealing.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use burn::{lr_scheduler::LrScheduler, tensor::backend::Backend, LearningRate};
use log::info;
#[derive(Clone, Debug)]
pub(crate) struct CosineAnnealingLR {
t_max: f64,
Expand Down Expand Up @@ -49,7 +48,7 @@ impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
self.t_max,
self.eta_min,
);
info!("lr: {}", self.current_lr);
// info!("lr: {}", self.current_lr);
self.current_lr
}

Expand Down
2 changes: 2 additions & 0 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use burn::{
data::dataset::Dataset,
tensor::{backend::Backend, Data, ElementConversion, Float, Int, Shape, Tensor},
};

use serde::{Deserialize, Serialize};

/// Stores a list of reviews for a card, in chronological order. Each FSRSItem corresponds
Expand Down Expand Up @@ -142,6 +143,7 @@ impl Dataset<FSRSItem> for FSRSDataset {
}

fn get(&self, index: usize) -> Option<FSRSItem> {
// info!("get {}", index);
self.items.get(index).cloned()
}
}
Expand Down
Loading

0 comments on commit f1afdd7

Please sign in to comment.