From edf60be31477d55ff741a00f33ede51174435ad3 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 19 Dec 2024 15:21:01 +0800 Subject: [PATCH 1/8] support sample weights --- src/batch_shuffle.rs | 5 +-- src/convertor_tests.rs | 8 +++-- src/dataset.rs | 73 ++++++++++++++++++++++++++++++++---------- src/inference.rs | 11 +++++-- src/training.rs | 21 ++++++++---- 5 files changed, 87 insertions(+), 31 deletions(-) diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs index 9ed15016..5cd20525 100644 --- a/src/batch_shuffle.rs +++ b/src/batch_shuffle.rs @@ -108,14 +108,15 @@ mod tests { use super::*; use crate::{ - convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::prepare_training_data, + convertor_tests::anki21_sample_file_converted_to_fsrs, + dataset::{prepare_training_data, simple_weighted_fsrs_items}, }; #[test] fn test_simple_dataloader() { let train_set = anki21_sample_file_converted_to_fsrs(); let (_pre_train_set, train_set) = prepare_training_data(train_set); - let dataset = FSRSDataset::from(train_set); + let dataset = FSRSDataset::from(simple_weighted_fsrs_items(train_set)); let batch_size = 512; let seed = 114514; let device = NdArrayDevice::Cpu; diff --git a/src/convertor_tests.rs b/src/convertor_tests.rs index 2c732d9b..b431351a 100644 --- a/src/convertor_tests.rs +++ b/src/convertor_tests.rs @@ -1,5 +1,5 @@ use crate::convertor_tests::RevlogReviewKind::*; -use crate::dataset::FSRSBatcher; +use crate::dataset::{simple_weighted_fsrs_items, FSRSBatcher}; use crate::dataset::{FSRSItem, FSRSReview}; use crate::optimal_retention::{RevlogEntry, RevlogReviewKind}; use crate::test_helpers::NdArrayAutodiff; @@ -256,7 +256,7 @@ fn conversion_works() { ); // convert a subset and check it matches expectations - let mut fsrs_items = single_card_revlog + let fsrs_items = single_card_revlog .into_iter() .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai)) .flatten() @@ -387,9 +387,11 @@ fn conversion_works() { ] ); + let mut weighted_fsrs_items = simple_weighted_fsrs_items(fsrs_items); + let device = NdArrayDevice::Cpu; let batcher = FSRSBatcher::::new(device); - let res = batcher.batch(vec![fsrs_items.pop().unwrap()]); + let res = batcher.batch(vec![weighted_fsrs_items.pop().unwrap()]); assert_eq!(res.delta_ts.into_scalar(), 64.0); assert_eq!( res.r_historys.squeeze(1).to_data(), diff --git a/src/dataset.rs b/src/dataset.rs index 043b143a..267f1488 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -19,6 +19,12 @@ pub struct FSRSItem { pub reviews: Vec, } +#[derive(Debug, Clone)] +pub(crate) struct WeightedFSRSItem { + pub weight: f32, + pub item: FSRSItem, +} + #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] pub struct FSRSReview { /// 1-4 @@ -88,13 +94,14 @@ pub(crate) struct FSRSBatch { pub r_historys: Tensor, pub delta_ts: Tensor, pub labels: Tensor, + pub weights: Tensor, } -impl Batcher> for FSRSBatcher { - fn batch(&self, items: Vec) -> FSRSBatch { +impl Batcher> for FSRSBatcher { + fn batch(&self, items: Vec) -> FSRSBatch { let pad_size = items .iter() - .map(|x| x.reviews.len()) + .map(|x| x.item.reviews.len()) .max() .expect("FSRSItem is empty") - 1; @@ -103,7 +110,7 @@ impl Batcher> for FSRSBatcher { .iter() .map(|item| { let (mut delta_t, mut rating): (Vec<_>, Vec<_>) = - item.history().map(|r| (r.delta_t, r.rating)).unzip(); + item.item.history().map(|r| (r.delta_t, r.rating)).unzip(); delta_t.resize(pad_size, 0); rating.resize(pad_size, 0); let delta_t = Tensor::from_data( @@ -130,19 +137,23 @@ impl Batcher> for FSRSBatcher { }) .unzip(); - let (delta_ts, labels) = items + let (delta_ts, labels, weights) = items .iter() .map(|item| { - let current = item.current(); - let delta_t = Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device); + let current = item.item.current(); + let delta_t: Tensor = + Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device); let label = match current.rating { 1 => 0.0, _ => 1.0, }; - let label = Tensor::from_data(Data::from([label.elem()]), &self.device); - (delta_t, label) + let label: Tensor = + Tensor::from_data(Data::from([label.elem()]), &self.device); + let weight: Tensor = + Tensor::from_data(Data::from([item.weight.elem()]), &self.device); + (delta_t, label, weight) }) - .unzip(); + .multiunzip(); let t_historys = Tensor::cat(time_histories, 0) .transpose() @@ -152,6 +163,7 @@ impl Batcher> for FSRSBatcher { .to_device(&self.device); // [seq_len, batch_size] let delta_ts = Tensor::cat(delta_ts, 0).to_device(&self.device); let labels = Tensor::cat(labels, 0).to_device(&self.device); + let weights = Tensor::cat(weights, 0).to_device(&self.device); // dbg!(&items[0].t_history); // dbg!(&t_historys); @@ -161,27 +173,28 @@ impl Batcher> for FSRSBatcher { r_historys, delta_ts, labels, + weights, } } } pub(crate) struct FSRSDataset { - pub(crate) items: Vec, + pub(crate) items: Vec, } -impl Dataset for FSRSDataset { +impl Dataset for FSRSDataset { fn len(&self) -> usize { self.items.len() } - fn get(&self, index: usize) -> Option { + fn get(&self, index: usize) -> Option { // info!("get {}", index); self.items.get(index).cloned() } } -impl From> for FSRSDataset { - fn from(items: Vec) -> Self { +impl From> for FSRSDataset { + fn from(items: Vec) -> Self { Self { items } } } @@ -252,6 +265,26 @@ pub fn prepare_training_data(items: Vec) -> (Vec, Vec) -> Vec { + items + .into_iter() + .map(|item| WeightedFSRSItem { weight: 1.0, item }) + .collect() +} + +/// The input items should be sorted by the review timestamp. +pub(crate) fn recency_weighted_fsrs_items(items: Vec) -> Vec { + let length = items.len() as f32; + items + .into_iter() + .enumerate() + .map(|(idx, item)| WeightedFSRSItem { + weight: idx as f32 / length + 0.5, + item, + }) + .collect() +} + #[cfg(test)] mod tests { use super::*; @@ -261,9 +294,11 @@ mod tests { fn from_anki() { use burn::data::dataloader::Dataset; - let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs()); + let dataset = FSRSDataset::from(simple_weighted_fsrs_items( + anki21_sample_file_converted_to_fsrs(), + )); assert_eq!( - dataset.get(704).unwrap(), + dataset.get(704).unwrap().item, FSRSItem { reviews: vec![ FSRSReview { @@ -435,6 +470,10 @@ mod tests { ], }, ]; + let items = items + .into_iter() + .map(|item| WeightedFSRSItem { weight: 1.0, item }) + .collect(); let batch = batcher.batch(items); assert_eq!( batch.t_historys.to_data(), diff --git a/src/inference.rs b/src/inference.rs index bd095fd9..6687f9b1 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -6,8 +6,8 @@ use burn::nn::loss::Reduction; use burn::tensor::{Data, Shape, Tensor}; use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend}; -use crate::dataset::FSRSBatch; use crate::dataset::FSRSBatcher; +use crate::dataset::{simple_weighted_fsrs_items, FSRSBatch}; use crate::error::Result; use crate::model::Model; use crate::training::BCELoss; @@ -210,9 +210,11 @@ impl FSRS { if items.is_empty() { return Err(FSRSError::NotEnoughData); } + let items = simple_weighted_fsrs_items(items); let batcher = FSRSBatcher::new(self.device()); let mut all_retention = vec![]; let mut all_labels = vec![]; + let mut all_weights = vec![]; let mut progress_info = ItemProgress { current: 0, total: items.len(), @@ -227,8 +229,9 @@ impl FSRS { let true_val = batch.labels.clone().to_data().convert::().value; all_retention.push(retention); all_labels.push(batch.labels); + all_weights.push(batch.weights); izip!(chunk, pred, true_val).for_each(|(item, p, y)| { - let bin = item.r_matrix_index(); + let bin = item.item.r_matrix_index(); let (pred, real, count) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0)); *pred += p; *real += y; @@ -251,7 +254,8 @@ impl FSRS { .sqrt(); let all_retention = Tensor::cat(all_retention, 0); let all_labels = Tensor::cat(all_labels, 0).float(); - let loss = BCELoss::new().forward(all_retention, all_labels, Reduction::Mean); + let all_weights = Tensor::cat(all_weights, 0); + let loss = BCELoss::new().forward(all_retention, all_labels, all_weights, Reduction::Mean); Ok(ModelEvaluation { log_loss: loss.to_data().value[0].elem(), rmse_bins: rmse, @@ -278,6 +282,7 @@ impl FSRS { if items.is_empty() { return Err(FSRSError::NotEnoughData); } + let items = simple_weighted_fsrs_items(items); let batcher = FSRSBatcher::new(self.device()); let mut all_predictions_self = vec![]; let mut all_predictions_other = vec![]; diff --git a/src/training.rs b/src/training.rs index d64b714c..697906a2 100644 --- a/src/training.rs +++ b/src/training.rs @@ -1,6 +1,6 @@ use crate::batch_shuffle::{BatchTensorDataset, ShuffleDataLoader}; use crate::cosine_annealing::CosineAnnealingLR; -use crate::dataset::{prepare_training_data, FSRSDataset, FSRSItem}; +use crate::dataset::{prepare_training_data, recency_weighted_fsrs_items, FSRSDataset, FSRSItem}; use crate::error::Result; use crate::model::{Model, ModelConfig}; use crate::parameter_clipper::parameter_clipper; @@ -37,10 +37,12 @@ impl BCELoss { &self, retentions: Tensor, labels: Tensor, + weights: Tensor, mean: Reduction, ) -> Tensor { - let loss = - labels.clone() * retentions.clone().log() + (-labels + 1) * (-retentions + 1).log(); + let loss = (labels.clone() * retentions.clone().log() + + (-labels + 1) * (-retentions + 1).log()) + * weights; // info!("loss: {}", &loss); match mean { Reduction::Mean => loss.mean().neg(), @@ -57,13 +59,14 @@ impl Model { r_historys: Tensor, delta_ts: Tensor, labels: Tensor, + weights: Tensor, reduce: Reduction, ) -> Tensor { // info!("t_historys: {}", &t_historys); // info!("r_historys: {}", &r_historys); let state = self.forward(t_historys, r_historys, None); let retention = self.power_forgetting_curve(delta_ts, state.stability); - BCELoss::new().forward(retention, labels.float(), reduce) + BCELoss::new().forward(retention, labels.float(), weights, reduce) } } @@ -325,14 +328,14 @@ fn train( // Training data let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs; let batch_dataset = BatchTensorDataset::::new( - FSRSDataset::from(train_set), + FSRSDataset::from(recency_weighted_fsrs_items(train_set)), config.batch_size, device.clone(), ); let dataloader_train = ShuffleDataLoader::new(batch_dataset, config.seed); let batch_dataset = BatchTensorDataset::::new( - FSRSDataset::from(test_set.clone()), + FSRSDataset::from(recency_weighted_fsrs_items(test_set.clone())), config.batch_size, device, ); @@ -365,6 +368,7 @@ fn train( item.r_historys, item.delta_ts, item.labels, + item.weights, Reduction::Sum, ); let mut gradients = loss.backward(); @@ -399,6 +403,7 @@ fn train( batch.r_historys, batch.delta_ts, batch.labels, + batch.weights, Reduction::Sum, ); let loss = loss.into_data().convert::().value[0]; @@ -493,6 +498,7 @@ mod tests { ), delta_ts: Tensor::from_floats(Data::from([4.0, 11.0, 12.0, 23.0]), &device), labels: Tensor::from_ints(Data::from([1, 1, 1, 0]), &device), + weights: Tensor::from_floats(Data::from([1.0, 1.0, 1.0, 1.0]), &device), }; let loss = model.forward_classification( @@ -500,6 +506,7 @@ mod tests { item.r_historys, item.delta_ts, item.labels, + item.weights, Reduction::Sum, ); @@ -559,6 +566,7 @@ mod tests { ), delta_ts: Tensor::from_floats(Data::from([4.0, 11.0, 12.0, 23.0]), &device), labels: Tensor::from_ints(Data::from([1, 1, 1, 0]), &device), + weights: Tensor::from_floats(Data::from([1.0, 1.0, 1.0, 1.0]), &device), }; let loss = model.forward_classification( @@ -566,6 +574,7 @@ mod tests { item.r_historys, item.delta_ts, item.labels, + item.weights, Reduction::Sum, ); assert_eq!(loss.clone().into_data().convert::().value[0], 4.176347); From f1c1371e8a21d062d1780d000ef34e7121381e7c Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 19 Dec 2024 16:12:02 +0800 Subject: [PATCH 2/8] don't sort by length of reviews at first --- src/batch_shuffle.rs | 6 +++++- src/convertor_tests.rs | 19 +++++++++++-------- src/dataset.rs | 18 ++++++++++++------ src/training.rs | 15 ++++++++++----- 4 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs index 5cd20525..34ce5e78 100644 --- a/src/batch_shuffle.rs +++ b/src/batch_shuffle.rs @@ -105,6 +105,7 @@ mod tests { backend::{ndarray::NdArrayDevice, NdArray}, tensor::Shape, }; + use itertools::Itertools; use super::*; use crate::{ @@ -114,7 +115,10 @@ mod tests { #[test] fn test_simple_dataloader() { - let train_set = anki21_sample_file_converted_to_fsrs(); + let train_set = anki21_sample_file_converted_to_fsrs() + .into_iter() + .sorted_by_cached_key(|item| item.reviews.len()) + .collect(); let (_pre_train_set, train_set) = prepare_training_data(train_set); let dataset = FSRSDataset::from(simple_weighted_fsrs_items(train_set)); let batch_size = 512; diff --git a/src/convertor_tests.rs b/src/convertor_tests.rs index b431351a..1df5ac0b 100644 --- a/src/convertor_tests.rs +++ b/src/convertor_tests.rs @@ -94,7 +94,7 @@ fn convert_to_fsrs_items( mut entries: Vec, next_day_starts_at: i64, timezone: Tz, -) -> Option> { +) -> Option> { // entries = filter_out_cram(entries); // entries = filter_out_manual(entries); entries = remove_revlog_before_last_first_learn(entries); @@ -110,7 +110,7 @@ fn convert_to_fsrs_items( .iter() .enumerate() .skip(1) - .map(|(idx, _)| { + .map(|(idx, entry)| { let reviews = entries .iter() .take(idx + 1) @@ -119,9 +119,9 @@ fn convert_to_fsrs_items( delta_t: r.last_interval.max(0) as u32, }) .collect(); - FSRSItem { reviews } + (entry.id, FSRSItem { reviews }) }) - .filter(|item| item.current().delta_t > 0) + .filter(|(_, item)| item.current().delta_t > 0) .collect(), ) } @@ -137,8 +137,8 @@ pub(crate) fn anki_to_fsrs(revlogs: Vec) -> Vec { }) .flatten() .collect_vec(); - revlogs.sort_by_cached_key(|r| r.reviews.len()); - revlogs + revlogs.sort_by_cached_key(|(id, _)| *id); + revlogs.into_iter().map(|(_, item)| item).collect() } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -260,6 +260,7 @@ fn conversion_works() { .into_iter() .filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai)) .flatten() + .map(|(_, item)| item) .collect_vec(); assert_eq!( fsrs_items, @@ -445,7 +446,8 @@ fn delta_t_is_correct() -> Result<()> { ], NEXT_DAY_AT, Tz::Asia__Shanghai - ), + ) + .map(|items| items.into_iter().map(|(_, item)| item).collect_vec()), Some(vec![FSRSItem { reviews: vec![ FSRSReview { @@ -470,7 +472,8 @@ fn delta_t_is_correct() -> Result<()> { ], NEXT_DAY_AT, Tz::Asia__Shanghai - ), + ) + .map(|items| items.into_iter().map(|(_, item)| item).collect_vec()), Some(vec![ FSRSItem { reviews: vec![ diff --git a/src/dataset.rs b/src/dataset.rs index 267f1488..264b813a 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -265,6 +265,12 @@ pub fn prepare_training_data(items: Vec) -> (Vec, Vec) -> Vec { + let mut items = items; + items.sort_by_cached_key(|item| item.item.reviews.len()); + items +} + pub(crate) fn simple_weighted_fsrs_items(items: Vec) -> Vec { items .into_iter() @@ -294,21 +300,21 @@ mod tests { fn from_anki() { use burn::data::dataloader::Dataset; - let dataset = FSRSDataset::from(simple_weighted_fsrs_items( + let dataset = FSRSDataset::from(sort_items_by_review_length(simple_weighted_fsrs_items( anki21_sample_file_converted_to_fsrs(), - )); + ))); assert_eq!( dataset.get(704).unwrap().item, FSRSItem { reviews: vec![ FSRSReview { - rating: 3, - delta_t: 0, + rating: 4, + delta_t: 0 }, FSRSReview { rating: 3, - delta_t: 1, - }, + delta_t: 3 + } ], } ); diff --git a/src/training.rs b/src/training.rs index 697906a2..d266a227 100644 --- a/src/training.rs +++ b/src/training.rs @@ -1,6 +1,9 @@ use crate::batch_shuffle::{BatchTensorDataset, ShuffleDataLoader}; use crate::cosine_annealing::CosineAnnealingLR; -use crate::dataset::{prepare_training_data, recency_weighted_fsrs_items, FSRSDataset, FSRSItem}; +use crate::dataset::{ + prepare_training_data, recency_weighted_fsrs_items, sort_items_by_review_length, FSRSDataset, + FSRSItem, +}; use crate::error::Result; use crate::model::{Model, ModelConfig}; use crate::parameter_clipper::parameter_clipper; @@ -238,7 +241,6 @@ impl FSRS { AdamConfig::new().with_epsilon(1e-8), ); train_set.retain(|item| item.reviews.len() <= config.max_seq_len); - train_set.sort_by_cached_key(|item| item.reviews.len()); if let Some(progress) = &progress { let progress_state = ProgressState { @@ -308,7 +310,6 @@ impl FSRS { AdamConfig::new().with_epsilon(1e-8), ); train_set.retain(|item| item.reviews.len() <= config.max_seq_len); - train_set.sort_by_cached_key(|item| item.reviews.len()); let model = train::>(train_set.clone(), train_set, &config, self.device(), None); let parameters: Vec = model.unwrap().w.val().to_data().convert().value; @@ -328,14 +329,18 @@ fn train( // Training data let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs; let batch_dataset = BatchTensorDataset::::new( - FSRSDataset::from(recency_weighted_fsrs_items(train_set)), + FSRSDataset::from(sort_items_by_review_length(recency_weighted_fsrs_items( + train_set, + ))), config.batch_size, device.clone(), ); let dataloader_train = ShuffleDataLoader::new(batch_dataset, config.seed); let batch_dataset = BatchTensorDataset::::new( - FSRSDataset::from(recency_weighted_fsrs_items(test_set.clone())), + FSRSDataset::from(sort_items_by_review_length(recency_weighted_fsrs_items( + test_set.clone(), + ))), config.batch_size, device, ); From aeac505646239760151d674299c1970e0ec98d9e Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 19 Dec 2024 18:08:10 +0800 Subject: [PATCH 3/8] apply recency weighting to evaluation --- src/batch_shuffle.rs | 4 ++-- src/convertor_tests.rs | 4 ++-- src/dataset.rs | 5 +++-- src/inference.rs | 27 +++++++++++++-------------- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs index 34ce5e78..1dc8e291 100644 --- a/src/batch_shuffle.rs +++ b/src/batch_shuffle.rs @@ -110,7 +110,7 @@ mod tests { use super::*; use crate::{ convertor_tests::anki21_sample_file_converted_to_fsrs, - dataset::{prepare_training_data, simple_weighted_fsrs_items}, + dataset::{constant_weighted_fsrs_items, prepare_training_data}, }; #[test] @@ -120,7 +120,7 @@ mod tests { .sorted_by_cached_key(|item| item.reviews.len()) .collect(); let (_pre_train_set, train_set) = prepare_training_data(train_set); - let dataset = FSRSDataset::from(simple_weighted_fsrs_items(train_set)); + let dataset = FSRSDataset::from(constant_weighted_fsrs_items(train_set)); let batch_size = 512; let seed = 114514; let device = NdArrayDevice::Cpu; diff --git a/src/convertor_tests.rs b/src/convertor_tests.rs index 1df5ac0b..9f8c2328 100644 --- a/src/convertor_tests.rs +++ b/src/convertor_tests.rs @@ -1,5 +1,5 @@ use crate::convertor_tests::RevlogReviewKind::*; -use crate::dataset::{simple_weighted_fsrs_items, FSRSBatcher}; +use crate::dataset::{constant_weighted_fsrs_items, FSRSBatcher}; use crate::dataset::{FSRSItem, FSRSReview}; use crate::optimal_retention::{RevlogEntry, RevlogReviewKind}; use crate::test_helpers::NdArrayAutodiff; @@ -388,7 +388,7 @@ fn conversion_works() { ] ); - let mut weighted_fsrs_items = simple_weighted_fsrs_items(fsrs_items); + let mut weighted_fsrs_items = constant_weighted_fsrs_items(fsrs_items); let device = NdArrayDevice::Cpu; let batcher = FSRSBatcher::::new(device); diff --git a/src/dataset.rs b/src/dataset.rs index 264b813a..bfe1726e 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -271,7 +271,8 @@ pub(crate) fn sort_items_by_review_length(items: Vec) -> Vec) -> Vec { +#[cfg(test)] +pub(crate) fn constant_weighted_fsrs_items(items: Vec) -> Vec { items .into_iter() .map(|item| WeightedFSRSItem { weight: 1.0, item }) @@ -300,7 +301,7 @@ mod tests { fn from_anki() { use burn::data::dataloader::Dataset; - let dataset = FSRSDataset::from(sort_items_by_review_length(simple_weighted_fsrs_items( + let dataset = FSRSDataset::from(sort_items_by_review_length(constant_weighted_fsrs_items( anki21_sample_file_converted_to_fsrs(), ))); assert_eq!( diff --git a/src/inference.rs b/src/inference.rs index 6687f9b1..a54e1432 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -6,8 +6,7 @@ use burn::nn::loss::Reduction; use burn::tensor::{Data, Shape, Tensor}; use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend}; -use crate::dataset::FSRSBatcher; -use crate::dataset::{simple_weighted_fsrs_items, FSRSBatch}; +use crate::dataset::{recency_weighted_fsrs_items, FSRSBatch, FSRSBatcher}; use crate::error::Result; use crate::model::Model; use crate::training::BCELoss; @@ -210,7 +209,7 @@ impl FSRS { if items.is_empty() { return Err(FSRSError::NotEnoughData); } - let items = simple_weighted_fsrs_items(items); + let items = recency_weighted_fsrs_items(items); let batcher = FSRSBatcher::new(self.device()); let mut all_retention = vec![]; let mut all_labels = vec![]; @@ -232,10 +231,10 @@ impl FSRS { all_weights.push(batch.weights); izip!(chunk, pred, true_val).for_each(|(item, p, y)| { let bin = item.item.r_matrix_index(); - let (pred, real, count) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0)); + let (pred, real, weight) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0)); *pred += p; *real += y; - *count += 1.0; + *weight += item.weight; }); progress_info.current += chunk.len(); if !progress(progress_info) { @@ -244,13 +243,13 @@ impl FSRS { } let rmse = (r_matrix .values() - .map(|(pred, real, count)| { - let pred = pred / count; - let real = real / count; - (pred - real).powi(2) * count + .map(|(pred, real, weight)| { + let pred = pred / weight; + let real = real / weight; + (pred - real).powi(2) * weight }) .sum::() - / r_matrix.values().map(|(_, _, count)| count).sum::()) + / r_matrix.values().map(|(_, _, weight)| weight).sum::()) .sqrt(); let all_retention = Tensor::cat(all_retention, 0); let all_labels = Tensor::cat(all_labels, 0).float(); @@ -282,7 +281,7 @@ impl FSRS { if items.is_empty() { return Err(FSRSError::NotEnoughData); } - let items = simple_weighted_fsrs_items(items); + let items = recency_weighted_fsrs_items(items); let batcher = FSRSBatcher::new(self.device()); let mut all_predictions_self = vec![]; let mut all_predictions_other = vec![]; @@ -499,17 +498,17 @@ mod tests { ]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.211007, 0.037216]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.212817, 0.034676]); let fsrs = FSRS::new(Some(&[]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.216286, 0.038692]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217251, 0.036590]); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203049, 0.027558]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203552, 0.025646]); let (self_by_other, other_by_self) = fsrs .universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true) From e0665e883026121652afd37ebd1e24c844d9e172 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 19 Dec 2024 18:24:10 +0800 Subject: [PATCH 4/8] fix rmse --- src/inference.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index a54e1432..0b229075 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -219,7 +219,7 @@ impl FSRS { total: items.len(), }; let model = self.model(); - let mut r_matrix: HashMap<(u32, u32, u32), (f32, f32, f32)> = HashMap::new(); + let mut r_matrix: HashMap<(u32, u32, u32), (f32, f32, f32, f32)> = HashMap::new(); for chunk in items.chunks(512) { let batch = batcher.batch(chunk.to_vec()); @@ -231,9 +231,11 @@ impl FSRS { all_weights.push(batch.weights); izip!(chunk, pred, true_val).for_each(|(item, p, y)| { let bin = item.item.r_matrix_index(); - let (pred, real, weight) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0)); + let (pred, real, count, weight) = + r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0, 0.0)); *pred += p; *real += y; + *count += 1.0; *weight += item.weight; }); progress_info.current += chunk.len(); @@ -243,13 +245,16 @@ impl FSRS { } let rmse = (r_matrix .values() - .map(|(pred, real, weight)| { - let pred = pred / weight; - let real = real / weight; + .map(|(pred, real, count, weight)| { + let pred = pred / count; + let real = real / count; (pred - real).powi(2) * weight }) .sum::() - / r_matrix.values().map(|(_, _, weight)| weight).sum::()) + / r_matrix + .values() + .map(|(_, _, _, weight)| weight) + .sum::()) .sqrt(); let all_retention = Tensor::cat(all_retention, 0); let all_labels = Tensor::cat(all_labels, 0).float(); @@ -498,17 +503,17 @@ mod tests { ]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.212817, 0.034676]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.212817, 0.040148]); let fsrs = FSRS::new(Some(&[]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217251, 0.036590]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217251, 0.041336]); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203552, 0.025646]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203552, 0.029828]); let (self_by_other, other_by_self) = fsrs .universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true) From 10e461c2b677806b3453341bc7631a8a0a610720 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 19 Dec 2024 18:42:27 +0800 Subject: [PATCH 5/8] refactor complex type --- src/inference.rs | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index 0b229075..dab2a588 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -69,6 +69,14 @@ pub fn next_interval(stability: f32, desired_retention: f32) -> f32 { stability / FACTOR as f32 * (desired_retention.powf(1.0 / DECAY as f32) - 1.0) } +#[derive(Default)] +struct RMatrixValue { + predicted: f32, + actual: f32, + count: f32, + weight: f32, +} + impl FSRS { /// Calculate the current memory state for a given card's history of reviews. /// In the case of truncated reviews, [starting_state] can be set to the value of @@ -219,7 +227,7 @@ impl FSRS { total: items.len(), }; let model = self.model(); - let mut r_matrix: HashMap<(u32, u32, u32), (f32, f32, f32, f32)> = HashMap::new(); + let mut r_matrix: HashMap<(u32, u32, u32), RMatrixValue> = HashMap::new(); for chunk in items.chunks(512) { let batch = batcher.batch(chunk.to_vec()); @@ -231,12 +239,11 @@ impl FSRS { all_weights.push(batch.weights); izip!(chunk, pred, true_val).for_each(|(item, p, y)| { let bin = item.item.r_matrix_index(); - let (pred, real, count, weight) = - r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0, 0.0)); - *pred += p; - *real += y; - *count += 1.0; - *weight += item.weight; + let value = r_matrix.entry(bin).or_default(); + value.predicted += p; + value.actual += y; + value.count += 1.0; + value.weight += item.weight; }); progress_info.current += chunk.len(); if !progress(progress_info) { @@ -245,16 +252,13 @@ impl FSRS { } let rmse = (r_matrix .values() - .map(|(pred, real, count, weight)| { - let pred = pred / count; - let real = real / count; - (pred - real).powi(2) * weight + .map(|v| { + let pred = v.predicted / v.count; + let real = v.actual / v.count; + (pred - real).powi(2) * v.weight }) .sum::() - / r_matrix - .values() - .map(|(_, _, _, weight)| weight) - .sum::()) + / r_matrix.values().map(|v| v.weight).sum::()) .sqrt(); let all_retention = Tensor::cat(all_retention, 0); let all_labels = Tensor::cat(all_labels, 0).float(); From 38d9a8661c94097fa867ca66f4087ccdd2780e6a Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Mon, 23 Dec 2024 16:35:02 +0800 Subject: [PATCH 6/8] improve recency weighting --- src/dataset.rs | 2 +- src/inference.rs | 13 +++++++------ src/training.rs | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/dataset.rs b/src/dataset.rs index bfe1726e..fa869e14 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -286,7 +286,7 @@ pub(crate) fn recency_weighted_fsrs_items(items: Vec) -> Vec FSRS { let all_retention = Tensor::cat(all_retention, 0); let all_labels = Tensor::cat(all_labels, 0).float(); let all_weights = Tensor::cat(all_weights, 0); - let loss = BCELoss::new().forward(all_retention, all_labels, all_weights, Reduction::Mean); + let loss = BCELoss::new().forward(all_retention, all_labels, all_weights, Reduction::Auto); Ok(ModelEvaluation { log_loss: loss.to_data().value[0].elem(), rmse_bins: rmse, @@ -502,22 +502,23 @@ mod tests { let items = [pretrainset, trainset].concat(); let fsrs = FSRS::new(Some(&[ - 0.669, 1.679, 4.1355, 9.862, 7.9435, 0.9379, 1.0148, 0.1588, 1.3851, 0.1248, 0.8421, - 1.992, 0.153, 0.284, 2.4282, 0.2547, 3.1847, 0.2196, 0.1906, + 0.6032805, 1.3376843, 4.4167747, 9.933699, 7.654044, 0.78219295, 2.336606, 0.001, + 1.3264198, 0.12967199, 0.82880765, 1.9360433, 0.13298263, 0.27427456, 2.4304862, + 0.10340813, 3.108867, 0.2114512, 0.2826002, ]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.212817, 0.040148]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.206160, 0.025809]); let fsrs = FSRS::new(Some(&[]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217251, 0.041336]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.223601, 0.042738]); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203552, 0.029828]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.208656, 0.030946]); let (self_by_other, other_by_self) = fsrs .universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true) diff --git a/src/training.rs b/src/training.rs index d266a227..338cee58 100644 --- a/src/training.rs +++ b/src/training.rs @@ -45,12 +45,12 @@ impl BCELoss { ) -> Tensor { let loss = (labels.clone() * retentions.clone().log() + (-labels + 1) * (-retentions + 1).log()) - * weights; + * weights.clone(); // info!("loss: {}", &loss); match mean { Reduction::Mean => loss.mean().neg(), Reduction::Sum => loss.sum().neg(), - Reduction::Auto => loss.neg(), + Reduction::Auto => (loss.sum() / weights.sum()).neg(), } } } From 6812b92c0bf29405d319b7abe00b43769b74431b Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 24 Dec 2024 11:43:39 +0800 Subject: [PATCH 7/8] rename --- src/dataset.rs | 33 ++++++++++++++++++--------------- src/inference.rs | 22 ++++++++++++---------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/src/dataset.rs b/src/dataset.rs index fa869e14..8504a127 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -98,19 +98,22 @@ pub(crate) struct FSRSBatch { } impl Batcher> for FSRSBatcher { - fn batch(&self, items: Vec) -> FSRSBatch { - let pad_size = items + fn batch(&self, weighted_items: Vec) -> FSRSBatch { + let pad_size = weighted_items .iter() .map(|x| x.item.reviews.len()) .max() .expect("FSRSItem is empty") - 1; - let (time_histories, rating_histories) = items + let (time_histories, rating_histories) = weighted_items .iter() - .map(|item| { - let (mut delta_t, mut rating): (Vec<_>, Vec<_>) = - item.item.history().map(|r| (r.delta_t, r.rating)).unzip(); + .map(|weighted_item| { + let (mut delta_t, mut rating): (Vec<_>, Vec<_>) = weighted_item + .item + .history() + .map(|r| (r.delta_t, r.rating)) + .unzip(); delta_t.resize(pad_size, 0); rating.resize(pad_size, 0); let delta_t = Tensor::from_data( @@ -137,10 +140,10 @@ impl Batcher> for FSRSBatcher { }) .unzip(); - let (delta_ts, labels, weights) = items + let (delta_ts, labels, weights) = weighted_items .iter() - .map(|item| { - let current = item.item.current(); + .map(|weighted_item| { + let current = weighted_item.item.current(); let delta_t: Tensor = Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device); let label = match current.rating { @@ -150,7 +153,7 @@ impl Batcher> for FSRSBatcher { let label: Tensor = Tensor::from_data(Data::from([label.elem()]), &self.device); let weight: Tensor = - Tensor::from_data(Data::from([item.weight.elem()]), &self.device); + Tensor::from_data(Data::from([weighted_item.weight.elem()]), &self.device); (delta_t, label, weight) }) .multiunzip(); @@ -265,13 +268,13 @@ pub fn prepare_training_data(items: Vec) -> (Vec, Vec) -> Vec { - let mut items = items; - items.sort_by_cached_key(|item| item.item.reviews.len()); - items +pub(crate) fn sort_items_by_review_length( + mut weighted_items: Vec, +) -> Vec { + weighted_items.sort_by_cached_key(|weighted_item| weighted_item.item.reviews.len()); + weighted_items } -#[cfg(test)] pub(crate) fn constant_weighted_fsrs_items(items: Vec) -> Vec { items .into_iter() diff --git a/src/inference.rs b/src/inference.rs index e3d7f854..afa4e8f7 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -6,7 +6,9 @@ use burn::nn::loss::Reduction; use burn::tensor::{Data, Shape, Tensor}; use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend}; -use crate::dataset::{recency_weighted_fsrs_items, FSRSBatch, FSRSBatcher}; +use crate::dataset::{ + constant_weighted_fsrs_items, recency_weighted_fsrs_items, FSRSBatch, FSRSBatcher, +}; use crate::error::Result; use crate::model::Model; use crate::training::BCELoss; @@ -217,19 +219,19 @@ impl FSRS { if items.is_empty() { return Err(FSRSError::NotEnoughData); } - let items = recency_weighted_fsrs_items(items); + let weighted_items = recency_weighted_fsrs_items(items); let batcher = FSRSBatcher::new(self.device()); let mut all_retention = vec![]; let mut all_labels = vec![]; let mut all_weights = vec![]; let mut progress_info = ItemProgress { current: 0, - total: items.len(), + total: weighted_items.len(), }; let model = self.model(); let mut r_matrix: HashMap<(u32, u32, u32), RMatrixValue> = HashMap::new(); - for chunk in items.chunks(512) { + for chunk in weighted_items.chunks(512) { let batch = batcher.batch(chunk.to_vec()); let (_state, retention) = infer::(model, batch.clone()); let pred = retention.clone().to_data().convert::().value; @@ -237,13 +239,13 @@ impl FSRS { all_retention.push(retention); all_labels.push(batch.labels); all_weights.push(batch.weights); - izip!(chunk, pred, true_val).for_each(|(item, p, y)| { - let bin = item.item.r_matrix_index(); + izip!(chunk, pred, true_val).for_each(|(weighted_item, p, y)| { + let bin = weighted_item.item.r_matrix_index(); let value = r_matrix.entry(bin).or_default(); value.predicted += p; value.actual += y; value.count += 1.0; - value.weight += item.weight; + value.weight += weighted_item.weight; }); progress_info.current += chunk.len(); if !progress(progress_info) { @@ -290,19 +292,19 @@ impl FSRS { if items.is_empty() { return Err(FSRSError::NotEnoughData); } - let items = recency_weighted_fsrs_items(items); + let weighted_items = constant_weighted_fsrs_items(items); let batcher = FSRSBatcher::new(self.device()); let mut all_predictions_self = vec![]; let mut all_predictions_other = vec![]; let mut all_true_val = vec![]; let mut progress_info = ItemProgress { current: 0, - total: items.len(), + total: weighted_items.len(), }; let model_self = self.model(); let fsrs_other = Self::new_with_backend(Some(parameters), self.device())?; let model_other = fsrs_other.model(); - for chunk in items.chunks(512) { + for chunk in weighted_items.chunks(512) { let batch = batcher.batch(chunk.to_vec()); let (_state, retention) = infer::(model_self, batch.clone()); From a5600a8ed668ba8f399a3bd3cb7b5a0c4fcfd460 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 24 Dec 2024 11:47:44 +0800 Subject: [PATCH 8/8] bump version --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7b442c75..331220c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,7 +1077,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "1.4.5" +version = "1.5.0" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index f0068919..06344eaf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "1.4.5" +version = "1.5.0" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021"