Skip to content

Commit

Permalink
weights -> parameters (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
dae authored Feb 7, 2024
1 parent 6b0207a commit a79fcd1
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 101 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.3.0"
version = "0.4.0"
authors = ["Open Spaced Repetition"]
categories = ["Algorithms", "Science"]
edition = "2021"
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use snafu::Snafu;
pub enum FSRSError {
NotEnoughData,
Interrupted,
InvalidWeights,
InvalidParameters,
OptimalNotFound,
InvalidInput,
}
Expand Down
28 changes: 14 additions & 14 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ pub(crate) const DECAY: f64 = -0.5;
pub(crate) const FACTOR: f64 = 19f64 / 81f64;
pub(crate) const S_MIN: f32 = 0.01;
/// This is a slice for efficiency, but should always be 17 in length.
pub type Weights = [f32];
pub type Parameters = [f32];
use itertools::izip;

pub static DEFAULT_WEIGHTS: [f32; 17] = [
pub static DEFAULT_PARAMETERS: [f32; 17] = [
0.5614, 1.2546, 3.5878, 7.9731, 5.1043, 1.1303, 0.823, 0.0465, 1.629, 0.135, 1.0045, 2.132,
0.0839, 0.3204, 1.3547, 0.219, 2.7849,
];
Expand Down Expand Up @@ -72,7 +72,7 @@ impl<B: Backend> FSRS<B> {
/// In the case of truncated reviews, [starting_state] can be set to the value of
/// [FSRS::memory_state_from_sm2] for the first review (which should not be included
/// in FSRSItem). If not provided, the card starts as new.
/// Weights must have been provided when calling FSRS::new().
/// Parameters must have been provided when calling FSRS::new().
pub fn memory_state(
&self,
item: FSRSItem,
Expand Down Expand Up @@ -102,7 +102,7 @@ impl<B: Backend> FSRS<B> {

/// If a card has incomplete learning history, memory state can be approximated from
/// current sm2 values.
/// Weights must have been provided when calling FSRS::new().
/// Parameters must have been provided when calling FSRS::new().
pub fn memory_state_from_sm2(
&self,
ease_factor: f32,
Expand All @@ -129,7 +129,7 @@ impl<B: Backend> FSRS<B> {

/// Calculate the next interval for the current memory state, for rescheduling. Stability
/// should be provided except when the card is new. Rating is ignored except when card is new.
/// Weights must have been provided when calling FSRS::new().
/// Parameters must have been provided when calling FSRS::new().
pub fn next_interval(
&self,
stability: Option<f32>,
Expand All @@ -146,7 +146,7 @@ impl<B: Backend> FSRS<B> {
}

/// The intervals and memory states for each answer button.
/// Weights must have been provided when calling FSRS::new().
/// Parameters must have been provided when calling FSRS::new().
pub fn next_states(
&self,
current_memory_state: Option<MemoryState>,
Expand Down Expand Up @@ -189,8 +189,8 @@ impl<B: Backend> FSRS<B> {
})
}

/// Determine how well the model and weights predict performance.
/// Weights must have been provided when calling FSRS::new().
/// Determine how well the model and parameters predict performance.
/// Parameters must have been provided when calling FSRS::new().
pub fn evaluate<F>(&self, items: Vec<FSRSItem>, mut progress: F) -> Result<ModelEvaluation>
where
F: FnMut(ItemProgress) -> bool,
Expand Down Expand Up @@ -243,7 +243,7 @@ impl<B: Backend> FSRS<B> {
pub fn universal_metrics<F>(
&self,
items: Vec<FSRSItem>,
parameters: &Weights,
parameters: &Parameters,
mut progress: F,
) -> Result<(f32, f32)>
where
Expand Down Expand Up @@ -354,7 +354,7 @@ mod tests {
use super::*;
use crate::{convertor_tests::anki21_sample_file_converted_to_fsrs, FSRSReview};

static WEIGHTS: &[f32] = &[
static PARAMETERS: &[f32] = &[
0.81497127,
1.5411042,
4.007436,
Expand Down Expand Up @@ -415,7 +415,7 @@ mod tests {
},
],
};
let fsrs = FSRS::new(Some(WEIGHTS))?;
let fsrs = FSRS::new(Some(PARAMETERS))?;
assert_eq!(
fsrs.memory_state(item, None).unwrap(),
MemoryState {
Expand Down Expand Up @@ -464,14 +464,14 @@ mod tests {
Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.204_001, 0.025_387]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let fsrs = FSRS::new(Some(PARAMETERS))?;
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.201_908, 0.013_894]), 5);

let (self_by_other, other_by_self) = fsrs
.universal_metrics(items, &DEFAULT_WEIGHTS, |_| true)
.universal_metrics(items, &DEFAULT_PARAMETERS, |_| true)
.unwrap();

Data::from([self_by_other, other_by_self])
Expand Down Expand Up @@ -501,7 +501,7 @@ mod tests {
},
],
};
let fsrs = FSRS::new(Some(WEIGHTS))?;
let fsrs = FSRS::new(Some(PARAMETERS))?;
let state = fsrs.memory_state(item, None).unwrap();
assert_eq!(
fsrs.next_states(Some(state), 0.9, 21).unwrap(),
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mod weight_clipper;
pub use dataset::{FSRSItem, FSRSReview};
pub use error::{FSRSError, Result};
pub use inference::{
ItemProgress, ItemState, MemoryState, ModelEvaluation, NextStates, DEFAULT_WEIGHTS,
ItemProgress, ItemState, MemoryState, ModelEvaluation, NextStates, DEFAULT_PARAMETERS,
};
pub use model::FSRS;
pub use optimal_retention::SimulatorConfig;
Expand Down
42 changes: 21 additions & 21 deletions src/model.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::error::{FSRSError, Result};
use crate::inference::{Weights, DECAY, FACTOR, S_MIN};
use crate::weight_clipper::clip_weights;
use crate::DEFAULT_WEIGHTS;
use crate::inference::{Parameters, DECAY, FACTOR, S_MIN};
use crate::weight_clipper::clip_parameters;
use crate::DEFAULT_PARAMETERS;
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::NdArray;
use burn::{
Expand Down Expand Up @@ -43,9 +43,9 @@ impl<B: Backend> Model<B> {
pub fn new(config: ModelConfig) -> Self {
let initial_params = config
.initial_stability
.unwrap_or_else(|| DEFAULT_WEIGHTS[0..4].try_into().unwrap())
.unwrap_or_else(|| DEFAULT_PARAMETERS[0..4].try_into().unwrap())
.into_iter()
.chain(DEFAULT_WEIGHTS[4..].iter().copied())
.chain(DEFAULT_PARAMETERS[4..].iter().copied())
.collect();

Self {
Expand Down Expand Up @@ -201,47 +201,47 @@ pub struct FSRS<B: Backend = NdArray> {
}

impl FSRS<NdArray> {
/// - Weights must be provided before running commands that need them.
/// - Weights may be an empty slice to use the default values instead.
pub fn new(weights: Option<&Weights>) -> Result<Self> {
Self::new_with_backend(weights, NdArrayDevice::Cpu)
/// - Parameters must be provided before running commands that need them.
/// - Parameters may be an empty slice to use the default values instead.
pub fn new(parameters: Option<&Parameters>) -> Result<Self> {
Self::new_with_backend(parameters, NdArrayDevice::Cpu)
}
}

impl<B: Backend> FSRS<B> {
pub fn new_with_backend<B2: Backend>(
mut weights: Option<&Weights>,
mut parameters: Option<&Parameters>,
device: B2::Device,
) -> Result<FSRS<B2>> {
if let Some(weights) = &mut weights {
if weights.is_empty() {
*weights = DEFAULT_WEIGHTS.as_slice()
} else if weights.len() != 17 {
return Err(FSRSError::InvalidWeights);
if let Some(parameters) = &mut parameters {
if parameters.is_empty() {
*parameters = DEFAULT_PARAMETERS.as_slice()
} else if parameters.len() != 17 {
return Err(FSRSError::InvalidParameters);
}
}
Ok(FSRS {
model: weights.map(weights_to_model),
model: parameters.map(parameters_to_model),
device,
})
}

pub(crate) fn model(&self) -> &Model<B> {
self.model
.as_ref()
.expect("command requires weights to be set on creation")
.expect("command requires parameters to be set on creation")
}

pub(crate) fn device(&self) -> B::Device {
self.device.clone()
}
}

pub(crate) fn weights_to_model<B: Backend>(weights: &Weights) -> Model<B> {
pub(crate) fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<B> {
let config = ModelConfig::default();
let mut model = Model::new(config);
model.w = Param::from(Tensor::from_floats(Data::new(
clip_weights(weights),
clip_parameters(parameters),
Shape { dims: [17] },
)));
model
Expand All @@ -256,7 +256,7 @@ mod tests {
#[test]
fn w() {
let model = Model::new(ModelConfig::default());
assert_eq!(model.w.val().to_data(), Data::from(DEFAULT_WEIGHTS))
assert_eq!(model.w.val().to_data(), Data::from(DEFAULT_PARAMETERS))
}

#[test]
Expand Down Expand Up @@ -363,6 +363,6 @@ mod tests {
fn fsrs() {
assert!(FSRS::new(Some(&[])).is_ok());
assert!(FSRS::new(Some(&[1.])).is_err());
assert!(FSRS::new(Some(DEFAULT_WEIGHTS.as_slice())).is_ok());
assert!(FSRS::new(Some(DEFAULT_PARAMETERS.as_slice())).is_ok());
}
}
Loading

0 comments on commit a79fcd1

Please sign in to comment.