Skip to content

Commit

Permalink
Split module with scaling helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
zmrocze committed Jul 24, 2024
1 parent 98f32fe commit 5b1b462
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 158 deletions.
12 changes: 2 additions & 10 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{collections::HashMap, vec};

use model::TrainedGraph;
use scalar::{copy_graph_roughly, scalar};
use snark::{CircuitField, MLSnark, SourceType};
use snark::{scaling_helpers::ScaleT, CircuitField, MLSnark, SourceType};

// #![feature(ascii_char)]
//
Expand All @@ -21,13 +21,6 @@ pub mod scalar;
pub mod snark;
pub mod utils;

// pub type ScaleT = u64;
#[derive(Debug, Clone, Copy)]
/// Parameters of the affine float-to-integer encoding `x => round(s * x) + z`.
pub struct ScaleT {
// Multiplicative scale factor applied to the float before rounding.
s : u128,
// Additive offset applied after rounding.
z : u128
}

/// Scaling parameters used throughout the crate: `s = 1_000`, `z = u128::MAX << 1`.
// NOTE(review): `u128::MAX << 1` evaluates to 2^128 - 2 (about 3.4e38), so `z / s` is about 3.4e35;
// the "1e33" range quoted in the trailing comment looks like a rough estimate - confirm the intended bound.
pub const SCALE: ScaleT = ScaleT {s : 1_000, z : u128::MAX << 1 /* ~ 1e38 */}; // giving float range from about -1e33 to 1e33

