From 35b80f444cb9cc5b0b8fe51a46fffd4278396e38 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Wed, 2 Oct 2024 23:36:31 +0800 Subject: [PATCH] Feat/float next interval (#213) --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/inference.rs | 22 ++++++++++------------ src/optimal_retention.rs | 4 +++- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 80fa2fd0..60386f75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,7 +1077,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "1.2.4" +version = "1.3.0" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index b7ef898c..cf1e291c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "1.2.4" +version = "1.3.0" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/inference.rs b/src/inference.rs index a6197c82..c593d437 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -65,10 +65,8 @@ impl From for MemoryStateTensors { } } -pub fn next_interval(stability: f32, desired_retention: f32) -> u32 { - (stability / FACTOR as f32 * (desired_retention.powf(1.0 / DECAY as f32) - 1.0)) - .round() - .max(1.0) as u32 +pub fn next_interval(stability: f32, desired_retention: f32) -> f32 { + stability / FACTOR as f32 * (desired_retention.powf(1.0 / DECAY as f32) - 1.0) } impl FSRS { @@ -144,7 +142,7 @@ impl FSRS { stability: Option, desired_retention: f32, rating: u32, - ) -> u32 { + ) -> f32 { let stability = stability.unwrap_or_else(|| { // get initial stability for new card let rating = Tensor::from_data( @@ -333,7 +331,7 @@ pub struct NextStates { #[derive(Debug, PartialEq, Clone)] pub struct ItemState { pub memory: MemoryState, - pub interval: u32, + pub interval: f32, } #[derive(Debug, Clone, Copy)] @@ -475,7 +473,7 @@ mod tests { let desired_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::>(); let intervals = desired_retentions .iter() - .map(|r| next_interval(1.0, *r)) + .map(|r| next_interval(1.0, *r).round().max(1.0) as i32) .collect::>(); assert_eq!(intervals, [422, 102, 43, 22, 13, 8, 4, 2, 1, 1]); } @@ -548,32 +546,32 @@ mod tests { stability: 2.969144, difficulty: 9.520562 }, - interval: 3 + interval: 2.9691453 }, hard: ItemState { memory: MemoryState { stability: 17.091442, difficulty: 8.4513445 }, - interval: 17 + interval: 17.09145 }, good: ItemState { memory: MemoryState { stability: 31.722975, difficulty: 7.382128 }, - interval: 32 + interval: 31.722988 }, easy: ItemState { memory: MemoryState { stability: 71.75015, difficulty: 6.3129106 }, - interval: 72 + interval: 71.75018 } } ); - assert_eq!(fsrs.next_interval(Some(121.01552), 0.9, 1), 121); + assert_eq!(fsrs.next_interval(Some(121.01552), 0.9, 1), 121.01557); Ok(()) } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 5d044921..6201d987 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -414,7 +414,9 @@ pub fn simulate( izip!(&mut new_interval, &new_stability, &true_review, &true_learn) .filter(|(.., &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag) .for_each(|(new_ivl, &new_stab, ..)| { - *new_ivl = (next_interval(new_stab, desired_retention) as f32).clamp(1.0, max_ivl); + *new_ivl = next_interval(new_stab, desired_retention) + .round() + .clamp(1.0, max_ivl); }); let old_due = card_table.slice(s![Column::Due, ..]);