From 333a63c232fdbcbd2e4824038b8baf45f2e4359e Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Fri, 1 Mar 2024 16:06:01 +0800 Subject: [PATCH] update default parameters (#159) --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/inference.rs | 10 +++++----- src/model.rs | 14 +++++++------- src/optimal_retention.rs | 8 ++++---- src/pre_training.rs | 8 +++++--- 6 files changed, 23 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 03753291..a7c4d3dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1046,7 +1046,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.4.2" +version = "0.4.3" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index f7858c45..bf62d012 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.4.2" +version = "0.4.3" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/inference.rs b/src/inference.rs index cebfa75a..25a9d2e1 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -21,8 +21,8 @@ pub type Parameters = [f32]; use itertools::izip; pub static DEFAULT_PARAMETERS: [f32; 17] = [ - 0.5614, 1.2546, 3.5878, 7.9731, 5.1043, 1.1303, 0.823, 0.0465, 1.629, 0.135, 1.0045, 2.132, - 0.0839, 0.3204, 1.3547, 0.219, 2.7849, + 0.5701, 1.4436, 4.1386, 10.9355, 5.1443, 1.2006, 0.8627, 0.0362, 1.629, 0.1342, 1.0166, 2.1174, + 0.0839, 0.3204, 1.4676, 0.219, 2.8237, ]; fn infer( @@ -478,7 +478,7 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.204_001, 0.025_387]), 5); + .assert_approx_eq(&Data::from([0.203_023, 0.024_624]), 5); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); @@ -491,7 +491,7 @@ mod tests { .unwrap(); Data::from([self_by_other, other_by_self]) - .assert_approx_eq(&Data::from([0.015_987, 0.019_767]), 5); + .assert_approx_eq(&Data::from([0.016_727, 0.019_213]), 5); Ok(()) } @@ -578,7 +578,7 @@ mod tests { fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(), MemoryState { stability: 9.999995, - difficulty: 7.200902 + difficulty: 7.255334 } ); assert_eq!( diff --git a/src/model.rs b/src/model.rs index dcca83d9..08ef3cd7 100644 --- a/src/model.rs +++ b/src/model.rs @@ -283,7 +283,7 @@ mod tests { let stability = model.init_stability(rating); assert_eq!( stability.to_data(), - Data::from([0.5614, 1.2546, 3.5878, 7.9731, 0.5614, 1.2546]) + Data::from([0.5701, 1.4436, 4.1386, 10.9355, 0.5701, 1.4436]) ) } @@ -295,7 +295,7 @@ mod tests { let difficulty = model.init_difficulty(rating); assert_eq!( difficulty.to_data(), - Data::from([7.3649, 6.2346, 5.1043, 3.974, 7.3649, 6.2346]) + Data::from([7.5455, 6.3449, 5.1443, 3.9436998, 7.5455, 6.3449]) ) } @@ -331,13 +331,13 @@ mod tests { next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.646, 5.823, 5.0, 4.177]) + Data::from([6.7254, 5.8627, 5.0, 4.1373]) ); let next_difficulty = model.mean_reversion(next_difficulty); next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.574311, 5.7895803, 5.00485, 4.2201195]) + Data::from([6.6681643, 5.836694, 5.0052238, 4.1737533]) ) } @@ -358,19 +358,19 @@ mod tests { s_recall.clone().backward(); assert_eq!( s_recall.to_data(), - Data::from([26.678038, 13.996968, 62.718544, 202.76956]) + Data::from([26.980936, 14.128489, 63.600677, 208.72739]) ); let s_forget = model.stability_after_failure(stability, difficulty, retention); s_forget.clone().backward(); assert_eq!( s_forget.to_data(), - Data::from([1.8932177, 2.0453987, 2.2637987, 2.5304008]) + Data::from([1.9016013, 2.0777826, 2.3257504, 2.6291647]) ); let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget); next_stability.clone().backward(); assert_eq!( next_stability.to_data(), - Data::from([1.8932177, 13.996968, 62.718544, 202.76956]) + Data::from([1.9016013, 14.128489, 63.600677, 208.72739]) ) } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 0739dcbd..2822215d 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -677,7 +677,7 @@ mod tests { None, ) .0; - assert_eq!(memorization[memorization.len() - 1], 3022.055014122344) + assert_eq!(memorization[memorization.len() - 1], 3130.8465582271774) } #[test] @@ -732,8 +732,8 @@ mod tests { assert_eq!( results.1.to_vec(), vec![ - 0, 16, 27, 29, 86, 73, 96, 95, 96, 105, 112, 113, 124, 131, 139, 124, 130, 141, - 162, 175, 168, 179, 186, 185, 198, 189, 200, 200, 200, 200 + 0, 16, 27, 34, 84, 80, 91, 92, 103, 107, 111, 113, 138, 132, 133, 116, 134, 148, + 152, 162, 172, 177, 188, 189, 200, 185, 185, 200, 198, 200 ] ); assert_eq!( @@ -747,7 +747,7 @@ mod tests { let config = SimulatorConfig::default(); let fsrs = FSRS::new(None)?; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.864870726919112); + assert_eq!(optimal_retention, 0.8468471175527587); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) } diff --git a/src/pre_training.rs b/src/pre_training.rs index 67050e08..c7964edc 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -335,8 +335,10 @@ mod tests { let items = anki21_sample_file_converted_to_fsrs(); let average_recall = calculate_average_recall(&items); let pretrainset = split_data(items, 1).0; - Data::from(pretrain(pretrainset, average_recall).unwrap()) - .assert_approx_eq(&Data::from([1.001_131, 1.810_561, 4.403_481, 8.530_161]), 4) + Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq( + &Data::from([1.001_131, 1.810_561, 4.403_226, 10.935_509]), + 4, + ) } #[test] @@ -349,6 +351,6 @@ mod tests { let mut rating_stability = HashMap::from([(2, 0.35)]); let rating_count = HashMap::from([(2, 1)]); let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap(); - assert_eq!(actual, [0.15661564, 0.35, 1.0009006, 2.2242827,]); + assert_eq!(actual, [0.13822041, 0.35, 1.0034012, 2.6513057,]); } }