diff --git a/Cargo.lock b/Cargo.lock
index 657cd269..e6f9aa00 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1065,7 +1065,7 @@ dependencies = [
 
 [[package]]
 name = "fsrs"
-version = "0.3.0"
+version = "0.4.0"
 dependencies = [
  "burn",
  "chrono",
diff --git a/Cargo.toml b/Cargo.toml
index 6e18eeb1..d6e1aebc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "0.3.0"
+version = "0.4.0"
 authors = ["Open Spaced Repetition"]
 categories = ["Algorithms", "Science"]
 edition = "2021"
diff --git a/src/error.rs b/src/error.rs
index e479c65f..46948751 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -4,7 +4,7 @@ use snafu::Snafu;
 pub enum FSRSError {
     NotEnoughData,
     Interrupted,
-    InvalidWeights,
+    InvalidParameters,
     OptimalNotFound,
     InvalidInput,
 }
diff --git a/src/inference.rs b/src/inference.rs
index 5dc9d3ad..3cc05aa6 100644
--- a/src/inference.rs
+++ b/src/inference.rs
@@ -17,10 +17,10 @@ pub(crate) const DECAY: f64 = -0.5;
 pub(crate) const FACTOR: f64 = 19f64 / 81f64;
 pub(crate) const S_MIN: f32 = 0.01;
 /// This is a slice for efficiency, but should always be 17 in length.
-pub type Weights = [f32];
+pub type Parameters = [f32];
 use itertools::izip;
 
-pub static DEFAULT_WEIGHTS: [f32; 17] = [
+pub static DEFAULT_PARAMETERS: [f32; 17] = [
     0.5614, 1.2546, 3.5878, 7.9731, 5.1043, 1.1303, 0.823, 0.0465, 1.629, 0.135, 1.0045, 2.132,
     0.0839, 0.3204, 1.3547, 0.219, 2.7849,
 ];
@@ -72,7 +72,7 @@ impl<B: Backend> FSRS<B> {
     /// In the case of truncated reviews, [starting_state] can be set to the value of
     /// [FSRS::memory_state_from_sm2] for the first review (which should not be included
     /// in FSRSItem). If not provided, the card starts as new.
-    /// Weights must have been provided when calling FSRS::new().
+    /// Parameters must have been provided when calling FSRS::new().
     pub fn memory_state(
         &self,
         item: FSRSItem,
@@ -102,7 +102,7 @@ impl<B: Backend> FSRS<B> {
 
     /// If a card has incomplete learning history, memory state can be approximated from
     /// current sm2 values.
-    /// Weights must have been provided when calling FSRS::new().
+    /// Parameters must have been provided when calling FSRS::new().
     pub fn memory_state_from_sm2(
         &self,
         ease_factor: f32,
@@ -129,7 +129,7 @@ impl<B: Backend> FSRS<B> {
 
     /// Calculate the next interval for the current memory state, for rescheduling. Stability
     /// should be provided except when the card is new. Rating is ignored except when card is new.
-    /// Weights must have been provided when calling FSRS::new().
+    /// Parameters must have been provided when calling FSRS::new().
     pub fn next_interval(
         &self,
         stability: Option<f32>,
@@ -146,7 +146,7 @@ impl<B: Backend> FSRS<B> {
     }
 
     /// The intervals and memory states for each answer button.
-    /// Weights must have been provided when calling FSRS::new().
+    /// Parameters must have been provided when calling FSRS::new().
     pub fn next_states(
         &self,
         current_memory_state: Option<MemoryState>,
@@ -189,8 +189,8 @@ impl<B: Backend> FSRS<B> {
         })
     }
 
-    /// Determine how well the model and weights predict performance.
-    /// Weights must have been provided when calling FSRS::new().
+    /// Determine how well the model and parameters predict performance.
+    /// Parameters must have been provided when calling FSRS::new().
     pub fn evaluate<F>(&self, items: Vec<FSRSItem>, mut progress: F) -> Result<ModelEvaluation>
     where
         F: FnMut(ItemProgress) -> bool,
@@ -243,7 +243,7 @@ impl<B: Backend> FSRS<B> {
     pub fn universal_metrics<F>(
         &self,
         items: Vec<FSRSItem>,
-        parameters: &Weights,
+        parameters: &Parameters,
         mut progress: F,
     ) -> Result<(f32, f32)>
     where
@@ -354,7 +354,7 @@ mod tests {
     use super::*;
     use crate::{convertor_tests::anki21_sample_file_converted_to_fsrs, FSRSReview};
 
-    static WEIGHTS: &[f32] = &[
+    static PARAMETERS: &[f32] = &[
         0.81497127,
         1.5411042,
         4.007436,
@@ -415,7 +415,7 @@ mod tests {
                 },
             ],
         };
-        let fsrs = FSRS::new(Some(WEIGHTS))?;
+        let fsrs = FSRS::new(Some(PARAMETERS))?;
         assert_eq!(
             fsrs.memory_state(item, None).unwrap(),
             MemoryState {
@@ -464,14 +464,14 @@ mod tests {
         Data::from([metrics.log_loss, metrics.rmse_bins])
             .assert_approx_eq(&Data::from([0.204_001, 0.025_387]), 5);
 
-        let fsrs = FSRS::new(Some(WEIGHTS))?;
+        let fsrs = FSRS::new(Some(PARAMETERS))?;
         let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
 
         Data::from([metrics.log_loss, metrics.rmse_bins])
             .assert_approx_eq(&Data::from([0.201_908, 0.013_894]), 5);
 
         let (self_by_other, other_by_self) = fsrs
-            .universal_metrics(items, &DEFAULT_WEIGHTS, |_| true)
+            .universal_metrics(items, &DEFAULT_PARAMETERS, |_| true)
             .unwrap();
 
         Data::from([self_by_other, other_by_self])
@@ -501,7 +501,7 @@ mod tests {
                 },
             ],
         };
-        let fsrs = FSRS::new(Some(WEIGHTS))?;
+        let fsrs = FSRS::new(Some(PARAMETERS))?;
         let state = fsrs.memory_state(item, None).unwrap();
         assert_eq!(
             fsrs.next_states(Some(state), 0.9, 21).unwrap(),
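For callers, the rename is mechanical, but the new identifiers are worth seeing in context. Below is a minimal sketch of the renamed inference API; it assumes the `rating`/`delta_t` fields on `FSRSReview` from the crate's dataset module and that `next_interval` returns a plain integer, neither of which is shown in this hunk.

```rust
use fsrs::{FSRSItem, FSRSReview, DEFAULT_PARAMETERS, FSRS};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Some(&[]) would fall back to the same defaults; passed explicitly here.
    let fsrs = FSRS::new(Some(DEFAULT_PARAMETERS.as_slice()))?;

    // One card, rated "good" (3) at its first review and again one day later.
    let item = FSRSItem {
        reviews: vec![
            FSRSReview { rating: 3, delta_t: 0 },
            FSRSReview { rating: 3, delta_t: 1 },
        ],
    };
    let state = fsrs.memory_state(item, None)?;

    // Next interval at 90% desired retention; rating is ignored for a non-new card.
    let interval = fsrs.next_interval(Some(state.stability), 0.9, 3);
    println!(
        "stability={}, difficulty={}, interval={interval}",
        state.stability, state.difficulty
    );
    Ok(())
}
```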
diff --git a/src/lib.rs b/src/lib.rs
index 909ecfd3..f17a7422 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -18,7 +18,7 @@ mod weight_clipper;
 pub use dataset::{FSRSItem, FSRSReview};
 pub use error::{FSRSError, Result};
 pub use inference::{
-    ItemProgress, ItemState, MemoryState, ModelEvaluation, NextStates, DEFAULT_WEIGHTS,
+    ItemProgress, ItemState, MemoryState, ModelEvaluation, NextStates, DEFAULT_PARAMETERS,
 };
 pub use model::FSRS;
 pub use optimal_retention::SimulatorConfig;
diff --git a/src/model.rs b/src/model.rs
index 700d2cc9..f8702b8e 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,7 +1,7 @@
 use crate::error::{FSRSError, Result};
-use crate::inference::{Weights, DECAY, FACTOR, S_MIN};
-use crate::weight_clipper::clip_weights;
-use crate::DEFAULT_WEIGHTS;
+use crate::inference::{Parameters, DECAY, FACTOR, S_MIN};
+use crate::weight_clipper::clip_parameters;
+use crate::DEFAULT_PARAMETERS;
 use burn::backend::ndarray::NdArrayDevice;
 use burn::backend::NdArray;
 use burn::{
@@ -43,9 +43,9 @@ impl<B: Backend> Model<B> {
     pub fn new(config: ModelConfig) -> Self {
         let initial_params = config
             .initial_stability
-            .unwrap_or_else(|| DEFAULT_WEIGHTS[0..4].try_into().unwrap())
+            .unwrap_or_else(|| DEFAULT_PARAMETERS[0..4].try_into().unwrap())
             .into_iter()
-            .chain(DEFAULT_WEIGHTS[4..].iter().copied())
+            .chain(DEFAULT_PARAMETERS[4..].iter().copied())
             .collect();
 
         Self {
@@ -201,27 +201,27 @@ pub struct FSRS<B: Backend = NdArray> {
 }
 
 impl FSRS<NdArray> {
-    /// - Weights must be provided before running commands that need them.
-    /// - Weights may be an empty slice to use the default values instead.
-    pub fn new(weights: Option<&Weights>) -> Result<Self> {
-        Self::new_with_backend(weights, NdArrayDevice::Cpu)
+    /// - Parameters must be provided before running commands that need them.
+    /// - Parameters may be an empty slice to use the default values instead.
+    pub fn new(parameters: Option<&Parameters>) -> Result<Self> {
+        Self::new_with_backend(parameters, NdArrayDevice::Cpu)
     }
 }
 
 impl<B: Backend> FSRS<B> {
     pub fn new_with_backend<B2: Backend>(
-        mut weights: Option<&Weights>,
+        mut parameters: Option<&Parameters>,
         device: B2::Device,
     ) -> Result<FSRS<B2>> {
-        if let Some(weights) = &mut weights {
-            if weights.is_empty() {
-                *weights = DEFAULT_WEIGHTS.as_slice()
-            } else if weights.len() != 17 {
-                return Err(FSRSError::InvalidWeights);
+        if let Some(parameters) = &mut parameters {
+            if parameters.is_empty() {
+                *parameters = DEFAULT_PARAMETERS.as_slice()
+            } else if parameters.len() != 17 {
+                return Err(FSRSError::InvalidParameters);
             }
         }
 
         Ok(FSRS {
-            model: weights.map(weights_to_model),
+            model: parameters.map(parameters_to_model),
             device,
         })
     }
@@ -229,7 +229,7 @@ impl<B: Backend> FSRS<B> {
     pub(crate) fn model(&self) -> &Model<B> {
         self.model
            .as_ref()
-            .expect("command requires weights to be set on creation")
+            .expect("command requires parameters to be set on creation")
     }
 
     pub(crate) fn device(&self) -> B::Device {
@@ -237,11 +237,11 @@ impl<B: Backend> FSRS<B> {
     }
 }
 
-pub(crate) fn weights_to_model<B: Backend>(weights: &Weights) -> Model<B> {
+pub(crate) fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<B> {
     let config = ModelConfig::default();
     let mut model = Model::new(config);
     model.w = Param::from(Tensor::from_floats(Data::new(
-        clip_weights(weights),
+        clip_parameters(parameters),
         Shape { dims: [17] },
     )));
     model
@@ -256,7 +256,7 @@ mod tests {
     #[test]
     fn w() {
         let model = Model::new(ModelConfig::default());
-        assert_eq!(model.w.val().to_data(), Data::from(DEFAULT_WEIGHTS))
+        assert_eq!(model.w.val().to_data(), Data::from(DEFAULT_PARAMETERS))
     }
 
     #[test]
@@ -363,6 +363,6 @@ mod tests {
     fn fsrs() {
         assert!(FSRS::new(Some(&[])).is_ok());
         assert!(FSRS::new(Some(&[1.])).is_err());
-        assert!(FSRS::new(Some(DEFAULT_WEIGHTS.as_slice())).is_ok());
+        assert!(FSRS::new(Some(DEFAULT_PARAMETERS.as_slice())).is_ok());
    }
 }
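The validation in `new_with_backend` is easy to miss: an empty slice is silently replaced with the defaults, while any other non-17 length is rejected. A short sketch of the three cases, mirroring the `fsrs()` test above:

```rust
use fsrs::{FSRSError, FSRS};

fn main() {
    // Empty slice: substituted with DEFAULT_PARAMETERS internally.
    assert!(FSRS::new(Some(&[])).is_ok());

    // Wrong length: rejected with the renamed error variant.
    assert!(matches!(
        FSRS::new(Some(&[1.0])),
        Err(FSRSError::InvalidParameters)
    ));

    // No parameters at all: construction succeeds, but any command that
    // needs the model will panic via the expect() in FSRS::model().
    assert!(FSRS::new(None).is_ok());
}
```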
diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs
index a2a0403e..0739dcbd 100644
--- a/src/optimal_retention.rs
+++ b/src/optimal_retention.rs
@@ -1,6 +1,6 @@
 use crate::error::{FSRSError, Result};
-use crate::inference::{next_interval, ItemProgress, Weights, DECAY, FACTOR, S_MIN};
-use crate::{DEFAULT_WEIGHTS, FSRS};
+use crate::inference::{next_interval, ItemProgress, Parameters, DECAY, FACTOR, S_MIN};
+use crate::{DEFAULT_PARAMETERS, FSRS};
 use burn::tensor::backend::Backend;
 use itertools::izip;
 use ndarray::{s, Array1, Array2, Ix0, Ix1, SliceInfoElem, Zip};
@@ -395,7 +395,7 @@ fn simulate(
 
 fn sample<F>(
     config: &SimulatorConfig,
-    weights: &[f64],
+    parameters: &[f64],
     desired_retention: f64,
     n: usize,
     progress: &mut F,
@@ -411,7 +411,7 @@ where
         .map(|i| {
             let memorization = simulate(
                 config,
-                weights,
+                parameters,
                 desired_retention,
                 Some((i + 42).try_into().unwrap()),
                 None,
@@ -430,7 +430,7 @@ fn bracket<F>(
     mut xa: f64,
     mut xb: f64,
     config: &SimulatorConfig,
-    weights: &[f64],
+    parameters: &[f64],
     progress: &mut F,
 ) -> Result<(f64, f64, f64, f64, f64, f64)>
 where
@@ -442,15 +442,15 @@ where
     const GOLD: f64 = 1.618_033_988_749_895; // wait for https://doc.rust-lang.org/std/f64/consts/constant.PHI.html
     const MAXITER: i32 = 20;
 
-    let mut fa = -sample(config, weights, xa, SAMPLE_SIZE, progress)?;
-    let mut fb = -sample(config, weights, xb, SAMPLE_SIZE, progress)?;
+    let mut fa = -sample(config, parameters, xa, SAMPLE_SIZE, progress)?;
+    let mut fb = -sample(config, parameters, xb, SAMPLE_SIZE, progress)?;
 
     if fa < fb {
         (fa, fb) = (fb, fa);
         (xa, xb) = (xb, xa);
     }
     let mut xc = GOLD.mul_add(xb - xa, xb).clamp(L_LIM, U_LIM);
-    let mut fc = -sample(config, weights, xc, SAMPLE_SIZE, progress)?;
+    let mut fc = -sample(config, parameters, xc, SAMPLE_SIZE, progress)?;
 
     let mut iter = 0;
     while fc < fb {
@@ -470,34 +470,38 @@ where
         let mut fw: f64;
 
         if (w - xc) * (xb - w) > 0.0 {
-            fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
+            fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
             if fw < fc {
                 (xa, xb) = (xb.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM));
                 (fa, fb) = (fb, fw);
                 break;
             } else if fw > fb {
                 xc = w.clamp(L_LIM, U_LIM);
-                fc = -sample(config, weights, xc, SAMPLE_SIZE, progress)?;
+                fc = -sample(config, parameters, xc, SAMPLE_SIZE, progress)?;
                 break;
             }
             w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
-            fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
+            fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
         } else if (w - wlim) * (wlim - xc) >= 0.0 {
             w = wlim;
-            fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
+            fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
         } else if (w - wlim) * (xc - w) > 0.0 {
-            fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
+            fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
             if fw < fc {
                 (xb, xc, w) = (
                     xc.clamp(L_LIM, U_LIM),
                     w.clamp(L_LIM, U_LIM),
                     GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM),
                 );
-                (fb, fc, fw) = (fc, fw, -sample(config, weights, w, SAMPLE_SIZE, progress)?);
+                (fb, fc, fw) = (
+                    fc,
+                    fw,
+                    -sample(config, parameters, w, SAMPLE_SIZE, progress)?,
+                );
             }
         } else {
             w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
-            fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
+            fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
         }
         (xa, xb, xc) = (
             xb.clamp(L_LIM, U_LIM),
@@ -510,23 +514,23 @@ where
 }
 
 impl<B: Backend> FSRS<B> {
-    /// For the given simulator parameters and weights, determine the suggested `desired_retention`
+    /// For the given simulator config and parameters, determine the suggested `desired_retention`
     /// value.
     pub fn optimal_retention<F>(
         &self,
         config: &SimulatorConfig,
-        weights: &Weights,
+        parameters: &Parameters,
         mut progress: F,
     ) -> Result<f64>
     where
         F: FnMut(ItemProgress) -> bool + Send,
     {
-        let weights = if weights.is_empty() {
-            &DEFAULT_WEIGHTS
-        } else if weights.len() != 17 {
-            return Err(FSRSError::InvalidWeights);
+        let parameters = if parameters.is_empty() {
+            &DEFAULT_PARAMETERS
+        } else if parameters.len() != 17 {
+            return Err(FSRSError::InvalidParameters);
         } else {
-            weights
+            parameters
         }
         .iter()
         .map(|v| *v as f64)
@@ -541,13 +545,13 @@ impl<B: Backend> FSRS<B> {
             progress(progress_info)
         };
 
-        Self::brent(config, &weights, inc_progress)
+        Self::brent(config, &parameters, inc_progress)
     }
     /// https://argmin-rs.github.io/argmin/argmin/solver/brent/index.html
     /// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2446
     fn brent<F>(
         config: &SimulatorConfig,
-        weights: &[f64],
+        parameters: &[f64],
         mut progress: F,
     ) -> Result<f64>
     where
@@ -558,7 +562,7 @@ impl<B: Backend> FSRS<B> {
         let maxiter = 64;
         let tol = 0.01f64;
 
-        let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, weights, &mut progress)?;
+        let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, parameters, &mut progress)?;
 
         let (mut v, mut w, mut x) = (xb, xb, xb);
         let (mut fx, mut fv, mut fw) = (fb, fb, fb);
@@ -616,7 +620,7 @@ impl<B: Backend> FSRS<B> {
                 rat
             };
             // calculate new output value
-            let fu = -sample(config, weights, u, SAMPLE_SIZE, &mut progress)?;
+            let fu = -sample(config, parameters, u, SAMPLE_SIZE, &mut progress)?;
 
             // if it's bigger than current
             if fu > fx {
@@ -660,14 +664,14 @@ mod tests {
     use itertools::Itertools;
 
     use super::*;
-    use crate::DEFAULT_WEIGHTS;
+    use crate::DEFAULT_PARAMETERS;
 
     #[test]
     fn simulator() {
         let config = SimulatorConfig::default();
         let memorization = simulate(
             &config,
-            &DEFAULT_WEIGHTS.iter().map(|v| *v as f64).collect_vec(),
+            &DEFAULT_PARAMETERS.iter().map(|v| *v as f64).collect_vec(),
             0.9,
             None,
             None,
@@ -701,7 +705,7 @@ mod tests {
         ];
         let memorization = simulate(
             &config,
-            &DEFAULT_WEIGHTS.iter().map(|v| *v as f64).collect_vec(),
+            &DEFAULT_PARAMETERS.iter().map(|v| *v as f64).collect_vec(),
             0.9,
             None,
             Some(cards),
@@ -720,7 +724,7 @@ mod tests {
         };
         let results = simulate(
             &config,
-            &DEFAULT_WEIGHTS.iter().map(|v| *v as f64).collect_vec(),
+            &DEFAULT_PARAMETERS.iter().map(|v| *v as f64).collect_vec(),
             0.9,
             None,
             None,
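`optimal_retention` applies the same empty-slice fallback as `FSRS::new`, so a caller without trained parameters can still get a suggestion. A hedged sketch of a call; it assumes `SimulatorConfig::default()` is usable as-is and that the method only drives the simulator, not the trained model:

```rust
use fsrs::{SimulatorConfig, FSRS};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let fsrs = FSRS::new(None)?;
    let config = SimulatorConfig::default();

    // &[] falls back to DEFAULT_PARAMETERS, per the branch above; a slice of
    // any other non-17 length would yield FSRSError::InvalidParameters.
    // Returning false from the callback interrupts the search.
    let desired_retention = fsrs.optimal_retention(&config, &[], |_progress| true)?;
    println!("suggested desired_retention: {desired_retention}");
    Ok(())
}
```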
diff --git a/src/pre_training.rs b/src/pre_training.rs
index 91d6dc2b..67050e08 100644
--- a/src/pre_training.rs
+++ b/src/pre_training.rs
@@ -1,15 +1,15 @@
 use crate::error::{FSRSError, Result};
 use crate::inference::{DECAY, FACTOR, S_MIN};
 use crate::FSRSItem;
-use crate::DEFAULT_WEIGHTS;
+use crate::DEFAULT_PARAMETERS;
 use ndarray::Array1;
 use std::collections::HashMap;
 
 static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[
-    (1, DEFAULT_WEIGHTS[0]),
-    (2, DEFAULT_WEIGHTS[1]),
-    (3, DEFAULT_WEIGHTS[2]),
-    (4, DEFAULT_WEIGHTS[3]),
+    (1, DEFAULT_PARAMETERS[0]),
+    (2, DEFAULT_PARAMETERS[1]),
+    (3, DEFAULT_PARAMETERS[2]),
+    (4, DEFAULT_PARAMETERS[3]),
 ];
 
 pub fn pretrain(fsrs_items: Vec<FSRSItem>, average_recall: f32) -> Result<[f32; 4]> {
diff --git a/src/training.rs b/src/training.rs
index 3d7c12f5..c0448dd3 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -5,7 +5,7 @@ use crate::error::Result;
 use crate::model::{Model, ModelConfig};
 use crate::pre_training::pretrain;
 use crate::weight_clipper::weight_clipper;
-use crate::{FSRSError, DEFAULT_WEIGHTS, FSRS};
+use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS};
 use burn::backend::Autodiff;
 use burn::data::dataloader::DataLoaderBuilder;
 use burn::module::Module;
@@ -235,8 +235,8 @@ pub fn calculate_average_recall(items: &[FSRSItem]) -> f32 {
 }
 
 impl<B: Backend> FSRS<B> {
-    /// Calculate appropriate weights for the provided review history.
-    pub fn compute_weights(
+    /// Calculate appropriate parameters for the provided review history.
+    pub fn compute_parameters(
         &self,
         items: Vec<FSRSItem>,
         pretrain_only: bool,
@@ -264,11 +264,11 @@ impl<B: Backend> FSRS<B> {
         })?;
         if pretrain_only {
             finish_progress();
-            let weights = initial_stability
+            let parameters = initial_stability
                 .into_iter()
-                .chain(DEFAULT_WEIGHTS[4..].iter().copied())
+                .chain(DEFAULT_PARAMETERS[4..].iter().copied())
                 .collect();
-            return Ok(weights);
+            return Ok(parameters);
         }
         let config = TrainingConfig::new(
             ModelConfig {
@@ -310,22 +310,22 @@ impl<B: Backend> FSRS<B> {
         finish_progress();
 
         let weight_sets = weight_sets?;
-        let average_weights: Vec<f32> = weight_sets
+        let average_parameters: Vec<f32> = weight_sets
             .iter()
-            .fold(vec![0.0; weight_sets[0].len()], |sum, weights| {
-                sum.par_iter().zip(weights).map(|(a, b)| a + b).collect()
+            .fold(vec![0.0; weight_sets[0].len()], |sum, parameters| {
+                sum.par_iter().zip(parameters).map(|(a, b)| a + b).collect()
             })
             .par_iter()
             .map(|&sum| sum / n_splits as f32)
             .collect();
 
-        for weight in &average_weights {
+        for weight in &average_parameters {
             if !weight.is_finite() {
                 return Err(FSRSError::InvalidInput);
             }
         }
 
-        Ok(average_weights)
+        Ok(average_parameters)
     }
 }
@@ -403,9 +403,9 @@ fn train<B: AutodiffBackend>(
         return Err(FSRSError::Interrupted);
     }
 
-    info!("trained weights: {}", &model_trained.w.val());
+    info!("trained parameters: {}", &model_trained.w.val());
     model_trained.w = Param::from(weight_clipper(model_trained.w.val()));
-    info!("clipped weights: {}", &model_trained.w.val());
+    info!("clipped parameters: {}", &model_trained.w.val());
 
     if let Ok(path) = artifact_dir {
         PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
@@ -460,7 +460,7 @@ mod tests {
             AdamConfig::new(),
         );
 
-        let weights_sets: Vec<Vec<f32>> = (0..n_splits)
+        let parameters_sets: Vec<Vec<f32>> = (0..n_splits)
             .into_par_iter()
             .map(|i| {
                 let trainset = trainsets
@@ -475,16 +475,16 @@ mod tests {
             })
             .collect();
 
-        dbg!(&weights_sets);
+        dbg!(&parameters_sets);
 
-        let average_weights: Vec<f32> = weights_sets
+        let average_parameters: Vec<f32> = parameters_sets
             .iter()
-            .fold(vec![0.0; weights_sets[0].len()], |sum, weights| {
-                sum.par_iter().zip(weights).map(|(a, b)| a + b).collect()
+            .fold(vec![0.0; parameters_sets[0].len()], |sum, parameters| {
+                sum.par_iter().zip(parameters).map(|(a, b)| a + b).collect()
             })
             .par_iter()
             .map(|&sum| sum / n_splits as f32)
             .collect();
-        dbg!(average_weights);
+        dbg!(average_parameters);
     }
 }
diff --git a/src/weight_clipper.rs b/src/weight_clipper.rs
index afdb1df3..93c62161 100644
--- a/src/weight_clipper.rs
+++ b/src/weight_clipper.rs
@@ -1,15 +1,15 @@
 use crate::{
-    inference::{Weights, S_MIN},
+    inference::{Parameters, S_MIN},
     pre_training::INIT_S_MAX,
 };
 use burn::tensor::{backend::Backend, Data, Tensor};
 
-pub(crate) fn weight_clipper<B: Backend>(weights: Tensor<B, 1>) -> Tensor<B, 1> {
-    let val = clip_weights(&weights.to_data().convert().value);
-    Tensor::from_data(Data::new(val, weights.shape()).convert())
+pub(crate) fn weight_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
+    let val = clip_parameters(&parameters.to_data().convert().value);
+    Tensor::from_data(Data::new(val, parameters.shape()).convert())
 }
 
-pub(crate) fn clip_weights(weights: &Weights) -> Vec<f32> {
+pub(crate) fn clip_parameters(parameters: &Parameters) -> Vec<f32> {
     // https://regex101.com/r/21mXNI/1
     const CLAMPS: [(f32, f32); 17] = [
         (S_MIN, INIT_S_MAX),
@@ -31,12 +31,12 @@ pub(crate) fn clip_weights(weights: &Weights) -> Vec<f32> {
         (1.0, 4.0),
     ];
 
-    let mut weights = weights.to_vec();
-    weights
+    let mut parameters = parameters.to_vec();
+    parameters
         .iter_mut()
         .zip(CLAMPS)
         .for_each(|(w, (low, high))| *w = w.clamp(low, high));
-    weights
+    parameters
 }
 
 #[cfg(test)]
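End to end, the renamed training entry point feeds back into construction. A sketch of that round trip: `load_items()` is a hypothetical stand-in for your own review-log storage, and the trailing progress argument of `compute_parameters` (cut off in the hunk above) is assumed to accept `None`.

```rust
use fsrs::{FSRSItem, FSRS};

// Hypothetical helper: one FSRSItem per card, reviews in chronological order.
fn load_items() -> Vec<FSRSItem> {
    vec![]
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let fsrs = FSRS::new(None)?;

    // pretrain_only = true would stop after fitting the four initial
    // stabilities and fill the rest from DEFAULT_PARAMETERS[4..].
    let parameters = fsrs.compute_parameters(load_items(), false, None)?;
    assert_eq!(parameters.len(), 17);

    // Training clamps each value into its legal range via weight_clipper,
    // so the result is valid input for a fresh FSRS instance.
    let _trained = FSRS::new(Some(&parameters))?;
    Ok(())
}
```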