Allow compiling with GenericCompiler before snark synthesis
zmrocze committed Jul 24, 2024
1 parent d7e088e commit 8ce0d4b
Showing 7 changed files with 102 additions and 54 deletions.
3 changes: 3 additions & 0 deletions lib/Cargo.toml
@@ -5,6 +5,9 @@ authors = [""]
 description = "Main"
 edition = "2018"
 
+[lib]
+doctest = false
+
 [dependencies]
 axum = "0.7.5"
 reqwest = "0.12.5"
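
Aside: `doctest = false` tells Cargo to skip compiling and running documentation examples for this crate under `cargo test`; presumably it was added here to shorten the test cycle.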
15 changes: 9 additions & 6 deletions lib/src/lib.rs
@@ -21,12 +21,15 @@ pub mod scalar;
 pub mod snark;
 pub mod utils;
 
-pub const SCALE: ScaleT = ScaleT {s : 1_000, z : u128::MAX << 1 /* ~ 1e38 */}; // giving float range from about -1e33 to 1e33
+pub const SCALE: ScaleT = ScaleT {s : 1_000_000, z : u128::MAX << 1 /* ~ 1e38 */}; // giving float range from about -1e29 to 1e29
 
 /// Main crate export. Take a tensor computation and rewrite to snark.
 pub fn compile(c: &TrainedGraph) -> MLSnark<CircuitField> {
-  let graph = copy_graph_roughly(&c.graph);
-  let weights = c.weights.clone();
+  let graph_for_snark = c.graph.copy_graph_roughly();
+  let graph = graph_for_snark.graph;
+  let weights = graph_for_snark.weights;
+  let input_id = graph_for_snark.input_id;
+  // let weights = c.weights.clone();
   // We set here the weights already. Set input with ::set_input.
   let sc = scalar(graph);
   let mut source_map = HashMap::new();
@@ -45,7 +48,7 @@ pub fn compile(c: &TrainedGraph) -> MLSnark<CircuitField> {
   let little_ids = sc
     .inputs_tracker
     .new_inputs
-    .get(&c.input_id)
+    .get(&input_id)
     .unwrap_or_else(|| panic!("Wrong input id"));
   for little_id in little_ids.into_iter() {
     source_map.insert(*little_id, SourceType::Private(None));
@@ -54,7 +57,7 @@ pub fn compile(c: &TrainedGraph) -> MLSnark<CircuitField> {
     graph: sc,
     scale: SCALE,
     source_map: source_map,
-    og_input_id: c.input_id,
+    og_input_id: input_id,
     recorded_public_inputs: vec![],
   }
 }
@@ -106,7 +109,7 @@ mod tests {
     diff,
     "The snark evaluates to the correct result (~ float precision)"
   );
-  tracing::info!("evaluated the model to {:?}, which is represented by a field element {:?}. Also evaluated the snark to a field element {:?}. The two results are within 0.01 float margin. Verifier correctly verified the proof that snark evaluates to that value.", model_eval_res_float, model_eval_result, snark_eval_result);
+  tracing::info!("evaluated the model to {:?}, which is represented by a field element {:?}. Also evaluated the snark to a field element {:?}. The two results are within 0.05 float margin. Verifier correctly verified the proof that snark evaluates to that value.", model_eval_res_float, model_eval_result, snark_eval_result);
 
   drop(scope);
   Ok(())
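
For orientation on the `SCALE` bump above, here is a minimal sketch of a fixed-point encoding of the kind `ScaleT { s, z }` suggests, assuming the map x ↦ z + round(x·s). The names `Scale`, `encode` and `decode` are hypothetical; the crate's actual encoding lives in `scalar.rs`/`snark.rs` and may differ. It illustrates the trade-off behind the changed range comment: precision is 1/s, so raising `s` from 1_000 to 1_000_000 buys three more decimal digits at the cost of representable range.

// Hypothetical sketch, not this crate's code: fixed-point floats over u128.
#[derive(Clone, Copy)]
struct Scale {
    s: u128, // scale factor; precision is 1/s
    z: u128, // offset standing for 0.0, keeping negatives unsigned
}

// Encode x as z + round(x * s); overflow checks elided in this sketch.
fn encode(x: f64, sc: Scale) -> u128 {
    if x >= 0.0 {
        sc.z + (x * sc.s as f64).round() as u128
    } else {
        sc.z - ((-x) * sc.s as f64).round() as u128
    }
}

// Decode back as (v - z) / s, done in f64 for the sketch.
fn decode(v: u128, sc: Scale) -> f64 {
    if v >= sc.z {
        (v - sc.z) as f64 / sc.s as f64
    } else {
        -((sc.z - v) as f64 / sc.s as f64)
    }
}

fn main() {
    let sc = Scale { s: 1_000_000, z: 1u128 << 100 }; // toy offset, not the real z
    assert_eq!(decode(encode(-1.25, sc), sc), -1.25);
}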
34 changes: 19 additions & 15 deletions lib/src/model/lessthan_model.rs
@@ -6,7 +6,7 @@ use luminal_training::{mse_loss, sgd_on_graph, Autograd};
 use tracing::info;
 
 use crate::{
-  model::{normalize_data, split_dataset, ExponentialAverage, InputsVec, OutputsVec},
+  model::{normalize_data, split_dataset, ExponentialAverage, GraphForSnark, InputsVec, OutputsVec},
   scalar::copy_graph_roughly,
 };
 
@@ -25,18 +25,19 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
 
   // todo: remove x=n
   // cx.display();
-  // cx.compile(
-  //   GenericCompiler::default(),
-  //   (
-  //     &mut input,
-  //     &mut output,
-  //   ),
-  // );
+  cx.compile(
+    GenericCompiler::default(),
+    (
+      &mut input,
+      &mut output,
+    ),
+  );
 
   // cx.display();
   // cx.display_shapes();
   // record graph without gradients. assuming nodeids dont change in Autograd::compile
-  let cx_og = copy_graph_roughly(&cx);
+  let (cx_og, remap) = copy_graph_roughly(&cx);
+  let input_id = input.id;
 
   let target = cx.tensor::<R1<1>>();
   let loss = mse_loss(output, target).retrieve();
@@ -95,7 +96,7 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
     .into_iter()
     .map(|a| {
       (
-        a,
+        remap[&a],
         cx.tensors
           .get(&(a, 0 /* assuming single output */))
           .unwrap()
@@ -108,11 +109,14 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
     })
     .collect();
   TrainedGraph {
+    graph : GraphForSnark {
+      graph: cx_og,
+      weights: weights_vec,
+      input_id,
+    },
     cx: cx,
-    graph: cx_og,
-    input_id: input.id,
-    output_id: output.id,
-    target_id: target.id,
-    weights: weights_vec,
+    cx_output_id: output.id,
+    cx_input_id: input.id,
+    cx_target_id: target.id,
   }
 }
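
Why the `remap` indirection above: with `cx.compile(GenericCompiler::default(), …)` now running before the snapshot, node indices in the copied graph are no longer guaranteed to coincide with those in `cx` (the old `assert!(x == n)` in `copy_graph_roughly` relied on exactly that). Trained weights are still read out of `cx` under their original id `a`, but get stored under the copy's id `remap[&a]` so that snark compilation finds them in the snapshot.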
47 changes: 32 additions & 15 deletions lib/src/model/medium_model.rs
@@ -112,28 +112,42 @@ pub struct TrainParams {
 }
 
 #[derive(Debug)]
-pub struct TrainedGraph {
+pub struct GraphForSnark {
   pub graph: Graph,
   pub input_id: NodeIndex,
   pub weights: Vec<(NodeIndex, Vec<f32>)>,
 }
 
+impl GraphForSnark {
+  pub fn copy_graph_roughly(&self) -> Self {
+    let (g, remap) = copy_graph_roughly(&self.graph);
+    GraphForSnark { graph: g, input_id: remap[&self.input_id],
+      weights: self.weights.iter().map(|(a, b)| (remap[a], b.clone())).collect() }
+  }
+}
+
+#[derive(Debug)]
+pub struct TrainedGraph {
+  pub graph: GraphForSnark,
   // below are needed to evaluate the model:
-  pub target_id: NodeIndex, // needed for evaluation, mostly tests
   pub cx: Graph, // full trained graph, the above "graph" is a rough copy
-  pub output_id: NodeIndex,
+  pub cx_input_id: NodeIndex, // needed for evaluation, mostly tests
+  pub cx_target_id: NodeIndex, // needed for evaluation, mostly tests
+  pub cx_output_id: NodeIndex,
 }
 
 // todo: make general
 // impl<I: Shape, M : Module<GraphTensor<I>>>
 impl TrainedGraph {
   pub fn evaluate(&mut self, input_data: Vec<f32>) -> Vec<f32> {
-    self.cx.get_op_mut::<Function>(self.input_id).1 =
+    self.cx.get_op_mut::<Function>(self.cx_input_id).1 =
       Box::new(move |_| vec![Tensor::new(input_data.to_owned())]);
-    self.cx.get_op_mut::<Function>(self.target_id).1 =
+    self.cx.get_op_mut::<Function>(self.cx_target_id).1 =
       Box::new(move |_| vec![Tensor::new(vec![0.0])]); // doesnt matter
     self.cx.execute();
     let d = self
       .cx
-      .get_tensor_ref(self.output_id, 0)
+      .get_tensor_ref(self.cx_output_id, 0)
       .unwrap()
       .clone()
       .downcast_ref::<Vec<f32>>()
@@ -154,7 +168,7 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
 
   // cx.display();
   // record graph without gradients. assuming nodeids dont change in Autograd::compile
-  let cx_og = copy_graph_roughly(&cx);
+  let (cx_og, remap) = copy_graph_roughly(&cx);
   let input_id = input.id;
 
   let target = cx.tensor::<R1<1>>();
@@ -213,7 +227,7 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
     .into_iter()
     .map(|a| {
       (
-        a,
+        remap[&a],
         cx.tensors
           .get(&(a, 0 /* assuming single output */))
           .unwrap()
@@ -227,12 +241,15 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
     .collect();
   assert!(input_id == input.id);
   TrainedGraph {
-    cx,
-    graph: cx_og,
-    input_id: input_id,
-    output_id: output.id,
-    target_id: target.id,
-    weights: weights_vec,
+    graph : GraphForSnark {
+      graph: cx_og,
+      weights: weights_vec,
+      input_id,
+    },
+    cx: cx,
+    cx_output_id: output.id,
+    cx_input_id: input.id,
+    cx_target_id: target.id,
   }
 }
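
A hedged usage sketch of the reshaped API (the `train_params` value and the calling code are assumed, not part of this diff): evaluation drives the full trained graph `cx` through the `cx_*` ids, while snark compilation reads only the embedded `GraphForSnark`.

// Sketch under assumed calling code; `train_params` comes from elsewhere.
let mut trained = run_model(train_params);

// Runs the full trained graph `cx`, using cx_input_id/cx_output_id.
let y = trained.evaluate(vec![0.5, 0.25]);

// Takes &TrainedGraph but only reads `trained.graph` (graph, weights, input_id).
let snark = compile(&trained);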
20 changes: 12 additions & 8 deletions lib/src/model/tiny_model.rs
@@ -6,7 +6,7 @@ use luminal_training::{mse_loss, sgd_on_graph, Autograd};
 use tracing::info;
 
 use crate::{
-  model::{normalize_data, split_dataset, ExponentialAverage, InputsVec, OutputsVec},
+  model::{normalize_data, split_dataset, ExponentialAverage, GraphForSnark, InputsVec, OutputsVec},
   scalar::copy_graph_roughly,
 };
 
@@ -25,7 +25,8 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
 
   // cx.display();
   // record graph without gradients. assuming nodeids dont change in Autograd::compile
-  let cx_og = copy_graph_roughly(&cx);
+  let (cx_og, remap) = copy_graph_roughly(&cx);
+  let input_id = input.id;
 
   let target = cx.tensor::<R1<1>>();
   let loss = mse_loss(output, target).retrieve();
@@ -84,7 +85,7 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
     .into_iter()
     .map(|a| {
       (
-        a,
+        remap[&a],
         cx.tensors
           .get(&(a, 0 /* assuming single output */))
           .unwrap()
@@ -98,11 +99,14 @@ pub fn run_model(train_params: TrainParams) -> TrainedGraph {
     .collect();
 
   TrainedGraph {
+    graph : GraphForSnark {
+      graph: cx_og,
+      weights: weights_vec,
+      input_id,
+    },
     cx: cx,
-    graph: cx_og,
-    input_id: input.id,
-    output_id: output.id,
-    target_id: target.id,
-    weights: weights_vec,
+    cx_output_id: output.id,
+    cx_input_id: input.id,
+    cx_target_id: target.id,
   }
 }
30 changes: 21 additions & 9 deletions lib/src/scalar.rs
@@ -48,8 +48,9 @@ pub struct ScalarGraph {
 
 impl ScalarGraph {
   pub fn copy_graph_roughly(&self) -> Self {
+    let (g, remap) = copy_graph_roughly(&self.graph);
     ScalarGraph {
-      graph: copy_graph_roughly(&self.graph),
+      graph: g,
       inputs_tracker: self.inputs_tracker.clone(),
     }
   }
@@ -141,6 +142,17 @@ pub struct InputsTracker {
   pub new_inputs: HashMap<NodeIndex, Vec<NodeIndex>>,
 }
 
+impl InputsTracker {
+  pub fn remap(&self, remap: HashMap<NodeIndex, NodeIndex>) -> Self {
+    let mut m = HashMap::new();
+    for (k, v) in self.new_inputs.iter() {
+      m.insert(*k, v.iter().map(|x| *remap.get(x).unwrap()).collect());
+    }
+    InputsTracker { new_inputs: m }
+  }
+}
+
 #[derive(Debug, Default)]
 pub struct Scalarize;
 
@@ -486,9 +498,9 @@ pub fn pretty_print_g(graph: &Graph) -> Result<(), Box<dyn Error>> {
 
 // copies things that are relevant. very much not exact copy
 // Expects a graph with indices from the [0..n] range without gaps (check the commented lines).
-pub fn copy_graph_roughly(src: &Graph) -> Graph {
+pub fn copy_graph_roughly(src: &Graph) -> (Graph, HashMap<NodeIndex, NodeIndex>) {
   let mut g = Graph::new();
-  // let mut map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
+  let mut map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
   // copy nodes
   for x in src.node_indices().sorted() {
     let n = if src.check_node_type::<Add>(x) {
@@ -526,21 +538,21 @@ pub fn copy_graph_roughly(src: &Graph) -> Graph {
         src.node_weight(x).unwrap().type_name()
      )
    };
-    // map.insert(x, n);
-    assert!(x == n)
+    map.insert(x, n);
+    // assert!(x == n)
  }
  // copy edges
  for e in src.edge_references() {
-    g.add_edge(e.source(), e.target(), e.weight().clone());
-    // g.add_edge(map[&e.source()], map[&e.target()], e.weight().clone());
+    // g.add_edge(e.source(), e.target(), e.weight().clone());
+    g.add_edge(map[&e.source()], map[&e.target()], e.weight().clone());
  }
  // copy retrieval marks
  // src.to_retrieve.iter().for_each(|(id, sh)| {g.to_retrieve.insert(map[id], *sh);});
  src.to_retrieve.iter().for_each(|(id, sh)| {
-    g.to_retrieve.insert(*id, *sh);
+    g.to_retrieve.insert(map[id], *sh);
  });
 
-  g
+  (g, map)
 }
 
 #[cfg(test)]
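
The core of the change, illustrated standalone: once a compiler pass may have deleted nodes, a copied graph's indices can differ from the source's, so edges and retrieval marks must be translated through an old→new map. Below is a self-contained sketch using plain petgraph; `copy_with_remap` is a hypothetical stand-in for the diff's `copy_graph_roughly` over luminal's `Graph`.

use std::collections::HashMap;
use petgraph::stable_graph::{NodeIndex, StableGraph};

// Copy a graph, returning the old->new index mapping. When `src` has
// gap-free indices the map is the identity; after node removals it is
// not, which is why the old `assert!(x == n)` had to go.
fn copy_with_remap<N: Clone, E: Clone>(
    src: &StableGraph<N, E>,
) -> (StableGraph<N, E>, HashMap<NodeIndex, NodeIndex>) {
    let mut dst = StableGraph::new();
    let mut map = HashMap::new();
    for x in src.node_indices() {
        map.insert(x, dst.add_node(src[x].clone()));
    }
    for e in src.edge_indices() {
        let (a, b) = src.edge_endpoints(e).unwrap();
        // translate endpoints through the map, as the diff now does
        dst.add_edge(map[&a], map[&b], src[e].clone());
    }
    (dst, map)
}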
7 changes: 6 additions & 1 deletion lib/src/snark/snark.rs
@@ -23,6 +23,7 @@
 use ark_snark::SNARK;
 use blake2::digest::generic_array::typenum::uint;
 use itertools::Itertools;
 
+use luminal::prelude::petgraph::data::DataMap;
 use luminal::prelude::petgraph::Direction::Outgoing;
 
@@ -36,7 +37,7 @@ use luminal::{
   },
 };
 use num_bigint::{BigInt, BigUint, ToBigInt};
-use tracing::{instrument, warn};
+use tracing::{info, instrument, warn};
 
 use crate::scalar::ConstantOp;
 use crate::scalar::InputOp;
@@ -517,6 +518,10 @@ impl ConstraintSynthesizer<CircuitField> for &mut MLSnark<CircuitField> {
         panic!("No n-ary ops for n>2")
       }
     };
+
+    // let nd_ty = graph.node_weight(x).unwrap().as_any().type_id();
+    // info!("{:?}: {:?} of {:?}", x, nd_ty, ass.clone().map(|x| unscaled_bigint(x, &scale)));
+
     vars.insert(x, v);
     assignments.insert(x, ass.clone());
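
The commented-out `info!` lines (together with the new `info` and petgraph `DataMap` imports) read like debug logging kept on hand for constraint-synthesis troubleshooting; as committed they are inert.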
