Skip to content

Commit

Permalink
Merge pull request #217 from zksecurity/feat/uint-ops
Browse files Browse the repository at this point in the history
support + - * / % for uints
  • Loading branch information
katat authored Nov 4, 2024
2 parents 999a574 + 3bff84e commit d740491
Show file tree
Hide file tree
Showing 9 changed files with 546 additions and 6 deletions.
157 changes: 156 additions & 1 deletion src/backends/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
helpers::PrettyField,
imports::FnHandle,
parser::FunctionDef,
var::{Value, Var},
var::{ConstOrCell, Value, Var},
witness::WitnessEnv,
};

Expand Down Expand Up @@ -199,6 +199,161 @@ pub trait Backend: Clone {

Ok(res)
}
Value::Div(lhs, rhs) => {
let res = match (lhs, rhs) {
(ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => {
if rhs.is_zero() {
return Err(Error::new(
"runtime",
ErrorKind::DivisionByZero,
Span::default(),
));
}

// convert to bigints
let lhs = lhs.to_biguint();
let rhs = rhs.to_biguint();

let res = lhs / rhs;
Self::Field::from(res)
}
(ConstOrCell::Cell(lhs), ConstOrCell::Const(rhs)) => {
if rhs.is_zero() {
return Err(Error::new(
"runtime",
ErrorKind::DivisionByZero,
Span::default(),
));
}

let lhs = self.compute_var(env, lhs)?;

// convert to bigints
let lhs = lhs.to_biguint();
let rhs = rhs.to_biguint();

let res = lhs / rhs;

Self::Field::from(res)
}
(ConstOrCell::Const(lhs), ConstOrCell::Cell(rhs)) => {
let rhs = self.compute_var(env, rhs)?;
if rhs.is_zero() {
return Err(Error::new(
"runtime",
ErrorKind::DivisionByZero,
Span::default(),
));
}

// convert to bigints
let lhs = lhs.to_biguint();
let rhs = rhs.to_biguint();

let res = lhs / rhs;

Self::Field::from(res)
}
(ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => {
let lhs = self.compute_var(env, lhs)?;
let rhs = self.compute_var(env, rhs)?;

if rhs.is_zero() {
return Err(Error::new(
"runtime",
ErrorKind::DivisionByZero,
Span::default(),
));
}
// convert to bigints
let lhs = lhs.to_biguint();
let rhs = rhs.to_biguint();

let res = lhs / rhs;

Self::Field::from(res)
}
};

env.cached_values.insert(cache_key, res); // cache
Ok(res)
}
Value::Mod(lhs, rhs) => {
match (lhs, rhs) {
(ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => {
if rhs.is_zero() {
return Err(Error::new(
"runtime",
ErrorKind::DivisionByZero,
Span::default(),
));
}

// convert to bigints
let lhs = lhs.to_biguint();
let rhs = rhs.to_biguint();

let res = lhs % rhs;
Ok(Self::Field::from(res))
}
(ConstOrCell::Cell(lhs), ConstOrCell::Const(rhs)) => {
if rhs.is_zero() {
return Err(Error::new(
"runtime",
ErrorKind::DivisionByZero,
Span::default(),
));
}

let lhs = self.compute_var(env, lhs)?;

// convert to bigints
let lhs = lhs.to_biguint();
let rhs = rhs.to_biguint();

let res = lhs % rhs;

Ok(Self::Field::from(res))
}
(ConstOrCell::Const(lhs), ConstOrCell::Cell(rhs)) => {
let rhs = self.compute_var(env, rhs)?;
if rhs.is_zero() {
return Err(Error::new(
"runtime",
ErrorKind::DivisionByZero,
Span::default(),
));
}

// convert to bigints
let lhs = lhs.to_biguint();
let rhs = rhs.to_biguint();

let res = lhs % rhs;

Ok(Self::Field::from(res))
}
(ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => {
let lhs = self.compute_var(env, lhs)?;
let rhs = self.compute_var(env, rhs)?;

if rhs.is_zero() {
return Err(Error::new(
"runtime",
ErrorKind::DivisionByZero,
Span::default(),
));
}
// convert to bigints
let lhs = lhs.to_biguint();
let rhs = rhs.to_biguint();

let res = lhs % rhs;

Ok(Self::Field::from(res))
}
}
}
}
}

Expand Down
39 changes: 37 additions & 2 deletions src/stdlib/bits.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use ark_ff::PrimeField;
use std::vec;

use kimchi::{o1_utils::FieldHelpers, turshi::helper::CairoFieldHelpers};
Expand All @@ -6,22 +7,26 @@ use crate::{
backends::Backend,
circuit_writer::{CircuitWriter, VarInfo},
constants::Span,
error::Result,
error::{Error, ErrorKind, Result},
parser::types::GenericParameters,
var::{ConstOrCell, Value, Var},
};

use super::{FnInfoType, Module};

const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field";
const CHECK_FIELD_SIZE_FN: &str = "check_field_size(cmp: Field)";

pub struct BitsLib {}

impl Module for BitsLib {
const MODULE: &'static str = "bits";

fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>)> {
vec![(NTH_BIT_FN, nth_bit)]
vec![
(NTH_BIT_FN, nth_bit),
(CHECK_FIELD_SIZE_FN, check_field_size),
]
}
}

Expand Down Expand Up @@ -67,3 +72,33 @@ fn nth_bit<B: Backend>(

Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span)))
}

