Feat/check_and_fill_parameters (#214)
L-M-Sherlock authored Aug 22, 2024
1 parent 840d80e commit 0256410
Showing 5 changed files with 56 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "1.1.5"
+version = "1.2.0"
 authors = ["Open Spaced Repetition"]
 categories = ["algorithms", "science"]
 edition = "2021"
2 changes: 1 addition & 1 deletion src/inference.rs
@@ -603,7 +603,7 @@ mod tests {
             .assert_approx_eq(&Data::from([4.170096, 9.462736]), 5);
         let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.95).unwrap();
         Data::from([memory_state.stability, memory_state.difficulty])
-            .assert_approx_eq(&Data::from([21.712555, 2.380210]), 5);
+            .assert_approx_eq(&Data::from([21.712555, 2.380_21]), 5);
         let memory_state = fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap();
         Data::from([memory_state.stability, memory_state.difficulty])
             .assert_approx_eq(&Data::from([19.999992, 10.0]), 5);
44 changes: 25 additions & 19 deletions src/model.rs
@@ -209,20 +209,19 @@ impl FSRS<NdArray> {
 
 impl<B: Backend> FSRS<B> {
     pub fn new_with_backend<B2: Backend>(
-        mut parameters: Option<&Parameters>,
+        parameters: Option<&Parameters>,
         device: B2::Device,
     ) -> Result<FSRS<B2>> {
-        if let Some(parameters) = &mut parameters {
-            if parameters.is_empty() {
-                *parameters = DEFAULT_PARAMETERS.as_slice()
-            } else if parameters.len() != 19 && parameters.len() != 17 {
-                return Err(FSRSError::InvalidParameters);
+        let model = match parameters {
+            Some(params) => {
+                let parameters = check_and_fill_parameters(params)?;
+                let model = parameters_to_model::<B2>(&parameters);
+                Some(model)
             }
-        }
-        Ok(FSRS {
-            model: parameters.map(parameters_to_model),
-            device,
-        })
+            None => None,
+        };
+
+        Ok(FSRS { model, device })
     }
 
     pub(crate) fn model(&self) -> &Model<B> {
@@ -239,20 +238,27 @@ impl<B: Backend> FSRS<B> {
 pub(crate) fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<B> {
     let config = ModelConfig::default();
     let mut model = Model::new(config);
-    let new_params = if parameters.len() == 17 {
-        let mut new_params = parameters.to_vec();
-        new_params.extend_from_slice(&[0.0, 0.0]);
-        new_params
-    } else {
-        parameters.to_vec()
-    };
     model.w = Param::from_tensor(Tensor::from_floats(
-        Data::new(clip_parameters(&new_params), Shape { dims: [19] }),
+        Data::new(clip_parameters(parameters), Shape { dims: [19] }),
         &B::Device::default(),
     ));
     model
 }
 
+pub(crate) fn check_and_fill_parameters(parameters: &Parameters) -> Result<Vec<f32>, FSRSError> {
+    let parameters = match parameters.len() {
+        0 => DEFAULT_PARAMETERS.to_vec(),
+        17 => {
+            let mut parameters = parameters.to_vec();
+            parameters.extend_from_slice(&[0.0, 0.0]);
+            parameters
+        }
+        19 => parameters.to_vec(),
+        _ => return Err(FSRSError::InvalidParameters),
+    };
+    Ok(parameters)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
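The new `check_and_fill_parameters` helper centralizes the length handling that was previously duplicated in `new_with_backend` and the optimal-retention code: an empty slice falls back to `DEFAULT_PARAMETERS`, a 17-element slice from older parameter sets is padded with two trailing zeros, a 19-element slice is used as-is, and any other length is rejected. A rough caller-side sketch, assuming `FSRS::new`, `DEFAULT_PARAMETERS`, and `FSRSError` are re-exported at the crate root as in fsrs-rs; the concrete values are illustrative only:

```rust
use fsrs::{FSRSError, DEFAULT_PARAMETERS, FSRS};

fn main() -> Result<(), FSRSError> {
    // Empty slice: falls back to DEFAULT_PARAMETERS instead of erroring.
    let empty: &[f32] = &[];
    let _fsrs_default = FSRS::new(Some(empty))?;

    // 17 legacy parameters: padded with two trailing zeros to reach 19.
    let legacy: Vec<f32> = DEFAULT_PARAMETERS[..17].to_vec();
    let _fsrs_legacy = FSRS::new(Some(legacy.as_slice()))?;

    // Any other length: rejected with FSRSError::InvalidParameters.
    let bad: &[f32] = &[0.4, 0.6, 2.4];
    assert!(matches!(FSRS::new(Some(bad)), Err(FSRSError::InvalidParameters)));

    Ok(())
}
```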
63 changes: 28 additions & 35 deletions src/optimal_retention.rs
@@ -1,6 +1,7 @@
 use crate::error::{FSRSError, Result};
 use crate::inference::{next_interval, ItemProgress, Parameters, DECAY, FACTOR, S_MIN};
-use crate::{DEFAULT_PARAMETERS, FSRS};
+use crate::model::check_and_fill_parameters;
+use crate::FSRS;
 use burn::tensor::backend::Backend;
 use itertools::{izip, Itertools};
 use ndarray::{s, Array1, Array2, Ix0, Ix1, SliceInfoElem, Zip};
@@ -143,13 +144,15 @@ pub struct Card {
     pub due: f32,
 }
 
+#[allow(clippy::type_complexity)]
 pub fn simulate(
     config: &SimulatorConfig,
-    w: &[f32],
+    w: &Parameters,
     desired_retention: f32,
     seed: Option<u64>,
     existing_cards: Option<Vec<Card>>,
-) -> (Array1<f32>, Array1<usize>, Array1<usize>, Array1<f32>) {
+) -> Result<(Array1<f32>, Array1<usize>, Array1<usize>, Array1<f32>), FSRSError> {
+    let w = &check_and_fill_parameters(w)?;
     let SimulatorConfig {
         deck_size,
         learn_span,
@@ -440,28 +443,28 @@ pub fn simulate(
             .sum();
     }
 
-    (
+    Ok((
         memorized_cnt_per_day,
         review_cnt_per_day,
         learn_cnt_per_day,
         cost_per_day,
-    )
+    ))
 }
 
 fn sample<F>(
     config: &SimulatorConfig,
-    parameters: &[f32],
+    parameters: &Parameters,
     desired_retention: f32,
     n: usize,
     progress: &mut F,
-) -> Result<f32>
+) -> Result<f32, FSRSError>
 where
     F: FnMut() -> bool,
 {
     if !progress() {
         return Err(FSRSError::Interrupted);
     }
-    Ok((0..n)
+    let results: Result<Vec<f32>, FSRSError> = (0..n)
         .into_par_iter()
         .map(|i| {
             let (memorized_cnt_per_day, _, _, cost_per_day) = simulate(
@@ -470,13 +473,13 @@ where
                 desired_retention,
                 Some((i + 42).try_into().unwrap()),
                 None,
-            );
+            )?;
             let total_memorized = memorized_cnt_per_day[memorized_cnt_per_day.len() - 1];
             let total_cost = cost_per_day.sum();
-            total_cost / total_memorized
+            Ok(total_cost / total_memorized)
         })
-        .sum::<f32>()
-        / n as f32)
+        .collect();
+    results.map(|v| v.iter().sum::<f32>() / n as f32)
 }
 
 impl<B: Backend> FSRS<B> {
@@ -491,19 +494,6 @@ impl<B: Backend> FSRS<B> {
     where
         F: FnMut(ItemProgress) -> bool + Send,
     {
-        let parameters = if parameters.is_empty() {
-            DEFAULT_PARAMETERS.to_vec()
-        } else if parameters.len() != 19 {
-            if parameters.len() == 17 {
-                let mut parameters = parameters.to_vec();
-                parameters.extend_from_slice(&[0.0, 0.0]);
-                parameters
-            } else {
-                return Err(FSRSError::InvalidParameters);
-            }
-        } else {
-            parameters.to_vec()
-        };
         let mut progress_info = ItemProgress {
             current: 0,
             // not provided for this method
@@ -514,13 +504,13 @@ impl<B: Backend> FSRS<B> {
             progress(progress_info)
         };
 
-        Self::brent(config, &parameters, inc_progress)
+        Self::brent(config, parameters, inc_progress)
     }
     /// https://argmin-rs.github.io/argmin/argmin/solver/brent/index.html
     /// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2446
     fn brent<F>(
         config: &SimulatorConfig,
-        parameters: &[f32],
+        parameters: &Parameters,
         mut progress: F,
     ) -> Result<f32, FSRSError>
     where
@@ -1002,18 +992,19 @@ mod tests {
     use crate::{convertor_tests::read_collection, DEFAULT_PARAMETERS};
 
     #[test]
-    fn simulator() {
+    fn simulator() -> Result<()> {
         let config = SimulatorConfig::default();
         let (memorized_cnt_per_day, _, _, _) =
-            simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None);
+            simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?;
         assert_eq!(
             memorized_cnt_per_day[memorized_cnt_per_day.len() - 1],
             6521.068
-        )
+        );
+        Ok(())
     }
 
     #[test]
-    fn simulate_with_existing_cards() {
+    fn simulate_with_existing_cards() -> Result<()> {
         let config = SimulatorConfig {
             learn_span: 30,
             learn_limit: 60,
@@ -1035,20 +1026,21 @@ mod tests {
                 due: 0.0,
             },
         ];
-        let memorization = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, Some(cards));
+        let memorization = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, Some(cards))?;
         dbg!(memorization);
+        Ok(())
     }
 
     #[test]
-    fn simulate_with_learn_review_limit() {
+    fn simulate_with_learn_review_limit() -> Result<()> {
         let config = SimulatorConfig {
             learn_span: 30,
             learn_limit: 60,
             review_limit: 200,
             max_cost_perday: f32::INFINITY,
             ..Default::default()
         };
-        let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None);
+        let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?;
         assert_eq!(
             results.1.to_vec(),
             vec![
@@ -1059,7 +1051,8 @@ mod tests {
         assert_eq!(
             results.2.to_vec(),
             vec![config.learn_limit; config.learn_span]
-        )
+        );
+        Ok(())
     }
 
     #[test]
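Because `simulate` now runs `check_and_fill_parameters` and returns a `Result`, callers must propagate or handle `FSRSError` where they previously destructured the tuple directly; the updated tests above show the `?` pattern. A minimal caller sketch, again assuming `simulate`, `SimulatorConfig`, `DEFAULT_PARAMETERS`, and `FSRSError` are re-exported at the crate root; the printed fields are illustrative:

```rust
use fsrs::{simulate, FSRSError, SimulatorConfig, DEFAULT_PARAMETERS};

fn main() -> Result<(), FSRSError> {
    let config = SimulatorConfig::default();

    // simulate now returns Result<(Array1<f32>, Array1<usize>, Array1<usize>, Array1<f32>), FSRSError>,
    // so a wrong-length parameter slice surfaces as FSRSError::InvalidParameters here
    // instead of relying on the caller to pass exactly 19 values.
    let (memorized_cnt_per_day, review_cnt_per_day, learn_cnt_per_day, cost_per_day) =
        simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?;

    println!(
        "memorized after {} days: {}",
        config.learn_span,
        memorized_cnt_per_day[memorized_cnt_per_day.len() - 1]
    );
    println!("reviews on day 0: {}", review_cnt_per_day[0]);
    println!("new cards on day 0: {}", learn_cnt_per_day[0]);
    println!("review cost on day 0: {}", cost_per_day[0]);

    Ok(())
}
```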
