From 9a20ec2fdc7de9c0a963d2ac35b838bdbc8b78c7 Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Mon, 4 Mar 2024 17:57:35 +0800
Subject: [PATCH] Fix/calculate average recall in item level & fix laplace
 smoothing & use f64 in pretrain (#161)

---
 Cargo.lock          |   2 +-
 Cargo.toml          |   2 +-
 src/pre_training.rs | 110 ++++++++++++++++++++++++--------------------
 src/training.rs     |  10 +++-
 4 files changed, 71 insertions(+), 53 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index da467917..034a94d2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1046,7 +1046,7 @@ dependencies = [
 
 [[package]]
 name = "fsrs"
-version = "0.4.4"
+version = "0.4.5"
 dependencies = [
  "burn",
  "chrono",
diff --git a/Cargo.toml b/Cargo.toml
index c9cb6a2c..49e14aa9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "0.4.4"
+version = "0.4.5"
 authors = ["Open Spaced Repetition"]
 categories = ["algorithms", "science"]
 edition = "2021"
diff --git a/src/pre_training.rs b/src/pre_training.rs
index c7964edc..872c1dc1 100644
--- a/src/pre_training.rs
+++ b/src/pre_training.rs
@@ -53,9 +53,9 @@ fn create_pretrain_data(fsrs_items: Vec<FSRSItem>) -> HashMap<u32, Vec<AverageR
             let avg = ratings.iter().sum::<f64>() / ratings.len() as f64;
 
             data.push(AverageRecall {
-                delta_t: *second_delta_t as f32,
-                recall: avg as f32,
-                count: ratings.len() as f32,
+                delta_t: *second_delta_t as f64,
+                recall: avg,
+                count: ratings.len() as f64,
             })
         }
 
@@ -69,9 +69,9 @@ fn create_pretrain_data(fsrs_items: Vec<FSRSItem>) -> HashMap<u32, Vec<AverageR
 fn total_rating_count(pretrainset: &HashMap<u32, Vec<AverageRecall>>) -> HashMap<u32, u32> {
     pretrainset
         .iter()
        .map(|(first_rating, data)| {
-            let count = data.iter().map(|d| d.count).sum::<f32>() as u32;
+            let count = data.iter().map(|d| d.count).sum::<f64>() as u32;
             (*first_rating, count)
         })
         .collect()
 }
@@ -79,14 +79,14 @@ fn total_rating_count(pretrainset: &HashMap<u32, Vec<AverageRecall>>) -> HashMap<u32, u32> {
-fn power_forgetting_curve(t: &Array1<f32>, s: f32) -> Array1<f32> {
-    (t / s * FACTOR as f32 + 1.0).mapv(|v| v.powf(DECAY as f32))
+fn power_forgetting_curve(t: &Array1<f64>, s: f64) -> Array1<f64> {
+    (t / s * FACTOR + 1.0).mapv(|v| v.powf(DECAY))
 }
 
 fn loss(
-    delta_t: &Array1<f32>,
-    recall: &Array1<f32>,
-    count: &Array1<f32>,
-    init_s0: f32,
-    default_s0: f32,
-) -> f32 {
+    delta_t: &Array1<f64>,
+    recall: &Array1<f64>,
+    count: &Array1<f64>,
+    init_s0: f64,
+    default_s0: f64,
+) -> f64 {
     let y_pred = power_forgetting_curve(delta_t, init_s0);
     let logloss = (-(recall * y_pred.clone().mapv_into(|v| v.ln())
         + (1.0 - recall) * (1.0 - &y_pred).mapv_into(|v| v.ln()))
@@ -113,25 +113,23 @@ fn search_parameters(
     mut pretrainset: HashMap<u32, Vec<AverageRecall>>,
     average_recall: f32,
 ) -> HashMap<u32, f32> {
     let mut optimal_stabilities = HashMap::new();
-    let epsilon = f32::EPSILON;
+    let epsilon = f64::EPSILON;
     for (first_rating, data) in &mut pretrainset {
         let r_s0_default: HashMap<u32, f32> = R_S0_DEFAULT_ARRAY.iter().cloned().collect();
-        let default_s0 = r_s0_default[first_rating];
+        let default_s0 = r_s0_default[first_rating] as f64;
         let delta_t = Array1::from_iter(data.iter().map(|d| d.delta_t));
+        let count = Array1::from_iter(data.iter().map(|d| d.count));
         let recall = {
             // Laplace smoothing
             // (real_recall * n + average_recall * 1) / (n + 1)
             // https://github.com/open-spaced-repetition/fsrs4anki/pull/358/files#diff-35b13c8e3466e8bd1231a51c71524fc31a945a8f332290726214d3a6fa7f442aR491
             let real_recall = Array1::from_iter(data.iter().map(|d| d.recall));
-            let n = data.iter().map(|d| d.count).sum::<f32>();
-            (real_recall * n + average_recall) / (n + 1.0)
+            (real_recall * count.clone() + average_recall as f64) / (count.clone() + 1.0)
         };
-        let count = Array1::from_iter(data.iter().map(|d| d.count));
-
-        let mut low = S_MIN;
-        let mut high = INIT_S_MAX;
-        let mut optimal_s = 1.0;
+        let mut low = S_MIN as f64;
+        let mut high = INIT_S_MAX as f64;
+        let mut optimal_s = default_s0;
 
         let mut iter = 0;
         while high - low > epsilon && iter < 1000 {
@@ -151,7 +149,7 @@ fn search_parameters(
             optimal_s = (high + low) / 2.0;
         }
 
-        optimal_stabilities.insert(*first_rating, optimal_s);
+        optimal_stabilities.insert(*first_rating, optimal_s as f32);
     }
 
     optimal_stabilities
@@ -274,7 +272,7 @@ mod tests {
     use burn::tensor::Data;
 
     use super::*;
-    use crate::dataset::split_data;
+    use crate::dataset::filter_outlier;
     use crate::training::calculate_average_recall;
 
     #[test]
@@ -282,62 +280,76 @@ mod tests {
     fn test_power_forgetting_curve() {
         let t = Array1::from(vec![0.0, 1.0, 2.0, 3.0]);
         let s = 1.0;
         let y = power_forgetting_curve(&t, s);
-        let expected = Array1::from(vec![1.0, 0.90000004, 0.82502866, 0.76613086]);
+        let expected = Array1::from(vec![1.0, 0.9, 0.8250286473253902, 0.7661308776828737]);
         assert_eq!(y, expected);
     }
 
     #[test]
     fn test_loss() {
-        let delta_t = Array1::from(vec![1.0, 2.0, 3.0]);
-        let recall = Array1::from(vec![0.9, 0.8181818, 0.75]);
-        let count = Array1::from(vec![100.0, 100.0, 100.0]);
-        let init_s0 = 1.0;
-        let actual = loss(&delta_t, &recall, &count, init_s0, init_s0);
-        assert_eq!(actual, 13.624332);
-        Data::from([loss(&delta_t, &recall, &count, 2.0, init_s0)])
-            .assert_approx_eq(&Data::from([14.5771]), 5);
+        let delta_t = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
+        let recall = Array1::from(vec![
+            0.86684181, 0.90758192, 0.73348482, 0.76776996, 0.68769064,
+        ]);
+        let count = Array1::from(vec![435.0, 97.0, 63.0, 38.0, 28.0]);
+        let default_s0 = DEFAULT_PARAMETERS[0] as f64;
+        let actual = loss(&delta_t, &recall, &count, 1.017056, default_s0);
+        dbg!(actual);
+        assert_eq!(actual, 22.922578338789826);
+        let actual = loss(&delta_t, &recall, &count, 1.017011, default_s0);
+        dbg!(actual);
+        assert_eq!(actual, 22.922578344493953);
     }
 
     #[test]
     fn test_search_parameters() {
+        let first_rating = 1;
         let pretrainset = HashMap::from([(
-            4,
+            first_rating,
             vec![
                 AverageRecall {
                     delta_t: 1.0,
-                    recall: 0.9,
-                    count: 30.0,
+                    recall: 0.86666667,
+                    count: 435.0,
                 },
                 AverageRecall {
                     delta_t: 2.0,
-                    recall: 0.8181818,
-                    count: 30.0,
+                    recall: 0.90721649,
+                    count: 97.0,
                 },
                 AverageRecall {
                     delta_t: 3.0,
-                    recall: 0.75,
-                    count: 30.0,
+                    recall: 0.73015873,
+                    count: 63.0,
                 },
                 AverageRecall {
                     delta_t: 4.0,
-                    recall: 0.6923077,
-                    count: 30.0,
+                    recall: 0.76315789,
+                    count: 38.0,
+                },
+                AverageRecall {
+                    delta_t: 5.0,
+                    recall: 0.67857143,
+                    count: 28.0,
                 },
             ],
         )]);
-        let actual = search_parameters(pretrainset, 0.9);
-        Data::from([*actual.get(&4).unwrap()]).assert_approx_eq(&Data::from([0.943_921]), 3);
+        let actual = search_parameters(pretrainset, 0.9430285915990116);
+        Data::from([*actual.get(&first_rating).unwrap()])
+            .assert_approx_eq(&Data::from([1.017_056]), 6);
     }
 
     #[test]
     fn test_pretrain() {
         use crate::convertor_tests::anki21_sample_file_converted_to_fsrs;
         let items = anki21_sample_file_converted_to_fsrs();
+        let (mut pretrainset, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) =
+            items.into_iter().partition(|item| item.reviews.len() == 2);
+        (pretrainset, trainset) = filter_outlier(pretrainset, trainset);
+        let items = [pretrainset.clone(), trainset].concat();
         let average_recall = calculate_average_recall(&items);
-        let pretrainset = split_data(items, 1).0;
         Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq(
-            &Data::from([1.001_131, 1.810_561, 4.403_226, 10.935_509]),
-            4,
+            &Data::from([1.017_056, 1.829_625, 4.414_563, 10.935_500]),
+            6,
         )
     }
diff --git a/src/training.rs b/src/training.rs
index 55685709..3022616c 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -223,7 +223,7 @@ pub(crate) struct TrainingConfig {
 pub fn calculate_average_recall(items: &[FSRSItem]) -> f32 {
     let (total_recall, total_reviews) = items
         .iter()
-        .flat_map(|item| item.reviews.iter())
+        .map(|item| item.current())
         .fold((0u32, 0u32), |(sum, count), review| {
             (sum + (review.rating > 1) as u32, count + 1)
         });
@@ -231,7 +231,6 @@ pub fn calculate_average_recall(items: &[FSRSItem]) -> f32 {
     if total_reviews == 0 {
         return 0.0;
     }
-
     total_recall as f32 / total_reviews as f32
 }
 
@@ -441,6 +440,13 @@ mod tests {
     use burn::backend::ndarray::NdArrayDevice;
     use rayon::prelude::IntoParallelIterator;
 
+    #[test]
+    fn test_calculate_average_recall() {
+        let items = anki21_sample_file_converted_to_fsrs();
+        let average_recall = calculate_average_recall(&items);
+        assert_eq!(average_recall, 0.9435269);
+    }
+
     #[test]
     fn training() {
         if std::env::var("SKIP_TRAINING").is_ok() {
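The commit bundles three behavioural fixes. First, average recall is now computed at the item level: `calculate_average_recall` samples each `FSRSItem` once through its latest review via `item.current()` instead of flat-mapping over every review, so cards with long histories no longer dominate the smoothing prior (the new `test_calculate_average_recall` pins the result at 0.9435269 for the Anki 2.1 sample deck). A minimal sketch with simplified stand-in types, assuming `current()` returns the item's last review as it does in the crate:

```rust
// Simplified stand-ins for the crate's FSRSItem/FSRSReview types.
struct Review {
    rating: u32, // 1 = forgot ("again"), 2..=4 = recalled
}

struct Item {
    reviews: Vec<Review>,
}

impl Item {
    // Assumption: mirrors FSRSItem::current(), which returns the last review.
    fn current(&self) -> &Review {
        self.reviews.last().expect("item has at least one review")
    }
}

fn calculate_average_recall(items: &[Item]) -> f32 {
    let (total_recall, total_reviews) = items
        .iter()
        // One sample per item (its current review), not one per review:
        // this is the item-level fix in the patch.
        .map(|item| item.current())
        .fold((0u32, 0u32), |(sum, count), review| {
            (sum + (review.rating > 1) as u32, count + 1)
        });
    if total_reviews == 0 {
        return 0.0;
    }
    total_recall as f32 / total_reviews as f32
}

fn main() {
    let items = vec![
        Item { reviews: vec![Review { rating: 3 }, Review { rating: 4 }] },
        Item { reviews: vec![Review { rating: 1 }] },
    ];
    // Two items, one recalled at its current review: 0.5, regardless of
    // how many earlier reviews the first item has.
    assert_eq!(calculate_average_recall(&items), 0.5);
}
```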
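Second, the Laplace smoothing in `search_parameters` is now weighted per delta_t bucket. Before, every bucket was blended with the prior using the total review count `n` across all buckets, which smoothed a 28-review bucket as weakly as a 435-review one; after, each bucket uses its own `count`: `(recall * count + average_recall) / (count + 1)`. A plain-`Vec` sketch (the `Bucket` type and `smooth_recall` helper are illustrative names; the real code operates on `ndarray::Array1<f64>`):

```rust
// One observed (delta_t, recall) bucket from the pretrain data.
struct Bucket {
    recall: f64, // observed recall rate at one delta_t
    count: f64,  // number of reviews behind that rate
}

fn smooth_recall(buckets: &[Bucket], average_recall: f64) -> Vec<f64> {
    buckets
        .iter()
        // Per-bucket Laplace smoothing, as in the fixed code:
        //   (recall * count + average_recall * 1) / (count + 1)
        .map(|b| (b.recall * b.count + average_recall) / (b.count + 1.0))
        .collect()
}

fn main() {
    // Figures taken from the new test_search_parameters data.
    let buckets = [
        Bucket { recall: 0.86666667, count: 435.0 },
        Bucket { recall: 0.67857143, count: 28.0 },
    ];
    // The dense bucket barely moves; the sparse one is pulled toward the
    // prior average recall.
    for r in smooth_recall(&buckets, 0.9430285915990116) {
        println!("{r:.6}"); // ~0.866842 and ~0.687691
    }
}
```

Run on those figures, the 435-review bucket moves only from 0.86667 to 0.86684 while the 28-review bucket is pulled from 0.67857 to 0.68769 — exactly the smoothed recall values that appear in the new `test_loss`.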
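Third, the pretrain stability search runs in `f64`: the bracket, the loss, and `epsilon = f64::EPSILON` all move from `f32`, the search starts from `default_s0` rather than a hard-coded 1.0, and only the final `optimal_s` is cast back to `f32`. A generic bracketing sketch under those assumptions (not the crate's exact loop, whose mid-point sampling scheme is not shown in this hunk):

```rust
// Ternary search for the minimum of a unimodal loss over [low, high].
fn ternary_search_min(mut low: f64, mut high: f64, f: impl Fn(f64) -> f64) -> f64 {
    let epsilon = f64::EPSILON;
    let mut iter = 0;
    // The iteration cap matters: once [low, high] shrinks to the local
    // floating-point resolution, high - low can plateau just above epsilon.
    while high - low > epsilon && iter < 1000 {
        iter += 1;
        let mid1 = low + (high - low) / 3.0;
        let mid2 = high - (high - low) / 3.0;
        if f(mid1) < f(mid2) {
            high = mid2; // minimum lies in [low, mid2]
        } else {
            low = mid1; // minimum lies in [mid1, high]
        }
    }
    (high + low) / 2.0
}

fn main() {
    // Toy unimodal loss with its minimum near the S0 found by the new
    // test_search_parameters.
    let s0 = ternary_search_min(0.1, 100.0, |s| (s - 1.017056).powi(2));
    println!("{s0:.6}"); // ~1.017056
}
```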