/// Main crate export. Take a tensor computation and rewrite to snark.
Expand Down Expand Up @@ -72,11 +65,10 @@ mod tests {
use crate::{
compile,
model::{parse_dataset, TrainParams, TrainedGraph},
snark::{f_from_bigint_unsafe, field_close_as_floats, field_elems_close, scaled_float, CircuitField},
snark::{CircuitField, scaling_helpers::{f_from_bigint_unsafe, field_close_as_floats, scaled_float, }},
SCALE,
};
use ark_bls12_381::Bls12_381;
use ark_ff::Field;
use ark_groth16::Groth16;
use ark_snark::SNARK;
use itertools::Itertools;
Expand Down
4 changes: 4 additions & 0 deletions lib/src/snark/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

mod snark;
pub mod scaling_helpers;
pub use snark::*;
101 changes: 101 additions & 0 deletions lib/src/snark/scaling_helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use std::convert::{TryFrom, TryInto};
use std::{fmt::Debug, ops::Div};

use ark_ff::{BigInteger256, PrimeField};
use ark_relations::r1cs::SynthesisError;
use num_bigint::{BigInt, BigUint, ToBigInt};
use super::CircuitField;

#[derive(Debug, Clone, Copy)]
/// Defines a scaling of a float by: x => round(s * x) + z
///
/// `scaled_float` applies this mapping and `unscaled_bigint` inverts it.
pub struct ScaleT {
/// Multiplicative scale factor applied to the float before rounding.
pub s : u128,
/// Additive offset applied after rounding (presumably chosen so that negative floats
/// still map to non-negative integers - see `positive_bigint`).
pub z : u128
}

/// Convert a float to a scaled integer.
///
/// See [Note: floats as ints]
pub fn scaled_float(x: f32, scale: &ScaleT) -> BigInt {
// // TODO: handle errors upstream
let s = scale.s;
let z = scale.z;
let x : f64 = x.into();
// assert!( (- (z as f64) / (s as f64) <= x) && (x <= (z as f64) / (s as f64)) , "Float within allowed range");
let scaled: BigInt = ((x * (s as f64)).round()).to_bigint().expect("scaled_float: Conversion to bigint failed");
scaled + z
// todo: handle the unwrap upstream
// assert!(y.is_positive(), "Scaled float outside of the range!");
// assert!( ((((y - z) / s) as f64) - x).abs() <= x * 0.0001 , "Float is recoverable");
}

/// Decode a field element back into the float it encodes.
///
/// The element's integer representation is converted to a `BigInt` and then
/// passed to `unscaled_bigint`; `None` is returned whenever that call returns `None`.
pub fn unscaled_f(x : CircuitField, scale: &ScaleT) -> Option<f32> {
    let as_int = i256_to_bigint(x.into_repr());
    unscaled_bigint(as_int, scale)
}

/// Invert the `ScaleT` encoding: recover the float `(x - z) / s` from a scaled integer.
///
/// The value is split into an integer quotient and a remainder so that the
/// fractional part is not lost by the integer division.
///
/// Returns `None` when the quotient or the remainder does not fit in an `i128`.
pub fn unscaled_bigint(x: BigInt, scale: &ScaleT) -> Option<f32> {
    // TODO: handle errors upstream
    let s = scale.s;
    let z = scale.z;
    let shifted = x - z;
    // `BigInt` division truncates toward zero and the remainder carries the sign of the
    // dividend, so for encoded negative floats (`x < z`) the remainder is negative.
    // It therefore has to be read as a signed integer; an unsigned conversion would
    // reject every negative float that has a fractional part.
    let div: i128 = (shifted.clone() / s).try_into().ok()?;
    let rem: i128 = (shifted % s).try_into().ok()?;

    Some(((div as f64) + ((rem as f64) / (s as f64))) as f32)
}

/// Convert a signed big integer into an unsigned one.
///
/// # Panics
///
/// Panics if the conversion fails (the value is negative), which indicates that an
/// encoded float underflowed the representable range.
pub fn positive_bigint(b: BigInt) -> BigUint {
    let converted: Result<BigUint, _> = b.try_into();
    converted.expect("Expects positive bigint, otherwise its negative float overflow")
}
/// Convert a big integer into a circuit field element.
///
/// Returns `SynthesisError::AssignmentMissing` when the unsigned value cannot be
/// converted into a `CircuitField`.
///
/// # Panics
///
/// Panics (inside `positive_bigint`) when `b` is negative.
pub fn f_from_bigint(b: BigInt) -> Result<CircuitField, SynthesisError> {
    let unsigned = positive_bigint(b);
    match CircuitField::try_from(unsigned) {
        Ok(element) => Ok(element),
        Err(_) => Err(SynthesisError::AssignmentMissing),
    }
}
/// Panicking variant of `f_from_bigint`.
///
/// # Panics
///
/// Panics when `f_from_bigint` returns an error, or when `b` is negative.
pub fn f_from_bigint_unsafe(b: BigInt) -> CircuitField {
    let converted = f_from_bigint(b);
    converted.expect("Expects bigint to fit in the prime field range, otherwise its positive float overflow")
}

pub fn i256_to_bigint(a: BigInteger256) -> BigInt {
let x : BigUint = a.into();
x.into()
// // let x: u128 = a.try_into();
// (BigInt::from(q) << (64 * 3)) + (BigInt::from(w) << (64 * 2)) + (BigInt::from(e) << 64) + BigInt::from(r)
}

/// Check whether two field elements are close when compared as integers.
///
/// The elements are considered close when their absolute difference is at most
/// the larger of the two values divided by `scale.s * 100`.
pub fn field_elems_close(a : CircuitField , b : CircuitField, scale: ScaleT) -> bool {
    let lhs = i256_to_bigint(a.into_repr());
    let rhs = i256_to_bigint(b.into_repr());
    // Order the pair once so both the gap and the tolerance can be derived from it.
    let (smaller, larger) = if lhs < rhs { (lhs, rhs) } else { (rhs, lhs) };
    let gap = &larger - &smaller;
    let tolerance = larger.div(scale.s * 100);
    gap <= tolerance
}

/// Approximate float equality.
///
/// Returns `true` when `|a - b|` is at most `0.001 * max(|a| + |b|, 1.0)`.
pub fn floats_close(a : f32, b: f32) -> bool {
    let magnitude = (a.abs() + b.abs()).max(1.0);
    let tolerance = 0.001 * magnitude;
    (a - b).abs() <= tolerance
}
/// Decode both scaled integers to floats and compare them with `floats_close`.
///
/// Returns `false` when either value cannot be decoded by `unscaled_bigint`.
pub fn bigints_close_as_floats(a : BigInt, b: BigInt, scale: &ScaleT) -> bool {
    unscaled_bigint(a, scale)
        .and_then(|lhs| unscaled_bigint(b, scale).map(|rhs| floats_close(lhs, rhs)))
        .unwrap_or(false)
}

/// Decode both field elements to floats and compare them with `floats_close`.
///
/// Returns `false` when either element cannot be decoded by `unscaled_f`.
pub fn field_close_as_floats(a : CircuitField, b: CircuitField, scale: &ScaleT) -> bool {
    unscaled_f(a, scale)
        .and_then(|lhs| unscaled_f(b, scale).map(|rhs| floats_close(lhs, rhs)))
        .unwrap_or(false)
}
190 changes: 42 additions & 148 deletions lib/src/snark.rs → lib/src/snark/snark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::scalar::ConstantOp;
use crate::scalar::InputOp;
// use crate::model::copy_graph_roughly;
use crate::scalar::{InputsTracker, ScalarGraph};
use crate::ScaleT;
use crate::snark::scaling_helpers::*;

/// Tensor computation is initialized by setting input tensors data and then evaluating.
/// This function takes a mapping from input index to its tensor and creates a
Expand Down Expand Up @@ -82,8 +82,45 @@ impl SourceType<BigUint> {
}
}