// Ensure that the field size is not exceeded
fn check_field_size<B: Backend>(
_compiler: &mut CircuitWriter<B>,
_generics: &GenericParameters,
vars: &[VarInfo<B::Field, B::Var>],
span: Span,
) -> Result<Option<Var<B::Field, B::Var>>> {
let var = &vars[0].var[0];
let bit_len = B::Field::size_in_bits() as u64;

match var {
ConstOrCell::Const(cst) => {
let to_cmp = cst.to_u64();
if to_cmp >= bit_len {
return Err(Error::new(
"constraint-generation",
ErrorKind::AssertionFailed,
span,
));
}
Ok(None)
}
ConstOrCell::Cell(_) => Err(Error::new(
"constraint-generation",
ErrorKind::ExpectedConstant,
span,
)),
}
}
8 changes: 6 additions & 2 deletions src/stdlib/builtins.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Builtins are imported by default.
use ark_ff::One;
use std::sync::Arc;

use ark_ff::{One, Zero};
use kimchi::o1_utils::FieldHelpers;
use num_bigint::BigUint;

use crate::{
backends::Backend,
Expand All @@ -9,7 +13,7 @@ use crate::{
error::{Error, ErrorKind, Result},
helpers::PrettyField,
parser::types::{GenericParameters, TyKind},
var::{ConstOrCell, Var},
var::{ConstOrCell, Value, Var},
};

use super::{FnInfoType, Module};
Expand Down
77 changes: 77 additions & 0 deletions src/stdlib/int.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use std::vec;

use kimchi::o1_utils::FieldHelpers;

use crate::{
backends::Backend,
circuit_writer::{CircuitWriter, VarInfo},
constants::Span,
error::Result,
parser::types::GenericParameters,
var::{ConstOrCell, Value, Var},
};

use super::{FnInfoType, Module};

const DIVMOD_FN: &str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]";

pub struct IntLib {}

impl Module for IntLib {
const MODULE: &'static str = "int";

fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>)> {
vec![(DIVMOD_FN, divmod_fn)]
}
}

/// Divides two field elements and returns the quotient and remainder.
fn divmod_fn<B: Backend>(
compiler: &mut CircuitWriter<B>,
_generics: &GenericParameters,
vars: &[VarInfo<B::Field, B::Var>],
span: Span,
) -> Result<Option<Var<B::Field, B::Var>>> {
// we get two vars
let dividend_info = &vars[0];
let divisor_info = &vars[1];

// retrieve the values
let dividend_var = &dividend_info.var[0];
let divisor_var = &divisor_info.var[0];

match (dividend_var, divisor_var) {
// two constants
(ConstOrCell::Const(a), ConstOrCell::Const(b)) => {
// convert to bigints
let a = a.to_biguint();
let b = b.to_biguint();

let quotient = a.clone() / b.clone();
let remainder = a % b;

// convert back to fields
let quotient = B::Field::from_biguint(&quotient).unwrap();
let remainder = B::Field::from_biguint(&remainder).unwrap();

Ok(Some(Var::new(
vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)],
span,
)))
}

_ => {
let quotient = compiler
.backend
.new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span);
let remainder = compiler
.backend
.new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span);

Ok(Some(Var::new(
vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)],
span,
)))
}
}
}
5 changes: 5 additions & 0 deletions src/stdlib/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::path::Path;
pub mod bits;
pub mod builtins;
pub mod crypto;
pub mod int;

/// The directory under [NONAME_DIRECTORY] containing the native stdlib.
pub const STDLIB_DIRECTORY: &str = "src/stdlib/native/";
Expand All @@ -27,6 +28,7 @@ pub enum AllStdModules {
Builtins,
Crypto,
Bits,
Int,
}

impl AllStdModules {
Expand All @@ -35,6 +37,7 @@ impl AllStdModules {
AllStdModules::Builtins,
AllStdModules::Crypto,
AllStdModules::Bits,
AllStdModules::Int,
]
}

Expand All @@ -43,6 +46,7 @@ impl AllStdModules {
AllStdModules::Builtins => builtins::BuiltinsLib::get_parsed_fns(),
AllStdModules::Crypto => crypto::CryptoLib::get_parsed_fns(),
AllStdModules::Bits => bits::BitsLib::get_parsed_fns(),
AllStdModules::Int => int::IntLib::get_parsed_fns(),
}
}

Expand All @@ -51,6 +55,7 @@ impl AllStdModules {
AllStdModules::Builtins => builtins::BuiltinsLib::MODULE,
AllStdModules::Crypto => crypto::CryptoLib::MODULE,
AllStdModules::Bits => bits::BitsLib::MODULE,
AllStdModules::Int => int::IntLib::MODULE,
}
}
}
Expand Down
Loading

0 comments on commit d740491

Please sign in to comment.