Skip to content

Commit

Permalink
Feat/RMSE based on R-Matrix (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Mar 2, 2024
1 parent 333a63c commit ce712c9
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.4.3"
version = "0.4.4"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
13 changes: 13 additions & 0 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ impl FSRSItem {
pub(crate) fn current(&self) -> &FSRSReview {
self.reviews.last().unwrap()
}

/// Returns the `(delta_t, length, lapse)` bin triple for this item.
///
/// Each component snaps a value onto a geometric grid,
/// `scale * base^floor(log_base(value))`, then rounds and casts it to `u32`:
///
/// * `delta_t` of the current (last) review: base 3.62, scale 2.48, and the
///   result is additionally multiplied by 100 before rounding;
/// * the number of reviews: base 1.89, scale 1.99;
/// * the number of `history()` reviews whose `rating == 1`: base 1.73,
///   scale 1.65, or plain `0` when there are none.
///
/// NOTE(review): which reviews `history()` yields is not visible here —
/// presumably everything before the current review; confirm.
pub(crate) fn r_matrix_index(&self) -> (u32, u32, u32) {
    // Snap `value` down to the closest power of `base`, scaled by `scale`.
    fn geometric_bin(value: f64, base: f64, scale: f64) -> f64 {
        scale * base.powf(value.log(base).floor())
    }

    let interval = self.current().delta_t as f64;
    let interval_bin = (geometric_bin(interval, 3.62, 2.48) * 100.0).round() as u32;

    let review_count = self.reviews.len() as f64;
    let review_count_bin = geometric_bin(review_count, 1.89, 1.99).round() as u32;

    let lapses = self.history().filter(|review| review.rating == 1).count();
    // Zero lapses gets its own bin; the logarithm is only taken for positive counts.
    let lapses_bin = if lapses == 0 {
        0
    } else {
        geometric_bin(lapses as f64, 1.73, 1.65).round() as u32
    };

    (interval_bin, review_count_bin, lapses_bin)
}
}

pub(crate) struct FSRSBatcher<B: Backend> {
Expand Down
94 changes: 46 additions & 48 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,30 +215,44 @@ impl<B: Backend> FSRS<B> {
return Err(FSRSError::NotEnoughData);
}
let batcher = FSRSBatcher::new(self.device());
let mut all_predictions = vec![];
let mut all_true_val = vec![];
let mut all_retention = vec![];
let mut all_labels = vec![];
let mut progress_info = ItemProgress {
current: 0,
total: items.len(),
};
let model = self.model();
let mut r_matrix: HashMap<(u32, u32, u32), (f32, f32, f32)> = HashMap::new();

for chunk in items.chunks(512) {
let batch = batcher.batch(chunk.to_vec());
let (_state, retention) = infer::<B>(model, batch.clone());
let pred = retention.clone().to_data().convert::<f32>().value;
all_predictions.extend(pred);
let true_val = batch.labels.clone().to_data().convert::<f32>().value;
all_true_val.extend(true_val);
all_retention.push(retention);
all_labels.push(batch.labels);
izip!(chunk, pred, true_val).for_each(|(item, p, y)| {
let bin = item.r_matrix_index();
let (pred, real, count) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0));
*pred += p;
*real += y;
*count += 1.0;
});
progress_info.current += chunk.len();
if !progress(progress_info) {
return Err(FSRSError::Interrupted);
}
}
let rmse = calibration_rmse(&all_predictions, &all_true_val);
let rmse = (r_matrix
.values()
.map(|(pred, real, count)| {
let pred = pred / count;
let real = real / count;
(pred - real).powi(2) * count
})
.sum::<f32>()
/ r_matrix.values().map(|(_, _, count)| count).sum::<f32>())
.sqrt();
let all_retention = Tensor::cat(all_retention, 0);
let all_labels = Tensor::cat(all_labels, 0).float();
let loss = BCELoss::new().forward(all_retention, all_labels);
Expand Down Expand Up @@ -337,13 +351,6 @@ fn get_bin(x: f32, bins: i32) -> i32 {
(binned_x as i32).clamp(0, bins - 1)
}

