From 54df730e1f3d5bd2775b97b5ef6ceb1773857763 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 13 Feb 2024 21:24:54 +0800 Subject: [PATCH] Fix/PLS shouldn't exceed last_s (#156) https://github.com/open-spaced-repetition/fsrs-optimizer/pull/82 --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/inference.rs | 4 ++-- src/model.rs | 9 ++++++--- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e6f9aa00..53e22a52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1065,7 +1065,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.4.0" +version = "0.4.1" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index a65fff78..5e9933ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.4.0" +version = "0.4.1" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/inference.rs b/src/inference.rs index 3cc05aa6..9b2a720a 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -468,14 +468,14 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.201_908, 0.013_894]), 5); + .assert_approx_eq(&Data::from([0.201_908, 0.013_964]), 5); let (self_by_other, other_by_self) = fsrs .universal_metrics(items, &DEFAULT_PARAMETERS, |_| true) .unwrap(); Data::from([self_by_other, other_by_self]) - .assert_approx_eq(&Data::from([0.015_987_674, 0.019_702_684]), 5); + .assert_approx_eq(&Data::from([0.015_987, 0.019_767]), 5); Ok(()) } diff --git a/src/model.rs b/src/model.rs index f8702b8e..62d3a38b 100644 --- a/src/model.rs +++ b/src/model.rs @@ -90,10 +90,13 @@ impl Model { last_d: Tensor, r: Tensor, ) -> Tensor { - self.w.get(11) + let new_s = self.w.get(11) * last_d.pow(-self.w.get(12)) - * ((last_s + 1).pow(self.w.get(13)) - 1) - * ((-r + 1) * self.w.get(14)).exp() + * ((last_s.clone() + 1).pow(self.w.get(13)) - 1) + * ((-r + 1) * self.w.get(14)).exp(); + new_s + .clone() + .mask_where(last_s.clone().lower(new_s), last_s) } fn mean_reversion(&self, new_d: Tensor) -> Tensor {