Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Dec 24, 2024
1 parent 38d9a86 commit 6812b92
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
33 changes: 18 additions & 15 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,22 @@ pub(crate) struct FSRSBatch<B: Backend> {
}

impl<B: Backend> Batcher<WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
fn batch(&self, items: Vec<WeightedFSRSItem>) -> FSRSBatch<B> {
let pad_size = items
fn batch(&self, weighted_items: Vec<WeightedFSRSItem>) -> FSRSBatch<B> {
let pad_size = weighted_items
.iter()
.map(|x| x.item.reviews.len())
.max()
.expect("FSRSItem is empty")
- 1;

let (time_histories, rating_histories) = items
let (time_histories, rating_histories) = weighted_items
.iter()
.map(|item| {
let (mut delta_t, mut rating): (Vec<_>, Vec<_>) =
item.item.history().map(|r| (r.delta_t, r.rating)).unzip();
.map(|weighted_item| {
let (mut delta_t, mut rating): (Vec<_>, Vec<_>) = weighted_item
.item
.history()
.map(|r| (r.delta_t, r.rating))
.unzip();
delta_t.resize(pad_size, 0);
rating.resize(pad_size, 0);
let delta_t = Tensor::from_data(
Expand All @@ -137,10 +140,10 @@ impl<B: Backend> Batcher<WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
})
.unzip();

let (delta_ts, labels, weights) = items
let (delta_ts, labels, weights) = weighted_items
.iter()
.map(|item| {
let current = item.item.current();
.map(|weighted_item| {
let current = weighted_item.item.current();
let delta_t: Tensor<B, 1> =
Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device);
let label = match current.rating {
Expand All @@ -150,7 +153,7 @@ impl<B: Backend> Batcher<WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
let label: Tensor<B, 1, Int> =
Tensor::from_data(Data::from([label.elem()]), &self.device);
let weight: Tensor<B, 1> =
Tensor::from_data(Data::from([item.weight.elem()]), &self.device);
Tensor::from_data(Data::from([weighted_item.weight.elem()]), &self.device);
(delta_t, label, weight)
})
.multiunzip();
Expand Down Expand Up @@ -265,13 +268,13 @@ pub fn prepare_training_data(items: Vec<FSRSItem>) -> (Vec<FSRSItem>, Vec<FSRSIt
(pretrainset.clone(), [pretrainset, trainset].concat())
}

pub(crate) fn sort_items_by_review_length(items: Vec<WeightedFSRSItem>) -> Vec<WeightedFSRSItem> {
let mut items = items;
items.sort_by_cached_key(|item| item.item.reviews.len());
items
/// Sorts weighted items in ascending order of their review-history length.
///
/// Returns the same vector, reordered in place (stable sort: items with
/// equal review counts keep their relative order).
///
/// NOTE(review): presumably used so that length-sorted items batch together
/// with minimal padding in `FSRSBatcher::batch` — confirm against callers.
pub(crate) fn sort_items_by_review_length(
    mut weighted_items: Vec<WeightedFSRSItem>,
) -> Vec<WeightedFSRSItem> {
    // `reviews.len()` is an O(1) accessor, so `sort_by_cached_key` (which
    // allocates a side table of precomputed keys) buys nothing here;
    // `sort_by_key` performs the same stable sort without the allocation.
    weighted_items.sort_by_key(|weighted_item| weighted_item.item.reviews.len());
    weighted_items
}

#[cfg(test)]
pub(crate) fn constant_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedFSRSItem> {
items
.into_iter()
Expand Down
22 changes: 12 additions & 10 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use burn::nn::loss::Reduction;
use burn::tensor::{Data, Shape, Tensor};
use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend};

use crate::dataset::{recency_weighted_fsrs_items, FSRSBatch, FSRSBatcher};
use crate::dataset::{
constant_weighted_fsrs_items, recency_weighted_fsrs_items, FSRSBatch, FSRSBatcher,
};
use crate::error::Result;
use crate::model::Model;
use crate::training::BCELoss;
Expand Down Expand Up @@ -217,33 +219,33 @@ impl<B: Backend> FSRS<B> {
if items.is_empty() {
return Err(FSRSError::NotEnoughData);
}
let items = recency_weighted_fsrs_items(items);
let weighted_items = recency_weighted_fsrs_items(items);
let batcher = FSRSBatcher::new(self.device());
let mut all_retention = vec![];
let mut all_labels = vec![];
let mut all_weights = vec![];
let mut progress_info = ItemProgress {
current: 0,
total: items.len(),
total: weighted_items.len(),
};
let model = self.model();
let mut r_matrix: HashMap<(u32, u32, u32), RMatrixValue> = HashMap::new();

for chunk in items.chunks(512) {
for chunk in weighted_items.chunks(512) {
let batch = batcher.batch(chunk.to_vec());
let (_state, retention) = infer::<B>(model, batch.clone());
let pred = retention.clone().to_data().convert::<f32>().value;
let true_val = batch.labels.clone().to_data().convert::<f32>().value;
all_retention.push(retention);
all_labels.push(batch.labels);
all_weights.push(batch.weights);
izip!(chunk, pred, true_val).for_each(|(item, p, y)| {
let bin = item.item.r_matrix_index();
izip!(chunk, pred, true_val).for_each(|(weighted_item, p, y)| {
let bin = weighted_item.item.r_matrix_index();
let value = r_matrix.entry(bin).or_default();
value.predicted += p;
value.actual += y;
value.count += 1.0;
value.weight += item.weight;
value.weight += weighted_item.weight;
});
progress_info.current += chunk.len();
if !progress(progress_info) {
Expand Down Expand Up @@ -290,19 +292,19 @@ impl<B: Backend> FSRS<B> {
if items.is_empty() {
return Err(FSRSError::NotEnoughData);
}
let items = recency_weighted_fsrs_items(items);
let weighted_items = constant_weighted_fsrs_items(items);
let batcher = FSRSBatcher::new(self.device());
let mut all_predictions_self = vec![];
let mut all_predictions_other = vec![];
let mut all_true_val = vec![];
let mut progress_info = ItemProgress {
current: 0,
total: items.len(),
total: weighted_items.len(),
};
let model_self = self.model();
let fsrs_other = Self::new_with_backend(Some(parameters), self.device())?;
let model_other = fsrs_other.model();
for chunk in items.chunks(512) {
for chunk in weighted_items.chunks(512) {
let batch = batcher.batch(chunk.to_vec());

let (_state, retention) = infer::<B>(model_self, batch.clone());
Expand Down

0 comments on commit 6812b92

Please sign in to comment.