Undiscovered bug fixed: use remap in input_tracker
zmrocze committed Jul 26, 2024
1 parent d5f6277 commit af35c01
Showing 7 changed files with 24 additions and 19 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -26,7 +26,7 @@ Here, either the weights or the inputs or any combination of the two can be made

### Demo

- To see how it's used and how it works, check out the tests in [lib::test_trained_into_snark](https://github.com/przyjacielpkp/zkml/blob/main/lib/src/lib.rs#L76). Or run them with `cargo test --profile=test`. Tests take some minute and a half to run on my laptop.
+ To see how it's used and how it works, check out the tests in [lib::test_trained_into_snark](https://github.com/przyjacielpkp/zkml/blob/main/lib/src/lib.rs#L76). Or run them with `cargo test --profile=test`. Tests take about a minute to run on my laptop.

These demonstrate the full functionality:
- a trained model is taken, that is a computation graph with the weight assignments (and some more bookkeeping; we admit the abstraction is leaky here)
lib/src/lib.rs: 10 changes (5 additions, 5 deletions)
@@ -124,7 +124,7 @@ mod tests {
// See the model shape at https://dreampuf.github.io/GraphvizOnline/#digraph%20%7B%0A%20%20%20%200%20%5B%20label%20%3D%20%22Weight%20Load%20%7C%200%22%20%5D%0A%20%20%20%201%20%5B%20label%20%3D%20%22Tensor%20Load%20%7C%201%22%20%5D%0A%20%20%20%202%20%5B%20label%20%3D%20%22Mul%20%7C%202%22%20%5D%0A%20%20%20%203%20%5B%20label%20%3D%20%22SumReduce(2)%20%7C%203%22%20%5D%0A%20%20%20%200%20-%3E%202%20%5B%20%20%5D%0A%20%20%20%201%20-%3E%202%20%5B%20%20%5D%0A%20%20%20%202%20-%3E%203%20%5B%20%20%5D%0A%7D%0A
tracing::info!("linear layer, data A");
let data = parse_dataset(include_str!("../../data/rp.data").to_string());
- let trained_model = crate::model::tiny_model::run_model(TrainParams { data, epochs: 10 });
+ let trained_model = crate::model::tiny_model::run_model(TrainParams { data, epochs: 2 });
let input = (0..9).map(|x| f32::from(x as i16)).collect_vec();
test_trained_into_snark(trained_model, input)
}
@@ -133,7 +133,7 @@ mod tests {
pub fn test_trained_into_snark_1() -> Result<(), String> {
tracing::info!("linear layer, data B");
let data = parse_dataset(include_str!("../../data/rp.data").to_string());
- let trained_model = crate::model::tiny_model::run_model(TrainParams { data, epochs: 10 });
+ let trained_model = crate::model::tiny_model::run_model(TrainParams { data, epochs: 2 });
let input = (9..18).map(|x| f32::from(x as i16)).collect_vec();
test_trained_into_snark(trained_model, input)
}
@@ -143,7 +143,7 @@ mod tests {
// see the model shape at https://dreampuf.github.io/GraphvizOnline/#digraph%20%7B%0A%20%20%20%200%20%5B%20label%20%3D%20%22Weight%20Load%20%7C%200%22%20%5D%0A%20%20%20%201%20%5B%20label%20%3D%20%22Weight%20Load%20%7C%201%22%20%5D%0A%20%20%20%202%20%5B%20label%20%3D%20%22Tensor%20Load%20%7C%202%22%20%5D%0A%20%20%20%203%20%5B%20label%20%3D%20%22Mul%20%7C%203%22%20%5D%0A%20%20%20%204%20%5B%20label%20%3D%20%22SumReduce(2)%20%7C%204%22%20%5D%0A%20%20%20%205%20%5B%20label%20%3D%20%22Constant(0.0)%20%7C%205%22%20%5D%0A%20%20%20%206%20%5B%20label%20%3D%20%22LessThan%20%7C%206%22%20%5D%0A%20%20%20%207%20%5B%20label%20%3D%20%22Mul%20%7C%207%22%20%5D%0A%20%20%20%208%20%5B%20label%20%3D%20%22LessThan%20%7C%208%22%20%5D%0A%20%20%20%209%20%5B%20label%20%3D%20%22Constant(-1.0)%20%7C%209%22%20%5D%0A%20%20%20%2010%20%5B%20label%20%3D%20%22Mul%20%7C%2010%22%20%5D%0A%20%20%20%2011%20%5B%20label%20%3D%20%22Constant(1.0)%20%7C%2011%22%20%5D%0A%20%20%20%2012%20%5B%20label%20%3D%20%22Add%20%7C%2012%22%20%5D%0A%20%20%20%2013%20%5B%20label%20%3D%20%22Mul%20%7C%2013%22%20%5D%0A%20%20%20%2014%20%5B%20label%20%3D%20%22Add%20%7C%2014%22%20%5D%0A%20%20%20%2015%20%5B%20label%20%3D%20%22Mul%20%7C%2015%22%20%5D%0A%20%20%20%2016%20%5B%20label%20%3D%20%22SumReduce(2)%20%7C%2016%22%20%5D%0A%20%20%20%200%20-%3E%203%20%5B%20%20%5D%0A%20%20%20%201%20-%3E%2015%20%5B%20%20%5D%0A%20%20%20%202%20-%3E%203%20%5B%20%20%5D%0A%20%20%20%203%20-%3E%204%20%5B%20%20%5D%0A%20%20%20%204%20-%3E%208%20%5B%20%20%5D%0A%20%20%20%204%20-%3E%206%20%5B%20%20%5D%0A%20%20%20%204%20-%3E%2013%20%5B%20%20%5D%0A%20%20%20%205%20-%3E%208%20%5B%20%20%5D%0A%20%20%20%205%20-%3E%207%20%5B%20%20%5D%0A%20%20%20%205%20-%3E%206%20%5B%20%20%5D%0A%20%20%20%206%20-%3E%207%20%5B%20%20%5D%0A%20%20%20%207%20-%3E%2014%20%5B%20%20%5D%0A%20%20%20%208%20-%3E%2010%20%5B%20%20%5D%0A%20%20%20%209%20-%3E%2010%20%5B%20%20%5D%0A%20%20%20%2010%20-%3E%2012%20%5B%20%20%5D%0A%20%20%20%2011%20-%3E%2012%20%5B%20%20%5D%0A%20%20%20%2012%20-%3E%2013%20%5B%20%20%5D%0A%20%20%20%2013%20-%3E%2014%20%5B%20%20%5D%0A%20%20%20%2014%20-%3E%2015%20%5B%20%20%5D%0A%20%20%20%2015%20-%3E%2016%20%5B%20%20%5D%0A%7D%0A
tracing::info!("linear layer into ReLU, data A");
let data = parse_dataset(include_str!("../../data/rp.data").to_string());
- let trained_model = crate::model::lessthan_model::run_model(TrainParams { data, epochs: 10 });
+ let trained_model = crate::model::lessthan_model::run_model(TrainParams { data, epochs: 2 });
let input = (0..9).map(|x| f32::from(x as i16)).collect_vec();
test_trained_into_snark(trained_model, input)
}
@@ -152,7 +152,7 @@ mod tests {
pub fn test_trained_into_snark_3() -> Result<(), String> {
tracing::info!("linear layer into ReLU, data B");
let data = parse_dataset(include_str!("../../data/rp.data").to_string());
- let trained_model = crate::model::lessthan_model::run_model(TrainParams { data, epochs: 10 });
+ let trained_model = crate::model::lessthan_model::run_model(TrainParams { data, epochs: 2 });
let input = (9..18).map(|x| f32::from(x as i16)).collect_vec();
test_trained_into_snark(trained_model, input)
}
@@ -161,7 +161,7 @@ mod tests {
pub fn test_trained_into_snark_4() -> Result<(), String> {
tracing::info!("linear layer into ReLU, data C");
let data = parse_dataset(include_str!("../../data/rp.data").to_string());
- let trained_model = crate::model::lessthan_model::run_model(TrainParams { data, epochs: 10 });
+ let trained_model = crate::model::lessthan_model::run_model(TrainParams { data, epochs: 2 });
let input: Vec<f32> = [
1.001231212412512,
0.3141512,
lib/src/model/fixed_weights.rs: 2 changes (1 addition, 1 deletion)
@@ -51,7 +51,7 @@ pub fn run_model() -> TrainedGraph {
// cx.display();
// record graph without gradients. assuming node ids don't change in Autograd::compile
let (cx_og, remap) = copy_graph_roughly(&cx);
- let input_id = input.id;
+ let input_id = remap[&input.id];

let target = cx.tensor::<R1<1>>();
// let loss = mse_loss(output, target).retrieve();
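This one-line change is the heart of the commit, and the same fix recurs in lessthan_model.rs, medium_model.rs and tiny_model.rs below. `copy_graph_roughly` itself is not shown in this diff, so the stand-in below is a hypothetical sketch; it only illustrates the invariant the fix relies on: a node-by-node copy hands out fresh `NodeIndex` values and returns a remap from old indices to new ones, so an id captured from the original graph must be translated before it can address the copy.

```rust
use std::collections::HashMap;
use petgraph::graph::{DiGraph, NodeIndex};

// Hypothetical stand-in for copy_graph_roughly: rebuild the graph node by
// node and return a map from old NodeIndex values to new ones.
fn copy_rough(g: &DiGraph<String, ()>) -> (DiGraph<String, ()>, HashMap<NodeIndex, NodeIndex>) {
    let mut copy = DiGraph::new();
    let mut remap = HashMap::new();
    for id in g.node_indices() {
        // The copy assigns its own indices; nothing guarantees they match.
        remap.insert(id, copy.add_node(g[id].clone()));
    }
    for e in g.edge_indices() {
        let (a, b) = g.edge_endpoints(e).unwrap();
        copy.add_edge(remap[&a], remap[&b], ());
    }
    (copy, remap)
}

fn main() {
    let mut g = DiGraph::new();
    let input = g.add_node("input".to_string());
    let (copy, remap) = copy_rough(&g);
    // The bug pattern fixed here: `input` indexes the original graph.
    // The copy must be addressed through the remap, as in `remap[&input.id]`.
    let input_id = remap[&input];
    assert_eq!(copy[input_id], "input");
}
```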
lib/src/model/lessthan_model.rs: 2 changes (1 addition, 1 deletion)
@@ -33,7 +33,7 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
// cx.display_shapes();
// record graph without gradients. assuming node ids don't change in Autograd::compile
let (cx_og, remap) = copy_graph_roughly(&cx);
- let input_id = input.id;
+ let input_id = remap[&input.id];

let target = cx.tensor::<R1<1>>();
let loss = mse_loss(output, target).retrieve();
lib/src/model/medium_model.rs: 18 changes (11 additions, 7 deletions)
@@ -104,15 +104,17 @@ pub fn get_weights(graph: &Graph, model: &Model) -> HashMap<NodeIndex, Vec<f32>>

pub struct TrainParams {
pub data: (InputsVec, OutputsVec),
// pub output: PathBuf,
pub epochs: usize,
// pub lr: f32,
// pub batch_size: u32,
// pub model: Model,
}

/// Contains everything needed to define the snark: the ml graph but without the gradients, the trained weights, and node indexes.
/// Note: this is quite a specific and frankly poor interface between training and snark synthesis, so don't take it as set in stone.
#[derive(Debug)]
pub struct GraphForSnark {
// the initial ml computation graph, without gradients
pub graph: Graph,
pub input_id: NodeIndex,
pub weights: Vec<(NodeIndex, Vec<f32>)>,
@@ -133,19 +135,21 @@ impl GraphForSnark {
}
}

/// Contains everything needed to define a snark and also evaluate the model.
/// Note: this is quite a specific and frankly poor interface between training and snark synthesis, so don't take it as set in stone.
/// Generally: this is the graph plus some bookkeeping recorded to evaluate it on input.
#[derive(Debug)]
pub struct TrainedGraph {
/// the original ml computation graph, without gradients + input id + trained weights
pub graph: GraphForSnark,
- // below are needed to evaluate the model:
- pub cx: Graph, // full trained graph for evaluation, the above "graph" is a rough copy
- pub cx_weights: Vec<(NodeIndex, Vec<f32>)>, // needed for evaluation, mostly tests
+ // below are needed to evaluate the model and compare its result against a snark derived from GraphForSnark:
+ pub cx: Graph, // full trained graph for evaluation; the above "graph" is similar but without gradients
+ pub cx_weights: Vec<(NodeIndex, Vec<f32>)>, // needed for evaluation, mostly tests; somewhat redundant
pub cx_input_id: NodeIndex, // needed for evaluation, mostly tests
pub cx_target_id: NodeIndex, // needed for evaluation, mostly tests
pub cx_output_id: NodeIndex,
}

// todo: make general
// impl<I: Shape, M : Module<GraphTensor<I>>>
impl TrainedGraph {
pub fn evaluate(&mut self, input_data: Vec<f32>) -> Vec<f32> {
self.cx.get_op_mut::<Function>(self.cx_input_id).1 =
@@ -181,7 +185,7 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
// cx.display();
// record graph without gradients. assuming node ids don't change in Autograd::compile
let (cx_og, remap) = copy_graph_roughly(&cx);
- let input_id = input.id;
+ let input_id = remap[&input.id];

let target = cx.tensor::<R1<1>>();
let loss = mse_loss(output, target).retrieve();
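Since the doc comments added above describe `TrainedGraph` as two loosely coupled halves, a short sketch of the intended call flow may help. It uses only names visible in this diff and in the tests above, assumes the crate's API as shown, and is not runnable outside the repository.

```rust
// Sketch of the intended division of labour in TrainedGraph.
fn sketch() {
    let data = parse_dataset(include_str!("../../data/rp.data").to_string());
    let mut trained = crate::model::medium_model::run_model(TrainParams { data, epochs: 2 });

    // Evaluation half: `cx` plus the cx_* ids reproduce the model's output.
    let input = (0..9).map(|x| f32::from(x as i16)).collect_vec();
    let expected = trained.evaluate(input.clone());

    // Snark half: the gradient-free copy plus the *remapped* input id.
    // Before this commit, `input_id` still pointed into the original graph,
    // so a circuit built from this struct would read the wrong node.
    let GraphForSnark { graph, input_id, weights } = trained.graph;
    let _ = (graph, input_id, weights, expected);
}
```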
lib/src/model/tiny_model.rs: 4 changes (2 additions, 2 deletions)
@@ -26,9 +26,9 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
let output = model.forward(input).retrieve();

// cx.display();
- // record graph without gradients. assuming nodeids dont change in Autograd::compile
+ // record graph without gradients.
let (cx_og, remap) = copy_graph_roughly(&cx);
- let input_id = input.id;
+ let input_id = remap[&input.id];

let target = cx.tensor::<R1<1>>();
let loss = mse_loss(output, target).retrieve();
lib/src/scalar.rs: 5 changes (3 additions, 2 deletions)
@@ -15,7 +15,7 @@ use petgraph::{
visit::{EdgeRef, IntoEdgeReferences, IntoNodeIdentifiers, NodeRef},
Direction::{Incoming, Outgoing},
};
- use tracing::{debug, info, instrument, warn};
+ use tracing::{debug, instrument, warn};

use luminal::{
op::{Constant, InputTensor, Operator},
@@ -49,9 +49,10 @@ pub struct ScalarGraph {
impl ScalarGraph {
pub fn copy_graph_roughly(&self) -> Self {
let (g, remap) = copy_graph_roughly(&self.graph);
+ let inputs_tracker = self.inputs_tracker.remap(remap);
ScalarGraph {
graph: g,
- inputs_tracker: self.inputs_tracker.clone(),
+ inputs_tracker,
}
}
}
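`InputsTracker::remap` is called above but not defined in this diff. A plausible shape, assuming the tracker stores `NodeIndex` values keyed by nodes of the original graph, is sketched below; the type and field names are guesses for illustration, not the crate's actual definitions.

```rust
use std::collections::HashMap;
use petgraph::graph::NodeIndex;

// Hypothetical InputsTracker. The essential behaviour of `remap` is to push
// every stored NodeIndex through the old-to-new map, mirroring the
// `input_id` fixes in the model files above.
#[derive(Clone, Debug, Default)]
pub struct InputsTracker {
    pub new_inputs: HashMap<NodeIndex, Vec<NodeIndex>>,
}

impl InputsTracker {
    pub fn remap(&self, remap: HashMap<NodeIndex, NodeIndex>) -> Self {
        InputsTracker {
            new_inputs: self
                .new_inputs
                .iter()
                .map(|(k, vs)| (remap[k], vs.iter().map(|v| remap[v]).collect()))
                .collect(),
        }
    }
}
```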
