diff --git a/Cargo.lock b/Cargo.lock
index 4fc2ae68..7fc3a0b2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -986,6 +986,15 @@ dependencies = [
  "simd-adler32",
 ]
 
+[[package]]
+name = "fern"
+version = "0.6.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d9f0c14694cbd524c8720dd69b0e3179344f04ebb5f90f2e4a440c6ea3b2f1ee"
+dependencies = [
+ "log",
+]
+
 [[package]]
 name = "flate2"
 version = "1.0.28"
@@ -1046,12 +1055,13 @@ dependencies = [
 
 [[package]]
 name = "fsrs"
-version = "0.4.6"
+version = "0.5.0"
 dependencies = [
  "burn",
  "chrono",
  "chrono-tz",
  "criterion",
+ "fern",
  "itertools 0.12.1",
  "log",
  "ndarray",
diff --git a/Cargo.toml b/Cargo.toml
index fe90b364..7539c314 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "0.4.6"
+version = "0.5.0"
 authors = ["Open Spaced Repetition"]
 categories = ["algorithms", "science"]
 edition = "2021"
@@ -45,6 +45,7 @@ strum = { version = "0.26.1", features = ["derive"] }
 chrono = { version = "0.4.31", default-features = false, features = ["std", "clock"] }
 chrono-tz = "0.8.4"
 criterion = { version = "0.5.1" }
+fern = "0.6.0"
 rusqlite = { version = "0.30.0" }
 
 [[bench]]
diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs
index e29a2b96..553287d0 100644
--- a/src/batch_shuffle.rs
+++ b/src/batch_shuffle.rs
@@ -1,34 +1,27 @@
 use burn::data::{
     dataloader::{
-        batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, FixBatchStrategy,
-        MultiThreadDataLoader, Progress,
+        batcher::Batcher, BatchStrategy, DataLoader, DataLoaderIterator, FixBatchStrategy, Progress,
     },
-    dataset::{transform::PartialDataset, Dataset},
+    dataset::Dataset,
 };
-use rand::{
-    distributions::Standard,
-    prelude::{Distribution, SliceRandom},
-    rngs::StdRng,
-    Rng, SeedableRng,
-};
+use rand::{distributions::Standard, prelude::SliceRandom, rngs::StdRng, Rng, SeedableRng};
 use std::{
     marker::PhantomData,
     sync::{Arc, Mutex},
 };
 
-pub(crate) struct BatchShuffledDataset<D, I> {
-    dataset: D,
+use crate::{dataset::FSRSDataset, FSRSItem};
+
+pub(crate) struct BatchShuffledDataset<I> {
+    dataset: Arc<FSRSDataset>,
     indices: Vec<usize>,
     input: PhantomData<I>,
 }
 
-impl<D, I> BatchShuffledDataset<D, I>
-where
-    D: Dataset<I>,
-{
+impl<I> BatchShuffledDataset<I> {
     /// Creates a new shuffled dataset.
-    pub fn new(dataset: D, batch_size: usize, rng: &mut StdRng) -> Self {
+    pub fn new(dataset: Arc<FSRSDataset>, batch_size: usize, rng: &mut StdRng) -> Self {
         let len = dataset.len();
 
         // Calculate the number of batches
@@ -57,22 +50,22 @@ where
     }
 
     /// Creates a new shuffled dataset with a fixed seed.
-    pub fn with_seed(dataset: D, batch_size: usize, seed: u64) -> Self {
+    pub fn with_seed(dataset: Arc<FSRSDataset>, batch_size: usize, seed: u64) -> Self {
         let mut rng = StdRng::seed_from_u64(seed);
         Self::new(dataset, batch_size, &mut rng)
     }
 }
 
-impl<D, I> Dataset<I> for BatchShuffledDataset<D, I>
-where
-    D: Dataset<I>,
-    I: Clone + Send + Sync,
-{
-    fn get(&self, index: usize) -> Option<I> {
-        let Some(index) = self.indices.get(index) else {
+impl Dataset<FSRSItem> for BatchShuffledDataset<FSRSItem> {
+    fn get(&self, index: usize) -> Option<FSRSItem> {
+        let Some(shuffled_index) = self.indices.get(index) else {
             return None;
         };
-        self.dataset.get(*index)
+        // info!(
+        //     "original index: {}, shuffled index: {}",
+        //     index, shuffled_index
+        // );
+        self.dataset.get(*shuffled_index)
     }
 
     fn len(&self) -> usize {
@@ -83,9 +76,9 @@ where
 /// A data loader that can be used to iterate over a dataset in batches.
 pub struct BatchShuffledDataLoader<I, O> {
     strategy: Box<dyn BatchStrategy<I>>,
-    dataset: Arc<dyn Dataset<I>>,
+    dataset: Arc<FSRSDataset>,
     batcher: Arc<dyn Batcher<I, O>>,
-    rng: Option<Mutex<rand::rngs::StdRng>>,
+    rng: Mutex<rand::rngs::StdRng>,
     batch_size: usize,
 }
 
@@ -105,16 +98,16 @@ impl<I, O> BatchShuffledDataLoader<I, O> {
     /// The batch data loader.
     pub fn new(
         strategy: Box<dyn BatchStrategy<I>>,
-        dataset: Arc<dyn Dataset<I>>,
+        dataset: Arc<FSRSDataset>,
         batcher: Arc<dyn Batcher<I, O>>,
-        rng: Option<rand::rngs::StdRng>,
+        rng: rand::rngs::StdRng,
         batch_size: usize,
     ) -> Self {
         Self {
             strategy,
             dataset,
             batcher,
-            rng: rng.map(Mutex::new),
+            rng: Mutex::new(rng),
             batch_size,
         }
     }
@@ -128,77 +121,21 @@ struct BatchShuffledDataloaderIterator<I, O> {
     batcher: Arc<dyn Batcher<I, O>>,
 }
 
-impl<I, O> BatchShuffledDataLoader<I, O>
-where
-    I: Send + Sync + Clone + 'static,
-    O: Send + Sync + Clone + 'static,
-{
-    /// Creates a new multi-threaded batch data loader.
-    ///
-    /// # Arguments
-    ///
-    /// * `strategy` - The batch strategy.
-    /// * `dataset` - The dataset.
-    /// * `batcher` - The batcher.
-    /// * `num_threads` - The number of threads.
-    ///
-    /// # Returns
-    ///
-    /// The multi-threaded batch data loader.
-    pub fn multi_thread(
-        strategy: Box<dyn BatchStrategy<I>>,
-        dataset: Arc<dyn Dataset<I>>,
-        batcher: Arc<dyn Batcher<I, O>>,
-        num_threads: usize,
-        mut rng: Option<rand::rngs::StdRng>,
-        batch_size: usize,
-    ) -> MultiThreadDataLoader<O> {
-        let datasets = PartialDataset::split(dataset, num_threads);
-
-        let mut dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>> =
-            Vec::with_capacity(num_threads);
-
-        // Create more rngs from the first one, one for each new dataloader.
-        let rngs = (0..num_threads).map(|_| {
-            rng.as_mut()
-                .map(|rng| StdRng::seed_from_u64(Distribution::sample(&Standard, rng)))
-        });
-
-        for (dataset, rng) in datasets.into_iter().zip(rngs) {
-            let strategy = strategy.new_like();
-            let dataloader = Self::new(
-                strategy,
-                Arc::new(dataset),
-                batcher.clone(),
-                rng,
-                batch_size,
-            );
-            let dataloader = Arc::new(dataloader);
-            dataloaders.push(dataloader);
-        }
-        MultiThreadDataLoader::new(dataloaders)
-    }
-}
-
 impl<I: Send + Sync + Clone + 'static, O: Send + Sync + Clone + 'static> DataLoader<O>
     for BatchShuffledDataLoader<I, O>
+where
+    BatchShuffledDataset<I>: Dataset<I>,
 {
     fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
         // When starting a new iteration, we first check if the dataloader was created with an rng,
         // implying that we should shuffle the dataset beforehand, while advancing the current
         // rng to ensure that each new iteration shuffles the dataset differently.
-        let dataset = match &self.rng {
-            Some(rng) => {
-                let mut rng = rng.lock().unwrap();
-
-                Arc::new(BatchShuffledDataset::with_seed(
-                    self.dataset.clone(),
-                    self.batch_size,
-                    rng.sample(Standard),
-                ))
-            }
-            None => self.dataset.clone(),
-        };
+        let mut rng = self.rng.lock().unwrap();
+        let dataset = Arc::new(BatchShuffledDataset::with_seed(
+            self.dataset.clone(),
+            self.batch_size,
+            rng.sample(Standard),
+        ));
         Box::new(BatchShuffledDataloaderIterator::new(
             self.strategy.new_like(),
             dataset,
@@ -211,7 +148,10 @@ impl<I: Send + Sync + Clone + 'static, O: Send + Sync + Clone + 'static> DataLoader<O>
     }
 }
 
-impl<I, O> BatchShuffledDataloaderIterator<I, O> {
+impl<I, O> BatchShuffledDataloaderIterator<I, O>
+where
+    BatchShuffledDataset<I>: Dataset<I>,
+{
     /// Creates a new batch data loader iterator.
     ///
     /// # Arguments
@@ -225,7 +165,7 @@ impl<I, O> BatchShuffledDataloaderIterator<I, O> {
     /// The batch data loader iterator.
     pub fn new(
         strategy: Box<dyn BatchStrategy<I>>,
-        dataset: Arc<dyn Dataset<I>>,
+        dataset: Arc<BatchShuffledDataset<I>>,
         batcher: Arc<dyn Batcher<I, O>>,
     ) -> Self {
         Self {
@@ -271,14 +211,13 @@ impl<I, O> DataLoaderIterator<O> for BatchShuffledDataloaderIterator<I, O> {
 pub struct BatchShuffledDataLoaderBuilder<I, O> {
     strategy: Option<Box<dyn BatchStrategy<I>>>,
     batcher: Arc<dyn Batcher<I, O>>,
-    num_threads: Option<usize>,
-    shuffle: Option<u64>,
 }
 
 impl<I, O> BatchShuffledDataLoaderBuilder<I, O>
 where
     I: Send + Sync + Clone + std::fmt::Debug + 'static,
     O: Send + Sync + Clone + std::fmt::Debug + 'static,
+    BatchShuffledDataset<I>: Dataset<I>,
 {
     /// Creates a new data loader builder.
     ///
@@ -296,8 +235,6 @@ where
         Self {
             batcher: Arc::new(batcher),
             strategy: None,
-            num_threads: None,
-            shuffle: None,
         }
     }
 
@@ -316,36 +253,6 @@ where
         self
     }
 
-    /// Sets the seed for shuffling.
-    ///
-    /// Each time the dataloader starts a new iteration, the dataset will be shuffled.
-    ///
-    /// # Arguments
-    ///
-    /// * `seed` - The seed.
-    ///
-    /// # Returns
-    ///
-    /// The data loader builder.
-    pub const fn shuffle(mut self, seed: u64) -> Self {
-        self.shuffle = Some(seed);
-        self
-    }
-
-    /// Sets the number of workers.
-    ///
-    /// # Arguments
-    ///
-    /// * `num_workers` - The number of workers.
-    ///
-    /// # Returns
-    ///
-    /// The data loader builder.
-    pub const fn num_workers(mut self, num_workers: usize) -> Self {
-        self.num_threads = Some(num_workers);
-        self
-    }
-
     /// Builds the data loader.
     ///
     /// # Arguments
@@ -355,27 +262,19 @@ where
     /// # Returns
     ///
     /// The data loader.
-    pub fn build(self, dataset: D, batch_size: usize) -> Arc<dyn DataLoader<O>>
-    where
-        D: Dataset<I> + 'static,
-    {
+    pub fn build(
+        self,
+        dataset: FSRSDataset,
+        batch_size: usize,
+        seed: u64,
+    ) -> Arc<BatchShuffledDataLoader<I, O>> {
         let dataset = Arc::new(dataset);
-        let rng = self.shuffle.map(StdRng::seed_from_u64);
+        let rng = StdRng::seed_from_u64(seed);
         let strategy = match self.strategy {
             Some(strategy) => strategy,
             None => Box::new(FixBatchStrategy::new(1)),
         };
 
-        if let Some(num_threads) = self.num_threads {
-            return Arc::new(BatchShuffledDataLoader::multi_thread(
-                strategy,
-                dataset,
-                self.batcher,
-                num_threads,
-                rng,
-                batch_size,
-            ));
-        }
 
         Arc::new(BatchShuffledDataLoader::new(
             strategy,
@@ -395,7 +294,7 @@ mod tests {
     #[test]
     fn batch_shuffle() {
         use crate::dataset::FSRSDataset;
-        let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs());
+        let dataset = Arc::new(FSRSDataset::from(anki21_sample_file_converted_to_fsrs()));
         let batch_size = 10;
         let seed = 42;
         let batch_shuffled_dataset = BatchShuffledDataset::with_seed(dataset, batch_size, seed);
diff --git a/src/cosine_annealing.rs b/src/cosine_annealing.rs
index 4d6fbe71..325acf10 100644
--- a/src/cosine_annealing.rs
+++ b/src/cosine_annealing.rs
@@ -1,5 +1,4 @@
 use burn::{lr_scheduler::LrScheduler, tensor::backend::Backend, LearningRate};
-use log::info;
 #[derive(Clone, Debug)]
 pub(crate) struct CosineAnnealingLR {
     t_max: f64,
@@ -49,7 +48,7 @@ impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
             self.t_max,
             self.eta_min,
         );
-        info!("lr: {}", self.current_lr);
+        // info!("lr: {}", self.current_lr);
         self.current_lr
     }
 
diff --git a/src/dataset.rs b/src/dataset.rs
index b6f5a7ca..82ec6e8a 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -5,6 +5,7 @@ use burn::{
     data::dataset::Dataset,
     tensor::{backend::Backend, Data, ElementConversion, Float, Int, Shape, Tensor},
 };
+
 use serde::{Deserialize, Serialize};
 
 /// Stores a list of reviews for a card, in chronological order. Each FSRSItem corresponds
@@ -142,6 +143,7 @@ impl Dataset<FSRSItem> for FSRSDataset {
     }
 
     fn get(&self, index: usize) -> Option<FSRSItem> {
+        // info!("get {}", index);
         self.items.get(index).cloned()
     }
 }
diff --git a/src/inference.rs b/src/inference.rs
index 4c8ff7ef..68ac0f0b 100644
--- a/src/inference.rs
+++ b/src/inference.rs
@@ -2,6 +2,7 @@ use std::collections::HashMap;
 use std::ops::{Add, Sub};
 
 use crate::model::{Get, MemoryStateTensors, FSRS};
+use burn::nn::loss::Reduction;
 use burn::tensor::{Data, Shape, Tensor};
 use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend};
 
@@ -255,7 +256,7 @@ impl<B: Backend> FSRS<B> {
             .sqrt();
         let all_retention = Tensor::cat(all_retention, 0);
         let all_labels = Tensor::cat(all_labels, 0).float();
-        let loss = BCELoss::new().forward(all_retention, all_labels);
+        let loss = BCELoss::new().forward(all_retention, all_labels, Reduction::Mean);
         Ok(ModelEvaluation {
             log_loss: loss.to_data().value[0].elem(),
             rmse_bins: rmse,
diff --git a/src/training.rs b/src/training.rs
index b1c7397a..394ebc44 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -1,36 +1,31 @@
-use crate::batch_shuffle::{BatchShuffledDataLoaderBuilder, BatchShuffledDataset};
+use crate::batch_shuffle::BatchShuffledDataLoaderBuilder;
 use crate::cosine_annealing::CosineAnnealingLR;
-use crate::dataset::{split_data, FSRSBatch, FSRSBatcher, FSRSDataset, FSRSItem};
+use crate::dataset::{split_data, FSRSBatcher, FSRSDataset, FSRSItem};
 use crate::error::Result;
 use crate::model::{Model, ModelConfig};
 use crate::pre_training::pretrain;
 use crate::weight_clipper::weight_clipper;
 use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS};
 use burn::backend::Autodiff;
+
 use burn::data::dataloader::DataLoaderBuilder;
-use burn::module::Module;
-use burn::optim::AdamConfig;
-use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder, Recorder};
+use burn::lr_scheduler::LrScheduler;
+use burn::module::AutodiffModule;
+use burn::nn::loss::Reduction;
+use burn::optim::Optimizer;
+use burn::optim::{AdamConfig, GradientsParams};
 use burn::tensor::backend::Backend;
 use burn::tensor::{Int, Tensor};
-use burn::train::logger::InMemoryMetricLogger;
-use burn::train::metric::store::{Aggregate, Direction, Split};
-use burn::train::metric::LossMetric;
 use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
-
-use burn::train::{
-    ClassificationOutput, MetricEarlyStoppingStrategy, StoppingCondition, TrainOutput, TrainStep,
-    TrainingInterrupter, ValidStep,
-};
-use burn::{
-    config::Config, module::Param, tensor::backend::AutodiffBackend, train::LearnerBuilder,
-};
+use burn::train::TrainingInterrupter;
+use burn::{config::Config, module::Param, tensor::backend::AutodiffBackend};
 use core::marker::PhantomData;
 use log::info;
+
 use rayon::prelude::{
     IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
 };
-use std::path::Path;
+
 use std::sync::{Arc, Mutex};
 
 pub struct BCELoss<B: Backend> {
@@ -43,11 +38,20 @@ impl<B: Backend> BCELoss<B> {
             backend: PhantomData,
         }
     }
-    pub fn forward(&self, retentions: Tensor<B, 1>, labels: Tensor<B, 1>) -> Tensor<B, 1> {
+    pub fn forward(
+        &self,
+        retentions: Tensor<B, 1>,
+        labels: Tensor<B, 1>,
+        mean: Reduction,
+    ) -> Tensor<B, 1> {
         let loss = labels.clone() * retentions.clone().log()
             + (-labels + 1) * (-retentions + 1).log();
         // info!("loss: {}", &loss);
-        loss.mean().neg()
+        match mean {
+            Reduction::Mean => loss.mean().neg(),
+            Reduction::Sum => loss.sum().neg(),
+            Reduction::Auto => loss.neg(),
+        }
     }
 }
 
@@ -58,15 +62,13 @@ impl<B: Backend> Model<B> {
         r_historys: Tensor<B, 2>,
         delta_ts: Tensor<B, 1>,
         labels: Tensor<B, 1, Int>,
-    ) -> ClassificationOutput<B> {
+        reduce: Reduction,
+    ) -> Tensor<B, 1> {
         // info!("t_historys: {}", &t_historys);
         // info!("r_historys: {}", &r_historys);
         let state = self.forward(t_historys, r_historys, None);
         let retention = self.power_forgetting_curve(delta_ts, state.stability);
-        let logits =
-            Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 0).unsqueeze::<2>();
-        let loss = BCELoss::new().forward(retention, labels.clone().float());
-        ClassificationOutput::new(loss, logits, labels)
+        BCELoss::new().forward(retention, labels.float(), reduce)
     }
 }
 
@@ -82,47 +84,6 @@ impl<B: AutodiffBackend> Model<B> {
     }
 }
 
-impl<B: AutodiffBackend> TrainStep<FSRSBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: FSRSBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
-        let item = self.forward_classification(
-            batch.t_historys,
-            batch.r_historys,
-            batch.delta_ts,
-            batch.labels,
-        );
-        let mut gradients = item.loss.backward();
-
-        if self.config.freeze_stability {
-            gradients = self.freeze_initial_stability(gradients);
-        }
-
-        TrainOutput::new(self, gradients, item)
-    }
-
-    fn optimize<B1, O>(self, optim: &mut O, lr: f64, grads: burn::optim::GradientsParams) -> Self
-    where
-        B: AutodiffBackend,
-        O: burn::optim::Optimizer<Self, B1>,
-        B1: burn::tensor::backend::AutodiffBackend,
-        Self: burn::module::AutodiffModule<B1>,
-    {
-        let mut model = optim.step(lr, self, grads);
-        model.w = Param::from(weight_clipper(model.w.val()));
-        model
-    }
-}
-
-impl<B: Backend> ValidStep<FSRSBatch<B>, ClassificationOutput<B>> for Model<B> {
-    fn step(&self, batch: FSRSBatch<B>) -> ClassificationOutput<B> {
-        self.forward_classification(
-            batch.t_historys,
-            batch.r_historys,
-            batch.delta_ts,
-            batch.labels,
-        )
-    }
-}
-
 #[derive(Debug, Default, Clone)]
 pub struct ProgressState {
     pub epoch: usize,
@@ -327,6 +288,24 @@ impl<B: Backend> FSRS<B> {
         Ok(average_parameters)
     }
+
+    pub fn benchmark(&self, train_set: Vec<FSRSItem>, test_set: Vec<FSRSItem>) -> Vec<f32> {
+        let average_recall = calculate_average_recall(&train_set.clone());
+        let (pre_train_set, next_train_set) = train_set
+            .into_iter()
+            .partition(|item| item.reviews.len() == 2);
+        let initial_stability = pretrain(pre_train_set, average_recall).unwrap();
+        let config = TrainingConfig::new(
+            ModelConfig {
+                freeze_stability: true,
+                initial_stability: Some(initial_stability),
+            },
+            AdamConfig::new(),
+        );
+        let model = train::<Autodiff<B>>(next_train_set, test_set, &config, self.device(), None);
+        let parameters: Vec<f32> = model.unwrap().w.val().to_data().convert().value;
+        parameters
+    }
 }
 
 fn train<B: AutodiffBackend>(
@@ -343,82 +322,98 @@ fn train<B: AutodiffBackend>(
     let batcher_train = FSRSBatcher::<B>::new(device.clone());
     let dataloader_train = BatchShuffledDataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
-        .shuffle(config.seed)
-        .num_workers(config.num_workers)
-        .build(
-            BatchShuffledDataset::with_seed(
-                FSRSDataset::from(trainset),
-                config.batch_size,
-                config.seed,
-            ),
-            config.batch_size,
-        );
+        .build(FSRSDataset::from(trainset), config.batch_size, config.seed);
 
     let batcher_valid = FSRSBatcher::new(device.clone());
     let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
-        .num_workers(config.num_workers)
-        .build(FSRSDataset::from(testset));
-
-    let lr_scheduler = CosineAnnealingLR::init(iterations as f64, config.learning_rate);
-
-    let artifact_dir = std::env::var("BURN_LOG");
-
-    let mut builder = LearnerBuilder::new(&artifact_dir.clone().unwrap_or_default())
-        .metric_loggers(
-            InMemoryMetricLogger::default(),
-            InMemoryMetricLogger::default(),
-        )
-        .metric_valid_numeric(LossMetric::new())
-        .early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B::InnerBackend>>(
-            Aggregate::Mean,
-            Direction::Lowest,
-            Split::Valid,
-            StoppingCondition::NoImprovementSince {
-                n_epochs: config.num_epochs,
-            },
-        ))
-        .devices(vec![device])
-        .num_epochs(config.num_epochs)
-        .log_to_file(false);
-    let interrupter = builder.interrupter();
-
-    if let Some(mut progress) = progress {
-        progress.interrupter = interrupter.clone();
-        builder = builder.renderer(progress);
-    } else {
-        // comment out if you want to see text interface
-        builder = builder.renderer(NoProgress {});
-    }
+        .build(FSRSDataset::from(testset.clone()));
+
+    let mut lr_scheduler = CosineAnnealingLR::init(iterations as f64, config.learning_rate);
+    let interrupter = TrainingInterrupter::new();
+    let mut renderer: Box<dyn MetricsRenderer> = match progress {
+        Some(mut progress) => {
+            progress.interrupter = interrupter.clone();
+            Box::new(progress)
+        }
+        None => Box::new(NoProgress {}),
+    };
+
+    let mut model: Model<B> = config.model.init();
+    let mut optim = config.optimizer.init::<B, Model<B>>();
+
+    let mut best_loss = std::f64::INFINITY;
+    let mut best_model = model.clone();
+    for epoch in 1..=config.num_epochs {
+        let mut iterator = dataloader_train.iter();
+        let mut iteration = 0;
+        while let Some(item) = iterator.next() {
+            iteration += 1;
+            let lr = LrScheduler::<B>::step(&mut lr_scheduler);
+            let progress = iterator.progress();
+            let loss = model.forward_classification(
+                item.t_historys,
+                item.r_historys,
+                item.delta_ts,
+                item.labels,
+                Reduction::Mean,
+            );
+            let mut gradients = loss.backward();
+            if model.config.freeze_stability {
+                gradients = model.freeze_initial_stability(gradients);
+            }
+            let grads = GradientsParams::from_grads(gradients, &model);
+            model = optim.step(lr, model, grads);
+            model.w = Param::from(weight_clipper(model.w.val()));
+            // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr);
+            renderer.render_train(TrainingProgress {
+                progress,
+                epoch,
+                epoch_total: config.num_epochs,
+                iteration,
+            });
+
+            if interrupter.should_stop() {
+                break;
+            }
+        }
 
-    if artifact_dir.is_ok() {
-        builder = builder
-            .log_to_file(true)
-            .with_file_checkpointer(PrettyJsonFileRecorder::<FullPrecisionSettings>::new());
-    }
+        if interrupter.should_stop() {
+            break;
+        }
 
-    let learner = builder.build(config.model.init(), config.optimizer.init(), lr_scheduler);
+        let model_valid = model.valid();
+        let mut loss_valid = 0.0;
+        for batch in dataloader_valid.iter() {
+            let loss = model_valid.forward_classification(
+                batch.t_historys,
+                batch.r_historys,
+                batch.delta_ts,
+                batch.labels,
+                Reduction::Sum,
+            );
+            let loss = loss.into_data().convert::<f64>().value[0];
+            loss_valid += loss;
+
+            if interrupter.should_stop() {
+                break;
+            }
+        }
+        loss_valid /= testset.len() as f64;
+        info!("epoch: {:?} loss: {:?}", epoch, loss_valid);
+        if loss_valid < best_loss {
+            best_loss = loss_valid;
+            best_model = model.clone();
+        }
+    }
 
-    let mut model_trained = learner.fit(dataloader_train, dataloader_valid);
+    info!("best_loss: {:?}", best_loss);
 
     if interrupter.should_stop() {
         return Err(FSRSError::Interrupted);
     }
 
-    info!("trained parameters: {}", &model_trained.w.val());
-    model_trained.w = Param::from(weight_clipper(model_trained.w.val()));
-    info!("clipped parameters: {}", &model_trained.w.val());
-
-    if let Ok(path) = artifact_dir {
-        PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
-            .record(
-                model_trained.clone().into_record(),
-                Path::new(&path).join("model"),
-            )
-            .expect("Failed to save trained model");
-    }
-
-    Ok(model_trained)
+    Ok(best_model)
 }
 
 struct NoProgress {}
@@ -435,12 +430,17 @@ impl MetricsRenderer for NoProgress {
 
 #[cfg(test)]
 mod tests {
+    use std::fs::create_dir_all;
+    use std::path::Path;
+    use std::thread;
+    use std::time::Duration;
+
     use super::*;
     use crate::convertor_tests::anki21_sample_file_converted_to_fsrs;
     use crate::pre_training::pretrain;
     use crate::test_helpers::NdArrayAutodiff;
     use burn::backend::ndarray::NdArrayDevice;
-    use rayon::prelude::IntoParallelIterator;
+    use log::LevelFilter;
 
     #[test]
     fn test_calculate_average_recall() {
@@ -455,6 +455,26 @@ mod tests {
             println!("Skipping test in CI");
             return;
         }
+
+        let artifact_dir = std::env::var("BURN_LOG");
+
+        if let Ok(artifact_dir) = artifact_dir {
+            let _ = create_dir_all(&artifact_dir);
+            let log_file = Path::new(&artifact_dir).join("training.log");
+            fern::Dispatch::new()
+                .format(|out, message, record| {
+                    out.finish(format_args!(
+                        "[{}][{}] {}",
+                        record.target(),
+                        record.level(),
+                        message
+                    ))
+                })
+                .level(LevelFilter::Info)
+                .chain(fern::log_file(log_file).unwrap())
+                .apply()
+                .unwrap();
+        }
         let n_splits = 5;
         let device = NdArrayDevice::Cpu;
         let items = anki21_sample_file_converted_to_fsrs();
@@ -471,6 +491,21 @@ mod tests {
             },
             AdamConfig::new(),
         );
+        let progress = CombinedProgressState::new_shared();
+        let progress2 = Some(progress.clone());
+        thread::spawn(move || {
+            let mut finished = false;
+            while !finished {
+                thread::sleep(Duration::from_millis(10));
+                let guard = progress.lock().unwrap();
+                finished = guard.finished();
+                info!("progress: {}/{}", guard.current(), guard.total());
+            }
+        });
+
+        if let Some(progress2) = &progress2 {
+            progress2.lock().unwrap().splits = vec![ProgressState::default(); n_splits];
+        }
 
         let parameters_sets: Vec<Vec<f32>> = (0..n_splits)
             .into_par_iter()
@@ -481,8 +516,13 @@ mod tests {
                     .filter(|&(j, _)| j != i)
                     .flat_map(|(_, trainset)| trainset.clone())
                     .collect();
-                let model =
-                    train::<NdArrayAutodiff>(trainset, testset.clone(), &config, device, None);
+                let model = train::<NdArrayAutodiff>(
+                    trainset,
+                    items.clone(),
+                    &config,
+                    device,
+                    progress2.clone().map(|p| ProgressCollector::new(p, i)),
+                );
                 model.unwrap().w.val().to_data().convert().value
             })
             .collect();
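
Usage note: the fern/log setup added in the test above only runs when the BURN_LOG environment variable is set, and it is the only logger initialization in this patch. Below is a minimal, self-contained sketch of the same configuration, assuming only the fern and log crates added to Cargo.toml; the helper name init_training_log is illustrative and not part of the patch.

    use log::LevelFilter;
    use std::{fs::create_dir_all, path::Path};

    // Sketch of the file logging the new test wires up when BURN_LOG is defined.
    fn init_training_log(artifact_dir: &str) -> Result<(), fern::InitError> {
        create_dir_all(artifact_dir)?;
        let log_file = Path::new(artifact_dir).join("training.log");
        fern::Dispatch::new()
            // Same "[target][level] message" format as in the test above.
            .format(|out, message, record| {
                out.finish(format_args!(
                    "[{}][{}] {}",
                    record.target(),
                    record.level(),
                    message
                ))
            })
            .level(LevelFilter::Info)
            .chain(fern::log_file(log_file)?)
            .apply()?;
        Ok(())
    }

In the patch itself this configuration is built inline inside the training test and chained to a file under the BURN_LOG directory, so the per-epoch info! lines emitted by the new manual training loop end up in training.log instead of burn's removed Learner dashboard.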