diff --git a/src/inference.rs b/src/inference.rs index f842ea77..3e475377 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -244,7 +244,7 @@ pub struct ItemProgress { fn get_bin(x: f32, bins: i32) -> i32 { let log_base = (bins.add(1) as f32).ln(); let binned_x = (x * log_base).exp().floor().sub(1.0); - (binned_x as i32).min(bins - 1).max(0) + (binned_x as i32).clamp(0, bins - 1) } fn calibration_rmse(pred: &[f32], true_val: &[f32]) -> f32 { @@ -387,13 +387,13 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.20753297, 0.041_122_54]), 5); + .assert_approx_eq(&Data::from([0.20745006, 0.040_497_02]), 5); let fsrs = FSRS::new(Some(WEIGHTS))?; let metrics = fsrs.evaluate(items, |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.20320644, 0.016_822_13]), 5); + .assert_approx_eq(&Data::from([0.20321770, 0.015_836_29]), 5); Ok(()) } @@ -426,17 +426,17 @@ mod tests { NextStates { again: ItemState { memory: MemoryState { - stability: 4.5604353, + stability: 4.5802255, difficulty: 8.881129, }, interval: 5 }, hard: ItemState { memory: MemoryState { - stability: 26.111229, + stability: 27.7025, difficulty: 7.9430957 }, - interval: 26, + interval: 28, }, good: ItemState { memory: MemoryState { @@ -447,10 +447,10 @@ mod tests { }, easy: ItemState { memory: MemoryState { - stability: 121.01552, + stability: 101.98282, difficulty: 6.0670285 }, - interval: 121, + interval: 102, } } ); diff --git a/src/model.rs b/src/model.rs index f94eafa1..8674ac5e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -69,7 +69,7 @@ impl Model { fn stability_after_success( &self, last_s: Tensor, - new_d: Tensor, + last_d: Tensor, r: Tensor, rating: Tensor, ) -> Tensor { @@ -81,7 +81,7 @@ impl Model { last_s.clone() * (self.w.get(8).exp() - * (-new_d + 11) + * (-last_d + 11) * (last_s.pow(-self.w.get(9))) * (((-r + 1) * self.w.get(10)).exp() - 1) * hard_penalty @@ -92,11 +92,11 @@ impl Model { fn stability_after_failure( &self, last_s: Tensor, - new_d: Tensor, + last_d: Tensor, r: Tensor, ) -> Tensor { self.w.get(11) - * new_d.pow(-self.w.get(12)) + * last_d.pow(-self.w.get(12)) * ((last_s + 1).pow(self.w.get(13)) - 1) * ((-r + 1) * self.w.get(14)).exp() } @@ -125,21 +125,22 @@ impl Model { ) -> MemoryStateTensors { let (new_s, new_d) = if let Some(state) = state { let retention = self.power_forgetting_curve(delta_t, state.stability.clone()); - let mut new_difficulty = self.next_difficulty(state.difficulty.clone(), rating.clone()); - new_difficulty = self.mean_reversion(new_difficulty).clamp(1.0, 10.0); let stability_after_success = self.stability_after_success( state.stability.clone(), - new_difficulty.clone(), + state.difficulty.clone(), retention.clone(), rating.clone(), ); let stability_after_failure = self.stability_after_failure( state.stability.clone(), - new_difficulty.clone(), + state.difficulty.clone(), retention.clone(), ); let mut new_stability = stability_after_success .mask_where(rating.clone().equal_elem(1), stability_after_failure); + + let mut new_difficulty = self.next_difficulty(state.difficulty.clone(), rating.clone()); + new_difficulty = self.mean_reversion(new_difficulty).clamp(1.0, 10.0); // mask padding zeros for rating new_stability = new_stability.mask_where(rating.clone().equal_elem(0), state.stability); new_difficulty = new_difficulty.mask_where(rating.equal_elem(0), state.difficulty); diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 59e4dea6..1df7ee17 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -76,8 +76,8 @@ impl Default for SimulatorConfig { } fn stability_after_success(w: &[f64], s: f64, r: f64, d: f64, response: usize) -> f64 { - let hard_penalty = if response == 1 { w[15] } else { 1.0 }; - let easy_bonus = if response == 3 { w[16] } else { 1.0 }; + let hard_penalty = if response == 2 { w[15] } else { 1.0 }; + let easy_bonus = if response == 4 { w[16] } else { 1.0 }; s * (1.0 + f64::exp(w[8]) * (11.0 - d) @@ -88,8 +88,8 @@ fn stability_after_success(w: &[f64], s: f64, r: f64, d: f64, response: usize) - } fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 { - s.min(w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f64::exp((1.0 - r) * w[14])) - .max(0.1) + (w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f64::exp((1.0 - r) * w[14])) + .clamp(0.1, s) } fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: Option) -> f64 { @@ -116,10 +116,10 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O // let mut learn_cnt_per_day = Array1::::zeros(learn_span); let mut memorized_cnt_per_day = Array1::zeros(learn_span); - let first_rating_choices = [0, 1, 2, 3]; + let first_rating_choices = [1, 2, 3, 4]; let first_rating_dist = WeightedIndex::new(first_rating_prob).unwrap(); - let review_rating_choices = [1, 2, 3]; + let review_rating_choices = [2, 3, 4]; let review_rating_dist = WeightedIndex::new(review_rating_prob).unwrap(); let mut rng = StdRng::seed_from_u64(seed.unwrap_or(42)); @@ -180,8 +180,8 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O // Sample 'rating' for 'need_review' entries let mut ratings = Array1::zeros(deck_size); - izip!(&mut ratings, &need_review) - .filter(|(_, &need_review_flag)| need_review_flag) + izip!(&mut ratings, &(&need_review & !&forget)) + .filter(|(_, &condition)| condition) .for_each(|(rating, _)| { *rating = review_rating_choices[review_rating_dist.sample(&mut rng)] }); @@ -193,12 +193,13 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O *cost = if forget_flag { forget_cost * loss_aversion } else { - recall_costs[rating - 1] + recall_costs[rating - 2] } }); // Calculate cumulative sum of 'cost' let mut cum_sum = Array1::::zeros(deck_size); + cum_sum[0] = cost[0]; for i in 1..deck_size { cum_sum[i] = cum_sum[i - 1] + cost[i]; } @@ -219,6 +220,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O *cost = learn_cost; }); + cum_sum[0] = cost[0]; for i in 1..deck_size { cum_sum[i] = cum_sum[i - 1] + cost[i]; } @@ -276,9 +278,21 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O izip!(&mut new_difficulty, &old_difficulty, &true_review, &forget) .filter(|(.., &true_rev, &frgt)| true_rev && frgt) .for_each(|(new_diff, &old_diff, ..)| { - *new_diff = (old_diff + 2.0 * w[6]).max(1.0).min(10.0); + *new_diff = (old_diff + 2.0 * w[6]).clamp(1.0, 10.0); }); + // Update the difficulty values based on the condition 'true_review & !forget' + izip!( + &mut new_difficulty, + &old_difficulty, + &ratings, + &(&true_review & !&forget) + ) + .filter(|(.., &condition)| condition) + .for_each(|(new_diff, &old_diff, &rating, ..)| { + *new_diff = (old_diff - w[6] * (rating as f64 - 3.0)).clamp(1.0, 10.0); + }); + // Update 'last_date' column where 'true_review' or 'true_learn' is true let mut new_last_date = old_last_date.to_owned(); izip!(&mut new_last_date, &true_review, &true_learn) @@ -295,8 +309,8 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O ) .filter(|(.., &true_learn_flag)| true_learn_flag) .for_each(|(new_stab, new_diff, &rating, _)| { - *new_stab = w[rating]; - *new_diff = w[4] - w[5] * (rating as f64 - 3.0); + *new_stab = w[rating - 1]; + *new_diff = (w[4] - w[5] * (rating as f64 - 3.0)).clamp(1.0, 10.0); }); let old_interval = card_table.slice(s![Column::Interval, ..]); let mut new_interval = old_interval.to_owned(); @@ -305,8 +319,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O .for_each(|(new_ivl, &new_stab, ..)| { *new_ivl = (9.0 * new_stab * (1.0 / request_retention - 1.0)) .round() - .min(max_ivl) - .max(1.0); + .clamp(1.0, max_ivl); }); let old_due = card_table.slice(s![Column::Due, ..]); @@ -331,7 +344,6 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O card_table .slice_mut(s![Column::Interval, ..]) .assign(&new_interval); - // Update the review_cnt_per_day, learn_cnt_per_day and memorized_cnt_per_day // review_cnt_per_day[today] = true_review.iter().filter(|&&x| x).count() as f64; // learn_cnt_per_day[today] = true_learn.iter().filter(|&&x| x).count() as f64; @@ -436,7 +448,7 @@ mod tests { 0.9, None, ); - assert_eq!(memorization, 2635.689850107157) + assert_eq!(memorization, 2633.365434092778) } #[test] @@ -444,7 +456,7 @@ mod tests { let config = SimulatorConfig::default(); let fsrs = FSRS::new(None)?; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.8633668648071942); + assert_eq!(optimal_retention, 0.8530025910684347); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) }