/// Convert a float to a scaled integer.
///
// ++
// A ++ B := A + B - z
// Addition on encoded values: the counterpart of float addition under the
// `scaled_float` mapping. Both operands carry the offset z, so it is removed once.
// See [Note: floats as integers]
fn add_add(a : BigInt, b : BigInt, scale: &ScaleT) -> BigInt {
    // TODO: handle errors upstream (overflow / underflow of the encoded range).
    let sum = a + b;
    sum - scale.z
}

// Quotient and remainder of an integer division, as returned by `mul_mul`.
#[derive(Debug, Clone)]
pub struct DivisionResult {
// Quotient of the division.
result: BigInt,
// Remainder of the division.
remainder: BigInt,
}

// **
// A ** B := (A * B + z*s + z*z - A*z - B*z) / s
// Multiplication on encoded values: the counterpart of float multiplication under the
// `scaled_float` mapping. Returns both the quotient and the remainder of the division by s.
// See [Note: floats as integers]
pub fn mul_mul(a : BigInt, b : BigInt, scale: &ScaleT) -> DivisionResult {
    // TODO: handle errors upstream
    let (s, z) = (scale.s, scale.z);
    let s_big = BigInt::from(s);
    let z_big = BigInt::from(z);
    // Terms that depend only on the scale: z*s + z*z.
    let offset = &z_big * &s_big + &z_big * &z_big;
    let numerator = &a * &b + offset - a * z - b * z;
    let quotient = &numerator / &s_big;
    let leftover = &numerator % &s_big;
    DivisionResult { result: quotient, remainder: leftover }
}

pub type Curve = ark_bls12_381::Bls12_381;
pub type CircuitField = ark_bls12_381::Fr;

