diff --git a/src/backends/mod.rs b/src/backends/mod.rs index 22d9efde4..431fa4c69 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -12,7 +12,7 @@ use crate::{ helpers::PrettyField, imports::FnHandle, parser::FunctionDef, - var::{Value, Var}, + var::{ConstOrCell, Value, Var}, witness::WitnessEnv, }; @@ -199,6 +199,161 @@ pub trait Backend: Clone { Ok(res) } + Value::Div(lhs, rhs) => { + let res = match (lhs, rhs) { + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + if rhs.is_zero() { + return Err(Error::new( + "runtime", + ErrorKind::DivisionByZero, + Span::default(), + )); + } + + // convert to bigints + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + + let res = lhs / rhs; + Self::Field::from(res) + } + (ConstOrCell::Cell(lhs), ConstOrCell::Const(rhs)) => { + if rhs.is_zero() { + return Err(Error::new( + "runtime", + ErrorKind::DivisionByZero, + Span::default(), + )); + } + + let lhs = self.compute_var(env, lhs)?; + + // convert to bigints + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + + let res = lhs / rhs; + + Self::Field::from(res) + } + (ConstOrCell::Const(lhs), ConstOrCell::Cell(rhs)) => { + let rhs = self.compute_var(env, rhs)?; + if rhs.is_zero() { + return Err(Error::new( + "runtime", + ErrorKind::DivisionByZero, + Span::default(), + )); + } + + // convert to bigints + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + + let res = lhs / rhs; + + Self::Field::from(res) + } + (ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => { + let lhs = self.compute_var(env, lhs)?; + let rhs = self.compute_var(env, rhs)?; + + if rhs.is_zero() { + return Err(Error::new( + "runtime", + ErrorKind::DivisionByZero, + Span::default(), + )); + } + // convert to bigints + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + + let res = lhs / rhs; + + Self::Field::from(res) + } + }; + + env.cached_values.insert(cache_key, res); // cache + Ok(res) + } + Value::Mod(lhs, rhs) => { + match (lhs, rhs) { + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + if rhs.is_zero() { + return Err(Error::new( + "runtime", + ErrorKind::DivisionByZero, + Span::default(), + )); + } + + // convert to bigints + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + + let res = lhs % rhs; + Ok(Self::Field::from(res)) + } + (ConstOrCell::Cell(lhs), ConstOrCell::Const(rhs)) => { + if rhs.is_zero() { + return Err(Error::new( + "runtime", + ErrorKind::DivisionByZero, + Span::default(), + )); + } + + let lhs = self.compute_var(env, lhs)?; + + // convert to bigints + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + + let res = lhs % rhs; + + Ok(Self::Field::from(res)) + } + (ConstOrCell::Const(lhs), ConstOrCell::Cell(rhs)) => { + let rhs = self.compute_var(env, rhs)?; + if rhs.is_zero() { + return Err(Error::new( + "runtime", + ErrorKind::DivisionByZero, + Span::default(), + )); + } + + // convert to bigints + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + + let res = lhs % rhs; + + Ok(Self::Field::from(res)) + } + (ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => { + let lhs = self.compute_var(env, lhs)?; + let rhs = self.compute_var(env, rhs)?; + + if rhs.is_zero() { + return Err(Error::new( + "runtime", + ErrorKind::DivisionByZero, + Span::default(), + )); + } + // convert to bigints + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + + let res = lhs % rhs; + + Ok(Self::Field::from(res)) + } + } + } } } diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index cc5898bfc..6460547e9 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -1,3 +1,4 @@ +use ark_ff::PrimeField; use std::vec; use kimchi::{o1_utils::FieldHelpers, turshi::helper::CairoFieldHelpers}; @@ -6,7 +7,7 @@ use crate::{ backends::Backend, circuit_writer::{CircuitWriter, VarInfo}, constants::Span, - error::Result, + error::{Error, ErrorKind, Result}, parser::types::GenericParameters, var::{ConstOrCell, Value, Var}, }; @@ -14,6 +15,7 @@ use crate::{ use super::{FnInfoType, Module}; const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field"; +const CHECK_FIELD_SIZE_FN: &str = "check_field_size(cmp: Field)"; pub struct BitsLib {} @@ -21,7 +23,10 @@ impl Module for BitsLib { const MODULE: &'static str = "bits"; fn get_fns() -> Vec<(&'static str, FnInfoType)> { - vec![(NTH_BIT_FN, nth_bit)] + vec![ + (NTH_BIT_FN, nth_bit), + (CHECK_FIELD_SIZE_FN, check_field_size), + ] } } @@ -67,3 +72,33 @@ fn nth_bit( Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) } + +// Ensure that the field size is not exceeded +fn check_field_size( + _compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + let var = &vars[0].var[0]; + let bit_len = B::Field::size_in_bits() as u64; + + match var { + ConstOrCell::Const(cst) => { + let to_cmp = cst.to_u64(); + if to_cmp >= bit_len { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + )); + } + Ok(None) + } + ConstOrCell::Cell(_) => Err(Error::new( + "constraint-generation", + ErrorKind::ExpectedConstant, + span, + )), + } +} diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index 7d820e79d..d0599fdda 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -1,6 +1,10 @@ //! Builtins are imported by default. -use ark_ff::One; +use std::sync::Arc; + +use ark_ff::{One, Zero}; +use kimchi::o1_utils::FieldHelpers; +use num_bigint::BigUint; use crate::{ backends::Backend, @@ -9,7 +13,7 @@ use crate::{ error::{Error, ErrorKind, Result}, helpers::PrettyField, parser::types::{GenericParameters, TyKind}, - var::{ConstOrCell, Var}, + var::{ConstOrCell, Value, Var}, }; use super::{FnInfoType, Module}; diff --git a/src/stdlib/int.rs b/src/stdlib/int.rs new file mode 100644 index 000000000..7c5b8bfc2 --- /dev/null +++ b/src/stdlib/int.rs @@ -0,0 +1,77 @@ +use std::vec; + +use kimchi::o1_utils::FieldHelpers; + +use crate::{ + backends::Backend, + circuit_writer::{CircuitWriter, VarInfo}, + constants::Span, + error::Result, + parser::types::GenericParameters, + var::{ConstOrCell, Value, Var}, +}; + +use super::{FnInfoType, Module}; + +const DIVMOD_FN: &str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; + +pub struct IntLib {} + +impl Module for IntLib { + const MODULE: &'static str = "int"; + + fn get_fns() -> Vec<(&'static str, FnInfoType)> { + vec![(DIVMOD_FN, divmod_fn)] + } +} + +/// Divides two field elements and returns the quotient and remainder. +fn divmod_fn( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + // we get two vars + let dividend_info = &vars[0]; + let divisor_info = &vars[1]; + + // retrieve the values + let dividend_var = ÷nd_info.var[0]; + let divisor_var = &divisor_info.var[0]; + + match (dividend_var, divisor_var) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + // convert to bigints + let a = a.to_biguint(); + let b = b.to_biguint(); + + let quotient = a.clone() / b.clone(); + let remainder = a % b; + + // convert back to fields + let quotient = B::Field::from_biguint("ient).unwrap(); + let remainder = B::Field::from_biguint(&remainder).unwrap(); + + Ok(Some(Var::new( + vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], + span, + ))) + } + + _ => { + let quotient = compiler + .backend + .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); + let remainder = compiler + .backend + .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); + + Ok(Some(Var::new( + vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], + span, + ))) + } + } +} diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index e01489059..9cda5d2d1 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -19,6 +19,7 @@ use std::path::Path; pub mod bits; pub mod builtins; pub mod crypto; +pub mod int; /// The directory under [NONAME_DIRECTORY] containing the native stdlib. pub const STDLIB_DIRECTORY: &str = "src/stdlib/native/"; @@ -27,6 +28,7 @@ pub enum AllStdModules { Builtins, Crypto, Bits, + Int, } impl AllStdModules { @@ -35,6 +37,7 @@ impl AllStdModules { AllStdModules::Builtins, AllStdModules::Crypto, AllStdModules::Bits, + AllStdModules::Int, ] } @@ -43,6 +46,7 @@ impl AllStdModules { AllStdModules::Builtins => builtins::BuiltinsLib::get_parsed_fns(), AllStdModules::Crypto => crypto::CryptoLib::get_parsed_fns(), AllStdModules::Bits => bits::BitsLib::get_parsed_fns(), + AllStdModules::Int => int::IntLib::get_parsed_fns(), } } @@ -51,6 +55,7 @@ impl AllStdModules { AllStdModules::Builtins => builtins::BuiltinsLib::MODULE, AllStdModules::Crypto => crypto::CryptoLib::MODULE, AllStdModules::Bits => bits::BitsLib::MODULE, + AllStdModules::Int => int::IntLib::MODULE, } } } diff --git a/src/stdlib/native/int/lib.no b/src/stdlib/native/int/lib.no index 4ad93c3b1..09bdbf16a 100644 --- a/src/stdlib/native/int/lib.no +++ b/src/stdlib/native/int/lib.no @@ -1,7 +1,12 @@ use std::bits; use std::comparator; +use std::int; + +// A hint function for calculating quotient and remainder. +hint fn divmod(dividend: Field, divisor: Field) -> [Field; 2]; // u8 +// must use new() to create a Uint8, so the value is range checked struct Uint8 { inner: Field, } @@ -9,6 +14,8 @@ struct Uint8 { fn Uint8.new(val: Field) -> Uint8 { let bit_len = 8; + bits::check_field_size(bit_len); + // range check let ignore_ = bits::to_bits(bit_len, val); @@ -18,6 +25,7 @@ fn Uint8.new(val: Field) -> Uint8 { } // u16 +// must use new() to create a Uint16, so the value is range checked struct Uint16 { inner: Field } @@ -25,6 +33,8 @@ struct Uint16 { fn Uint16.new(val: Field) -> Uint16 { let bit_len = 16; + bits::check_field_size(bit_len); + // range check let ignore_ = bits::to_bits(bit_len, val); @@ -34,6 +44,7 @@ fn Uint16.new(val: Field) -> Uint16 { } // u32 +// must use new() to create a Uint32, so the value is range checked struct Uint32 { inner: Field } @@ -41,6 +52,8 @@ struct Uint32 { fn Uint32.new(val: Field) -> Uint32 { let bit_len = 32; + bits::check_field_size(bit_len); + // range check let ignore_ = bits::to_bits(bit_len, val); @@ -50,6 +63,7 @@ fn Uint32.new(val: Field) -> Uint32 { } // u64 +// must use new() to create a Uint64, so the value is range checked struct Uint64 { inner: Field } @@ -57,6 +71,8 @@ struct Uint64 { fn Uint64.new(val: Field) -> Uint64 { let bit_len = 64; + bits::check_field_size(bit_len); + // range check let ignore_ = bits::to_bits(bit_len, val); @@ -66,7 +82,6 @@ fn Uint64.new(val: Field) -> Uint64 { } // implement comparator - fn Uint8.less_than(self, rhs: Uint8) -> Bool { return comparator::less_than(8, self.inner, rhs.inner); } @@ -97,4 +112,179 @@ fn Uint64.less_than(self, rhs: Uint64) -> Bool { fn Uint64.less_eq_than(self, rhs: Uint64) -> Bool { return comparator::less_eq_than(64, self.inner, rhs.inner); +} + +// + +fn Uint8.add(self, rhs: Uint8) -> Uint8 { + return Uint8.new(self.inner + rhs.inner); +} + +fn Uint16.add(self, rhs: Uint16) -> Uint16 { + return Uint16.new(self.inner + rhs.inner); +} + +fn Uint32.add(self, rhs: Uint32) -> Uint32 { + return Uint32.new(self.inner + rhs.inner); +} + +fn Uint64.add(self, rhs: Uint64) -> Uint64 { + return Uint64.new(self.inner + rhs.inner); +} + +// - +fn Uint8.sub(self, rhs: Uint8) -> Uint8 { + return Uint8.new(self.inner - rhs.inner); +} + +fn Uint16.sub(self, rhs: Uint16) -> Uint16 { + return Uint16.new(self.inner - rhs.inner); +} + +fn Uint32.sub(self, rhs: Uint32) -> Uint32 { + return Uint32.new(self.inner - rhs.inner); +} + +fn Uint64.sub(self, rhs: Uint64) -> Uint64 { + return Uint64.new(self.inner - rhs.inner); +} + +// * +fn Uint8.mul(self, rhs: Uint8) -> Uint8 { + return Uint8.new(self.inner * rhs.inner); +} + +fn Uint16.mul(self, rhs: Uint16) -> Uint16 { + return Uint16.new(self.inner * rhs.inner); +} + +fn Uint32.mul(self, rhs: Uint32) -> Uint32 { + return Uint32.new(self.inner * rhs.inner); +} + +fn Uint64.mul(self, rhs: Uint64) -> Uint64 { + return Uint64.new(self.inner * rhs.inner); +} + +// Division with quotient and remainder +// a = q * b + r +fn Uint8.divmod(self, rhs: Uint8) -> [Uint8; 2] { + // not allow divide by zero + assert(rhs.inner != 0); + + let q_rem = unsafe int::divmod(self.inner, rhs.inner); + let quotient = Uint8.new(q_rem[0]); + let rem = Uint8.new(q_rem[1]); + + // r < b + let is_lt = rem.less_than(rhs); + assert(is_lt); + + let qb = quotient.mul(rhs); // q * b + let expected = qb.add(rem); // a = q * b + r + + assert_eq(self.inner, expected.inner); + + return [quotient, rem]; +} + +fn Uint16.divmod(self, rhs: Uint16) -> [Uint16; 2] { + // not allow divide by zero + assert(rhs.inner != 0); + + let q_rem = unsafe int::divmod(self.inner, rhs.inner); + let quotient = Uint16.new(q_rem[0]); + let rem = Uint16.new(q_rem[1]); + + // r < b + let is_lt = rem.less_than(rhs); + assert(is_lt); + + let qb = quotient.mul(rhs); // q * b + let expected = qb.add(rem); // a = q * b + r + + assert_eq(self.inner, expected.inner); + + return [quotient, rem]; +} + +fn Uint32.divmod(self, rhs: Uint32) -> [Uint32; 2] { + // not allow divide by zero + assert(rhs.inner != 0); + + let q_rem = unsafe int::divmod(self.inner, rhs.inner); + let quotient = Uint32.new(q_rem[0]); + let rem = Uint32.new(q_rem[1]); + + // r < b + let is_lt = rem.less_than(rhs); + assert(is_lt); + + let qb = quotient.mul(rhs); // q * b + let expected = qb.add(rem); // a = q * b + r + + assert_eq(self.inner, expected.inner); + + return [quotient, rem]; +} + +fn Uint64.divmod(self, rhs: Uint64) -> [Uint64; 2] { + // not allow divide by zero + assert(rhs.inner != 0); + + let q_rem = unsafe int::divmod(self.inner, rhs.inner); + let quotient = Uint64.new(q_rem[0]); + let rem = Uint64.new(q_rem[1]); + + // r < b + let is_lt = rem.less_than(rhs); + assert(is_lt); + + let qb = quotient.mul(rhs); // q * b + let expected = qb.add(rem); // a = q * b + r + + assert_eq(self.inner, expected.inner); + + return [quotient, rem]; +} + +// Division (quotient only) +fn Uint8.div(self, rhs: Uint8) -> Uint8 { + let res = self.divmod(rhs); + return res[0]; +} + +fn Uint16.div(self, rhs: Uint16) -> Uint16 { + let res = self.divmod(rhs); + return res[0]; +} + +fn Uint32.div(self, rhs: Uint32) -> Uint32 { + let res = self.divmod(rhs); + return res[0]; +} + +fn Uint64.div(self, rhs: Uint64) -> Uint64 { + let res = self.divmod(rhs); + return res[0]; +} + +// Modulo (remainder only) +fn Uint8.mod(self, rhs: Uint8) -> Uint8 { + let res = self.divmod(rhs); + return res[1]; +} + +fn Uint16.mod(self, rhs: Uint16) -> Uint16 { + let res = self.divmod(rhs); + return res[1]; +} + +fn Uint32.mod(self, rhs: Uint32) -> Uint32 { + let res = self.divmod(rhs); + return res[1]; +} + +fn Uint64.mod(self, rhs: Uint64) -> Uint64 { + let res = self.divmod(rhs); + return res[1]; } \ No newline at end of file diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 3bac61a6f..25c98ed65 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -1,6 +1,7 @@ mod comparator; mod mimc; mod multiplexer; +mod uints; use std::{path::Path, str::FromStr}; diff --git a/src/tests/stdlib/uints/mod.rs b/src/tests/stdlib/uints/mod.rs new file mode 100644 index 000000000..efd1ede83 --- /dev/null +++ b/src/tests/stdlib/uints/mod.rs @@ -0,0 +1,67 @@ +use crate::error::{self, ErrorKind}; + +use super::test_stdlib_code; +use error::Result; +use rstest::rstest; + +// code template +static TPL: &str = r#" +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Field { + let lhs_u = int::{inttyp}.new(lhs); + let rhs_u = int::{inttyp}.new(rhs); + + let res = lhs_u.{opr}(rhs_u); + + return res.inner; +} +"#; + +#[rstest] +#[case("Uint8", "add", r#"{"lhs": "2"}"#, r#"{"rhs": "2"}"#, vec!["4"])] +#[case("Uint8", "sub", r#"{"lhs": "2"}"#, r#"{"rhs": "2"}"#, vec!["0"])] +#[case("Uint8", "mul", r#"{"lhs": "2"}"#, r#"{"rhs": "2"}"#, vec!["4"])] +#[case("Uint8", "div", r#"{"lhs": "5"}"#, r#"{"rhs": "3"}"#, vec!["1"])] +#[case("Uint8", "mod", r#"{"lhs": "5"}"#, r#"{"rhs": "3"}"#, vec!["2"])] +fn test_uint_ops( + #[case] int_type: &str, + #[case] operation: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + // Replace placeholders with the given integer type. + let code = TPL + .replace("{inttyp}", int_type) + .replace("{opr}", operation); + + // Call the test function with the given inputs and expected output. + test_stdlib_code(&code, None, public_inputs, private_inputs, expected_output)?; + + Ok(()) +} + +/// test overflow after operation +#[rstest] +#[case("Uint8", "add", r#"{"lhs": "255"}"#, r#"{"rhs": "1"}"#)] +#[case("Uint8", "sub", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#)] +#[case("Uint8", "mul", r#"{"lhs": "255"}"#, r#"{"rhs": "2"}"#)] +fn test_uint_overflow( + #[case] int_type: &str, + #[case] operation: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, +) -> Result<()> { + let code = TPL + .replace("{inttyp}", int_type) + .replace("{opr}", operation); + + let err = test_stdlib_code(&code, None, public_inputs, private_inputs, vec![]) + .err() + .expect("expected overflow error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} diff --git a/src/var.rs b/src/var.rs index e183a7085..ce2445472 100644 --- a/src/var.rs +++ b/src/var.rs @@ -46,6 +46,10 @@ where /// Note that it will potentially return 0 if the given variable is 0. Inverse(B::Var), + Div(ConstOrCell, ConstOrCell), + + Mod(ConstOrCell, ConstOrCell), + /// Extract the nth bit from a value. // todo: consider using Value itself as argument, which wraps B::Var or B::Field, as Value::Var or Value::Field // - so the arguments for these operations can be either B::Var or B::Field @@ -74,6 +78,8 @@ impl std::fmt::Debug for Value { Value::LinearCombination(..) => write!(f, "LinearCombination"), Value::Mul(..) => write!(f, "Mul"), Value::Inverse(_) => write!(f, "Inverse"), + Value::Div(..) => write!(f, "Divide"), + Value::Mod(..) => write!(f, "Mod"), Value::External(..) => write!(f, "External"), Value::PublicOutput(..) => write!(f, "PublicOutput"), Value::Scale(..) => write!(f, "Scaling"),