diff --git a/Cargo.lock b/Cargo.lock index 7fc3a0b2..a63302a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1055,7 +1055,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.5.0" +version = "0.5.1" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 7539c314..bc90171d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.5.0" +version = "0.5.1" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/training.rs b/src/training.rs index 394ebc44..4825c55a 100644 --- a/src/training.rs +++ b/src/training.rs @@ -214,9 +214,6 @@ impl FSRS { }; let n_splits = 5; - if let Some(progress) = &progress { - progress.lock().unwrap().splits = vec![ProgressState::default(); n_splits]; - } let average_recall = calculate_average_recall(&items); let (pre_trainset, trainsets, testset) = split_data(items, n_splits); let initial_stability = pretrain(pre_trainset, average_recall).map_err(|e| { @@ -239,18 +236,32 @@ impl FSRS { AdamConfig::new(), ); - let weight_sets: Result>> = (0..n_splits) + let trainsets: Vec> = (0..n_splits) .into_par_iter() .map(|i| { - let trainset = trainsets + trainsets .par_iter() .enumerate() .filter(|&(j, _)| j != i) .flat_map(|(_, trainset)| trainset.clone()) - .collect(); + .collect() + }) + .collect(); + + if let Some(progress) = &progress { + let mut progress_states = vec![ProgressState::default(); n_splits]; + for (i, progress_state) in progress_states.iter_mut().enumerate() { + progress_state.epoch_total = config.num_epochs; + progress_state.items_total = trainsets[i].len(); + } + progress.lock().unwrap().splits = progress_states + } + let weight_sets: Result>> = (0..n_splits) + .into_par_iter() + .map(|i| { let model = train::>( - trainset, + trainsets[i].clone(), testset.clone(), &config, self.device(), @@ -280,10 +291,8 @@ impl FSRS { .map(|&sum| sum / n_splits as f32) .collect(); - for weight in &average_parameters { - if !weight.is_finite() { - return Err(FSRSError::InvalidInput); - } + if average_parameters.iter().any(|weight| weight.is_infinite()) { + return Err(FSRSError::InvalidInput); } Ok(average_parameters) @@ -499,25 +508,36 @@ mod tests { thread::sleep(Duration::from_millis(10)); let guard = progress.lock().unwrap(); finished = guard.finished(); - info!("progress: {}/{}", guard.current(), guard.total()); + println!("progress: {}/{}", guard.current(), guard.total()); } }); - if let Some(progress2) = &progress2 { - progress2.lock().unwrap().splits = vec![ProgressState::default(); n_splits]; - } - - let parameters_sets: Vec> = (0..n_splits) + let trainsets: Vec> = (0..n_splits) .into_par_iter() .map(|i| { - let trainset = trainsets + trainsets .par_iter() .enumerate() .filter(|&(j, _)| j != i) .flat_map(|(_, trainset)| trainset.clone()) - .collect(); + .collect() + }) + .collect(); + + if let Some(progress2) = &progress2 { + let mut progress_states = vec![ProgressState::default(); n_splits]; + for (i, progress_state) in progress_states.iter_mut().enumerate() { + progress_state.epoch_total = config.num_epochs; + progress_state.items_total = trainsets[i].len(); + } + progress2.lock().unwrap().splits = progress_states + } + + let parameters_sets: Vec> = (0..n_splits) + .into_par_iter() + .map(|i| { let model = train::( - trainset, + trainsets[i].clone(), items.clone(), &config, device,