Feat/calc SInc based on last d #111

Merged (5 commits) on Oct 27, 2023
16 changes: 8 additions & 8 deletions src/inference.rs
@@ -244,7 +244,7 @@ pub struct ItemProgress {
fn get_bin(x: f32, bins: i32) -> i32 {
let log_base = (bins.add(1) as f32).ln();
let binned_x = (x * log_base).exp().floor().sub(1.0);
- (binned_x as i32).min(bins - 1).max(0)
+ (binned_x as i32).clamp(0, bins - 1)
}

fn calibration_rmse(pred: &[f32], true_val: &[f32]) -> f32 {
@@ -387,13 +387,13 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
- .assert_approx_eq(&Data::from([0.20753297, 0.041_122_54]), 5);
+ .assert_approx_eq(&Data::from([0.20745006, 0.040_497_02]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
- .assert_approx_eq(&Data::from([0.20320644, 0.016_822_13]), 5);
+ .assert_approx_eq(&Data::from([0.20321770, 0.015_836_29]), 5);
Ok(())
}

@@ -426,17 +426,17 @@ mod tests {
NextStates {
again: ItemState {
memory: MemoryState {
- stability: 4.5604353,
+ stability: 4.5802255,
difficulty: 8.881129,
},
interval: 5
},
hard: ItemState {
memory: MemoryState {
- stability: 26.111229,
+ stability: 27.7025,
difficulty: 7.9430957
},
- interval: 26,
+ interval: 28,
},
good: ItemState {
memory: MemoryState {
@@ -447,10 +447,10 @@
},
easy: ItemState {
memory: MemoryState {
- stability: 121.01552,
+ stability: 101.98282,
difficulty: 6.0670285
},
- interval: 121,
+ interval: 102,
}
}
);
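
The get_bin change is a pure refactor: `.min(bins - 1).max(0)` and `.clamp(0, bins - 1)` select the same bucket, and the updated metric and scheduling constants in these tests follow from the model change in src/model.rs below. A standalone sketch of the binning helper (illustrative, mirroring the diff):

```rust
/// Map a prediction x in [0, 1] to one of `bins` log-scaled buckets (sketch).
fn get_bin(x: f32, bins: i32) -> i32 {
    let log_base = ((bins + 1) as f32).ln();
    let binned_x = (x * log_base).exp().floor() - 1.0;
    // Same result as .min(bins - 1).max(0): both confine the index to [0, bins - 1].
    (binned_x as i32).clamp(0, bins - 1)
}

fn main() {
    assert_eq!(get_bin(0.0, 20), 0); // lowest predictions fall into the first bucket
    assert_eq!(get_bin(1.0, 20), 19); // highest predictions fall into the last bucket
}
```
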
17 changes: 9 additions & 8 deletions src/model.rs
@@ -69,7 +69,7 @@ impl<B: Backend> Model<B> {
fn stability_after_success(
&self,
last_s: Tensor<B, 1>,
- new_d: Tensor<B, 1>,
+ last_d: Tensor<B, 1>,
r: Tensor<B, 1>,
rating: Tensor<B, 1>,
) -> Tensor<B, 1> {
@@ -81,7 +81,7 @@ impl<B: Backend> Model<B> {

last_s.clone()
* (self.w.get(8).exp()
- * (-new_d + 11)
+ * (-last_d + 11)
* (last_s.pow(-self.w.get(9)))
* (((-r + 1) * self.w.get(10)).exp() - 1)
* hard_penalty
@@ -92,11 +92,11 @@ impl<B: Backend> Model<B> {
fn stability_after_failure(
&self,
last_s: Tensor<B, 1>,
- new_d: Tensor<B, 1>,
+ last_d: Tensor<B, 1>,
r: Tensor<B, 1>,
) -> Tensor<B, 1> {
self.w.get(11)
- * new_d.pow(-self.w.get(12))
+ * last_d.pow(-self.w.get(12))
* ((last_s + 1).pow(self.w.get(13)) - 1)
* ((-r + 1) * self.w.get(14)).exp()
}
@@ -125,21 +125,22 @@ impl<B: Backend> Model<B> {
) -> MemoryStateTensors<B> {
let (new_s, new_d) = if let Some(state) = state {
let retention = self.power_forgetting_curve(delta_t, state.stability.clone());
- let mut new_difficulty = self.next_difficulty(state.difficulty.clone(), rating.clone());
- new_difficulty = self.mean_reversion(new_difficulty).clamp(1.0, 10.0);
let stability_after_success = self.stability_after_success(
state.stability.clone(),
- new_difficulty.clone(),
+ state.difficulty.clone(),
retention.clone(),
rating.clone(),
);
let stability_after_failure = self.stability_after_failure(
state.stability.clone(),
- new_difficulty.clone(),
+ state.difficulty.clone(),
retention.clone(),
);
let mut new_stability = stability_after_success
.mask_where(rating.clone().equal_elem(1), stability_after_failure);

+ let mut new_difficulty = self.next_difficulty(state.difficulty.clone(), rating.clone());
+ new_difficulty = self.mean_reversion(new_difficulty).clamp(1.0, 10.0);
// mask padding zeros for rating
new_stability = new_stability.mask_where(rating.clone().equal_elem(0), state.stability);
new_difficulty = new_difficulty.mask_where(rating.equal_elem(0), state.difficulty);
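
The core change: stability_after_success and stability_after_failure now take last_d, the difficulty from before the current review, instead of the freshly updated difficulty, which is why the next_difficulty / mean_reversion lines can move below the stability computation. In scalar form the updated rules read (a sketch from the tensor code in this diff, with S and D the previous stability and difficulty, R the predicted retention, and w the weight vector):

$$
S'_{\mathrm{success}} = S\left(1 + e^{w_8}(11 - D)\,S^{-w_9}\left(e^{(1-R)w_{10}} - 1\right)\cdot \mathrm{hard\_penalty}\cdot \mathrm{easy\_bonus}\right)
$$

$$
S'_{\mathrm{fail}} = w_{11}\,D^{-w_{12}}\left((S+1)^{w_{13}} - 1\right)e^{(1-R)w_{14}}
$$

where hard_penalty is w_15 for a Hard rating, easy_bonus is w_16 for an Easy rating, and both are 1 otherwise.
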
46 changes: 29 additions & 17 deletions src/optimal_retention.rs
@@ -76,8 +76,8 @@ impl Default for SimulatorConfig {
}

fn stability_after_success(w: &[f64], s: f64, r: f64, d: f64, response: usize) -> f64 {
- let hard_penalty = if response == 1 { w[15] } else { 1.0 };
- let easy_bonus = if response == 3 { w[16] } else { 1.0 };
+ let hard_penalty = if response == 2 { w[15] } else { 1.0 };
+ let easy_bonus = if response == 4 { w[16] } else { 1.0 };
s * (1.0
+ f64::exp(w[8])
* (11.0 - d)
@@ -88,8 +88,8 @@ fn stability_after_success(w: &[f64], s: f64, r: f64, d: f64, response: usize) -
}

fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 {
- s.min(w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f64::exp((1.0 - r) * w[14]))
- .max(0.1)
+ (w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f64::exp((1.0 - r) * w[14]))
+ .clamp(0.1, s)
}
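
The rewrite computes the raw post-lapse stability once and then bounds it to [0.1, s], so a lapse can never raise stability above its pre-lapse value. A standalone sketch of the same expression; the weight values below are illustrative only, not trained defaults:

```rust
// Post-lapse stability, bounded to [0.1, previous stability] (sketch).
fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 {
    (w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f64::exp((1.0 - r) * w[14]))
        .clamp(0.1, s)
}

fn main() {
    let mut w = vec![0.0; 15];
    (w[11], w[12], w[13], w[14]) = (1.5, 0.1, 0.3, 2.3); // illustrative weights
    let s_new = stability_after_failure(&w, 30.0, 0.9, 5.0);
    assert!((0.1..=30.0).contains(&s_new)); // never exceeds the previous stability
}
```
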

fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: Option<u64>) -> f64 {
@@ -116,10 +116,10 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
// let mut learn_cnt_per_day = Array1::<f64>::zeros(learn_span);
let mut memorized_cnt_per_day = Array1::zeros(learn_span);

- let first_rating_choices = [0, 1, 2, 3];
+ let first_rating_choices = [1, 2, 3, 4];
let first_rating_dist = WeightedIndex::new(first_rating_prob).unwrap();

- let review_rating_choices = [1, 2, 3];
+ let review_rating_choices = [2, 3, 4];
let review_rating_dist = WeightedIndex::new(review_rating_prob).unwrap();

let mut rng = StdRng::seed_from_u64(seed.unwrap_or(42));
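
The simulator now draws ratings on the same 1..=4 scale used elsewhere in the crate (1 = Again, 2 = Hard, 3 = Good, 4 = Easy) instead of 0-based values, so the indices line up with the weight and cost lookups below. A standalone sketch of the sampling pattern, with illustrative probabilities:

```rust
use rand::distributions::WeightedIndex;
use rand::prelude::*;

fn main() {
    // First reviews can yield any rating; the probabilities here are illustrative only.
    let first_rating_choices = [1usize, 2, 3, 4];
    let first_rating_prob = [0.15, 0.20, 0.60, 0.05];
    let first_rating_dist = WeightedIndex::new(&first_rating_prob).unwrap();

    let mut rng = StdRng::seed_from_u64(42);
    let rating = first_rating_choices[first_rating_dist.sample(&mut rng)];
    assert!((1..=4).contains(&rating)); // Again, Hard, Good or Easy
}
```
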
@@ -180,8 +180,8 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O

// Sample 'rating' for 'need_review' entries
let mut ratings = Array1::zeros(deck_size);
- izip!(&mut ratings, &need_review)
- .filter(|(_, &need_review_flag)| need_review_flag)
+ izip!(&mut ratings, &(&need_review & !&forget))
+ .filter(|(_, &condition)| condition)
.for_each(|(rating, _)| {
*rating = review_rating_choices[review_rating_dist.sample(&mut rng)]
});
@@ -193,12 +193,13 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
*cost = if forget_flag {
forget_cost * loss_aversion
} else {
- recall_costs[rating - 1]
+ recall_costs[rating - 2]
}
});

// Calculate cumulative sum of 'cost'
let mut cum_sum = Array1::<f64>::zeros(deck_size);
+ cum_sum[0] = cost[0];
for i in 1..deck_size {
cum_sum[i] = cum_sum[i - 1] + cost[i];
}
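
The added cum_sum[0] = cost[0]; seeds the running total; before this fix index 0 stayed at zero, so the first card's cost never counted against the daily time budget. The same prefix-sum pattern as a standalone sketch:

```rust
use ndarray::Array1;

/// Running total of per-card costs; the simulator compares it against the daily budget.
fn cumulative(cost: &Array1<f64>) -> Array1<f64> {
    let mut cum_sum = Array1::<f64>::zeros(cost.len());
    if !cost.is_empty() {
        cum_sum[0] = cost[0]; // seed the prefix sum so the first card is counted
        for i in 1..cost.len() {
            cum_sum[i] = cum_sum[i - 1] + cost[i];
        }
    }
    cum_sum
}

fn main() {
    let cost = Array1::from(vec![14.0, 7.0, 7.0]);
    assert_eq!(cumulative(&cost), Array1::from(vec![14.0, 21.0, 28.0]));
}
```
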
@@ -219,6 +220,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
*cost = learn_cost;
});

+ cum_sum[0] = cost[0];
for i in 1..deck_size {
cum_sum[i] = cum_sum[i - 1] + cost[i];
}
@@ -276,9 +278,21 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
izip!(&mut new_difficulty, &old_difficulty, &true_review, &forget)
.filter(|(.., &true_rev, &frgt)| true_rev && frgt)
.for_each(|(new_diff, &old_diff, ..)| {
- *new_diff = (old_diff + 2.0 * w[6]).max(1.0).min(10.0);
+ *new_diff = (old_diff + 2.0 * w[6]).clamp(1.0, 10.0);
});

+ // Update the difficulty values based on the condition 'true_review & !forget'
+ izip!(
+ &mut new_difficulty,
+ &old_difficulty,
+ &ratings,
+ &(&true_review & !&forget)
+ )
+ .filter(|(.., &condition)| condition)
+ .for_each(|(new_diff, &old_diff, &rating, ..)| {
+ *new_diff = (old_diff - w[6] * (rating as f64 - 3.0)).clamp(1.0, 10.0);
+ });
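
This new block closes a gap in the simulator: previously only forgotten cards had their difficulty updated, while successful reviews left it unchanged. The rule applied here, in scalar form (a sketch; w[6] is the difficulty step weight, and the simulator applies no mean-reversion term):

```rust
// Scalar sketch of the simulator's difficulty update; ratings are on the 1..=4 scale.
fn next_difficulty(w6: f64, old_diff: f64, rating: usize, forgot: bool) -> f64 {
    let shifted = if forgot {
        old_diff + 2.0 * w6 // a lapse pushes difficulty up by two steps
    } else {
        old_diff - w6 * (rating as f64 - 3.0) // Hard raises, Good keeps, Easy lowers difficulty
    };
    shifted.clamp(1.0, 10.0)
}

fn main() {
    assert_eq!(next_difficulty(0.5, 5.0, 4, false), 4.5); // Easy lowers difficulty
    assert_eq!(next_difficulty(0.5, 9.5, 1, true), 10.0); // lapses are clamped at 10
}
```
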

// Update 'last_date' column where 'true_review' or 'true_learn' is true
let mut new_last_date = old_last_date.to_owned();
izip!(&mut new_last_date, &true_review, &true_learn)
@@ -295,8 +309,8 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
)
.filter(|(.., &true_learn_flag)| true_learn_flag)
.for_each(|(new_stab, new_diff, &rating, _)| {
- *new_stab = w[rating];
- *new_diff = w[4] - w[5] * (rating as f64 - 3.0);
+ *new_stab = w[rating - 1];
+ *new_diff = (w[4] - w[5] * (rating as f64 - 3.0)).clamp(1.0, 10.0);
});
let old_interval = card_table.slice(s![Column::Interval, ..]);
let mut new_interval = old_interval.to_owned();
@@ -305,8 +319,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
.for_each(|(new_ivl, &new_stab, ..)| {
*new_ivl = (9.0 * new_stab * (1.0 / request_retention - 1.0))
.round()
- .min(max_ivl)
- .max(1.0);
+ .clamp(1.0, max_ivl);
});
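
The interval expression inverts the FSRS power forgetting curve, which (consistent with the 9.0 * new_stab * (1.0 / request_retention - 1.0) term here) has the form R(t) = (1 + t / (9S))^{-1}. Solving for the time at which retention drops to the requested level and bounding the result gives:

$$
t = 9S\left(\frac{1}{R_{\mathrm{req}}} - 1\right), \qquad \mathrm{interval} = \mathrm{clamp}\left(\mathrm{round}(t),\ 1,\ \mathrm{max\_ivl}\right)
$$

so the only behavioural change in this hunk is swapping the .min(max_ivl).max(1.0) chain for the equivalent .clamp(1.0, max_ivl).
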

let old_due = card_table.slice(s![Column::Due, ..]);
@@ -331,7 +344,6 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
card_table
.slice_mut(s![Column::Interval, ..])
.assign(&new_interval);

// Update the review_cnt_per_day, learn_cnt_per_day and memorized_cnt_per_day
// review_cnt_per_day[today] = true_review.iter().filter(|&&x| x).count() as f64;
// learn_cnt_per_day[today] = true_learn.iter().filter(|&&x| x).count() as f64;
@@ -436,15 +448,15 @@ mod tests {
0.9,
None,
);
- assert_eq!(memorization, 2635.689850107157)
+ assert_eq!(memorization, 2633.365434092778)
}

#[test]
fn optimal_retention() -> Result<()> {
let config = SimulatorConfig::default();
let fsrs = FSRS::new(None)?;
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
- assert_eq!(optimal_retention, 0.8633668648071942);
+ assert_eq!(optimal_retention, 0.8530025910684347);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}