update burn to 0.12.1 (#157)
* update burn to 0.12.1

* try bump version

* replace &B::Device::default() with &self.device (see the migration sketch after the file summary)

* cargo fmt

---------

Co-authored-by: Asuka Minato <[email protected]>
L-M-Sherlock and asukaminato0721 authored Feb 14, 2024
1 parent 54df730 commit e34265a
Showing 10 changed files with 552 additions and 472 deletions.
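
The central change running through these files is burn 0.12's switch to explicit devices: tensor-creation ops such as from_data, from_floats, ones, and zeros now take a device argument instead of silently using the backend default. A minimal sketch of the before/after pattern with the crate's ndarray backend (the variable names are illustrative only, not code from this commit):

use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::tensor::Tensor;

fn main() {
    // burn 0.11 style (no longer compiles): the default device was implicit.
    // let t: Tensor<NdArray<f32>, 1> = Tensor::from_floats([1.0, 2.0, 3.0]);

    // burn 0.12 style: pass the target device explicitly.
    let device = NdArrayDevice::Cpu;
    let t: Tensor<NdArray<f32>, 1> = Tensor::from_floats([1.0, 2.0, 3.0], &device);
    println!("{:?}", t.to_data());
}

In library code the commit threads a stored device through (&self.device in the batcher, self.device() on FSRS) and only falls back to &B::Device::default() where no handle is available. The remaining diffs are smaller 0.12 API adjustments: LrScheduler becoming generic over the backend, the DataLoader impl gaining num_items, and powf being renamed to powf_scalar for scalar exponents.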
843 changes: 439 additions & 404 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -15,15 +15,15 @@ description = "FSRS for Rust, including Optimizer and Scheduler"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies.burn]
version = "0.11.1"
version = "0.12.1"
# git = "https://github.com/burn-rs/burn.git"
# rev = "d2639682367f39d0d0ed049d0cf3a2077259e05d"
# path = "../burn/burn"
default-features = false
features = ["std", "train", "ndarray"]

[dev-dependencies.burn]
version = "0.11.1"
version = "0.12.1"
# git = "https://github.com/burn-rs/burn.git"
# rev = "d2639682367f39d0d0ed049d0cf3a2077259e05d"
# path = "../burn/burn"
@@ -38,8 +38,8 @@ ndarray-rand = "0.14.0"
rand = "0.8.5"
rayon = "1.8.0"
serde = "1.0.193"
snafu = "0.7.5"
strum = { version = "0.25.0", features = ["derive"] }
snafu = "0.8.0"
strum = { version = "0.26.1", features = ["derive"] }

[dev-dependencies]
chrono = { version = "0.4.31", default-features = false, features = ["std", "clock"] }
2 changes: 1 addition & 1 deletion rust-toolchain.toml
@@ -1,4 +1,4 @@
[toolchain]
# older versions may fail to compile; newer versions may fail the clippy tests
channel = "1.75"
channel = "1.76"
components = ["rustfmt", "clippy"]
4 changes: 4 additions & 0 deletions src/batch_shuffle.rs
@@ -205,6 +205,10 @@ impl<I: Send + Sync + Clone + 'static, O: Send + Sync> DataLoader<O>
self.batcher.clone(),
))
}

fn num_items(&self) -> usize {
self.dataset.len()
}
}

impl<I, O> BatchShuffledDataloaderIterator<I, O> {
9 changes: 5 additions & 4 deletions src/cosine_annealing.rs
@@ -1,4 +1,4 @@
use burn::{lr_scheduler::LrScheduler, LearningRate};
use burn::{lr_scheduler::LrScheduler, tensor::backend::Backend, LearningRate};
use log::info;
#[derive(Clone, Debug)]
pub(crate) struct CosineAnnealingLR {
@@ -21,7 +21,7 @@ impl CosineAnnealingLR {
}
}

impl LrScheduler for CosineAnnealingLR {
impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
type Record = usize;

fn step(&mut self) -> LearningRate {
@@ -66,15 +66,16 @@ impl LrScheduler for CosineAnnealingLR {
#[cfg(test)]
mod tests {
use super::*;
use burn::tensor::Data;
use burn::{backend::NdArray, tensor::Data};
type Backend = NdArray<f32>;

#[test]
fn lr_scheduler() {
let mut lr_scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);

let lrs = (0..=200000)
.map(|_| {
lr_scheduler.step();
LrScheduler::<Backend>::step(&mut lr_scheduler);
lr_scheduler.current_lr
})
.step_by(20000)
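
For reference, the scheduler adapted here follows the standard cosine-annealing curve. A standalone sketch under that assumption (textbook formula, with min_lr assumed to be 0.0; not necessarily a line-for-line match with CosineAnnealingLR):

// lr(step) = min_lr + (init_lr - min_lr) * (1 + cos(pi * step / t_max)) / 2
fn cosine_annealing_lr(step: f64, t_max: f64, init_lr: f64, min_lr: f64) -> f64 {
    min_lr + (init_lr - min_lr) * (1.0 + (std::f64::consts::PI * step / t_max).cos()) / 2.0
}

With init_lr = 1.0e-1 and t_max = 100000.0 as in the test above, the rate decays from 0.1 toward min_lr over the first 100000 steps.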
20 changes: 12 additions & 8 deletions src/dataset.rs
@@ -70,12 +70,16 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
item.history().map(|r| (r.delta_t, r.rating)).unzip();
delta_t.resize(pad_size, 0);
rating.resize(pad_size, 0);
let delta_t =
Tensor::from_data(Data::new(delta_t, Shape { dims: [pad_size] }).convert())
.unsqueeze();
let rating =
Tensor::from_data(Data::new(rating, Shape { dims: [pad_size] }).convert())
.unsqueeze();
let delta_t = Tensor::from_data(
Data::new(delta_t, Shape { dims: [pad_size] }).convert(),
&self.device,
)
.unsqueeze();
let rating = Tensor::from_data(
Data::new(rating, Shape { dims: [pad_size] }).convert(),
&self.device,
)
.unsqueeze();
(delta_t, rating)
})
.unzip();
@@ -84,12 +88,12 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
.iter()
.map(|item| {
let current = item.current();
let delta_t = Tensor::from_data(Data::from([current.delta_t.elem()]));
let delta_t = Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device);
let label = match current.rating {
1 => 0.0,
_ => 1.0,
};
let label = Tensor::from_data(Data::from([label.elem()]));
let label = Tensor::from_data(Data::from([label.elem()]), &self.device);
(delta_t, label)
})
.unzip();
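
Beyond the device argument, the hunks above show how each FSRSItem becomes a training example: review histories are zero-padded to the longest sequence in the batch, and the current review is reduced to a binary recall label. A plain-Rust sketch of that labelling step (function name and types are illustrative, not the crate's):

// Rating 1 ("Again") counts as a lapse; ratings 2-4 count as successful recall.
fn review_to_target(delta_t: u32, rating: u32) -> (f32, f32) {
    let label = if rating == 1 { 0.0 } else { 1.0 };
    (delta_t as f32, label)
}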
48 changes: 32 additions & 16 deletions src/inference.rs
@@ -52,11 +52,14 @@ impl<B: Backend> From<MemoryStateTensors<B>> for MemoryState {
impl<B: Backend> From<MemoryState> for MemoryStateTensors<B> {
fn from(m: MemoryState) -> Self {
Self {
stability: Tensor::from_data(Data::new(vec![m.stability.elem()], Shape { dims: [1] })),
difficulty: Tensor::from_data(Data::new(
vec![m.difficulty.elem()],
Shape { dims: [1] },
)),
stability: Tensor::from_data(
Data::new(vec![m.stability.elem()], Shape { dims: [1] }),
&B::Device::default(),
),
difficulty: Tensor::from_data(
Data::new(vec![m.difficulty.elem()], Shape { dims: [1] }),
&B::Device::default(),
),
}
}
}
@@ -81,14 +84,18 @@ impl<B: Backend> FSRS<B> {
let (time_history, rating_history) =
item.reviews.iter().map(|r| (r.delta_t, r.rating)).unzip();
let size = item.reviews.len();
let time_history =
Tensor::from_data(Data::new(time_history, Shape { dims: [size] }).convert())
.unsqueeze()
.transpose();
let rating_history =
Tensor::from_data(Data::new(rating_history, Shape { dims: [size] }).convert())
.unsqueeze()
.transpose();
let time_history = Tensor::from_data(
Data::new(time_history, Shape { dims: [size] }).convert(),
&self.device(),
)
.unsqueeze()
.transpose();
let rating_history = Tensor::from_data(
Data::new(rating_history, Shape { dims: [size] }).convert(),
&self.device(),
)
.unsqueeze()
.transpose();
let state: MemoryState = self
.model()
.forward(time_history, rating_history, starting_state.map(Into::into))
@@ -138,7 +145,10 @@ impl<B: Backend> FSRS<B> {
) -> u32 {
let stability = stability.unwrap_or_else(|| {
// get initial stability for new card
let rating = Tensor::from_data(Data::new(vec![rating.elem()], Shape { dims: [1] }));
let rating = Tensor::from_data(
Data::new(vec![rating.elem()], Shape { dims: [1] }),
&self.device(),
);
let model = self.model();
model.init_stability(rating).into_scalar().elem()
});
@@ -153,7 +163,10 @@ impl<B: Backend> FSRS<B> {
desired_retention: f32,
days_elapsed: u32,
) -> Result<NextStates> {
let delta_t = Tensor::from_data(Data::new(vec![days_elapsed.elem()], Shape { dims: [1] }));
let delta_t = Tensor::from_data(
Data::new(vec![days_elapsed.elem()], Shape { dims: [1] }),
&self.device(),
);
let current_memory_state_tensors = current_memory_state.map(MemoryStateTensors::from);
let model = self.model();
let mut next_memory_states = (1..=4).map(|rating| {
@@ -164,7 +177,10 @@ } else {
} else {
let state = MemoryState::from(model.step(
delta_t.clone(),
Tensor::from_data(Data::new(vec![rating.elem()], Shape { dims: [1] })),
Tensor::from_data(
Data::new(vec![rating.elem()], Shape { dims: [1] }),
&self.device(),
),
current_memory_state_tensors.clone(),
));
if !state.stability.is_finite() || !state.difficulty.is_finite() {
74 changes: 43 additions & 31 deletions src/model.rs
@@ -49,16 +49,16 @@ impl<B: Backend> Model<B> {
.collect();

Self {
w: Param::from(Tensor::from_floats(Data::new(
initial_params,
Shape { dims: [17] },
))),
w: Param::from(Tensor::from_floats(
Data::new(initial_params, Shape { dims: [17] }),
&B::Device::default(),
)),
config,
}
}

pub fn power_forgetting_curve(&self, t: Tensor<B, 1>, s: Tensor<B, 1>) -> Tensor<B, 1> {
(t / s * FACTOR + 1).powf(DECAY as f32)
(t / s * FACTOR + 1).powf_scalar(DECAY as f32)
}

fn stability_after_success(
@@ -69,10 +69,10 @@ impl<B: Backend> Model<B> {
rating: Tensor<B, 1>,
) -> Tensor<B, 1> {
let batch_size = rating.dims()[0];
let hard_penalty =
Tensor::ones([batch_size]).mask_where(rating.clone().equal_elem(2), self.w.get(15));
let easy_bonus =
Tensor::ones([batch_size]).mask_where(rating.equal_elem(4), self.w.get(16));
let hard_penalty = Tensor::ones([batch_size], &B::Device::default())
.mask_where(rating.clone().equal_elem(2), self.w.get(15));
let easy_bonus = Tensor::ones([batch_size], &B::Device::default())
.mask_where(rating.equal_elem(4), self.w.get(16));

last_s.clone()
* (self.w.get(8).exp()
@@ -243,10 +243,10 @@ impl<B: Backend> FSRS<B> {
pub(crate) fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<B> {
let config = ModelConfig::default();
let mut model = Model::new(config);
model.w = Param::from(Tensor::from_floats(Data::new(
clip_parameters(parameters),
Shape { dims: [17] },
)));
model.w = Param::from(Tensor::from_floats(
Data::new(clip_parameters(parameters), Shape { dims: [17] }),
&B::Device::default(),
));
model
}

@@ -264,9 +264,10 @@ mod tests {

#[test]
fn power_forgetting_curve() {
let device = NdArrayDevice::Cpu;
let model = Model::new(ModelConfig::default());
let delta_t = Tensor::from_floats([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
let stability = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 4.0, 2.0]);
let delta_t = Tensor::from_floats([0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &device);
let stability = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 4.0, 2.0], &device);
let retention = model.power_forgetting_curve(delta_t, stability);
assert_eq!(
retention.to_data(),
@@ -276,8 +277,9 @@

#[test]
fn init_stability() {
let device = NdArrayDevice::Cpu;
let model = Model::new(ModelConfig::default());
let rating = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 1.0, 2.0]);
let rating = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 1.0, 2.0], &device);
let stability = model.init_stability(rating);
assert_eq!(
stability.to_data(),
@@ -287,8 +289,9 @@

#[test]
fn init_difficulty() {
let device = NdArrayDevice::Cpu;
let model = Model::new(ModelConfig::default());
let rating = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 1.0, 2.0]);
let rating = Tensor::from_floats([1.0, 2.0, 3.0, 4.0, 1.0, 2.0], &device);
let difficulty = model.init_difficulty(rating);
assert_eq!(
difficulty.to_data(),
@@ -298,24 +301,32 @@

#[test]
fn forward() {
let device = NdArrayDevice::Cpu;
let model = Model::new(ModelConfig::default());
let delta_ts = Tensor::from_floats([
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 2.0, 2.0],
]);
let ratings = Tensor::from_floats([
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
]);
let delta_ts = Tensor::from_floats(
[
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 2.0, 2.0],
],
&device,
);
let ratings = Tensor::from_floats(
[
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
],
&device,
);
let state = model.forward(delta_ts, ratings, None);
dbg!(&state);
}

#[test]
fn next_difficulty() {
let device = NdArrayDevice::Cpu;
let model = Model::new(ModelConfig::default());
let difficulty = Tensor::from_floats([5.0; 4]);
let rating = Tensor::from_floats([1.0, 2.0, 3.0, 4.0]);
let difficulty = Tensor::from_floats([5.0; 4], &device);
let rating = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let next_difficulty = model.next_difficulty(difficulty, rating);
next_difficulty.clone().backward();
assert_eq!(
@@ -332,11 +343,12 @@

#[test]
fn next_stability() {
let device = NdArrayDevice::Cpu;
let model = Model::new(ModelConfig::default());
let stability = Tensor::from_floats([5.0; 4]);
let difficulty = Tensor::from_floats([1.0, 2.0, 3.0, 4.0]);
let retention = Tensor::from_floats([0.9, 0.8, 0.7, 0.6]);
let rating = Tensor::from_floats([1.0, 2.0, 3.0, 4.0]);
let stability = Tensor::from_floats([5.0; 4], &device);
let difficulty = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let retention = Tensor::from_floats([0.9, 0.8, 0.7, 0.6], &device);
let rating = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let s_recall = model.stability_after_success(
stability.clone(),
difficulty.clone(),
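
The powf → powf_scalar change above is in the power forgetting curve, R(t) = (1 + FACTOR * t / S) ^ DECAY. A standalone sketch of that formula; the constants assume the FSRS-4.5 values (DECAY = -0.5 and FACTOR = 19/81, chosen so that R(S) = 0.9) and may not match the crate's exact definitions:

fn power_forgetting_curve(t: f32, s: f32) -> f32 {
    const DECAY: f32 = -0.5;
    const FACTOR: f32 = 19.0 / 81.0;
    (1.0 + FACTOR * t / s).powf(DECAY)
}

// At t == s this returns 0.9, i.e. 90% predicted recall at the scheduled interval.
// Inverting it, t = S / FACTOR * (R^(1/DECAY) - 1), is the usual way a desired
// retention is turned into the next review interval.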
3 changes: 2 additions & 1 deletion src/training.rs
@@ -73,7 +73,8 @@ impl<B: Backend> Model<B> {
impl<B: AutodiffBackend> Model<B> {
fn freeze_initial_stability(&self, mut grad: B::Gradients) -> B::Gradients {
let grad_tensor = self.w.grad(&grad).unwrap();
let updated_grad_tensor = grad_tensor.slice_assign([0..4], Tensor::zeros([4]));
let updated_grad_tensor =
grad_tensor.slice_assign([0..4], Tensor::zeros([4], &B::Device::default()));

self.w.grad_remove(&mut grad);
self.w.grad_replace(&mut grad, updated_grad_tensor);
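
For context, freeze_initial_stability zeroes the gradient of the first four parameters (the per-rating initial stabilities, w[0..4]) so the optimizer never updates them; the diff only adds the explicit device to the zero tensor. The same idea in plain Rust, detached from burn's gradient types:

fn zero_first_four(grad: &mut [f32]) {
    // Parameters 0..4 hold the initial stabilities; a zero gradient freezes them.
    for g in grad.iter_mut().take(4) {
        *g = 0.0;
    }
}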
13 changes: 10 additions & 3 deletions src/weight_clipper.rs
@@ -6,7 +6,10 @@ use burn::tensor::{backend::Backend, Data, Tensor};

pub(crate) fn weight_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
let val = clip_parameters(&parameters.to_data().convert().value);
Tensor::from_data(Data::new(val, parameters.shape()).convert())
Tensor::from_data(
Data::new(val, parameters.shape()).convert(),
&B::Device::default(),
)
}

pub(crate) fn clip_parameters(parameters: &Parameters) -> Vec<f32> {
@@ -43,11 +46,15 @@ pub(crate) fn clip_parameters(parameters: &Parameters) -> Vec<f32> {
mod tests {
use super::*;
use crate::test_helpers::Tensor;
use burn::backend::ndarray::NdArrayDevice;

#[test]
fn weight_clipper_works() {
let tensor =
Tensor::from_floats([0.0, -1000.0, 1000.0, 0.0, 1000.0, -1000.0, 1.0, 0.25, -0.1]);
let device = NdArrayDevice::Cpu;
let tensor = Tensor::from_floats(
[0.0, -1000.0, 1000.0, 0.0, 1000.0, -1000.0, 1.0, 0.25, -0.1],
&device,
);

let param: Tensor<1> = weight_clipper(tensor);
let values = &param.to_data().value;
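
Conceptually, weight_clipper reads the tensor's values out with to_data, clamps every parameter into its allowed range via clip_parameters, and rebuilds the tensor on an explicit device. A minimal sketch of the clamping itself; the bounds argument is a placeholder, since the real per-parameter ranges live in clip_parameters and are not shown in this diff:

fn clip(parameters: &[f32], bounds: &[(f32, f32)]) -> Vec<f32> {
    parameters
        .iter()
        .zip(bounds)
        .map(|(p, &(lo, hi))| p.clamp(lo, hi))
        .collect()
}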
