From ce712c9d0fe8905d94214c7937628760f00602ff Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Sat, 2 Mar 2024 20:45:17 +0800 Subject: [PATCH] Feat/RMSE based on R-Matrix (#160) --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/dataset.rs | 13 +++++++ src/inference.rs | 94 ++++++++++++++++++++++++------------------------ 4 files changed, 61 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a7c4d3dc..da467917 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1046,7 +1046,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.4.3" +version = "0.4.4" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index bf62d012..c9cb6a2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.4.3" +version = "0.4.4" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/dataset.rs b/src/dataset.rs index 3e894d1c..b6f5a7ca 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -34,6 +34,19 @@ impl FSRSItem { pub(crate) fn current(&self) -> &FSRSReview { self.reviews.last().unwrap() } + + pub(crate) fn r_matrix_index(&self) -> (u32, u32, u32) { + let delta_t = self.current().delta_t as f64; + let delta_t_bin = (2.48 * 3.62f64.powf(delta_t.log(3.62).floor()) * 100.0).round() as u32; + let length = self.reviews.len() as f64; + let length_bin = (1.99 * 1.89f64.powf(length.log(1.89).floor())).round() as u32; + let lapse = self.history().filter(|review| review.rating == 1).count(); + if lapse == 0 { + return (delta_t_bin, length_bin, 0); + } + let lapse_bin = (1.65 * 1.73f64.powf((lapse as f64).log(1.73).floor())).round() as u32; + (delta_t_bin, length_bin, lapse_bin) + } } pub(crate) struct FSRSBatcher { diff --git a/src/inference.rs b/src/inference.rs index 25a9d2e1..4c8ff7ef 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -215,8 +215,6 @@ impl FSRS { return Err(FSRSError::NotEnoughData); } let batcher = FSRSBatcher::new(self.device()); - let mut all_predictions = vec![]; - let mut all_true_val = vec![]; let mut all_retention = vec![]; let mut all_labels = vec![]; let mut progress_info = ItemProgress { @@ -224,21 +222,37 @@ impl FSRS { total: items.len(), }; let model = self.model(); + let mut r_matrix: HashMap<(u32, u32, u32), (f32, f32, f32)> = HashMap::new(); + for chunk in items.chunks(512) { let batch = batcher.batch(chunk.to_vec()); let (_state, retention) = infer::(model, batch.clone()); let pred = retention.clone().to_data().convert::().value; - all_predictions.extend(pred); let true_val = batch.labels.clone().to_data().convert::().value; - all_true_val.extend(true_val); all_retention.push(retention); all_labels.push(batch.labels); + izip!(chunk, pred, true_val).for_each(|(item, p, y)| { + let bin = item.r_matrix_index(); + let (pred, real, count) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0)); + *pred += p; + *real += y; + *count += 1.0; + }); progress_info.current += chunk.len(); if !progress(progress_info) { return Err(FSRSError::Interrupted); } } - let rmse = calibration_rmse(&all_predictions, &all_true_val); + let rmse = (r_matrix + .values() + .map(|(pred, real, count)| { + let pred = pred / count; + let real = real / count; + (pred - real).powi(2) * count + }) + .sum::() + / r_matrix.values().map(|(_, _, count)| count).sum::()) + .sqrt(); let all_retention = Tensor::cat(all_retention, 0); let all_labels = Tensor::cat(all_labels, 0).float(); let loss = BCELoss::new().forward(all_retention, all_labels); @@ -337,13 +351,6 @@ fn get_bin(x: f32, bins: i32) -> i32 { (binned_x as i32).clamp(0, bins - 1) } -fn calibration_rmse(pred: &[f32], true_val: &[f32]) -> f32 { - if pred.len() != true_val.len() { - panic!("Vectors pred and true_val must have the same length"); - } - measure_a_by_b(pred, pred, true_val) -} - fn measure_a_by_b(pred_a: &[f32], pred_b: &[f32], true_val: &[f32]) -> f32 { let mut groups = HashMap::new(); izip!(pred_a, pred_b, true_val).for_each(|(a, b, t)| { @@ -368,26 +375,13 @@ fn measure_a_by_b(pred_a: &[f32], pred_b: &[f32], true_val: &[f32]) -> f32 { #[cfg(test)] mod tests { use super::*; - use crate::{convertor_tests::anki21_sample_file_converted_to_fsrs, FSRSReview}; + use crate::{ + convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::filter_outlier, FSRSReview, + }; static PARAMETERS: &[f32] = &[ - 0.81497127, - 1.5411042, - 4.007436, - 9.045982, - 4.9264183, - 1.039322, - 0.93803364, - 0.0, - 1.5530516, - 0.10299722, - 0.9981442, - 2.210701, - 0.018248068, - 0.3422524, - 1.3384504, - 0.22278537, - 2.6646678, + 1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321, + 2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244, ]; #[test] @@ -435,8 +429,8 @@ mod tests { assert_eq!( fsrs.memory_state(item, None).unwrap(), MemoryState { - stability: 51.31289, - difficulty: 7.005062 + stability: 43.055424, + difficulty: 7.7609 } ); @@ -453,7 +447,7 @@ mod tests { .good .memory, MemoryState { - stability: 51.339684, + stability: 51.441338, difficulty: 7.005062 } ); @@ -473,25 +467,29 @@ mod tests { #[test] fn test_evaluate() -> Result<()> { let items = anki21_sample_file_converted_to_fsrs(); + let (mut pretrainset, mut trainset): (Vec, Vec) = + items.into_iter().partition(|item| item.reviews.len() == 2); + (pretrainset, trainset) = filter_outlier(pretrainset, trainset); + let items = [pretrainset, trainset].concat(); let fsrs = FSRS::new(Some(&[]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.203_023, 0.024_624]), 5); + .assert_approx_eq(&Data::from([0.203_888, 0.029_732]), 5); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.201_908, 0.013_964]), 5); + .assert_approx_eq(&Data::from([0.202_188, 0.021_781]), 5); let (self_by_other, other_by_self) = fsrs .universal_metrics(items, &DEFAULT_PARAMETERS, |_| true) .unwrap(); Data::from([self_by_other, other_by_self]) - .assert_approx_eq(&Data::from([0.016_727, 0.019_213]), 5); + .assert_approx_eq(&Data::from([0.014_089, 0.016_483]), 5); Ok(()) } @@ -524,31 +522,31 @@ mod tests { NextStates { again: ItemState { memory: MemoryState { - stability: 4.577856, - difficulty: 8.881129, + stability: 3.9653313, + difficulty: 9.7949 }, - interval: 5 + interval: 4 }, hard: ItemState { memory: MemoryState { - stability: 27.6745, - difficulty: 7.9430957 + stability: 22.415548, + difficulty: 8.7779 }, - interval: 28, + interval: 22 }, good: ItemState { memory: MemoryState { - stability: 51.31289, - difficulty: 7.005062 + stability: 43.055424, + difficulty: 7.7609 }, - interval: 51, + interval: 43 }, easy: ItemState { memory: MemoryState { - stability: 101.94249, - difficulty: 6.0670285 + stability: 90.86977, + difficulty: 6.7439003 }, - interval: 102, + interval: 91 } } );