/// Computes the calibration error of `pred` against `true_val` by delegating
/// to `measure_a_by_b`, passing the predictions as both `pred_a` and `pred_b`.
///
/// # Panics
///
/// Panics if `pred` and `true_val` have different lengths.
fn calibration_rmse(pred: &[f32], true_val: &[f32]) -> f32 {
    assert!(
        pred.len() == true_val.len(),
        "Vectors pred and true_val must have the same length"
    );
    measure_a_by_b(pred, pred, true_val)
}

fn measure_a_by_b(pred_a: &[f32], pred_b: &[f32], true_val: &[f32]) -> f32 {
let mut groups = HashMap::new();
izip!(pred_a, pred_b, true_val).for_each(|(a, b, t)| {
Expand All @@ -368,26 +375,13 @@ fn measure_a_by_b(pred_a: &[f32], pred_b: &[f32], true_val: &[f32]) -> f32 {
#[cfg(test)]
mod tests {
use super::*;
use crate::{convertor_tests::anki21_sample_file_converted_to_fsrs, FSRSReview};
use crate::{
convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::filter_outlier, FSRSReview,
};

static PARAMETERS: &[f32] = &[
0.81497127,
1.5411042,
4.007436,
9.045982,
4.9264183,
1.039322,
0.93803364,
0.0,
1.5530516,
0.10299722,
0.9981442,
2.210701,
0.018248068,
0.3422524,
1.3384504,
0.22278537,
2.6646678,
1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321,
2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244,
];

#[test]
Expand Down Expand Up @@ -435,8 +429,8 @@ mod tests {
assert_eq!(
fsrs.memory_state(item, None).unwrap(),
MemoryState {
stability: 51.31289,
difficulty: 7.005062
stability: 43.055424,
difficulty: 7.7609
}
);

Expand All @@ -453,7 +447,7 @@ mod tests {
.good
.memory,
MemoryState {
stability: 51.339684,
stability: 51.441338,
difficulty: 7.005062
}
);
Expand All @@ -473,25 +467,29 @@ mod tests {
#[test]
fn test_evaluate() -> Result<()> {
let items = anki21_sample_file_converted_to_fsrs();
let (mut pretrainset, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) =
items.into_iter().partition(|item| item.reviews.len() == 2);
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
let items = [pretrainset, trainset].concat();
let fsrs = FSRS::new(Some(&[]))?;

let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.203_023, 0.024_624]), 5);
.assert_approx_eq(&Data::from([0.203_888, 0.029_732]), 5);

let fsrs = FSRS::new(Some(PARAMETERS))?;
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.201_908, 0.013_964]), 5);
.assert_approx_eq(&Data::from([0.202_188, 0.021_781]), 5);

let (self_by_other, other_by_self) = fsrs
.universal_metrics(items, &DEFAULT_PARAMETERS, |_| true)
.unwrap();

Data::from([self_by_other, other_by_self])
.assert_approx_eq(&Data::from([0.016_727, 0.019_213]), 5);
.assert_approx_eq(&Data::from([0.014_089, 0.016_483]), 5);
Ok(())
}

Expand Down Expand Up @@ -524,31 +522,31 @@ mod tests {
NextStates {
again: ItemState {
memory: MemoryState {
stability: 4.577856,
difficulty: 8.881129,
stability: 3.9653313,
difficulty: 9.7949
},
interval: 5
interval: 4
},
hard: ItemState {
memory: MemoryState {
stability: 27.6745,
difficulty: 7.9430957
stability: 22.415548,
difficulty: 8.7779
},
interval: 28,
interval: 22
},
good: ItemState {
memory: MemoryState {
stability: 51.31289,
difficulty: 7.005062
stability: 43.055424,
difficulty: 7.7609
},
interval: 51,
interval: 43
},
easy: ItemState {
memory: MemoryState {
stability: 101.94249,
difficulty: 6.0670285
stability: 90.86977,
difficulty: 6.7439003
},
interval: 102,
interval: 91
}
}
);
Expand Down

0 comments on commit ce712c9

Please sign in to comment.