Pass e2e test: Record public inputs in synthesize & tiny model
zmrocze committed Jul 16, 2024
1 parent c989003 commit c1c8bbd
Showing 7 changed files with 160 additions and 33 deletions.
1 change: 1 addition & 0 deletions lib/Cargo.toml
@@ -20,6 +20,7 @@ itertools = "0.13.0"
dfdx = "0.13.0" # maybe completely unneeded
rand = "0.8.5"
better-panic = "0.2.0"
+ human-panic = "2.0.0"
petgraph-graphml = "3.0.0"

# arkworks
22 changes: 13 additions & 9 deletions lib/src/lib.rs
@@ -1,10 +1,10 @@
- use std::collections::HashMap;
+ use std::{collections::HashMap, vec};

use itertools::Itertools;
use luminal::prelude::NodeIndex;
use model::{get_weights, Model, TrainedGraph};
use scalar::{scalar, InputsTracker};
- use snark::{MLSnark, SourceMap, SourceType};
+ use snark::{CircuitField, MLSnark, SourceMap, SourceType};
use tracing::{error, info};

// #![feature(ascii_char)]
@@ -25,7 +25,7 @@ pub mod utils;
pub const SCALE: usize = 1000000;

/// Main crate export. Take a tensor computation and rewrite to snark.
- pub fn compile(c: TrainedGraph) -> MLSnark {
+ pub fn compile(c: TrainedGraph) -> MLSnark<CircuitField> {
// The weights are already set here. Set the input with ::set_input.
let sc = scalar(c.graph);
let mut source_map = HashMap::new();
@@ -54,6 +54,7 @@ pub fn compile(c: TrainedGraph) -> MLSnark {
scale: SCALE,
source_map: source_map,
og_input_id: c.input_id,
+ recorded_public_inputs : vec![],
}
}

@@ -70,28 +71,31 @@ mod tests {
use ark_bls12_381::Bls12_381;
use ark_groth16::Groth16;
use ark_snark::SNARK;
+ use tracing::info;

use crate::{
compile,
- model::{parse_dataset, run_model, TrainParams},
- snark::{CircuitField, MLSnark},
- utils,
+ model::{parse_dataset, TrainParams},
+ snark::{scaled_float, CircuitField},
+ utils, SCALE,
};

#[test]
pub fn test_trained_into_snark_0() -> Result<(), String> {
utils::init_logging().unwrap();
let err = |e| format!("{:?}", e).to_string();
let data = parse_dataset(include_str!("../../data/rp.data").to_string());
- let (_, _model, trained_model) = crate::model::run_model(TrainParams { data, epochs: 1 });
+ let (_, _model, trained_model) = crate::model::tiny_model::run_model(TrainParams { data, epochs: 1 });
+ let we = trained_model.weights.clone();
let mut snark = compile(trained_model);
let (pk, vk) = snark.make_keys().map_err(err)?;
// set input
snark.set_input(vec![0.0; 9]);
let proof = snark.make_proof(&pk).map_err(err)?;
- let verified = Groth16::<Bls12_381>::verify(&vk, &[CircuitField::from(73)], &proof);
+ let public_inputs = snark.recorded_public_inputs;
+ let verified = Groth16::<Bls12_381>::verify(&vk, &public_inputs, &proof);
println!("{:?}", verified);
- // assert!(.unwrap());
+ assert!(verified == Ok(true));
Ok(())
}
}
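
A sketch for context (not part of the commit): the change means the verifier side only ever needs the verifying key, the proof, and the exact public-input vector recorded during synthesis. Assuming the same ark-groth16 / ark-bls12-381 API the test uses (the helper name is hypothetical):

use ark_bls12_381::Bls12_381;
use ark_groth16::{Groth16, Proof, VerifyingKey};
use ark_snark::SNARK;

use crate::snark::CircuitField;

// Hypothetical verifier-side helper: the witness never leaves the prover;
// only vk, the proof, and the recorded public inputs are required here.
fn verify_recorded(
    vk: &VerifyingKey<Bls12_381>,
    public_inputs: &[CircuitField],
    proof: &Proof<Bls12_381>,
) -> bool {
    Groth16::<Bls12_381>::verify(vk, public_inputs, proof).unwrap_or(false)
}
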
16 changes: 2 additions & 14 deletions lib/src/model.rs → lib/src/model/medium_model.rs
@@ -113,6 +113,7 @@ pub struct TrainParams {
// pub model: Model,
}

+ #[derive(Debug)]
pub struct TrainedGraph {
pub graph: Graph,
pub input_id: NodeIndex,
@@ -143,19 +144,6 @@ pub fn run_model(train_params: TrainParams) -> (Graph, Model, TrainedGraph) {
cx.keep_tensors(&weights);
lr.set(5e-3);

- // #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
- // cx.compile(
- // GenericCompiler::default(),
- // (
- // &mut input,
- // &mut target,
- // &mut loss,
- // &mut output,
- // &mut weights,
- // &mut new_weights,
- // ),
- // );
- 
let (mut loss_avg, mut acc_avg) = (ExponentialAverage::new(1.0), ExponentialAverage::new(0.0));
let start = std::time::Instant::now();
// let EPOCHS = 20;
@@ -234,7 +222,7 @@ pub struct ExponentialAverage {
}

impl ExponentialAverage {
- fn new(initial: f32) -> Self {
+ pub fn new(initial: f32) -> Self {
ExponentialAverage {
beta: 0.999,
moment: 0.,
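
Making ExponentialAverage::new public lets the new tiny_model module reuse the loss/accuracy trackers. The update rule itself is not shown in this diff; a typical bias-corrected exponential moving average consistent with the visible fields would look like the following sketch (the t counter and the update body are assumptions beyond what the diff shows):

// Sketch only: fields beyond `beta`, `moment`, and `value` are assumed.
pub struct ExponentialAverage {
    beta: f32,
    moment: f32,
    pub value: f32,
    t: i32,
}

impl ExponentialAverage {
    pub fn new(initial: f32) -> Self {
        ExponentialAverage { beta: 0.999, moment: 0., value: initial, t: 0 }
    }

    pub fn update(&mut self, observation: f32) {
        self.t += 1;
        self.moment = self.beta * self.moment + (1. - self.beta) * observation;
        // Divide out the bias toward zero introduced by moment's zero init.
        self.value = self.moment / (1. - self.beta.powi(self.t));
    }
}
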
5 changes: 5 additions & 0 deletions lib/src/model/mod.rs
@@ -0,0 +1,5 @@

pub mod tiny_model;
pub mod medium_model;

pub use medium_model::*;
113 changes: 113 additions & 0 deletions lib/src/model/tiny_model.rs
@@ -0,0 +1,113 @@
use std::{
any::Any,
collections::HashMap,
convert::TryInto,
fs::{self, File},
iter::zip,
ops::Deref,
path::Path,
};

use luminal::prelude::*;
use luminal_nn::{Linear, ReLU};
use luminal_training::{mse_loss, sgd_on_graph, Autograd};
use petgraph::visit::{EdgeRef, IntoEdgeReferences, IntoNodeReferences};

use crate::{model::{normalize_data, split_dataset, ExponentialAverage, InputsVec, OutputsVec}, scalar::copy_graph_roughly};

use super::{TrainParams, TrainedGraph};

pub type Model = (Linear<9, 1>);

pub fn run_model(train_params: TrainParams) -> (Graph, Model, TrainedGraph) {
let dataset: (InputsVec, OutputsVec) = train_params.data;
let epochs = train_params.epochs;
// Setup gradient graph
let mut cx = Graph::new();
let model = <Model>::initialize(&mut cx);
let input = cx.tensor::<R1<9>>();
let output = model.forward(input).retrieve();

cx.display();
// Record the graph without gradients, assuming NodeIds don't change in Autograd::compile.
let cx_og = copy_graph_roughly(&cx);

let target = cx.tensor::<R1<1>>();
let loss = mse_loss(output, target).retrieve();
let weights = params(&model);

let grads = cx.compile(Autograd::new(&weights, loss), ());
let (new_weights, lr) = sgd_on_graph(&mut cx, &weights, &grads);
cx.keep_tensors(&new_weights);
cx.keep_tensors(&weights);
lr.set(5e-3);

let (mut loss_avg, mut acc_avg) = (ExponentialAverage::new(1.0), ExponentialAverage::new(0.0));
let start = std::time::Instant::now();
// let EPOCHS = 20;

let (X, Y) = dataset;
let (X_train, _x_test, y_train, _y_test) = split_dataset(X, Y, 0.8);
let X_train = normalize_data(X_train);
let mut iter = 0;
for _ in 0..epochs {
for (x, y) in zip(X_train.iter(), y_train.iter()) {
let answer = [y.to_owned()];
input.set(x.to_owned());
target.set(answer);

cx.execute();
transfer_data_same_graph(&new_weights, &weights, &mut cx);
loss_avg.update(loss.data()[0]);
loss.drop();
// println!("{:}, {:}", output.data()[0], answer[0]);
acc_avg.update(
output
.data()
.into_iter()
.zip(answer)
.filter(|(a, b)| (a - b).abs() < 0.5)
.count() as f32,
);
output.drop();
// println!(
// "Iter {iter} Loss: {:.2} Acc: {:.2}",
// loss_avg.value, acc_avg.value
// );
iter += 1;
}
}
println!("Finished in {iter} iterations");
println!(
"Took {:.2}s, {:.2}µs / iter",
start.elapsed().as_secs_f32(),
start.elapsed().as_micros() / iter
);
cx.display();
let weights_vec = weights
.into_iter()
.map(|a| {
(
a,
cx.tensors
.get(&(a, 0 /* assuming single output */))
.unwrap()
.downcast_ref::<Vec<f32>>()
.unwrap()
.clone()
.into_iter()
.collect(),
)
})
.collect();
(
cx,
model,
TrainedGraph {
graph: cx_og,
input_id: input.id,
weights: weights_vec,
// model: model.clone()
},
)
}
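
The point of tiny_model is that a single Linear<9, 1> layer keeps the scalar graph, and therefore the circuit, small. As a plain-Rust reference, the forward pass amounts to a dot product (whether luminal_nn::Linear adds a bias term is not visible here, so this sketch assumes it does not):

// Reference-only sketch of the tiny model's forward pass: nine weights
// against nine input features, no bias (assumed).
fn tiny_forward(weights: &[f32; 9], x: &[f32; 9]) -> f32 {
    weights.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum()
}
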
34 changes: 25 additions & 9 deletions lib/src/snark.rs
@@ -30,7 +30,7 @@ use luminal::{
NodeIndex,
},
};
- use tracing::{info, instrument, warn};
+ use tracing::{instrument, warn};

use crate::scalar::ConstantOp;
use crate::scalar::InputOp;
@@ -77,7 +77,7 @@ impl<F: From<u128>> SourceType<F> {
}
}

- fn scaled_float<F: From<u128>>(x: f32, scale: usize) -> F {
+ pub fn scaled_float<F: From<u128>>(x: f32, scale: usize) -> F {
let y: u128 = (x * (scale as f32)).round() as u128;
F::from(y)
}
@@ -111,7 +111,7 @@ pub type CircuitField = ark_bls12_381::Fr;
/// and then snark synthesis would rewrite Div_int to a similar circuit as above.
///
#[derive(Debug)]
- pub struct MLSnark {
+ pub struct MLSnark<F> {
pub graph: ScalarGraph,
// start here
pub scale: usize,
@@ -120,11 +120,16 @@ pub struct MLSnark {
// for convenience
pub og_input_id: NodeIndex,
// pub inputs_tracker : InputsTracker

+ // This is needed due to some redundancy in how public inputs must be passed to verify.
+ // The field is filled during circuit synthesis with the assignments given to public inputs, in order.
+ // The last few elements record the result of the circuit; just the last element if there is a single output.
+ pub recorded_public_inputs : Vec<F>
}

pub type SourceMap = HashMap<NodeIndex, SourceType<f32>>;

- impl MLSnark {
+ impl MLSnark<CircuitField> {
pub fn set_input(&mut self, value: Vec<f32>) {
set_input(
&mut self.source_map,
@@ -135,7 +140,7 @@ impl MLSnark {
}

pub fn make_keys(
- &self,
+ &mut self,
) -> Result<(ProvingKey<Bls12_381>, VerifyingKey<Bls12_381>), SynthesisError> {
// let cloned = MLSnark {
// graph: self.graph.copy_graph_roughly(),
@@ -149,7 +154,7 @@ }
}

// first provide all inputs with the set_input method, otherwise SynthesisError
- pub fn make_proof(&self, pk: &ProvingKey<Bls12_381>) -> Result<Proof<Bls12_381>, SynthesisError> {
+ pub fn make_proof(&mut self, pk: &ProvingKey<Bls12_381>) -> Result<Proof<Bls12_381>, SynthesisError> {
let rng = &mut ark_std::test_rng();
// let cloned = MLSnark {
// graph: self.graph.copy_graph_roughly(),
@@ -171,7 +176,7 @@ fn set_input(source_map: &mut SourceMap, tracker: &InputsTracker, id: NodeIndex,
}
}

- impl ConstraintSynthesizer<CircuitField> for &MLSnark {
+ impl ConstraintSynthesizer<CircuitField> for &mut MLSnark<CircuitField> {
// THIS-WORKS

#[instrument(level = "debug", name = "generate_constraints")]
@@ -195,6 +200,14 @@ impl ConstraintSynthesizer<CircuitField> for &MLSnark {
(k, v)
})
.collect();
+ let mut public_record = vec![];
+ 
+ // Return the public input variable and its assignment, but also record the value in public_record.
+ let mut mk_public_input = |n| {
+ public_record.push(n);
+ let v = cs.new_input_variable(|| Ok(n))?;
+ Ok((v , Some(n)))
+ };

let pi = petgraph::algo::toposort(&graph.graph, None).unwrap();
let mut vars: HashMap<NodeIndex, ark_relations::r1cs::Variable> = HashMap::new();
@@ -222,7 +235,7 @@ impl ConstraintSynthesizer<CircuitField> for &MLSnark {
.downcast_ref::<ConstantOp>()
.unwrap();
let n = scaled_float(constant_op.val, scale);
- (cs.new_input_variable(|| Ok(n))?, Some(n))
+ mk_public_input(n)?
} else if graph.check_node_type::<InputOp>(x) {
let src_ty = source_map
.get(&x)
@@ -233,7 +246,9 @@ impl ConstraintSynthesizer<CircuitField> for &MLSnark {
cs.new_witness_variable(|| mn.ok_or(SynthesisError::AssignmentMissing))?,
mn.clone(),
),
- Public(n) => (cs.new_input_variable(|| Ok(*n))?, Some(*n)),
+ Public(n) =>
+ mk_public_input(*n)?
+ ,
}
} else {
panic!(
@@ -330,6 +345,7 @@ impl ConstraintSynthesizer<CircuitField> for &MLSnark {
vars.insert(x, v);
assignments.insert(x, ass);
}
+ self.recorded_public_inputs = public_record;
Ok(())
}
}
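
scaled_float is now pub so the tests can reuse the crate's fixed-point convention: a float x embeds into the field as round(x * SCALE). A small sketch under SCALE = 1000000; note that the `as u128` cast in scaled_float saturates to zero for negative products, so the encoding assumes non-negative values:

use crate::{snark::{scaled_float, CircuitField}, SCALE};

fn scale_examples() {
    // 0.5 * 1_000_000 rounds to 500_000, then maps into the field.
    let half: CircuitField = scaled_float(0.5, SCALE);
    assert_eq!(half, CircuitField::from(500_000_u128));

    // -0.5 would saturate to 0 via `as u128`; negative values need
    // separate handling before this encoding applies.
}
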
2 changes: 1 addition & 1 deletion lib/src/utils.rs
@@ -1,10 +1,10 @@
#[cfg(not(debug_assertions))]
use human_panic::setup_panic;
- use tracing::subscriber::{self, SetGlobalDefaultError};

#[cfg(debug_assertions)]
extern crate better_panic;

+ use tracing::subscriber::{self, SetGlobalDefaultError};
use tracing_subscriber::{self, fmt, layer::SubscriberExt};

// [NOTE] tracing
Expand Down
