Skip to content

Commit

Permalink
Fix/progress goes backwards (#166)
Browse files Browse the repository at this point in the history
Co-authored-by: Asuka Minato <[email protected]>
  • Loading branch information
L-M-Sherlock and asukaminato0721 authored Mar 10, 2024
1 parent f1afdd7 commit ec2eae5
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.5.0"
version = "0.5.1"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
60 changes: 40 additions & 20 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,6 @@ impl<B: Backend> FSRS<B> {
};

let n_splits = 5;
if let Some(progress) = &progress {
progress.lock().unwrap().splits = vec![ProgressState::default(); n_splits];
}
let average_recall = calculate_average_recall(&items);
let (pre_trainset, trainsets, testset) = split_data(items, n_splits);
let initial_stability = pretrain(pre_trainset, average_recall).map_err(|e| {
Expand All @@ -239,18 +236,32 @@ impl<B: Backend> FSRS<B> {
AdamConfig::new(),
);

let weight_sets: Result<Vec<Vec<f32>>> = (0..n_splits)
let trainsets: Vec<Vec<FSRSItem>> = (0..n_splits)
.into_par_iter()
.map(|i| {
let trainset = trainsets
trainsets
.par_iter()
.enumerate()
.filter(|&(j, _)| j != i)
.flat_map(|(_, trainset)| trainset.clone())
.collect();
.collect()
})
.collect();

if let Some(progress) = &progress {
let mut progress_states = vec![ProgressState::default(); n_splits];
for (i, progress_state) in progress_states.iter_mut().enumerate() {
progress_state.epoch_total = config.num_epochs;
progress_state.items_total = trainsets[i].len();
}
progress.lock().unwrap().splits = progress_states
}

let weight_sets: Result<Vec<Vec<f32>>> = (0..n_splits)
.into_par_iter()
.map(|i| {
let model = train::<Autodiff<B>>(
trainset,
trainsets[i].clone(),
testset.clone(),
&config,
self.device(),
Expand Down Expand Up @@ -280,10 +291,8 @@ impl<B: Backend> FSRS<B> {
.map(|&sum| sum / n_splits as f32)
.collect();

for weight in &average_parameters {
if !weight.is_finite() {
return Err(FSRSError::InvalidInput);
}
if average_parameters.iter().any(|weight| weight.is_infinite()) {
return Err(FSRSError::InvalidInput);
}

Ok(average_parameters)
Expand Down Expand Up @@ -499,25 +508,36 @@ mod tests {
thread::sleep(Duration::from_millis(10));
let guard = progress.lock().unwrap();
finished = guard.finished();
info!("progress: {}/{}", guard.current(), guard.total());
println!("progress: {}/{}", guard.current(), guard.total());
}
});

if let Some(progress2) = &progress2 {
progress2.lock().unwrap().splits = vec![ProgressState::default(); n_splits];
}

let parameters_sets: Vec<Vec<f32>> = (0..n_splits)
let trainsets: Vec<Vec<FSRSItem>> = (0..n_splits)
.into_par_iter()
.map(|i| {
let trainset = trainsets
trainsets
.par_iter()
.enumerate()
.filter(|&(j, _)| j != i)
.flat_map(|(_, trainset)| trainset.clone())
.collect();
.collect()
})
.collect();

if let Some(progress2) = &progress2 {
let mut progress_states = vec![ProgressState::default(); n_splits];
for (i, progress_state) in progress_states.iter_mut().enumerate() {
progress_state.epoch_total = config.num_epochs;
progress_state.items_total = trainsets[i].len();
}
progress2.lock().unwrap().splits = progress_states
}

let parameters_sets: Vec<Vec<f32>> = (0..n_splits)
.into_par_iter()
.map(|i| {
let model = train::<NdArrayAutodiff>(
trainset,
trainsets[i].clone(),
items.clone(),
&config,
device,
Expand Down

0 comments on commit ec2eae5

Please sign in to comment.