Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support hint function #220

Merged
merged 21 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
370 changes: 354 additions & 16 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ toml = "0.8.8"
constraint_writers = { git = "https://github.com/iden3/circom.git", tag = "v2.1.8" } # to generate r1cs file
num-bigint-dig = "0.6.0" # to adapt for circom lib
rstest = "0.19.0" # for testing different backend cases
rug = "1.26.1" # circ uses this for integer type
circ = { git = "https://github.com/circify/circ", rev = "8140b1369edd5992ede038d2e9e5721510ae7065" } # for compiling to circ IR
circ_fields = { git = "https://github.com/circify/circ", rev = "8140b1369edd5992ede038d2e9e5721510ae7065", subdir = "circ_fields" } # for field types supported by circ
fxhash = "0.2.1" # hash algorithm used by circ
tokio = { version = "1.41.0", features = ["full"] }
tower = "0.5.1"
tower-http = { version = "0.6.1", features = ["trace", "fs"] }
Expand Down
30 changes: 30 additions & 0 deletions examples/fixture/asm/kimchi/hint.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
@ noname.0.7.0
@ public inputs: 2

DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1,0,0,0,-2>
DoubleGeneric<1,0,0,0,-2>
DoubleGeneric<2,0,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,0,0,0,-16>
DoubleGeneric<1,-1>
DoubleGeneric<1,0,0,0,-3>
DoubleGeneric<1,0,-1,0,1>
DoubleGeneric<1,0,0,0,-1>
DoubleGeneric<1,0,0,0,-1>
DoubleGeneric<1,1>
DoubleGeneric<1,0,-1,0,1>
DoubleGeneric<1,0,0,0,-1>
DoubleGeneric<1,0,0,0,-1>
DoubleGeneric<1,-1>
(0,0) -> (18,0)
(1,0) -> (2,0) -> (9,1)
(4,0) -> (6,1)
(4,2) -> (5,1)
(5,0) -> (7,1) -> (18,1)
(9,0) -> (11,0)
(14,1) -> (15,0)
(15,2) -> (16,0)
16 changes: 16 additions & 0 deletions examples/fixture/asm/r1cs/hint.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@ noname.0.7.0
@ public inputs: 2

2 == (v_2) * (1)
2 == (v_3) * (1)
2 * v_5 == (v_4) * (1)
v_5 == (v_6) * (1)
v_4 == (v_7) * (1)
16 == (v_8) * (1)
v_2 == (v_9) * (1)
3 == (v_10) * (1)
1 == (v_11) * (1)
1 == (v_12) * (1)
1 == (-1 * v_13 + 1) * (1)
1 == (v_14) * (1)
v_4 == (v_1) * (1)
86 changes: 86 additions & 0 deletions examples/hint.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
struct Thing {
xx: Field,
yy: Field,
}

hint fn mul(lhs: Field, rhs: Field) -> Field {
return lhs * rhs;
}

hint fn add_mul_2(lhs: Field, rhs: Field) -> Field {
let sum = lhs + rhs;
return unsafe mul(sum, 2);
}

hint fn div(lhs: Field, rhs: Field) -> Field {
return lhs / rhs;
}

hint fn ite(lhs: Field, rhs: Field) -> Field {
return if lhs != rhs { lhs } else { rhs };
}

hint fn exp(const EXP: Field, val: Field) -> Field {
let mut res = val;

for num in 1..EXP {
res = res * val;
}

return res;
}

hint fn sub(lhs: Field, rhs: Field) -> Field {
return lhs - rhs;
}

hint fn boolean_ops(lhs: Field, rhs: Field) -> [Bool; 3] {
let aa = lhs == rhs;

let bb = aa && false;
let cc = bb || true;

return [aa, bb, cc];
}

hint fn multiple_inputs_outputs(aa: [Field; 2]) -> Thing {
return Thing {
xx: aa[0],
yy: aa[1],
};
}

fn main(pub public_input: Field, private_input: Field) -> Field {
// have to assert these inputs, otherwise it throws vars not in circuit error
assert_eq(public_input, 2);
assert_eq(private_input, 2);

let xx = unsafe add_mul_2(public_input, private_input);
let yy = unsafe mul(public_input, private_input);
assert_eq(xx, yy * 2);

let zz = unsafe div(xx, public_input);
assert_eq(zz, yy);

let ww = unsafe ite(xx, yy);
assert_eq(ww, xx);

let kk = unsafe exp(4, public_input);
assert_eq(kk, 16);

let thing = unsafe multiple_inputs_outputs([public_input, 3]);
// have to include all the outputs from hint function, otherwise it throws vars not in circuit error.
// this is because each individual element in the hint output maps to a separate cell var in noname.
assert_eq(thing.xx, public_input);
assert_eq(thing.yy, 3);

let jj = unsafe sub(thing.xx + 1, public_input);
assert_eq(jj, 1);

let oo = unsafe boolean_ops(2, 2);
assert(oo[0]);
assert(!oo[1]);
assert(oo[2]);

return xx;
}
2 changes: 2 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"
30 changes: 28 additions & 2 deletions src/backends/kimchi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ pub mod asm;
pub mod builtin;
pub mod prover;

use circ::cfg::{CircCfg, CircOpt};
use educe::Educe;
use rug::Integer;
use std::{
collections::{BTreeMap, HashMap, HashSet},
fmt::Write,
ops::Neg as _,
};

use itertools::{izip, Itertools};
use kimchi::circuits::polynomials::generic::{GENERIC_COEFFS, GENERIC_REGISTERS};
use kimchi::{
circuits::polynomials::generic::{GENERIC_COEFFS, GENERIC_REGISTERS},
o1_utils::FieldHelpers,
};
use serde::{Deserialize, Serialize};

use crate::{
Expand Down Expand Up @@ -39,7 +44,28 @@ pub const NUM_REGISTERS: usize = kimchi::circuits::wires::COLUMNS;

use super::{Backend, BackendField, BackendVar};

impl BackendField for VestaField {}
impl BackendField for VestaField {
fn to_circ_field(&self) -> circ_fields::FieldV {
let mut opt = CircOpt::default();

// define the modulus for the field
opt.field.custom_modulus = VestaField::modulus_biguint().to_str_radix(10);

let cfg = CircCfg::from(opt);

let cfg_f = cfg.field();
let int = Integer::from_str_radix(&self.to_biguint().to_str_radix(10), 10).unwrap();

cfg_f.new_v(int)
}

fn to_circ_type() -> circ_fields::FieldT {
let digits = VestaField::modulus_biguint().to_bytes_le();
circ_fields::FieldT::IntField(
Integer::from_digits::<u8>(&digits, rug::integer::Order::Lsf).into(),
)
}
}

#[derive(Debug)]
pub struct Witness(Vec<[VestaField; NUM_REGISTERS]>);
Expand Down
51 changes: 48 additions & 3 deletions src/backends/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::{fmt::Debug, hash::Hash, str::FromStr};
use std::{fmt::Debug, str::FromStr};

use ::kimchi::o1_utils::FieldHelpers;
use ark_ff::{Field, One, Zero};
use ark_ff::{Field, One, PrimeField, Zero};
use circ::ir::term::precomp::PreComp;
use fxhash::FxHashMap;
use num_bigint::BigUint;

use crate::{
Expand All @@ -11,7 +13,6 @@ use crate::{
error::{Error, ErrorKind, Result},
helpers::PrettyField,
imports::FnHandle,
parser::FunctionDef,
var::{ConstOrCell, Value, Var},
witness::WitnessEnv,
};
Expand All @@ -28,6 +29,8 @@ pub mod r1cs;
pub trait BackendField:
Field + FromStr + TryFrom<BigUint> + TryInto<BigUint> + Into<BigUint> + PrettyField
{
fn to_circ_field(&self) -> circ_fields::FieldV;
fn to_circ_type() -> circ_fields::FieldT;
}

/// This trait allows different backends to have different cell var types.
Expand Down Expand Up @@ -199,6 +202,48 @@ pub trait Backend: Clone {

Ok(res)
}
Value::HintIR(t, named_vars) => {
let mut precomp = PreComp::new();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lgtm, but I would suggest adding some unit tests here to check corner cases (0, p, p-1, p+1).

// For hint evaluation purpose, precomp only has only one output and no connections with other parts,
// so just use a dummy output var name.
precomp.add_output("x".to_string(), t.clone());

// map the named vars to env
let env = named_vars
.iter()
.map(|(name, var)| {
let val = match var {
crate::var::ConstOrCell::Const(cst) => cst.to_circ_field(),
crate::var::ConstOrCell::Cell(var) => {
let val = self.compute_var(env, var).unwrap();
val.to_circ_field()
}
};
(name.clone(), circ::ir::term::Value::Field(val))
})
.collect::<FxHashMap<String, circ::ir::term::Value>>();

// evaluate and get the only one output
let eval_map = precomp.eval(&env);
let value = eval_map.get("x").unwrap();
// convert to field
let res = match value {
circ::ir::term::Value::Field(f) => {
let bytes = f.i().to_digits::<u8>(rug::integer::Order::Lsf);
Self::Field::from_le_bytes_mod_order(&bytes)
}
circ::ir::term::Value::Bool(b) => {
if *b {
Self::Field::one()
} else {
Self::Field::zero()
}
}
_ => panic!("unexpected output type"),
};

Ok(res)
}
Value::Div(lhs, rhs) => {
let res = match (lhs, rhs) {
(ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => {
Expand Down
45 changes: 43 additions & 2 deletions src/backends/r1cs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ pub mod snarkjs;
use std::collections::{HashMap, HashSet};

use ark_ff::FpParameters;
use circ::cfg::{CircCfg, CircOpt};
use circ_fields::FieldV;
use itertools::{izip, Itertools as _};
use kimchi::o1_utils::FieldHelpers;
use num_bigint::BigUint;
use rug::Integer;
use serde::{Deserialize, Serialize};

use crate::circuit_writer::VarInfo;
Expand All @@ -22,8 +26,45 @@ pub type R1csBls12381Field = ark_bls12_381::Fr;
pub type R1csBn254Field = ark_bn254::Fr;

// Because the associated field type is BackendField, we need to implement it for the actual field types in order to use them.
impl BackendField for R1csBls12381Field {}
impl BackendField for R1csBn254Field {}
impl BackendField for R1csBls12381Field {
fn to_circ_field(&self) -> FieldV {
let mut opt = CircOpt::default();

// define the modulus for the field
opt.field.custom_modulus = R1csBls12381Field::modulus_biguint().to_str_radix(10);

let cfg = CircCfg::from(opt);

let cfg_f = cfg.field();
let int = Integer::from_str_radix(&self.to_biguint().to_str_radix(10), 10).unwrap();

cfg_f.new_v(int)
}

fn to_circ_type() -> circ_fields::FieldT {
circ_fields::FieldT::FBls12381
}
}

impl BackendField for R1csBn254Field {
fn to_circ_field(&self) -> FieldV {
let mut opt = CircOpt::default();

// define the modulus for the field
opt.field.custom_modulus = R1csBn254Field::modulus_biguint().to_str_radix(10);

let cfg = CircCfg::from(opt);

let cfg_f = cfg.field();
let int = Integer::from_str_radix(&self.to_biguint().to_str_radix(10), 10).unwrap();

cfg_f.new_v(int)
}

fn to_circ_type() -> circ_fields::FieldT {
circ_fields::FieldT::FBn254
}
}

#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CellVar {
Expand Down
Loading
Loading