Skip to content

Commit

Permalink
Fix/smooth stability after training (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Aug 19, 2024
1 parent ea80c61 commit 840d80e
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "1.1.4"
version = "1.1.5"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
14 changes: 10 additions & 4 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@ static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[
(4, DEFAULT_PARAMETERS[3]),
];

// NOTE(review): this span is a unified-diff hunk with the +/- markers stripped.
// The first `pub fn pretrain(...)` line and the bare `smooth_and_fill(...)`
// expression further down are the OLD (removed) lines; the multi-line
// signature and the `Ok((...))` tail are the NEW (added) lines.
pub fn pretrain(fsrs_items: Vec<FSRSItem>, average_recall: f32) -> Result<[f32; 4]> {
// New signature: in addition to the four estimated initial stabilities
// (one per first rating 1..=4), also return the per-rating item counts,
// so training.rs can re-run `smooth_and_fill` on the optimized
// parameters after gradient training (the point of this commit).
pub fn pretrain(
fsrs_items: Vec<FSRSItem>,
average_recall: f32,
) -> Result<([f32; 4], HashMap<u32, u32>)> {
let pretrainset = create_pretrain_data(fsrs_items);
// Count of items per first-rating bucket; returned alongside the stabilities.
let rating_count = total_rating_count(&pretrainset);
let mut rating_stability = search_parameters(pretrainset, average_recall);
// Old tail: returned only the smoothed stabilities.
smooth_and_fill(&mut rating_stability, &rating_count)
// New tail: propagate smoothing errors with `?`, then pair the smoothed
// stabilities with the counts the caller now needs.
Ok((
smooth_and_fill(&mut rating_stability, &rating_count)?,
rating_count,
))
}

type FirstRating = u32;
Expand Down Expand Up @@ -158,7 +164,7 @@ fn search_parameters(
optimal_stabilities
}

fn smooth_and_fill(
pub(crate) fn smooth_and_fill(
rating_stability: &mut HashMap<u32, f32>,
rating_count: &HashMap<u32, u32>,
) -> Result<[f32; 4]> {
Expand Down Expand Up @@ -349,7 +355,7 @@ mod tests {
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
let items = [pretrainset.clone(), trainset].concat();
let average_recall = calculate_average_recall(&items);
Data::from(pretrain(pretrainset, average_recall).unwrap())
Data::from(pretrain(pretrainset, average_recall).unwrap().0)
.assert_approx_eq(&Data::from([0.908_688, 1.678_973, 4.216_837, 9.615_904]), 6)
}

Expand Down
25 changes: 19 additions & 6 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::dataset::{prepare_training_data, FSRSBatcher, FSRSDataset, FSRSItem};
use crate::error::Result;
use crate::model::{Model, ModelConfig};
use crate::parameter_clipper::parameter_clipper;
use crate::pre_training::pretrain;
use crate::pre_training::{pretrain, smooth_and_fill};
use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS};
use burn::backend::Autodiff;

Expand Down Expand Up @@ -213,10 +213,11 @@ impl<B: Backend> FSRS<B> {
return Ok(DEFAULT_PARAMETERS.to_vec());
}

let initial_stability = pretrain(pre_train_set.clone(), average_recall).map_err(|e| {
finish_progress();
e
})?;
let (initial_stability, initial_rating_count) =
pretrain(pre_train_set.clone(), average_recall).map_err(|e| {
finish_progress();
e
})?;
let pretrained_parameters: Vec<f32> = initial_stability
.into_iter()
.chain(DEFAULT_PARAMETERS[4..].iter().copied())
Expand Down Expand Up @@ -272,6 +273,18 @@ impl<B: Backend> FSRS<B> {
return Err(FSRSError::InvalidInput);
}

let mut optimized_initial_stability = optimized_parameters[0..4]
.iter()
.enumerate()
.map(|(i, &val)| (i as u32 + 1, val))
.collect();
let clamped_stability =
smooth_and_fill(&mut optimized_initial_stability, &initial_rating_count).unwrap();
let optimized_parameters = clamped_stability
.into_iter()
.chain(optimized_parameters[4..].iter().copied())
.collect();

Ok(optimized_parameters)
}

Expand All @@ -281,7 +294,7 @@ impl<B: Backend> FSRS<B> {
.clone()
.into_iter()
.partition(|item| item.long_term_review_cnt() == 1);
let initial_stability = pretrain(pre_train_set, average_recall).unwrap();
let initial_stability = pretrain(pre_train_set, average_recall).unwrap().0;
let config = TrainingConfig::new(
ModelConfig {
freeze_stability: false,
Expand Down

0 comments on commit 840d80e

Please sign in to comment.