///
/// NOTE on integer vs float computation:
///
/// The ML computation is obviously meant to evaluate to floats.
/// If we were to take the static description of the expression for evaluation, but treat all Op's as if
/// they act on integers - then what changes do we need to do to the expression?
///
/// [Note: floats as integers]
/// ## intro
Expand Down Expand Up @@ -190,96 +227,6 @@ impl SourceType<BigUint> {
/// BUT the operations done in the prime field may still overflow without us noticing, leading to wrong results.
/// The solution is to add careful asserts on the ranges of the subresults of all calculations - boring, let's do that later.
///
pub fn scaled_float(x: f32, scale: &ScaleT) -> BigInt {
// // TODO: handle errors upstream
let s = scale.s;
let z = scale.z;
let x : f64 = x.into();
// assert!( (- (z as f64) / (s as f64) <= x) && (x <= (z as f64) / (s as f64)) , "Float within allowed range");
let scaled: BigInt = ((x * (s as f64)).round()).to_bigint().expect("scaled_float: Conversion to bigint failed");
scaled + z
// todo: handle the unwrap upstream
// assert!(y.is_positive(), "Scaled float outside of the range!");
// assert!( ((((y - z) / s) as f64) - x).abs() <= x * 0.0001 , "Float is recoverable");
}

// TODO: factor out a module with conversion helpers
/// Decode a field element back into the float it encodes.
///
/// The element's integer representation is converted to a `BigInt` and then
/// passed to `unscaled_bigint`; `None` is returned whenever that call returns `None`.
pub fn unscaled_f(x : CircuitField, scale: &ScaleT) -> Option<f32> {
    let as_int = i256_to_bigint(x.into_repr());
    unscaled_bigint(as_int, scale)
}

/// Invert the `ScaleT` encoding: recover the float `(x - z) / s` from a scaled integer.
///
/// The value is split into an integer quotient and a remainder so that the
/// fractional part is not lost by the integer division.
///
/// Returns `None` when the quotient or the remainder does not fit in an `i128`.
pub fn unscaled_bigint(x: BigInt, scale: &ScaleT) -> Option<f32> {
    // TODO: handle errors upstream
    let s = scale.s;
    let z = scale.z;
    let shifted = x - z;
    // `BigInt` division truncates toward zero and the remainder carries the sign of the
    // dividend, so for encoded negative floats (`x < z`) the remainder is negative.
    // It therefore has to be read as a signed integer; an unsigned conversion would
    // reject every negative float that has a fractional part.
    let div: i128 = (shifted.clone() / s).try_into().ok()?;
    let rem: i128 = (shifted % s).try_into().ok()?;

    Some(((div as f64) + ((rem as f64) / (s as f64))) as f32)
}


// ++
// A ++ B := A + B - z
// Addition on encoded values: the counterpart of float addition under the
// `scaled_float` mapping. Both operands carry the offset z, so it is removed once.
// See [Note: floats as integers]
fn add_add(a : BigInt, b : BigInt, scale: &ScaleT) -> BigInt {
    // TODO: handle errors upstream (overflow / underflow of the encoded range).
    let sum = a + b;
    sum - scale.z
}

// Quotient and remainder of an integer division, as returned by `mul_mul`.
#[derive(Debug, Clone)]
struct DivisionResult {
// Quotient of the division.
result: BigInt,
// Remainder of the division.
remainder: BigInt,
}

// **
// A ** B := (A * B + z*s + z*z - A*z - B*z) / s
// Multiplication on encoded values: the counterpart of float multiplication under the
// `scaled_float` mapping. Returns both the quotient and the remainder of the division by s.
// See [Note: floats as integers]
fn mul_mul(a : BigInt, b : BigInt, scale: &ScaleT) -> DivisionResult {
    // TODO: handle errors upstream
    let (s, z) = (scale.s, scale.z);
    let s_big = BigInt::from(s);
    let z_big = BigInt::from(z);
    // Terms that depend only on the scale: z*s + z*z.
    let offset = &z_big * &s_big + &z_big * &z_big;
    let numerator = &a * &b + offset - a * z - b * z;
    let quotient = &numerator / &s_big;
    let leftover = &numerator % &s_big;
    DivisionResult { result: quotient, remainder: leftover }
}


pub type Curve = ark_bls12_381::Bls12_381;
pub type CircuitField = ark_bls12_381::Fr;

///
/// NOTE on integer vs float computation:
///
/// The ML computation is obviously meant to evaluate to floats.
/// If we were to take the static description of the expression for evaluation, but treat all Op's as if
/// they act on integers - then what changes do we need to do to the expression?
///
/// We define a scale factor and use the integer `round(scale * f)` to represent a float `f`.
/// Firstly, we scale the inputs by the scale factor.
/// Addition and similar operations are fine as is.
/// Mul needs to divide the result by scale, something along the same lines applies to Recip, etc. LessThan probably needs to divide by scale (?).
/// In the end the result is multiplied by scale.
///
/// - Recip: n = f * s. 1/f = s/n. So we represent Recip(n) as s^2/n, where / is in F?
///
/// Q: There are two ways, in terms of code structure, to implement this.
/// We can separate it into a compilation step or we can combine this step with snark synthesis.
/// Both are fine.
/// For example, in the snark we see a multiplication and
/// we'd like to just say: Mul_float a b => (Mul_int a' b') / scale
/// But because we can't divide (TODO: can we?) we instead take an additional witness for the division result and say:
/// Mul_float a b => (if (Mul_int a' b' == witness * scale)) then witness else abort
/// If doing a separate integer step we'd say: Mul_float a b => (Div_int scale (Mul_int a' b'))
/// and then snark synthesis would rewrite Div_int to a similar circuit as above.
///
#[derive(Debug)]
pub struct MLSnark<F> {
pub graph: ScalarGraph,
Expand Down Expand Up @@ -343,16 +290,6 @@ fn set_input(source_map: &mut SourceMap, tracker: &InputsTracker, id: NodeIndex,
}
}

/// Convert a signed big integer into an unsigned one.
///
/// # Panics
///
/// Panics if the conversion fails (the value is negative), which indicates that an
/// encoded float underflowed the representable range.
pub fn positive_bigint(b: BigInt) -> BigUint {
    let converted: Result<BigUint, _> = b.try_into();
    converted.expect("Expects positive bigint, otherwise its negative float overflow")
}
/// Convert a big integer into a circuit field element.
///
/// Returns `SynthesisError::AssignmentMissing` when the unsigned value cannot be
/// converted into a `CircuitField`.
///
/// # Panics
///
/// Panics (inside `positive_bigint`) when `b` is negative.
fn f_from_bigint(b: BigInt) -> Result<CircuitField, SynthesisError> {
    let unsigned = positive_bigint(b);
    match CircuitField::try_from(unsigned) {
        Ok(element) => Ok(element),
        Err(_) => Err(SynthesisError::AssignmentMissing),
    }
}
/// Panicking variant of `f_from_bigint`.
///
/// # Panics
///
/// Panics when `f_from_bigint` returns an error, or when `b` is negative.
pub fn f_from_bigint_unsafe(b: BigInt) -> CircuitField {
    let converted = f_from_bigint(b);
    converted.expect("Expects bigint to fit in the prime field range, otherwise its positive float overflow")
}

impl ConstraintSynthesizer<CircuitField> for &mut MLSnark<CircuitField> {

#[instrument(level = "debug", name = "generate_constraints")]
Expand Down Expand Up @@ -609,59 +546,16 @@ impl ConstraintSynthesizer<CircuitField> for &mut MLSnark<CircuitField> {
}
}

pub fn i256_to_bigint(a: BigInteger256) -> BigInt {
let x : BigUint = a.into();
x.into()
// // let x: u128 = a.try_into();
// (BigInt::from(q) << (64 * 3)) + (BigInt::from(w) << (64 * 2)) + (BigInt::from(e) << 64) + BigInt::from(r)
}

/// Check whether two field elements are close when compared as integers.
///
/// The elements are considered close when their absolute difference is at most
/// the larger of the two values divided by `scale.s * 100`.
pub fn field_elems_close(a : CircuitField , b : CircuitField, scale: ScaleT) -> bool {
    let lhs = i256_to_bigint(a.into_repr());
    let rhs = i256_to_bigint(b.into_repr());
    // Order the pair once so both the gap and the tolerance can be derived from it.
    let (smaller, larger) = if lhs < rhs { (lhs, rhs) } else { (rhs, lhs) };
    let gap = &larger - &smaller;
    let tolerance = larger.div(scale.s * 100);
    gap <= tolerance
}

/// Approximate float equality.
///
/// Returns `true` when `|a - b|` is at most `0.001 * max(|a| + |b|, 1.0)`.
pub fn floats_close(a : f32, b: f32) -> bool {
    let magnitude = (a.abs() + b.abs()).max(1.0);
    let tolerance = 0.001 * magnitude;
    (a - b).abs() <= tolerance
}
/// Decode both scaled integers to floats and compare them with `floats_close`.
///
/// Returns `false` when either value cannot be decoded by `unscaled_bigint`.
pub fn bigints_close_as_floats(a : BigInt, b: BigInt, scale: &ScaleT) -> bool {
    unscaled_bigint(a, scale)
        .and_then(|lhs| unscaled_bigint(b, scale).map(|rhs| floats_close(lhs, rhs)))
        .unwrap_or(false)
}

/// Decode both field elements to floats and compare them with `floats_close`.
///
/// Returns `false` when either element cannot be decoded by `unscaled_f`.
pub fn field_close_as_floats(a : CircuitField, b: CircuitField, scale: &ScaleT) -> bool {
    unscaled_f(a, scale)
        .and_then(|lhs| unscaled_f(b, scale).map(|rhs| floats_close(lhs, rhs)))
        .unwrap_or(false)
}

mod tests {
use num_bigint::BigInt;
// use ark_ff::PrimeField;
// use quickcheck::quickcheck;
use proptest::prelude::*;
use proptest::num::f32::{POSITIVE, NEGATIVE};
use std::ops::Div;
use crate::snark::{f_from_bigint_unsafe, field_elems_close, mul_mul, scaled_float, unscaled_bigint, unscaled_f, CircuitField};
use crate::snark::{mul_mul, CircuitField};
use crate::SCALE;

use crate::snark::scaling_helpers::*;
use super::bigints_close_as_floats;

proptest! {
Expand Down

0 comments on commit 5b1b462

Please sign in to comment.