-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pass e2e test: Record public inputs in synthesiz & tiny model
- Loading branch information
Showing
7 changed files
with
160 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
|
||
pub mod tiny_model; | ||
pub mod medium_model; | ||
|
||
pub use medium_model::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
use std::{ | ||
any::Any, | ||
collections::HashMap, | ||
convert::TryInto, | ||
fs::{self, File}, | ||
iter::zip, | ||
ops::Deref, | ||
path::Path, | ||
}; | ||
|
||
use luminal::prelude::*; | ||
use luminal_nn::{Linear, ReLU}; | ||
use luminal_training::{mse_loss, sgd_on_graph, Autograd}; | ||
use petgraph::visit::{EdgeRef, IntoEdgeReferences, IntoNodeReferences}; | ||
|
||
use crate::{model::{normalize_data, split_dataset, ExponentialAverage, InputsVec, OutputsVec}, scalar::copy_graph_roughly}; | ||
|
||
use super::{TrainParams, TrainedGraph}; | ||
|
||
pub type Model = (Linear<9, 1>); | ||
|
||
pub fn run_model(train_params: TrainParams) -> (Graph, Model, TrainedGraph) { | ||
let dataset: (InputsVec, OutputsVec) = train_params.data; | ||
let epochs = train_params.epochs; | ||
// Setup gradient graph | ||
let mut cx = Graph::new(); | ||
let model = <Model>::initialize(&mut cx); | ||
let input = cx.tensor::<R1<9>>(); | ||
let output = model.forward(input).retrieve(); | ||
|
||
cx.display(); | ||
// record graph without gradients. assuming nodeids dont change in Autograd::compile | ||
let cx_og = copy_graph_roughly(&cx); | ||
|
||
let target = cx.tensor::<R1<1>>(); | ||
let loss = mse_loss(output, target).retrieve(); | ||
let weights = params(&model); | ||
|
||
let grads = cx.compile(Autograd::new(&weights, loss), ()); | ||
let (new_weights, lr) = sgd_on_graph(&mut cx, &weights, &grads); | ||
cx.keep_tensors(&new_weights); | ||
cx.keep_tensors(&weights); | ||
lr.set(5e-3); | ||
|
||
let (mut loss_avg, mut acc_avg) = (ExponentialAverage::new(1.0), ExponentialAverage::new(0.0)); | ||
let start = std::time::Instant::now(); | ||
// let EPOCHS = 20; | ||
|
||
let (X, Y) = dataset; | ||
let (X_train, _x_test, y_train, _y_test) = split_dataset(X, Y, 0.8); | ||
let X_train = normalize_data(X_train); | ||
let mut iter = 0; | ||
for _ in 0..epochs { | ||
for (x, y) in zip(X_train.iter(), y_train.iter()) { | ||
let answer = [y.to_owned()]; | ||
input.set(x.to_owned()); | ||
target.set(answer); | ||
|
||
cx.execute(); | ||
transfer_data_same_graph(&new_weights, &weights, &mut cx); | ||
loss_avg.update(loss.data()[0]); | ||
loss.drop(); | ||
// println!("{:}, {:}", output.data()[0], answer[0]); | ||
acc_avg.update( | ||
output | ||
.data() | ||
.into_iter() | ||
.zip(answer) | ||
.filter(|(a, b)| (a - b).abs() < 0.5) | ||
.count() as f32, | ||
); | ||
output.drop(); | ||
// println!( | ||
// "Iter {iter} Loss: {:.2} Acc: {:.2}", | ||
// loss_avg.value, acc_avg.value | ||
// ); | ||
iter += 1; | ||
} | ||
} | ||
println!("Finished in {iter} iterations"); | ||
println!( | ||
"Took {:.2}s, {:.2}µs / iter", | ||
start.elapsed().as_secs_f32(), | ||
start.elapsed().as_micros() / iter | ||
); | ||
cx.display(); | ||
let weights_vec = weights | ||
.into_iter() | ||
.map(|a| { | ||
( | ||
a, | ||
cx.tensors | ||
.get(&(a, 0 /* assuming single output */)) | ||
.unwrap() | ||
.downcast_ref::<Vec<f32>>() | ||
.unwrap() | ||
.clone() | ||
.into_iter() | ||
.collect(), | ||
) | ||
}) | ||
.collect(); | ||
( | ||
cx, | ||
model, | ||
TrainedGraph { | ||
graph: cx_og, | ||
input_id: input.id, | ||
weights: weights_vec, | ||
// model: model.clone() | ||
}, | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters