diff --git a/Cargo.lock b/Cargo.lock index 034a94d2..4fc2ae68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1046,7 +1046,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.4.5" +version = "0.4.6" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 49e14aa9..fe90b364 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.4.5" +version = "0.4.6" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/training.rs b/src/training.rs index 3022616c..b1c7397a 100644 --- a/src/training.rs +++ b/src/training.rs @@ -374,7 +374,9 @@ fn train( Aggregate::Mean, Direction::Lowest, Split::Valid, - StoppingCondition::NoImprovementSince { n_epochs: 1 }, + StoppingCondition::NoImprovementSince { + n_epochs: config.num_epochs, + }, )) .devices(vec![device]) .num_epochs(config.num_epochs) @@ -457,8 +459,11 @@ mod tests { let device = NdArrayDevice::Cpu; let items = anki21_sample_file_converted_to_fsrs(); let (pre_trainset, trainsets, testset) = split_data(items.clone(), n_splits); - let average_recall = calculate_average_recall(&pre_trainset); + let items = [pre_trainset.clone(), testset.clone()].concat(); + let average_recall = calculate_average_recall(&items); + dbg!(average_recall); let initial_stability = pretrain(pre_trainset, average_recall).unwrap(); + dbg!(initial_stability); let config = TrainingConfig::new( ModelConfig { freeze_stability: true, @@ -492,6 +497,10 @@ mod tests { .par_iter() .map(|&sum| sum / n_splits as f32) .collect(); - dbg!(average_parameters); + dbg!(&average_parameters); + + let fsrs = FSRS::new(Some(&average_parameters)).unwrap(); + let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); + dbg!(&metrics); } }