From fbfa222f1da7712e87bf0de0e62e3e60abeb967a Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 27 Sep 2024 16:47:13 +0800 Subject: [PATCH 01/36] fix: avoid re-instantiate functions/methods --- src/mast/mod.rs | 129 ++++++++++++++++++++++++++++++------------------ 1 file changed, 80 insertions(+), 49 deletions(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 60039c45d..07d36a276 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -152,10 +152,10 @@ where { tast: TypeChecker, generic_func_scope: Option, - // fully qualified function name - functions_to_delete: Vec, - // fully qualified struct name, method name - methods_to_delete: Vec<(FullyQualified, String)>, + // new fully qualified function name as the key, old fully qualified function name as the value + functions_instantiated: HashMap, + // new method name as the key, old method name as the value + methods_instantiated: HashMap<(FullyQualified, String), String>, } impl MastCtx { @@ -163,8 +163,8 @@ impl MastCtx { Self { tast, generic_func_scope: Some(0), - functions_to_delete: vec![], - methods_to_delete: vec![], + functions_instantiated: HashMap::new(), + methods_instantiated: HashMap::new(), } } @@ -190,9 +190,7 @@ impl MastCtx { ) { self.tast .add_monomorphized_fn(new_qualified.clone(), fn_info); - if new_qualified != old_qualified { - self.functions_to_delete.push(old_qualified); - } + self.functions_instantiated.insert(new_qualified, old_qualified); } pub fn add_monomorphized_method( @@ -205,20 +203,23 @@ impl MastCtx { self.tast .add_monomorphized_method(struct_qualified.clone(), method_name, fn_info); - if method_name != old_method_name { - self.methods_to_delete - .push((struct_qualified, old_method_name.to_string())); - } + self.methods_instantiated + .insert((struct_qualified, method_name.to_string()), old_method_name.to_string()); } pub fn clear_generic_fns(&mut self) { - for qualified in &self.functions_to_delete { - self.tast.remove_fn(qualified); + for (new, old) in &self.functions_instantiated { + // don't remove the instantiated function with no generic arguments + if new != old { + self.tast.remove_fn(old); + } } - self.functions_to_delete.clear(); - for (struct_qualified, method_name) in &self.methods_to_delete { - self.tast.remove_method(struct_qualified, method_name); + for ((struct_qualified, new), old) in &self.methods_instantiated { + // don't remove the instantiated method with no generic arguments + if new != old { + self.tast.remove_method(struct_qualified, old); + } } } } @@ -422,21 +423,36 @@ fn monomorphize_expr( let args_mono = observed.clone().into_iter().map(|e| e.expr).collect(); - let fn_name_mono = &fn_info_mono.sig().name; - let mexpr = Expr { - kind: ExprKind::FnCall { - module: module.clone(), - fn_name: fn_name_mono.clone(), - args: args_mono, - }, - ..expr.clone() - }; - - let qualified = FullyQualified::new(module, &fn_name_mono.value); - ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); - - // assume the function call won't return constant value - ExprMonoInfo::new(mexpr, typ, None) + // check if this function is already monomorphized + if ctx.functions_instantiated.contains_key(&old_qualified) { + // todo: cache the propagated constant from instantiated function, + // so it doesn't need to re-instantiate the function + let mexpr = Expr { + kind: ExprKind::FnCall { + module: module.clone(), + fn_name: fn_name.clone(), + args: args_mono, + }, + ..expr.clone() + }; + ExprMonoInfo::new(mexpr, typ, None) + } + else { + let fn_name_mono = &fn_info_mono.sig().name; + let mexpr = Expr { + kind: ExprKind::FnCall { + module: module.clone(), + fn_name: fn_name_mono.clone(), + args: args_mono, + }, + ..expr.clone() + }; + + let qualified = FullyQualified::new(module, &fn_name_mono.value); + ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); + + ExprMonoInfo::new(mexpr, typ, None) + } } // `lhs.method_name(args)` @@ -493,22 +509,37 @@ fn monomorphize_expr( // monomorphize the function call let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; - let fn_name_mono = &fn_info_mono.sig().name; - let mexpr = Expr { - kind: ExprKind::MethodCall { - lhs: Box::new(lhs_mono.expr), - method_name: fn_name_mono.clone(), - args: args_mono, - }, - ..expr.clone() - }; - - let fn_def = fn_info_mono.native(); - ctx.tast - .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); - - // assume the function call won't return constant value - ExprMonoInfo::new(mexpr, typ, None) + // check if this function is already monomorphized + if ctx.methods_instantiated.contains_key(&(struct_qualified.clone(), method_name.value.clone())) { + // todo: cache the propagated constant from instantiated method, + // so it doesn't need to re-instantiate the function + let mexpr = Expr { + kind: ExprKind::MethodCall { + lhs: Box::new(lhs_mono.expr), + method_name: method_name.clone(), + args: args_mono, + }, + ..expr.clone() + }; + ExprMonoInfo::new(mexpr, typ, None) + } + else { + let fn_name_mono = &fn_info_mono.sig().name; + let mexpr = Expr { + kind: ExprKind::MethodCall { + lhs: Box::new(lhs_mono.expr), + method_name: fn_name_mono.clone(), + args: args_mono, + }, + ..expr.clone() + }; + + let fn_def = fn_info_mono.native(); + ctx.tast + .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); + + ExprMonoInfo::new(mexpr, typ, None) + } } ExprKind::Assignment { lhs, rhs } => { From 4eaed5bc2b570c1bc4c5092e229aa601a720cc58 Mon Sep 17 00:00:00 2001 From: kata Date: Sat, 5 Oct 2024 11:56:06 +0800 Subject: [PATCH 02/36] fix: monomorphized expr shouldn't override existing nodes --- src/mast/ast.rs | 13 +++------ src/mast/mod.rs | 70 +++++++++++++++++++++++++++++-------------------- 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/src/mast/ast.rs b/src/mast/ast.rs index 3dbde0da2..ecb63e864 100644 --- a/src/mast/ast.rs +++ b/src/mast/ast.rs @@ -8,15 +8,10 @@ use super::MastCtx; impl Expr { /// Convert an expression to another expression, with the same span and a regenerated node id. pub fn to_mast(&self, ctx: &mut MastCtx, kind: &ExprKind) -> Expr { - match ctx.generic_func_scope { - // not in any generic function scope - Some(0) => self.clone(), - // in a generic function scope - _ => Expr { - node_id: ctx.next_node_id(), - kind: kind.clone(), - ..self.clone() - }, + Expr { + node_id: ctx.next_node_id(), + kind: kind.clone(), + ..self.clone() } } } diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 07d36a276..454b670e5 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -418,36 +418,41 @@ fn monomorphize_expr( .expect("function not found") .to_owned(); - // monomorphize the function call - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; - let args_mono = observed.clone().into_iter().map(|e| e.expr).collect(); // check if this function is already monomorphized if ctx.functions_instantiated.contains_key(&old_qualified) { - // todo: cache the propagated constant from instantiated function, - // so it doesn't need to re-instantiate the function - let mexpr = Expr { - kind: ExprKind::FnCall { + let mexpr = expr.to_mast( + ctx, + &ExprKind::FnCall { module: module.clone(), fn_name: fn_name.clone(), args: args_mono, }, - ..expr.clone() - }; + ); + let fn_info = ctx + .tast + .fn_info(&old_qualified) + .expect("function not found") + .to_owned(); + let typ = fn_info.sig().return_type.clone().map(|t| t.kind); + ExprMonoInfo::new(mexpr, typ, None) } else { + // monomorphize the function call + let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let fn_name_mono = &fn_info_mono.sig().name; - let mexpr = Expr { - kind: ExprKind::FnCall { + let mexpr = expr.to_mast( + ctx, + &ExprKind::FnCall { module: module.clone(), fn_name: fn_name_mono.clone(), args: args_mono, }, - ..expr.clone() - }; - + ); + let qualified = FullyQualified::new(module, &fn_name_mono.value); ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); @@ -506,34 +511,41 @@ fn monomorphize_expr( args_mono.push(expr_mono.expr); } - // monomorphize the function call - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; - // check if this function is already monomorphized - if ctx.methods_instantiated.contains_key(&(struct_qualified.clone(), method_name.value.clone())) { - // todo: cache the propagated constant from instantiated method, - // so it doesn't need to re-instantiate the function - let mexpr = Expr { - kind: ExprKind::MethodCall { + if ctx + .methods_instantiated + .contains_key(&(struct_qualified.clone(), method_name.value.clone())) + { + let mexpr = expr.to_mast( + ctx, + &ExprKind::MethodCall { lhs: Box::new(lhs_mono.expr), method_name: method_name.clone(), args: args_mono, }, - ..expr.clone() - }; + ); + let fn_info = ctx + .tast + .fn_info(&struct_qualified) + .expect("function not found") + .to_owned(); + let typ = fn_info.sig().return_type.clone().map(|t| t.kind); ExprMonoInfo::new(mexpr, typ, None) } else { + // monomorphize the function call + let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let fn_name_mono = &fn_info_mono.sig().name; - let mexpr = Expr { - kind: ExprKind::MethodCall { + let mexpr = expr.to_mast( + ctx, + &ExprKind::MethodCall { lhs: Box::new(lhs_mono.expr), method_name: fn_name_mono.clone(), args: args_mono, }, - ..expr.clone() - }; - + ); + let fn_def = fn_info_mono.native(); ctx.tast .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); From 10432c3f98eb19c8641e7b4098e274b8367d00ee Mon Sep 17 00:00:00 2001 From: kata Date: Wed, 9 Oct 2024 17:15:53 +0800 Subject: [PATCH 03/36] fmt --- src/mast/mod.rs | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 454b670e5..d610422da 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -190,7 +190,8 @@ impl MastCtx { ) { self.tast .add_monomorphized_fn(new_qualified.clone(), fn_info); - self.functions_instantiated.insert(new_qualified, old_qualified); + self.functions_instantiated + .insert(new_qualified, old_qualified); } pub fn add_monomorphized_method( @@ -203,8 +204,10 @@ impl MastCtx { self.tast .add_monomorphized_method(struct_qualified.clone(), method_name, fn_info); - self.methods_instantiated - .insert((struct_qualified, method_name.to_string()), old_method_name.to_string()); + self.methods_instantiated.insert( + (struct_qualified, method_name.to_string()), + old_method_name.to_string(), + ); } pub fn clear_generic_fns(&mut self) { @@ -436,13 +439,12 @@ fn monomorphize_expr( .expect("function not found") .to_owned(); let typ = fn_info.sig().return_type.clone().map(|t| t.kind); - + ExprMonoInfo::new(mexpr, typ, None) - } - else { + } else { // monomorphize the function call let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; - + let fn_name_mono = &fn_info_mono.sig().name; let mexpr = expr.to_mast( ctx, @@ -455,7 +457,7 @@ fn monomorphize_expr( let qualified = FullyQualified::new(module, &fn_name_mono.value); ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); - + ExprMonoInfo::new(mexpr, typ, None) } } @@ -531,8 +533,7 @@ fn monomorphize_expr( .to_owned(); let typ = fn_info.sig().return_type.clone().map(|t| t.kind); ExprMonoInfo::new(mexpr, typ, None) - } - else { + } else { // monomorphize the function call let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; @@ -549,7 +550,7 @@ fn monomorphize_expr( let fn_def = fn_info_mono.native(); ctx.tast .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); - + ExprMonoInfo::new(mexpr, typ, None) } } From 0221102f61a14d8cde90bb28058f9130383d1da0 Mon Sep 17 00:00:00 2001 From: kata Date: Wed, 9 Oct 2024 14:14:55 +0800 Subject: [PATCH 04/36] support unsafe/hint attributes --- src/circuit_writer/writer.rs | 1 + src/error.rs | 9 +++ src/lexer/mod.rs | 8 +++ src/mast/mod.rs | 7 +++ src/name_resolution/context.rs | 2 +- src/name_resolution/expr.rs | 1 + src/negative_tests.rs | 112 +++++++++++++++++++++++++++------ src/parser/expr.rs | 20 ++++++ src/parser/mod.rs | 21 +++++++ src/parser/structs.rs | 2 +- src/parser/types.rs | 29 ++++++++- src/stdlib/mod.rs | 1 + src/type_checker/checker.rs | 16 +++++ src/type_checker/mod.rs | 44 +++++++++++++ 14 files changed, 252 insertions(+), 21 deletions(-) diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 924c57b6f..ec1a744c7 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -379,6 +379,7 @@ impl CircuitWriter { module, fn_name, args, + .. } => { // sanity check if fn_name.value == "main" { diff --git a/src/error.rs b/src/error.rs index 9c53036f5..87f8845d0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -237,6 +237,9 @@ pub enum ErrorKind { #[error("function `{0}` not present in scope (did you misspell it?)")] UndefinedFunction(String), + #[error("hint function `{0}` signature is missing its corresponding builtin function")] + MissingHintMapping(String), + #[error("function name `{0}` is already in use by a variable present in the scope")] FunctionNameInUsebyVariable(String), @@ -246,6 +249,12 @@ pub enum ErrorKind { #[error("attribute not recognized: `{0:?}`")] InvalidAttribute(AttributeKind), + #[error("unsafe attribute is needed to call a hint function. eg: `unsafe fn foo()`")] + ExpectedUnsafeAttribute, + + #[error("unsafe attribute should only be applied to hint function calls")] + UnexpectedUnsafeAttribute, + #[error("A return value is not used")] UnusedReturnValue, diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 89784b213..21d634fc3 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -44,6 +44,10 @@ pub enum Keyword { Use, /// A function Fn, + /// A hint function + Hint, + /// Attribute required for hint functions + Unsafe, /// New variable Let, /// Public input @@ -75,6 +79,8 @@ impl Keyword { match s { "use" => Some(Self::Use), "fn" => Some(Self::Fn), + "hint" => Some(Self::Hint), + "unsafe" => Some(Self::Unsafe), "let" => Some(Self::Let), "pub" => Some(Self::Pub), "return" => Some(Self::Return), @@ -97,6 +103,8 @@ impl Display for Keyword { let desc = match self { Self::Use => "use", Self::Fn => "fn", + Self::Hint => "hint", + Self::Unsafe => "unsafe", Self::Let => "let", Self::Pub => "pub", Self::Return => "return", diff --git a/src/mast/mod.rs b/src/mast/mod.rs index d610422da..5ca421e2b 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -405,6 +405,7 @@ fn monomorphize_expr( module, fn_name, args, + unsafe_attr, } => { // compute the observed arguments types let mut observed = Vec::with_capacity(args.len()); @@ -431,6 +432,7 @@ fn monomorphize_expr( module: module.clone(), fn_name: fn_name.clone(), args: args_mono, + unsafe_attr: *unsafe_attr, }, ); let fn_info = ctx @@ -452,6 +454,7 @@ fn monomorphize_expr( module: module.clone(), fn_name: fn_name_mono.clone(), args: args_mono, + unsafe_attr: *unsafe_attr, }, ); @@ -495,6 +498,7 @@ fn monomorphize_expr( let fn_kind = FnKind::Native(method_type.clone()); let fn_info = FnInfo { kind: fn_kind, + is_hint: false, span: method_type.span, }; @@ -1151,6 +1155,7 @@ pub fn instantiate_fn_call( let func_def = match fn_info.kind { FnKind::BuiltIn(_, handle) => FnInfo { kind: FnKind::BuiltIn(sig_typed, handle), + is_hint: fn_info.is_hint, span: fn_info.span, }, FnKind::Native(fn_def) => { @@ -1162,7 +1167,9 @@ pub fn instantiate_fn_call( sig: sig_typed, body: stmts_typed, span: fn_def.span, + is_hint: fn_def.is_hint, }), + is_hint: fn_info.is_hint, span: fn_info.span, } } diff --git a/src/name_resolution/context.rs b/src/name_resolution/context.rs index 59a5efad6..d4d196eea 100644 --- a/src/name_resolution/context.rs +++ b/src/name_resolution/context.rs @@ -66,7 +66,7 @@ impl NameResCtx { } pub(crate) fn resolve_fn_def(&self, fn_def: &mut FunctionDef) -> Result<()> { - let FunctionDef { sig, body, span: _ } = fn_def; + let FunctionDef { sig, body, .. } = fn_def; // // signature diff --git a/src/name_resolution/expr.rs b/src/name_resolution/expr.rs index 292afeb97..d2630fe00 100644 --- a/src/name_resolution/expr.rs +++ b/src/name_resolution/expr.rs @@ -20,6 +20,7 @@ impl NameResCtx { module, fn_name, args, + unsafe_attr: _, } => { if matches!(module, ModulePath::Local) && BUILTIN_FN_NAMES.contains(&fn_name.value.as_str()) diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 9c7013d4d..e2a560118 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -1,30 +1,23 @@ use crate::{ - backends::r1cs::{R1csBn254Field, R1CS}, - circuit_writer::CircuitWriter, - compiler::{get_nast, typecheck_next_file_inner, Sources}, - error::{ErrorKind, Result}, - mast::Mast, - name_resolution::NAST, - type_checker::TypeChecker, - witness::CompiledCircuit, + backends::{r1cs::{R1csBn254Field, R1CS}, Backend}, circuit_writer::{CircuitWriter, VarInfo}, compiler::{get_nast, typecheck_next_file_inner, Sources}, constants::Span, error::{ErrorKind, Result}, imports::FnKind, lexer::Token, mast::Mast, name_resolution::NAST, parser::{types::{FnSig, GenericParameters}, ParserCtx}, type_checker::{FnInfo, FullyQualified, TypeChecker}, var::Var, witness::CompiledCircuit }; -fn nast_pass(code: &str) -> Result<(NAST>, usize)> { +type R1csBackend = R1CS; + +fn nast_pass(code: &str) -> Result<(NAST, usize)> { let mut source = Sources::new(); - let res = get_nast( + get_nast( None, &mut source, "example.no".to_string(), code.to_string(), 0, - ); - - res + ) } -fn tast_pass(code: &str) -> (Result, TypeChecker>, Sources) { +fn tast_pass(code: &str) -> (Result, TypeChecker, Sources) { let mut source = Sources::new(); - let mut tast = TypeChecker::>::new(); + let mut tast = TypeChecker::::new(); let res = typecheck_next_file_inner( &mut tast, None, @@ -37,12 +30,12 @@ fn tast_pass(code: &str) -> (Result, TypeChecker>, S (res, tast, source) } -fn mast_pass(code: &str) -> Result>> { +fn mast_pass(code: &str) -> Result> { let (_, tast, _) = tast_pass(code); crate::mast::monomorphize(tast) } -fn synthesizer_pass(code: &str) -> Result>> { +fn synthesizer_pass(code: &str) -> Result> { let mast = mast_pass(code); CircuitWriter::generate_circuit(mast?, R1CS::new()) } @@ -396,6 +389,89 @@ fn test_generic_missing_parenthesis() { "#; let res = nast_pass(code).err(); - println!("{:?}", res); assert!(matches!(res.unwrap().kind, ErrorKind::MissingParenthesis)); } +fn test_hint_builtin_fn(qualified: &FullyQualified, code: &str) -> Result { + let mut source = Sources::new(); + let mut tast = TypeChecker::::new(); + // mock a builtin function + let ctx = &mut ParserCtx::default(); + let mut tokens = Token::parse(0, "calc(val: Field) -> Field;").unwrap(); + let sig = FnSig::parse(ctx, &mut tokens).unwrap(); + + fn mocked_builtin_fn( + _: &mut CircuitWriter, + _: &GenericParameters, + _: &[VarInfo], + _: Span, + ) -> Result>> { + Ok(None) + } + + let fn_info = FnInfo { + kind: FnKind::BuiltIn(sig, mocked_builtin_fn::), + is_hint: true, + span: Span::default(), + }; + + // add the mocked builtin function + // note that this should happen in the tast phase, instead of mast phase. + // currently this function is the only way to mock a builtin function. + tast.add_monomorphized_fn(qualified.clone(), fn_info); + + typecheck_next_file_inner( + &mut tast, + None, + &mut source, + "example.no".to_string(), + code.to_string(), + 0, + ) +} + +#[test] +fn test_hint_call_missing_unsafe() { + let qualified = FullyQualified { + module: None, + name: "calc".to_string(), + }; + + let valid_code = r#" + hint fn calc(val: Field) -> Field; + + fn main(pub xx: Field) { + let yy = unsafe calc(xx); + } + "#; + + + let res = test_hint_builtin_fn(&qualified, valid_code); + assert!(res.is_ok()); + + let invalid_code = r#" + hint fn calc(val: Field) -> Field; + + fn main(pub xx: Field) { + let yy = calc(xx); + } + "#; + + let res = test_hint_builtin_fn(&qualified, invalid_code); + assert!(matches!(res.unwrap_err().kind, ErrorKind::ExpectedUnsafeAttribute)); +} + +#[test] +fn test_nonhint_call_with_unsafe() { + let code = r#" + fn calc(val: Field) -> Field { + return val + 1; + } + + fn main(pub xx: Field) { + let yy = unsafe calc(xx); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!(res.unwrap_err().kind, ErrorKind::UnexpectedUnsafeAttribute)); +} \ No newline at end of file diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 9fed71992..89a646d3f 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -58,6 +58,7 @@ pub enum ExprKind { module: ModulePath, fn_name: Ident, args: Vec, + unsafe_attr: bool, }, /// `lhs.method_name(args)` @@ -379,6 +380,24 @@ impl Expr { } } + TokenKind::Keyword(Keyword::Unsafe) => { + let mut fn_call = Expr::parse(ctx, tokens)?; + // should be FnCall + match &mut fn_call.kind { + ExprKind::FnCall { unsafe_attr, .. } => { + *unsafe_attr = true; + }, + _ => { + return Err(ctx.error( + ErrorKind::InvalidExpression, + fn_call.span, + )); + } + }; + + fn_call + } + // unrecognized pattern _ => { return Err(ctx.error(ErrorKind::InvalidExpression, token.span)); @@ -576,6 +595,7 @@ impl Expr { module, fn_name, args, + unsafe_attr: false, }, span, ) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index da765cbe3..dd838663c 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -158,6 +158,26 @@ impl AST { }); } + // `hint fn calc() { }` + TokenKind::Keyword(Keyword::Hint) => { + // expect fn token + tokens.bump_expected(ctx, TokenKind::Keyword(Keyword::Fn))?; + + function_observed = true; + + let func = FunctionDef::parse_hint(ctx, &mut tokens)?; + + // expect ;, as the hint function is an empty function wired with a builtin. + // todo: later these hint functions will be migrated from builtins to native functions + // then it will expect a function block instead of ; + tokens.bump_expected(ctx, TokenKind::SemiColon)?; + + ast.push(Root { + kind: RootKind::FunctionDef(func), + span: token.span, + }); + } + // `struct Foo { a: Field, b: Field }` TokenKind::Keyword(Keyword::Struct) => { let s = StructDef::parse(ctx, &mut tokens)?; @@ -200,6 +220,7 @@ mod tests { let code = r#"main(pub public_input: [Fel; 3], private_input: [Fel; 3]) -> [Fel; 3] { return public_input; }"#; let tokens = &mut Token::parse(0, code).unwrap(); let ctx = &mut ParserCtx::default(); + let is_hint = false; let parsed = FunctionDef::parse(ctx, tokens).unwrap(); println!("{:?}", parsed); } diff --git a/src/parser/structs.rs b/src/parser/structs.rs index 38c5f62bf..12bbffcda 100644 --- a/src/parser/structs.rs +++ b/src/parser/structs.rs @@ -104,7 +104,7 @@ impl StructDef { } // TODO: why is Default implemented here? -#[derive(Default, Debug, Clone, Serialize, Deserialize)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct CustomType { pub module: ModulePath, // name resolution pub name: String, diff --git a/src/parser/types.rs b/src/parser/types.rs index 9c60dc65d..58c59ee87 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -739,6 +739,7 @@ impl Attribute { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FunctionDef { + pub is_hint: bool, pub sig: FnSig, pub body: Vec, pub span: Span, @@ -1161,7 +1162,33 @@ impl FunctionDef { )); } - let func = Self { sig, body, span }; + let func = Self { sig, body, span, is_hint: false }; + + Ok(func) + } + + /// Parse a hint function signature + pub fn parse_hint(ctx: &mut ParserCtx, tokens: &mut Tokens) -> Result { + // parse signature + let sig = FnSig::parse(ctx, tokens)?; + let span = sig.name.span; + + // make sure that it doesn't shadow a builtin + if BUILTIN_FN_NAMES.contains(&sig.name.value.as_ref()) { + return Err(ctx.error( + ErrorKind::ShadowingBuiltIn(sig.name.value.clone()), + span, + )); + } + + // for now the body is empty. + // this will be changed once the native hint is implemented. + let func = Self { + sig, + body: vec![], + span, + is_hint: true, + }; Ok(func) } diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index f2ad66a8e..150d31517 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -72,6 +72,7 @@ trait Module { let sig = FnSig::parse(ctx, &mut tokens).unwrap(); res.push(FnInfo { kind: FnKind::BuiltIn(sig, fn_handle), + is_hint: false, span: Span::default(), }); } diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 7046d8141..74a8ace86 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -25,6 +25,11 @@ where B: Backend, { pub kind: FnKind, + // TODO: We will remove this once the native hint is supported + // This field is to indicate if a builtin function should be treated as a hint. + // instead of adding this flag to the FnKind::Builtin, we add this field to the FnInfo. + // Then this flag will only present in the FunctionDef. + pub is_hint: bool, pub span: Span, } @@ -130,6 +135,7 @@ impl TypeChecker { module, fn_name, args, + unsafe_attr, } => { // retrieve the function signature let qualified = FullyQualified::new(&module, &fn_name.value); @@ -141,6 +147,16 @@ impl TypeChecker { })?; let fn_sig = fn_info.sig().clone(); + // check if the function is a hint + if fn_info.is_hint && !unsafe_attr { + return Err(self.error(ErrorKind::ExpectedUnsafeAttribute, expr.span)); + } + + // unsafe attribute should only be used on hints + if !fn_info.is_hint && *unsafe_attr { + return Err(self.error(ErrorKind::UnexpectedUnsafeAttribute, expr.span)); + } + // check if generic is allowed if fn_sig.require_monomorphization() && typed_fn_env.is_in_forloop() { return Err(self.error(ErrorKind::GenericInForLoop, expr.span)); diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 7306f2d39..7f8b237de 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -312,9 +312,53 @@ impl TypeChecker { } // save the function in the typed global env + + if function.is_hint { + // convert to builtin function + let qualified = match &function.sig.kind { + FuncOrMethod::Function(module) => { + FullyQualified::new(module, &function.sig.name.value) + } + FuncOrMethod::Method(_) => unreachable!("methods are not supported") + }; + + // this will override the builtin function in the global env + let builtin_fn = self.functions + .get(&qualified) + .ok_or_else(|| { + Error::new( + "type-checker", + ErrorKind::MissingHintMapping(qualified.name.clone()), + function.span, + ) + })? + .kind + .clone(); + + // check it is a builtin function + let fn_handle = match builtin_fn { + FnKind::BuiltIn(_, fn_handle) => fn_handle, + _ => return Err(Error::new( + "type-checker", + ErrorKind::UnexpectedError("expected builtin function"), + function.span, + )), + }; + + // override the builtin function as a hint function + self.functions.insert(qualified, FnInfo { + is_hint: true, + kind: FnKind::BuiltIn(function.sig.clone(), fn_handle), + span: function.span, + }); + + continue; + }; + let fn_kind = FnKind::Native(function.clone()); let fn_info = FnInfo { kind: fn_kind, + is_hint: function.is_hint, span: function.span, }; From 5353d42f2c6ffb7311c49f8481afcf59760cf5b1 Mon Sep 17 00:00:00 2001 From: kata Date: Wed, 9 Oct 2024 17:35:30 +0800 Subject: [PATCH 05/36] fmt --- src/negative_tests.rs | 33 ++++++++++++++++++++++++++++----- src/parser/expr.rs | 7 ++----- src/parser/structs.rs | 2 +- src/parser/types.rs | 12 +++++++----- src/type_checker/mod.rs | 30 ++++++++++++++++++------------ 5 files changed, 56 insertions(+), 28 deletions(-) diff --git a/src/negative_tests.rs b/src/negative_tests.rs index e2a560118..1ed4dc063 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -1,5 +1,23 @@ use crate::{ - backends::{r1cs::{R1csBn254Field, R1CS}, Backend}, circuit_writer::{CircuitWriter, VarInfo}, compiler::{get_nast, typecheck_next_file_inner, Sources}, constants::Span, error::{ErrorKind, Result}, imports::FnKind, lexer::Token, mast::Mast, name_resolution::NAST, parser::{types::{FnSig, GenericParameters}, ParserCtx}, type_checker::{FnInfo, FullyQualified, TypeChecker}, var::Var, witness::CompiledCircuit + backends::{ + r1cs::{R1csBn254Field, R1CS}, + Backend, + }, + circuit_writer::{CircuitWriter, VarInfo}, + compiler::{get_nast, typecheck_next_file_inner, Sources}, + constants::Span, + error::{ErrorKind, Result}, + imports::FnKind, + lexer::Token, + mast::Mast, + name_resolution::NAST, + parser::{ + types::{FnSig, GenericParameters}, + ParserCtx, + }, + type_checker::{FnInfo, FullyQualified, TypeChecker}, + var::Var, + witness::CompiledCircuit, }; type R1csBackend = R1CS; @@ -444,7 +462,6 @@ fn test_hint_call_missing_unsafe() { } "#; - let res = test_hint_builtin_fn(&qualified, valid_code); assert!(res.is_ok()); @@ -457,7 +474,10 @@ fn test_hint_call_missing_unsafe() { "#; let res = test_hint_builtin_fn(&qualified, invalid_code); - assert!(matches!(res.unwrap_err().kind, ErrorKind::ExpectedUnsafeAttribute)); + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ExpectedUnsafeAttribute + )); } #[test] @@ -473,5 +493,8 @@ fn test_nonhint_call_with_unsafe() { "#; let res = tast_pass(code).0; - assert!(matches!(res.unwrap_err().kind, ErrorKind::UnexpectedUnsafeAttribute)); -} \ No newline at end of file + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::UnexpectedUnsafeAttribute + )); +} diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 89a646d3f..d070290bb 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -386,12 +386,9 @@ impl Expr { match &mut fn_call.kind { ExprKind::FnCall { unsafe_attr, .. } => { *unsafe_attr = true; - }, + } _ => { - return Err(ctx.error( - ErrorKind::InvalidExpression, - fn_call.span, - )); + return Err(ctx.error(ErrorKind::InvalidExpression, fn_call.span)); } }; diff --git a/src/parser/structs.rs b/src/parser/structs.rs index 12bbffcda..38c5f62bf 100644 --- a/src/parser/structs.rs +++ b/src/parser/structs.rs @@ -104,7 +104,7 @@ impl StructDef { } // TODO: why is Default implemented here? -#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct CustomType { pub module: ModulePath, // name resolution pub name: String, diff --git a/src/parser/types.rs b/src/parser/types.rs index 58c59ee87..a8685896d 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -1162,7 +1162,12 @@ impl FunctionDef { )); } - let func = Self { sig, body, span, is_hint: false }; + let func = Self { + sig, + body, + span, + is_hint: false, + }; Ok(func) } @@ -1175,10 +1180,7 @@ impl FunctionDef { // make sure that it doesn't shadow a builtin if BUILTIN_FN_NAMES.contains(&sig.name.value.as_ref()) { - return Err(ctx.error( - ErrorKind::ShadowingBuiltIn(sig.name.value.clone()), - span, - )); + return Err(ctx.error(ErrorKind::ShadowingBuiltIn(sig.name.value.clone()), span)); } // for now the body is empty. diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 7f8b237de..2b826f8a2 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -319,11 +319,12 @@ impl TypeChecker { FuncOrMethod::Function(module) => { FullyQualified::new(module, &function.sig.name.value) } - FuncOrMethod::Method(_) => unreachable!("methods are not supported") + FuncOrMethod::Method(_) => unreachable!("methods are not supported"), }; // this will override the builtin function in the global env - let builtin_fn = self.functions + let builtin_fn = self + .functions .get(&qualified) .ok_or_else(|| { Error::new( @@ -338,19 +339,24 @@ impl TypeChecker { // check it is a builtin function let fn_handle = match builtin_fn { FnKind::BuiltIn(_, fn_handle) => fn_handle, - _ => return Err(Error::new( - "type-checker", - ErrorKind::UnexpectedError("expected builtin function"), - function.span, - )), + _ => { + return Err(Error::new( + "type-checker", + ErrorKind::UnexpectedError("expected builtin function"), + function.span, + )) + } }; // override the builtin function as a hint function - self.functions.insert(qualified, FnInfo { - is_hint: true, - kind: FnKind::BuiltIn(function.sig.clone(), fn_handle), - span: function.span, - }); + self.functions.insert( + qualified, + FnInfo { + is_hint: true, + kind: FnKind::BuiltIn(function.sig.clone(), fn_handle), + span: function.span, + }, + ); continue; }; From 253f4b57b32d6fdf39bc5b009827edf0753eda46 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 10 Oct 2024 17:10:04 +0800 Subject: [PATCH 06/36] fix: always check the resolved qualified name --- src/mast/mod.rs | 220 ++++++++++++++++++++++++++++---------------- src/parser/types.rs | 66 ++++++------- 2 files changed, 169 insertions(+), 117 deletions(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 5ca421e2b..f16a2cbac 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -8,7 +8,9 @@ use crate::{ error::{Error, ErrorKind, Result}, imports::FnKind, parser::{ - types::{FnArg, FnSig, ForLoopArgument, Range, Stmt, StmtKind, Symbolic, Ty, TyKind}, + types::{ + FnSig, ForLoopArgument, GenericParameters, Range, Stmt, StmtKind, Symbolic, Ty, TyKind, + }, CustomType, Expr, ExprKind, FunctionDef, Op2, }, syntax::{is_generic_parameter, is_type}, @@ -144,6 +146,117 @@ impl MonomorphizedFnEnv { } } +impl FnInfo { + /// Resolves the generic values based on observed arguments. + pub fn resolve_generic_signature( + &mut self, + observed_args: &[ExprMonoInfo], + ctx: &mut MastCtx, + ) -> Result { + match self.kind { + FnKind::BuiltIn(ref mut sig, _) => { + sig.resolve_generic_values(observed_args, ctx)?; + } + FnKind::Native(ref mut func) => { + func.sig.resolve_generic_values(observed_args, ctx)?; + } + }; + + Ok(self.resolved_sig()) + } + + /// Returns the resolved signature of the function. + pub fn resolved_sig(&self) -> FnSig { + let fn_sig = self.sig(); + + let (ret_typed, fn_args_typed) = if let Some(resolved) = &fn_sig.generics.resolved_sig { + (resolved.return_type.clone(), resolved.arguments.clone()) + } else { + (fn_sig.return_type.clone(), fn_sig.arguments.clone()) + }; + + FnSig { + name: fn_sig.monomorphized_name(), + arguments: fn_args_typed, + return_type: ret_typed, + ..fn_sig.clone() + } + } +} + +impl FnSig { + /// Recursively resolve a type based on generic values + pub fn resolve_type(&self, typ: &TyKind, ctx: &mut MastCtx) -> TyKind { + match typ { + TyKind::Array(ty, size) => TyKind::Array(Box::new(self.resolve_type(ty, ctx)), *size), + TyKind::GenericSizedArray(ty, sym) => { + let val = sym.eval(&self.generics, &ctx.tast); + TyKind::Array(Box::new(self.resolve_type(ty, ctx)), val) + } + _ => typ.clone(), + } + } + + /// Resolve generic values for each generic parameter + pub fn resolve_generic_values( + &mut self, + observed: &[ExprMonoInfo], + ctx: &mut MastCtx, + ) -> Result<()> { + for (sig_arg, observed_arg) in self.arguments.clone().iter().zip(observed) { + let observed_ty = observed_arg.typ.clone().expect("expected type"); + match (&sig_arg.typ.kind, &observed_ty) { + (TyKind::GenericSizedArray(_, _), TyKind::Array(_, _)) + | (TyKind::Array(_, _), TyKind::Array(_, _)) => { + self.resolve_generic_array( + &sig_arg.typ.kind, + &observed_ty, + observed_arg.expr.span, + )?; + } + // const NN: Field + _ => { + let cst = observed_arg.constant; + if is_generic_parameter(sig_arg.name.value.as_str()) && cst.is_some() { + self.generics.assign( + &sig_arg.name.value, + cst.unwrap(), + observed_arg.expr.span, + )?; + } + } + } + } + + // resolve the argument types + let mut resolved_args = vec![]; + for arg in &self.arguments { + let resolved_arg_typ = self.resolve_type(&arg.typ.kind, ctx); + let mut resolved_arg = arg.clone(); + resolved_arg.typ = Ty { + kind: resolved_arg_typ, + span: arg.typ.span, + }; + resolved_args.push(resolved_arg); + } + + // resolve the return type + let mut return_type: Option = None; + if let Some(ty) = &self.return_type { + let ret_typed = self.resolve_type(&ty.kind, ctx); + return_type = Some(Ty { + kind: ret_typed, + span: ty.span, + }); + } + + // store the resolved types in arguments and return + self.generics.resolve_sig(resolved_args, return_type); + + Ok(()) + } +} + /// A context to store the last node id for the monomorphized AST. #[derive(Debug)] pub struct MastCtx @@ -229,7 +342,7 @@ impl MastCtx { impl Symbolic { /// Evaluate symbolic size to an integer. - pub fn eval(&self, mono_fn_env: &MonomorphizedFnEnv, tast: &TypeChecker) -> u32 { + pub fn eval(&self, gens: &GenericParameters, tast: &TypeChecker) -> u32 { match self { Symbolic::Concrete(v) => *v, Symbolic::Constant(var) => { @@ -240,9 +353,9 @@ impl Symbolic { let bigint: BigUint = cst.value[0].into(); bigint.try_into().expect("biguint too large") } - Symbolic::Generic(g) => mono_fn_env.get_type_info(&g.value).unwrap().value.unwrap(), - Symbolic::Add(a, b) => a.eval(mono_fn_env, tast) + b.eval(mono_fn_env, tast), - Symbolic::Mul(a, b) => a.eval(mono_fn_env, tast) * b.eval(mono_fn_env, tast), + Symbolic::Generic(g) => gens.get(&g.value), + Symbolic::Add(a, b) => a.eval(gens, tast) + b.eval(gens, tast), + Symbolic::Mul(a, b) => a.eval(gens, tast) * b.eval(gens, tast), } } } @@ -416,7 +529,7 @@ fn monomorphize_expr( // retrieve the function signature let old_qualified = FullyQualified::new(module, &fn_name.value); - let fn_info = ctx + let mut fn_info = ctx .tast .fn_info(&old_qualified) .expect("function not found") @@ -424,23 +537,26 @@ fn monomorphize_expr( let args_mono = observed.clone().into_iter().map(|e| e.expr).collect(); + let resolved_sig = fn_info.resolve_generic_signature(&observed, ctx)?; + + let mono_qualified = FullyQualified::new(module, &resolved_sig.name.value); + // check if this function is already monomorphized - if ctx.functions_instantiated.contains_key(&old_qualified) { + if ctx.functions_instantiated.contains_key(&mono_qualified) { let mexpr = expr.to_mast( ctx, &ExprKind::FnCall { module: module.clone(), - fn_name: fn_name.clone(), + fn_name: resolved_sig.name, args: args_mono, unsafe_attr: *unsafe_attr, }, ); - let fn_info = ctx - .tast - .fn_info(&old_qualified) - .expect("function not found") - .to_owned(); - let typ = fn_info.sig().return_type.clone().map(|t| t.kind); + let resolved_sig = &fn_info.sig().generics.resolved_sig; + + let typ = resolved_sig + .as_ref() + .and_then(|sig| sig.return_type.clone().map(|t| t.kind)); ExprMonoInfo::new(mexpr, typ, None) } else { @@ -458,8 +574,8 @@ fn monomorphize_expr( }, ); - let qualified = FullyQualified::new(module, &fn_name_mono.value); - ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); + let new_qualified = FullyQualified::new(module, &fn_name_mono.value); + ctx.add_monomorphized_fn(old_qualified, new_qualified, fn_info_mono); ExprMonoInfo::new(mexpr, typ, None) } @@ -496,7 +612,7 @@ fn monomorphize_expr( .expect("method not found on custom struct (TODO: better error)"); let fn_kind = FnKind::Native(method_type.clone()); - let fn_info = FnInfo { + let mut fn_info = FnInfo { kind: fn_kind, is_hint: false, span: method_type.span, @@ -517,25 +633,22 @@ fn monomorphize_expr( args_mono.push(expr_mono.expr); } + let resolved_sig = fn_info.resolve_generic_signature(&observed, ctx)?; + // check if this function is already monomorphized if ctx .methods_instantiated - .contains_key(&(struct_qualified.clone(), method_name.value.clone())) + .contains_key(&(struct_qualified.clone(), resolved_sig.name.value.clone())) { let mexpr = expr.to_mast( ctx, &ExprKind::MethodCall { lhs: Box::new(lhs_mono.expr), - method_name: method_name.clone(), + method_name: resolved_sig.name, args: args_mono, }, ); - let fn_info = ctx - .tast - .fn_info(&struct_qualified) - .expect("function not found") - .to_owned(); - let typ = fn_info.sig().return_type.clone().map(|t| t.kind); + let typ = resolved_sig.return_type.clone().map(|t| t.kind); ExprMonoInfo::new(mexpr, typ, None) } else { // monomorphize the function call @@ -1058,21 +1171,7 @@ pub fn instantiate_fn_call( ) -> Result<(FnInfo, Option)> { ctx.start_monomorphize_func(); - // resolve generic values - let (fn_sig, stmts) = match &fn_info.kind { - FnKind::BuiltIn(sig, _) => { - let mut sig = sig.clone(); - sig.resolve_generic_values(args)?; - - (sig, Vec::::new()) - } - FnKind::Native(func) => { - let mut sig = func.sig.clone(); - sig.resolve_generic_values(args)?; - - (sig, func.body.clone()) - } - }; + let fn_sig = fn_info.sig(); // canonicalize the arguments depending on method call or not let expected: Vec<_> = fn_sig.arguments.iter().collect(); @@ -1113,43 +1212,8 @@ pub fn instantiate_fn_call( )?; } - // reconstruct FnArgs using the observed types - let fn_args_typed = expected - .iter() - .zip(args) - .map(|(arg, mono_info)| FnArg { - name: arg.name.clone(), - attribute: arg.attribute.clone(), - span: arg.span, - typ: Ty { - kind: mono_info.typ.clone().expect("expected a type"), - span: arg.typ.span, - }, - }) - .collect(); - - // evaluate return types using the resolved generic values - let ret_typed = match &fn_sig.return_type { - Some(ret_ty) => match &ret_ty.kind { - TyKind::GenericSizedArray(typ, size) => { - let val = size.eval(mono_fn_env, &ctx.tast); - let tykind = TyKind::Array(typ.clone(), val); - Some(Ty { - kind: tykind, - span: ret_ty.span, - }) - } - _ => Some(ret_ty.clone()), - }, - None => None, - }; - - let sig_typed = FnSig { - name: fn_sig.monomorphized_name(), - arguments: fn_args_typed, - return_type: ret_typed.clone(), - ..fn_sig - }; + let sig_typed = fn_info.resolved_sig(); + let ret_typed = sig_typed.return_type.clone(); // construct the monomorphized function AST let func_def = match fn_info.kind { @@ -1160,7 +1224,7 @@ pub fn instantiate_fn_call( }, FnKind::Native(fn_def) => { let (stmts_typed, _) = - monomorphize_block(ctx, mono_fn_env, &stmts, ret_typed.as_ref())?; + monomorphize_block(ctx, mono_fn_env, &fn_def.body, ret_typed.as_ref())?; FnInfo { kind: FnKind::Native(FunctionDef { diff --git a/src/parser/types.rs b/src/parser/types.rs index a8685896d..106092f32 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -559,7 +559,7 @@ impl FnSig { } /// Recursively assign values to the generic parameters based on observed Array type argument - fn resolve_generic_array( + pub fn resolve_generic_array( &mut self, sig_arg: &TyKind, observed: &TyKind, @@ -590,36 +590,6 @@ impl FnSig { Ok(()) } - /// Resolve generic values for each generic parameter - pub fn resolve_generic_values(&mut self, observed: &[ExprMonoInfo]) -> Result<()> { - for (sig_arg, observed_arg) in self.arguments.clone().iter().zip(observed) { - let observed_ty = observed_arg.typ.clone().expect("expected type"); - match (&sig_arg.typ.kind, &observed_ty) { - (TyKind::GenericSizedArray(_, _), TyKind::Array(_, _)) - | (TyKind::Array(_, _), TyKind::Array(_, _)) => { - self.resolve_generic_array( - &sig_arg.typ.kind, - &observed_ty, - observed_arg.expr.span, - )?; - } - // const NN: Field - _ => { - let cst = observed_arg.constant; - if is_generic_parameter(sig_arg.name.value.as_str()) && cst.is_some() { - self.generics.assign( - &sig_arg.name.value, - cst.unwrap(), - observed_arg.expr.span, - )?; - } - } - } - } - - Ok(()) - } - /// Returns true if the function signature contains generic parameters or generic array types. /// Either: /// - `const NN: Field` or `[[Field; NN]; MM]` @@ -658,7 +628,7 @@ impl FnSig { let mut name = self.name.clone(); if self.require_monomorphization() { - let mut generics = self.generics.0.iter().collect::>(); + let mut generics = self.generics.parameters.iter().collect::>(); generics.sort_by(|a, b| a.0.cmp(b.0)); let generics = generics @@ -762,23 +732,34 @@ impl Default for FuncOrMethod { } } +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +/// Resolved types for a function signature +pub struct ResolvedSig { + pub arguments: Vec, + pub return_type: Option, +} + #[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub struct GenericParameters(HashMap>); +/// Generic parameters for a function signature +pub struct GenericParameters { + pub parameters: HashMap>, + pub resolved_sig: Option, +} impl GenericParameters { /// Return all generic parameter names pub fn names(&self) -> HashSet { - self.0.keys().cloned().collect() + self.parameters.keys().cloned().collect() } /// Add an unbound generic parameter pub fn add(&mut self, name: String) { - self.0.insert(name, None); + self.parameters.insert(name, None); } /// Get the value of a generic parameter pub fn get(&self, name: &str) -> u32 { - self.0 + self.parameters .get(name) .expect("generic parameter not found") .expect("generic value not assigned") @@ -786,12 +767,12 @@ impl GenericParameters { /// Returns whether the generic parameters are empty pub fn is_empty(&self) -> bool { - self.0.is_empty() + self.parameters.is_empty() } /// Bind a generic parameter to a value pub fn assign(&mut self, name: &String, value: u32, span: Span) -> Result<()> { - let existing = self.0.get(name); + let existing = self.parameters.get(name); match existing { Some(Some(v)) => { if *v == value { @@ -805,7 +786,7 @@ impl GenericParameters { )) } Some(None) => { - self.0.insert(name.to_string(), Some(value)); + self.parameters.insert(name.to_string(), Some(value)); Ok(()) } None => Err(Error::new( @@ -815,6 +796,13 @@ impl GenericParameters { )), } } + + pub fn resolve_sig(&mut self, arguments: Vec, return_type: Option) { + self.resolved_sig = Some(ResolvedSig { + arguments, + return_type, + }); + } } // TODO: remove default here? From 64f17582c1aa81ba4c130982db141ec7b84f39cb Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 27 Sep 2024 16:53:20 +0800 Subject: [PATCH 07/36] incorporate native stdlib for testing --- src/error.rs | 5 ++- src/tests/examples.rs | 4 ++ src/tests/mod.rs | 23 ++++++++++ src/tests/stdlib/mod.rs | 94 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 src/tests/stdlib/mod.rs diff --git a/src/error.rs b/src/error.rs index 87f8845d0..1c3586e77 100644 --- a/src/error.rs +++ b/src/error.rs @@ -130,7 +130,7 @@ pub enum ErrorKind { #[error("invalid array size, expected [_; x] with x in [0,2^32]")] InvalidArraySize, - #[error("only allow a single generic parameter for the size of an array argument")] + #[error("Invalid expression in symbolic size")] InvalidSymbolicSize, #[error("invalid generic parameter, expected single uppercase letter, such as N, M, etc.")] @@ -358,4 +358,7 @@ pub enum ErrorKind { #[error("invalid range, the end value can't be smaller than the start value")] InvalidRange, + + #[error("division by zero")] + DivisionByZero, } diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 3aade9de0..05a4b4768 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -13,6 +13,8 @@ use crate::{ type_checker::TypeChecker, }; +use super::init_stdlib_dep; + fn test_file( file_name: &str, public_inputs: &str, @@ -36,6 +38,7 @@ fn test_file( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); + init_stdlib_dep(&mut sources, &mut tast); let this_module = None; let _node_id = typecheck_next_file( &mut tast, @@ -98,6 +101,7 @@ fn test_file( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); + init_stdlib_dep(&mut sources, &mut tast); let this_module = None; let _node_id = typecheck_next_file( &mut tast, diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 3d2702bcb..34c0d577b 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,2 +1,25 @@ +use std::path::Path; + +use crate::{ + backends::Backend, + cli::packages::UserRepo, + compiler::{typecheck_next_file, Sources}, + type_checker::TypeChecker, +}; + mod examples; mod modules; +mod stdlib; + +fn init_stdlib_dep(sources: &mut Sources, tast: &mut TypeChecker) { + let libs = vec!["int", "comparator", "bigint"]; + + // read stdlib files from src/stdlib/native/ + for lib in libs { + let module = UserRepo::new(&format!("std/{}", lib)); + let prefix_stdlib = Path::new("src/stdlib/native/"); + let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}.no"))).unwrap(); + let _node_id = + typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); + } +} diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs new file mode 100644 index 000000000..3955a474f --- /dev/null +++ b/src/tests/stdlib/mod.rs @@ -0,0 +1,94 @@ +mod bigint; +mod comparator; + +use std::{path::Path, str::FromStr}; + +use crate::{ + backends::r1cs::{R1csBn254Field, R1CS}, + circuit_writer::CircuitWriter, + compiler::{typecheck_next_file, Sources}, + error::Result, + inputs::parse_inputs, + mast, + tests::init_stdlib_dep, + type_checker::TypeChecker, + witness::CompiledCircuit, +}; + +fn test_stdlib( + path: &str, + asm_path: &str, + public_inputs: &str, + private_inputs: &str, + expected_public_output: Vec<&str>, +) -> Result>> { + let r1cs = R1CS::new(); + let root = env!("CARGO_MANIFEST_DIR"); + let prefix_path = Path::new(root).join("src/tests/stdlib"); + + // read noname file + let code = std::fs::read_to_string(prefix_path.clone().join(path)).unwrap(); + + // parse inputs + let public_inputs = parse_inputs(public_inputs).unwrap(); + let private_inputs = parse_inputs(private_inputs).unwrap(); + + // compile + let mut sources = Sources::new(); + let mut tast = TypeChecker::new(); + init_stdlib_dep(&mut sources, &mut tast); + + let this_module = None; + let _node_id = typecheck_next_file( + &mut tast, + this_module, + &mut sources, + path.to_string(), + code.clone(), + 0, + ) + .unwrap(); + + let mast = mast::monomorphize(tast)?; + let compiled_circuit = CircuitWriter::generate_circuit(mast, r1cs)?; + + // this should check the constraints + let generated_witness = compiled_circuit + .generate_witness(public_inputs.clone(), private_inputs.clone()) + .unwrap(); + + let expected_public_output = expected_public_output + .iter() + .map(|x| crate::backends::r1cs::R1csBn254Field::from_str(x).unwrap()) + .collect::>(); + + if generated_witness.outputs != expected_public_output { + eprintln!("obtained by executing the circuit:"); + generated_witness + .outputs + .iter() + .for_each(|x| eprintln!("- {x}")); + eprintln!("passed as output by the verifier:"); + expected_public_output + .iter() + .for_each(|x| eprintln!("- {x}")); + panic!("Obtained output does not match expected output"); + } + + // check the ASM + if compiled_circuit.circuit.backend.num_constraints() < 100 { + let prefix_asm = Path::new(root).join("src/tests/stdlib/"); + let expected_asm = std::fs::read_to_string(prefix_asm.clone().join(asm_path)).unwrap(); + let obtained_asm = compiled_circuit.asm(&Sources::new(), false); + + if obtained_asm != expected_asm { + eprintln!("obtained:"); + eprintln!("{obtained_asm}"); + eprintln!("expected:"); + eprintln!("{expected_asm}"); + panic!("Obtained ASM does not match expected ASM"); + } + } + + Ok(compiled_circuit) +} From be66e557dc7589a3f09c7f18ae4125aece94bd3c Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 27 Sep 2024 16:54:28 +0800 Subject: [PATCH 08/36] add int module --- src/stdlib/native/int.no | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 src/stdlib/native/int.no diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no new file mode 100644 index 000000000..9b0bc1067 --- /dev/null +++ b/src/stdlib/native/int.no @@ -0,0 +1,17 @@ +use std::bits; + +struct Uint8 { + // todo: maybe add a const attribute to Field to forbid reassignment + inner: Field, + bit_len: Field, +} + +fn Uint8.new(val: Field) -> Uint8 { + // range check + let ignore_ = bits::to_bits(8, val); + + return Uint8 { + inner: val, + bit_len: 8 + }; +} \ No newline at end of file From 6f26d5bc12d8587ba9638a08e7521e6156af4c6e Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 27 Sep 2024 16:54:58 +0800 Subject: [PATCH 09/36] add comparator stdlib --- src/stdlib/native/comparator.no | 44 +++++++++ .../comparator/less_eq_than/less_eq_than.asm | 11 +++ .../less_eq_than/less_eq_than_main.no | 6 ++ .../stdlib/comparator/less_than/less_than.asm | 11 +++ .../comparator/less_than/less_than_main.no | 10 +++ src/tests/stdlib/comparator/mod.rs | 89 +++++++++++++++++++ 6 files changed, 171 insertions(+) create mode 100644 src/stdlib/native/comparator.no create mode 100644 src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm create mode 100644 src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no create mode 100644 src/tests/stdlib/comparator/less_than/less_than.asm create mode 100644 src/tests/stdlib/comparator/less_than/less_than_main.no create mode 100644 src/tests/stdlib/comparator/mod.rs diff --git a/src/stdlib/native/comparator.no b/src/stdlib/native/comparator.no new file mode 100644 index 000000000..43a4e14bf --- /dev/null +++ b/src/stdlib/native/comparator.no @@ -0,0 +1,44 @@ +use std::bits; +use std::int; + +// Instead of comparing bit by bit, we check the carry bit: +// lhs + (1 << LEN) - rhs +// proof: +// lhs + (1 << LEN) will add a carry bit, valued 1, to the bit array representing lhs, +// resulted in a bit array of length LEN + 1, named as sum_bits. +// if `lhs < rhs``, then `lhs - rhs < 0`, thus `(1 << LEN) + lhs - rhs < (1 << LEN)` +// then, the carry bit of sum_bits is 0. +// otherwise, the carry bit of sum_bits is 1. +fn less_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { + let carry_bit_len = LEN + 1; + + // 1 << LEN + let mut pow2 = 1; + for ii in 0..LEN { + pow2 = pow2 + pow2; + } + + let sum = (pow2 + lhs) - rhs; + let sum_bit = bits::to_bits(carry_bit_len, sum); + + // todo: modify the ife to allow literals + let b1 = false; + let b2 = true; + let res = if sum_bit[LEN] { b1 } else { b2 }; + + return res; +} + +// Less than or equal to. +// based on the proof of less_than(): +// adding 1 to the rhs, can upper bound by 1 for the lhs: +// lhs < rhs + 1 +// is equivalent to +// lhs <= rhs +fn less_eq_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { + return less_than(LEN, lhs, rhs + 1); +} + +fn uint8_less_than(lhs: int::Uint8, rhs: int::Uint8) -> Bool { + return less_than(8, lhs.inner, rhs.inner); +} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm new file mode 100644 index 000000000..396f070da --- /dev/null +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm @@ -0,0 +1,11 @@ +@ noname.0.7.0 +@ public inputs: 2 + +v_5 == (v_4) * (v_4 + -1) +0 == (v_5) * (1) +v_7 == (v_6) * (v_6 + -1) +0 == (v_7) * (1) +v_9 == (v_8) * (v_8 + -1) +0 == (v_9) * (1) +v_4 + 2 * v_6 + 4 * v_8 == (v_2 + -1 * v_3 + 3) * (1) +-1 * v_8 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no new file mode 100644 index 000000000..5f998ae9f --- /dev/null +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no @@ -0,0 +1,6 @@ +use std::comparator; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let res = comparator::less_eq_than(2, lhs, rhs); + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_than/less_than.asm b/src/tests/stdlib/comparator/less_than/less_than.asm new file mode 100644 index 000000000..8c33ac348 --- /dev/null +++ b/src/tests/stdlib/comparator/less_than/less_than.asm @@ -0,0 +1,11 @@ +@ noname.0.7.0 +@ public inputs: 2 + +v_5 == (v_4) * (v_4 + -1) +0 == (v_5) * (1) +v_7 == (v_6) * (v_6 + -1) +0 == (v_7) * (1) +v_9 == (v_8) * (v_8 + -1) +0 == (v_9) * (1) +v_4 + 2 * v_6 + 4 * v_8 == (v_2 + -1 * v_3 + 4) * (1) +-1 * v_8 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_than/less_than_main.no b/src/tests/stdlib/comparator/less_than/less_than_main.no new file mode 100644 index 000000000..f0cba5462 --- /dev/null +++ b/src/tests/stdlib/comparator/less_than/less_than_main.no @@ -0,0 +1,10 @@ +use std::comparator; +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + // todo bug: this also throws error "method call only work on custom types" + let lhs_bigint = int::Uint8.new(lhs); + let rhs_bigint = int::Uint8.new(rhs); + // let res = comparator::uint8_less_than(lhs_bigint, rhs_bigint); + return true; +} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/mod.rs b/src/tests/stdlib/comparator/mod.rs new file mode 100644 index 000000000..bc85f91fe --- /dev/null +++ b/src/tests/stdlib/comparator/mod.rs @@ -0,0 +1,89 @@ +use crate::error; + +use super::test_stdlib; +use error::Result; + +#[test] +fn test_less_than_true() -> Result<()> { + let public_inputs = r#"{"lhs": "0"}"#; + let private_inputs = r#"{"rhs": "1"}"#; + + test_stdlib( + "comparator/less_than/less_than_main.no", + "comparator/less_than/less_than.asm", + public_inputs, + private_inputs, + vec!["1"], + )?; + + Ok(()) +} + +// test false +#[test] +fn test_less_than_false() -> Result<()> { + let public_inputs = r#"{"lhs": "1"}"#; + let private_inputs = r#"{"rhs": "0"}"#; + + test_stdlib( + "comparator/less_than/less_than_main.no", + "comparator/less_than/less_than.asm", + public_inputs, + private_inputs, + vec!["0"], + )?; + + Ok(()) +} + +#[test] +fn test_less_eq_than_true_1() -> Result<()> { + let public_inputs = r#"{"lhs": "0"}"#; + let private_inputs = r#"{"rhs": "1"}"#; + + test_stdlib( + "comparator/less_eq_than/less_eq_than_main.no", + "comparator/less_eq_than/less_eq_than.asm", + public_inputs, + private_inputs, + vec!["1"], + )?; + + Ok(()) +} + +#[test] +fn test_less_eq_than_true_2() -> Result<()> { + let public_inputs = r#"{"lhs": "1"}"#; + let private_inputs = r#"{"rhs": "1"}"#; + + test_stdlib( + "comparator/less_eq_than/less_eq_than_main.no", + "comparator/less_eq_than/less_eq_than.asm", + public_inputs, + private_inputs, + vec!["1"], + )?; + + Ok(()) +} + +#[test] +fn test_less_eq_than_false() -> Result<()> { + let public_inputs = r#"{"lhs": "1"}"#; + let private_inputs = r#"{"rhs": "0"}"#; + + test_stdlib( + "comparator/less_eq_than/less_eq_than_main.no", + "comparator/less_eq_than/less_eq_than.asm", + public_inputs, + private_inputs, + vec!["0"], + )?; + + Ok(()) +} + +// test value overflow modulus +// it shouldn't need user to enter the bit length +// should have a way to restrict and type check the value to a certain bit length From b2830469ab73e03e00645dc383afa11828534765 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 27 Sep 2024 16:55:17 +0800 Subject: [PATCH 10/36] native bits stdlib --- .../asm/kimchi/generic_builtin_bits.asm | 95 +++++++++------- .../fixture/asm/r1cs/generic_builtin_bits.asm | 7 +- examples/generic_builtin_bits.no | 103 ++++++++++++++++-- 3 files changed, 155 insertions(+), 50 deletions(-) diff --git a/examples/fixture/asm/kimchi/generic_builtin_bits.asm b/examples/fixture/asm/kimchi/generic_builtin_bits.asm index 671fcfc5e..a447ae7b1 100644 --- a/examples/fixture/asm/kimchi/generic_builtin_bits.asm +++ b/examples/fixture/asm/kimchi/generic_builtin_bits.asm @@ -2,19 +2,34 @@ @ public inputs: 1 DoubleGeneric<1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,-1,0,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,-1,0,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,-1,0,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<4,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> DoubleGeneric<1,1> @@ -25,47 +40,51 @@ DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,0,0,-1> DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<4,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> -DoubleGeneric<1> -DoubleGeneric<1,0,0,0,-1> -DoubleGeneric<1,0,-1> -DoubleGeneric<2,0,-1> -DoubleGeneric<1,1,-1> -DoubleGeneric<4,0,-1> -DoubleGeneric<1,1,-1> -DoubleGeneric<1,-1> -(0,0) -> (15,0) -> (28,1) -> (36,1) -(1,0) -> (2,0) -> (4,0) -> (16,0) -> (23,0) -(1,2) -> (2,1) -(2,2) -> (3,0) -(4,2) -> (9,0) -(5,0) -> (6,0) -> (8,0) -> (19,0) -> (24,0) -(5,2) -> (6,1) -(6,2) -> (7,0) -(8,2) -> (9,1) -(9,2) -> (14,0) -(10,0) -> (11,0) -> (13,0) -> (20,0) -> (26,0) -(10,2) -> (11,1) -(11,2) -> (12,0) +DoubleGeneric<1,0,0,0,-2> +(0,0) -> (30,1) -> (49,1) -> (50,0) +(1,0) -> (2,0) -> (7,0) -> (8,0) -> (31,0) -> (38,0) -> (39,0) +(1,2) -> (4,0) -> (5,0) +(2,1) -> (3,0) +(4,2) -> (5,1) +(5,2) -> (6,0) +(7,2) -> (19,0) +(8,1) -> (9,0) +(10,0) -> (11,0) -> (16,0) -> (17,0) -> (34,0) -> (41,0) -> (42,0) +(10,2) -> (13,0) -> (14,0) +(11,1) -> (12,0) (13,2) -> (14,1) -(14,2) -> (15,1) -(16,1) -> (17,0) -(17,2) -> (18,0) -(20,1) -> (21,0) -(21,2) -> (22,0) -(23,2) -> (25,0) -(24,2) -> (25,1) -(25,2) -> (27,0) -(26,2) -> (27,1) -(27,2) -> (28,0) -(29,0) -> (31,0) -> (34,0) -(30,0) -> (32,0) -(31,2) -> (33,0) -(32,2) -> (33,1) -(33,2) -> (35,0) -(34,2) -> (35,1) -(35,2) -> (36,0) +(14,2) -> (15,0) +(16,2) -> (19,1) +(17,1) -> (18,0) +(19,2) -> (29,0) +(20,0) -> (21,0) -> (26,0) -> (27,0) -> (35,0) -> (45,0) -> (46,0) +(20,2) -> (23,0) -> (24,0) +(21,1) -> (22,0) +(23,2) -> (24,1) +(24,2) -> (25,0) +(26,2) -> (29,1) +(27,1) -> (28,0) +(29,2) -> (30,0) +(31,1) -> (32,0) +(32,2) -> (33,0) +(35,1) -> (36,0) +(36,2) -> (37,0) +(38,2) -> (44,0) +(39,1) -> (40,0) +(41,2) -> (44,1) +(42,1) -> (43,0) +(44,2) -> (48,0) +(45,2) -> (48,1) +(46,1) -> (47,0) +(48,2) -> (49,0) diff --git a/examples/fixture/asm/r1cs/generic_builtin_bits.asm b/examples/fixture/asm/r1cs/generic_builtin_bits.asm index c216dc416..43d7a4f02 100644 --- a/examples/fixture/asm/r1cs/generic_builtin_bits.asm +++ b/examples/fixture/asm/r1cs/generic_builtin_bits.asm @@ -7,12 +7,9 @@ v_5 == (v_4) * (v_4 + -1) 0 == (v_5) * (1) v_7 == (v_6) * (v_6 + -1) 0 == (v_7) * (1) -v_2 + 2 * v_4 + 4 * v_6 == (v_1) * (1) +v_1 == (v_2 + 2 * v_4 + 4 * v_6) * (1) 1 == (-1 * v_2 + 1) * (1) 1 == (v_4) * (1) 1 == (-1 * v_6 + 1) * (1) v_1 == (v_2 + 2 * v_4 + 4 * v_6) * (1) -0 == (v_8) * (1) -1 == (v_9) * (1) -0 == (v_10) * (1) -v_1 == (v_8 + 2 * v_9 + 4 * v_10) * (1) +2 == (v_1) * (1) diff --git a/examples/generic_builtin_bits.no b/examples/generic_builtin_bits.no index db4830d07..62c412c08 100644 --- a/examples/generic_builtin_bits.no +++ b/examples/generic_builtin_bits.no @@ -1,22 +1,111 @@ +// circom versions: +// template Num2Bits(n) { +// signal input in; +// signal output out[n]; +// var lc1=0; + +// var e2=1; +// for (var i = 0; i> i) & 1; +// out[i] * (out[i] -1 ) === 0; +// lc1 += out[i] * e2; +// e2 = e2+e2; +// } + +// lc1 === in; +// } + +// template Bits2Num(n) { +// signal input in[n]; +// signal output out; +// var lc1=0; + +// var e2 = 1; +// for (var i = 0; i out; +// } + use std::bits; -// 010 = xx, where xx = 2 +// obviously writing this in native is much simpler than the builtin version +fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { + let mut bits = [false; LEN]; + let mut lc1 = 0; + let mut e2 = 1; + + let one = 1; + let zero = 0; + + for index in 0..LEN { + // maybe add a unconstrained / unsafe attribute before bits::nth_bit, such that: + // bits[index] = unsafe bits::nth_bit(value, index); + // here we need to ensure the related variables are constrained: + // 1. value: constrained to be equal with the sum of bits, which involves the index as well + // 2. index: a cell index in bits + // 3. bits: the output bits + // beyond the notation purpose, what security measures can we take to help guide this unsafe operation? + // one idea is to rely on this unsafe attribute to check if it is non-deterministic when constraining the bits[index] + // eg. + // - bits::nth_bit(value, index) is non-deterministic + // - a metadata can be added to the var of the bits as non-deterministic + // - when CS trying to constrain the non-deterministic var, + // it will raise an error if the var is not marked unsafe via the attribute unsafe + // thus, it seems we also need to add the attribute to the builtin function signature + // eg. `unsafe nth_bit(val: Field, const nth: Field) -> Bool` + // while the unsafe attribute in `bits[index] = unsafe bits::nth_bit(value, index);` + // is for the users to acknowledge they are responsible for having additional constraints. + // This approach makes it explicit whether an expression is non-deterministic at the first place. + // On the other hand, circom lang determines whether it is non-deterministic by folding the arithmetic operation. + + bits[index] = bits::nth_bit(value, index); + // nth_bit is a hint function, and it doesn't constraint the value of the bits as boolean, + // although its return type is boolean. + // can we make the arithmetic operation compatible with boolean? + // or just make a stdlib to convert boolean to Field while adding the constraint? + let bit_num = if bits[index] {one} else {zero}; + assert_eq(bit_num * (bit_num - 1), 0); + + lc1 = lc1 + if bits[index] {e2} else {zero}; + e2 = e2 + e2; + } + assert_eq(lc1, value); + return bits; +} + +fn from_bits(bits: [Bool; LEN]) -> Field { + let mut lc1 = 0; + let mut e2 = 1; + let zero = 0; + + for index in 0..LEN { + lc1 = lc1 + if bits[index] {e2} else {zero}; + e2 = e2 + e2; + } + return lc1; +} + fn main(pub xx: Field) { - // var - let bits = bits::to_bits(3, xx); + // calculate on a cell var + let bits = to_bits(3, xx); assert(!bits[0]); assert(bits[1]); assert(!bits[2]); - let val = bits::from_bits(bits); + let val = from_bits(bits); assert_eq(val, xx); - // constant - let cst_bits = bits::to_bits(3, 2); + // calculate on a constant + let cst_bits = to_bits(3, 2); assert(!cst_bits[0]); assert(cst_bits[1]); assert(!cst_bits[2]); - let cst = bits::from_bits(cst_bits); + let cst = from_bits(cst_bits); assert_eq(cst, xx); } + +// ^ the asm diffs originated from the fact that the builtin version stored constant as cell vars. \ No newline at end of file From 00d230bd133d955363495d74eb3169533db5bc31 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 11 Oct 2024 11:50:47 +0800 Subject: [PATCH 11/36] move init_stdlib_dep to stdlib --- src/stdlib/mod.rs | 30 ++++++++++++++++++++---------- src/tests/examples.rs | 15 ++++++--------- src/tests/mod.rs | 22 ---------------------- src/tests/stdlib/mod.rs | 16 ++++------------ 4 files changed, 30 insertions(+), 53 deletions(-) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 150d31517..0483bdeaa 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -1,17 +1,10 @@ use crate::{ - backends::Backend, - circuit_writer::{CircuitWriter, VarInfo}, - constants::Span, - error::Result, - imports::FnKind, - lexer::Token, - parser::{ + backends::Backend, circuit_writer::{CircuitWriter, VarInfo}, cli::packages::UserRepo, compiler::{typecheck_next_file, Sources}, constants::Span, error::Result, imports::FnKind, lexer::Token, parser::{ types::{FnSig, GenericParameters}, ParserCtx, - }, - type_checker::FnInfo, - var::Var, + }, type_checker::{FnInfo, TypeChecker}, var::Var }; +use std::path::Path; pub mod bits; pub mod builtins; @@ -79,3 +72,20 @@ trait Module { res } } + +pub fn init_stdlib_dep(sources: &mut Sources, tast: &mut TypeChecker, node_id: usize) -> usize { + // list the stdlib dependency in order + let libs = vec!["int", "comparator"]; + + let mut node_id = node_id; + + for lib in libs { + let module = UserRepo::new(&format!("std/{}", lib)); + let prefix_stdlib = Path::new("src/stdlib/native/"); + let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}.no"))).unwrap(); + node_id = + typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); + } + + node_id +} diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 05a4b4768..9fd373db8 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -7,14 +7,9 @@ use crate::{ kimchi::{KimchiVesta, VestaField}, r1cs::R1CS, BackendKind, - }, - compiler::{compile, typecheck_next_file, Sources}, - inputs::{parse_inputs, ExtField}, - type_checker::TypeChecker, + }, compiler::{compile, typecheck_next_file, Sources}, inputs::{parse_inputs, ExtField}, stdlib::init_stdlib_dep, type_checker::TypeChecker }; -use super::init_stdlib_dep; - fn test_file( file_name: &str, public_inputs: &str, @@ -38,7 +33,8 @@ fn test_file( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); - init_stdlib_dep(&mut sources, &mut tast); + let mut node_id = 0; + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id); let this_module = None; let _node_id = typecheck_next_file( &mut tast, @@ -46,7 +42,7 @@ fn test_file( &mut sources, file_name.to_string(), code.clone(), - 0, + node_id, ) .unwrap(); @@ -101,7 +97,8 @@ fn test_file( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); - init_stdlib_dep(&mut sources, &mut tast); + let mut node_id = 0; + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id); let this_module = None; let _node_id = typecheck_next_file( &mut tast, diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 34c0d577b..3036fa55d 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,25 +1,3 @@ -use std::path::Path; - -use crate::{ - backends::Backend, - cli::packages::UserRepo, - compiler::{typecheck_next_file, Sources}, - type_checker::TypeChecker, -}; - mod examples; mod modules; mod stdlib; - -fn init_stdlib_dep(sources: &mut Sources, tast: &mut TypeChecker) { - let libs = vec!["int", "comparator", "bigint"]; - - // read stdlib files from src/stdlib/native/ - for lib in libs { - let module = UserRepo::new(&format!("std/{}", lib)); - let prefix_stdlib = Path::new("src/stdlib/native/"); - let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}.no"))).unwrap(); - let _node_id = - typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); - } -} diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 3955a474f..a00c50d48 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -1,18 +1,9 @@ -mod bigint; mod comparator; use std::{path::Path, str::FromStr}; use crate::{ - backends::r1cs::{R1csBn254Field, R1CS}, - circuit_writer::CircuitWriter, - compiler::{typecheck_next_file, Sources}, - error::Result, - inputs::parse_inputs, - mast, - tests::init_stdlib_dep, - type_checker::TypeChecker, - witness::CompiledCircuit, + backends::r1cs::{R1csBn254Field, R1CS}, circuit_writer::CircuitWriter, compiler::{typecheck_next_file, Sources}, error::Result, inputs::parse_inputs, mast, stdlib::init_stdlib_dep, type_checker::TypeChecker, witness::CompiledCircuit }; fn test_stdlib( @@ -36,7 +27,8 @@ fn test_stdlib( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); - init_stdlib_dep(&mut sources, &mut tast); + let mut node_id = 0; + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id); let this_module = None; let _node_id = typecheck_next_file( @@ -45,7 +37,7 @@ fn test_stdlib( &mut sources, path.to_string(), code.clone(), - 0, + node_id, ) .unwrap(); From 1d2cba78a1235336699763904af14d0b25d4e520 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 11 Oct 2024 14:16:56 +0800 Subject: [PATCH 12/36] add bits stdlib --- .../asm/kimchi/generic_builtin_bits.asm | 121 ++++++++++++------ .../fixture/asm/r1cs/generic_builtin_bits.asm | 31 +++-- examples/generic_builtin_bits.no | 98 +------------- src/stdlib/bits.rs | 59 ++++++++- src/stdlib/mod.rs | 22 +++- src/stdlib/native/bits.no | 66 ++++++++++ src/tests/examples.rs | 6 +- .../comparator/less_eq_than/less_eq_than.asm | 27 +++- .../comparator/less_than/less_than_main.no | 4 +- src/tests/stdlib/mod.rs | 10 +- 10 files changed, 289 insertions(+), 155 deletions(-) create mode 100644 src/stdlib/native/bits.no diff --git a/examples/fixture/asm/kimchi/generic_builtin_bits.asm b/examples/fixture/asm/kimchi/generic_builtin_bits.asm index a447ae7b1..a7843d1c4 100644 --- a/examples/fixture/asm/kimchi/generic_builtin_bits.asm +++ b/examples/fixture/asm/kimchi/generic_builtin_bits.asm @@ -2,10 +2,16 @@ @ public inputs: 1 DoubleGeneric<1> -DoubleGeneric<1,0,-1> +DoubleGeneric<1,0,-1,0,-1> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1> +DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,1,-1> +DoubleGeneric<0,0,-1,1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,0,-1,0,-1> +DoubleGeneric<1,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<1,0,-1> @@ -17,16 +23,35 @@ DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,-1,0,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> -DoubleGeneric<2,0,-1> DoubleGeneric<1,1> -DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,-1> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1> DoubleGeneric<1,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,1,-1> DoubleGeneric<1,0,-1,0,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> +DoubleGeneric<1,1> +DoubleGeneric<1,1,-1> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,-1> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<4,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> @@ -52,39 +77,61 @@ DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> DoubleGeneric<1,0,0,0,-2> -(0,0) -> (30,1) -> (49,1) -> (50,0) -(1,0) -> (2,0) -> (7,0) -> (8,0) -> (31,0) -> (38,0) -> (39,0) -(1,2) -> (4,0) -> (5,0) -(2,1) -> (3,0) -(4,2) -> (5,1) -(5,2) -> (6,0) -(7,2) -> (19,0) +(0,0) -> (55,1) -> (74,1) -> (75,0) +(1,0) -> (2,0) -> (5,0) +(1,2) -> (2,1) +(2,2) -> (3,0) +(4,0) -> (6,0) -> (23,0) -> (41,0) +(5,1) -> (6,1) +(6,2) -> (7,1) -> (11,1) +(7,2) -> (10,0) +(8,0) -> (11,0) -> (13,0) -> (14,0) (8,1) -> (9,0) -(10,0) -> (11,0) -> (16,0) -> (17,0) -> (34,0) -> (41,0) -> (42,0) -(10,2) -> (13,0) -> (14,0) -(11,1) -> (12,0) -(13,2) -> (14,1) -(14,2) -> (15,0) -(16,2) -> (19,1) +(9,2) -> (10,1) +(11,2) -> (12,0) +(13,2) -> (16,0) -> (17,0) -> (56,0) -> (63,0) -> (64,0) +(14,1) -> (15,0) +(16,2) -> (36,0) (17,1) -> (18,0) -(19,2) -> (29,0) -(20,0) -> (21,0) -> (26,0) -> (27,0) -> (35,0) -> (45,0) -> (46,0) -(20,2) -> (23,0) -> (24,0) -(21,1) -> (22,0) -(23,2) -> (24,1) -(24,2) -> (25,0) -(26,2) -> (29,1) -(27,1) -> (28,0) -(29,2) -> (30,0) +(19,0) -> (20,0) -> (22,0) +(19,2) -> (20,1) +(20,2) -> (21,0) +(22,1) -> (23,1) +(23,2) -> (24,1) -> (28,1) +(24,2) -> (27,0) +(25,0) -> (28,0) -> (30,0) -> (31,0) +(25,1) -> (26,0) +(26,2) -> (27,1) +(28,2) -> (29,0) +(30,2) -> (33,0) -> (34,0) -> (59,0) -> (66,0) -> (67,0) (31,1) -> (32,0) -(32,2) -> (33,0) -(35,1) -> (36,0) -(36,2) -> (37,0) -(38,2) -> (44,0) -(39,1) -> (40,0) -(41,2) -> (44,1) -(42,1) -> (43,0) -(44,2) -> (48,0) -(45,2) -> (48,1) -(46,1) -> (47,0) -(48,2) -> (49,0) +(33,2) -> (36,1) +(34,1) -> (35,0) +(36,2) -> (54,0) +(37,0) -> (38,0) -> (40,0) +(37,2) -> (38,1) +(38,2) -> (39,0) +(40,1) -> (41,1) +(41,2) -> (42,1) -> (46,1) +(42,2) -> (45,0) +(43,0) -> (46,0) -> (48,0) -> (49,0) +(43,1) -> (44,0) +(44,2) -> (45,1) +(46,2) -> (47,0) +(48,2) -> (51,0) -> (52,0) -> (60,0) -> (70,0) -> (71,0) +(49,1) -> (50,0) +(51,2) -> (54,1) +(52,1) -> (53,0) +(54,2) -> (55,0) +(56,1) -> (57,0) +(57,2) -> (58,0) +(60,1) -> (61,0) +(61,2) -> (62,0) +(63,2) -> (69,0) +(64,1) -> (65,0) +(66,2) -> (69,1) +(67,1) -> (68,0) +(69,2) -> (73,0) +(70,2) -> (73,1) +(71,1) -> (72,0) +(73,2) -> (74,0) diff --git a/examples/fixture/asm/r1cs/generic_builtin_bits.asm b/examples/fixture/asm/r1cs/generic_builtin_bits.asm index 43d7a4f02..fb72f65a4 100644 --- a/examples/fixture/asm/r1cs/generic_builtin_bits.asm +++ b/examples/fixture/asm/r1cs/generic_builtin_bits.asm @@ -3,13 +3,28 @@ v_3 == (v_2) * (v_2 + -1) 0 == (v_3) * (1) -v_5 == (v_4) * (v_4 + -1) -0 == (v_5) * (1) -v_7 == (v_6) * (v_6 + -1) -0 == (v_7) * (1) -v_1 == (v_2 + 2 * v_4 + 4 * v_6) * (1) -1 == (-1 * v_2 + 1) * (1) 1 == (v_4) * (1) -1 == (-1 * v_6 + 1) * (1) -v_1 == (v_2 + 2 * v_4 + 4 * v_6) * (1) +v_6 == (v_5) * (-1 * v_2 + v_4) +-1 * v_7 + 1 == (v_6) * (1) +v_8 == (v_7) * (-1 * v_2 + v_4) +0 == (v_8) * (1) +v_10 == (v_9) * (v_9 + -1) +0 == (v_10) * (1) +1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_9 + v_11) +-1 * v_14 + 1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_9 + v_11) +0 == (v_15) * (1) +v_17 == (v_16) * (v_16 + -1) +0 == (v_17) * (1) +1 == (v_18) * (1) +v_20 == (v_19) * (-1 * v_16 + v_18) +-1 * v_21 + 1 == (v_20) * (1) +v_22 == (v_21) * (-1 * v_16 + v_18) +0 == (v_22) * (1) +v_1 == (v_7 + 2 * v_14 + 4 * v_21) * (1) +1 == (-1 * v_7 + 1) * (1) +1 == (v_14) * (1) +1 == (-1 * v_21 + 1) * (1) +v_1 == (v_7 + 2 * v_14 + 4 * v_21) * (1) 2 == (v_1) * (1) diff --git a/examples/generic_builtin_bits.no b/examples/generic_builtin_bits.no index 62c412c08..c15b2b178 100644 --- a/examples/generic_builtin_bits.no +++ b/examples/generic_builtin_bits.no @@ -1,111 +1,21 @@ -// circom versions: -// template Num2Bits(n) { -// signal input in; -// signal output out[n]; -// var lc1=0; - -// var e2=1; -// for (var i = 0; i> i) & 1; -// out[i] * (out[i] -1 ) === 0; -// lc1 += out[i] * e2; -// e2 = e2+e2; -// } - -// lc1 === in; -// } - -// template Bits2Num(n) { -// signal input in[n]; -// signal output out; -// var lc1=0; - -// var e2 = 1; -// for (var i = 0; i out; -// } - use std::bits; -// obviously writing this in native is much simpler than the builtin version -fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { - let mut bits = [false; LEN]; - let mut lc1 = 0; - let mut e2 = 1; - - let one = 1; - let zero = 0; - - for index in 0..LEN { - // maybe add a unconstrained / unsafe attribute before bits::nth_bit, such that: - // bits[index] = unsafe bits::nth_bit(value, index); - // here we need to ensure the related variables are constrained: - // 1. value: constrained to be equal with the sum of bits, which involves the index as well - // 2. index: a cell index in bits - // 3. bits: the output bits - // beyond the notation purpose, what security measures can we take to help guide this unsafe operation? - // one idea is to rely on this unsafe attribute to check if it is non-deterministic when constraining the bits[index] - // eg. - // - bits::nth_bit(value, index) is non-deterministic - // - a metadata can be added to the var of the bits as non-deterministic - // - when CS trying to constrain the non-deterministic var, - // it will raise an error if the var is not marked unsafe via the attribute unsafe - // thus, it seems we also need to add the attribute to the builtin function signature - // eg. `unsafe nth_bit(val: Field, const nth: Field) -> Bool` - // while the unsafe attribute in `bits[index] = unsafe bits::nth_bit(value, index);` - // is for the users to acknowledge they are responsible for having additional constraints. - // This approach makes it explicit whether an expression is non-deterministic at the first place. - // On the other hand, circom lang determines whether it is non-deterministic by folding the arithmetic operation. - - bits[index] = bits::nth_bit(value, index); - // nth_bit is a hint function, and it doesn't constraint the value of the bits as boolean, - // although its return type is boolean. - // can we make the arithmetic operation compatible with boolean? - // or just make a stdlib to convert boolean to Field while adding the constraint? - let bit_num = if bits[index] {one} else {zero}; - assert_eq(bit_num * (bit_num - 1), 0); - - lc1 = lc1 + if bits[index] {e2} else {zero}; - e2 = e2 + e2; - } - assert_eq(lc1, value); - return bits; -} - -fn from_bits(bits: [Bool; LEN]) -> Field { - let mut lc1 = 0; - let mut e2 = 1; - let zero = 0; - - for index in 0..LEN { - lc1 = lc1 + if bits[index] {e2} else {zero}; - e2 = e2 + e2; - } - return lc1; -} - fn main(pub xx: Field) { // calculate on a cell var - let bits = to_bits(3, xx); + let bits = bits::to_bits(3, xx); assert(!bits[0]); assert(bits[1]); assert(!bits[2]); - let val = from_bits(bits); + let val = bits::from_bits(bits); assert_eq(val, xx); // calculate on a constant - let cst_bits = to_bits(3, 2); + let cst_bits = bits::to_bits(3, 2); assert(!cst_bits[0]); assert(cst_bits[1]); assert(!cst_bits[2]); - let cst = from_bits(cst_bits); + let cst = bits::from_bits(cst_bits); assert_eq(cst, xx); } - -// ^ the asm diffs originated from the fact that the builtin version stored constant as cell vars. \ No newline at end of file diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index 3fe256ee2..6b7d49643 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -1,7 +1,7 @@ use std::vec; use ark_ff::One; -use kimchi::o1_utils::FieldHelpers; +use kimchi::{o1_utils::FieldHelpers, turshi::helper::CairoFieldHelpers}; use crate::{ backends::Backend, @@ -17,6 +17,7 @@ use super::{FnInfoType, Module}; const TO_BITS_FN: &str = "to_bits(const LEN: Field, val: Field) -> [Bool; LEN]"; const FROM_BITS_FN: &str = "from_bits(bits: [Bool; LEN]) -> Field"; +const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field"; pub struct BitsLib {} @@ -24,7 +25,11 @@ impl Module for BitsLib { const MODULE: &'static str = "bits"; fn get_fns() -> Vec<(&'static str, FnInfoType)> { - vec![(TO_BITS_FN, to_bits), (FROM_BITS_FN, from_bits)] + vec![ + (TO_BITS_FN, to_bits), + (FROM_BITS_FN, from_bits), + (NTH_BIT_FN, nth_bit), + ] } } @@ -164,3 +169,53 @@ fn from_bits( Ok(Some(Var::new_cvar(cvar, span))) } + +fn nth_bit( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + // should be two input vars + assert_eq!(vars.len(), 2); + + // these should be type checked already, unless it is called by other low level functions + // eg. builtins + let var_info = &vars[0]; + let val = &var_info.var; + assert_eq!(val.len(), 1); + + let var_info = &vars[1]; + let nth = &var_info.var; + assert_eq!(nth.len(), 1); + + let nth: usize = match &nth[0] { + ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), + ConstOrCell::Const(cst) => cst.to_u64() as usize, + }; + + let val = match &val[0] { + ConstOrCell::Cell(cvar) => cvar.clone(), + ConstOrCell::Const(cst) => { + // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var + let bit = cst.to_bits(); + return Ok(Some(Var::new_cvar( + ConstOrCell::Const(B::Field::from(bit[nth])), + span, + ))); + } + }; + + // create a cell var for the symbolic value representing the nth bit. + // it seems we will always have to create cell vars to allocate the symbolic values that involve non-deterministic calculations. + // it is non-deterministic because it involves non-deterministic arithmetic on a cell var. + let bit = compiler + .backend + .new_internal_var(Value::NthBit(val.clone(), nth), span); + + // constrain it to be either 0 or 1 + // bit * (bit - 1 ) === 0; + // boolean::check(compiler, &ConstOrCell::Cell(bit.clone()), span); + + Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) +} diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 0483bdeaa..e89ee71a5 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -1,8 +1,18 @@ use crate::{ - backends::Backend, circuit_writer::{CircuitWriter, VarInfo}, cli::packages::UserRepo, compiler::{typecheck_next_file, Sources}, constants::Span, error::Result, imports::FnKind, lexer::Token, parser::{ + backends::Backend, + circuit_writer::{CircuitWriter, VarInfo}, + cli::packages::UserRepo, + compiler::{typecheck_next_file, Sources}, + constants::Span, + error::Result, + imports::FnKind, + lexer::Token, + parser::{ types::{FnSig, GenericParameters}, ParserCtx, - }, type_checker::{FnInfo, TypeChecker}, var::Var + }, + type_checker::{FnInfo, TypeChecker}, + var::Var, }; use std::path::Path; @@ -73,9 +83,13 @@ trait Module { } } -pub fn init_stdlib_dep(sources: &mut Sources, tast: &mut TypeChecker, node_id: usize) -> usize { +pub fn init_stdlib_dep( + sources: &mut Sources, + tast: &mut TypeChecker, + node_id: usize, +) -> usize { // list the stdlib dependency in order - let libs = vec!["int", "comparator"]; + let libs = vec!["int", "bits", "comparator"]; let mut node_id = node_id; diff --git a/src/stdlib/native/bits.no b/src/stdlib/native/bits.no new file mode 100644 index 000000000..65825cdf1 --- /dev/null +++ b/src/stdlib/native/bits.no @@ -0,0 +1,66 @@ +hint fn nth_bit(value: Field, const nth: Field) -> Field; + +// obviously writing this in native is much simpler than the builtin version +fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { + let mut bits = [false; LEN]; + let mut lc1 = 0; + let mut e2 = 1; + + let one = 1; + let zero = 0; + + // todo: ITE should allow literals + let true_val = true; + let false_val = false; + + for index in 0..LEN { + // maybe add a unconstrained / unsafe attribute before bits::nth_bit, such that: + // bits[index] = unsafe bits::nth_bit(value, index); + // here we need to ensure the related variables are constrained: + // 1. value: constrained to be equal with the sum of bits, which involves the index as well + // 2. index: a cell index in bits + // 3. bits: the output bits + // beyond the notation purpose, what security measures can we take to help guide this unsafe operation? + // one idea is to rely on this unsafe attribute to check if it is non-deterministic when constraining the bits[index] + // eg. + // - bits::nth_bit(value, index) is non-deterministic + // - a metadata can be added to the var of the bits as non-deterministic + // - when CS trying to constrain the non-deterministic var, + // it will raise an error if the var is not marked unsafe via the attribute unsafe + // thus, it seems we also need to add the attribute to the builtin function signature + // eg. `unsafe nth_bit(val: Field, const nth: Field) -> Bool` + // while the unsafe attribute in `bits[index] = unsafe bits::nth_bit(value, index);` + // is for the users to acknowledge they are responsible for having additional constraints. + // This approach makes it explicit whether an expression is non-deterministic at the first place. + // On the other hand, circom lang determines whether it is non-deterministic by folding the arithmetic operation. + + let bit_num = unsafe nth_bit(value, index); + // nth_bit is a hint function, and it doesn't constraint the value of the bits as boolean, + // although its return type is boolean. + // can we make the arithmetic operation compatible with boolean? + // or just make a stdlib to convert boolean to Field while adding the constraint? + + // constrain the bit_num to be 0 or 1 + assert_eq(bit_num * (bit_num - 1), 0); + + // convert the bit_num to boolean + bits[index] = if bit_num == 1 {true_val} else {false_val}; + + lc1 = lc1 + if bits[index] {e2} else {zero}; + e2 = e2 + e2; + } + assert_eq(lc1, value); + return bits; +} + +fn from_bits(bits: [Bool; LEN]) -> Field { + let mut lc1 = 0; + let mut e2 = 1; + let zero = 0; + + for index in 0..LEN { + lc1 = lc1 + if bits[index] {e2} else {zero}; + e2 = e2 + e2; + } + return lc1; +} \ No newline at end of file diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 9fd373db8..ddffba037 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -7,7 +7,11 @@ use crate::{ kimchi::{KimchiVesta, VestaField}, r1cs::R1CS, BackendKind, - }, compiler::{compile, typecheck_next_file, Sources}, inputs::{parse_inputs, ExtField}, stdlib::init_stdlib_dep, type_checker::TypeChecker + }, + compiler::{compile, typecheck_next_file, Sources}, + inputs::{parse_inputs, ExtField}, + stdlib::init_stdlib_dep, + type_checker::TypeChecker, }; fn test_file( diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm index 396f070da..4608ebd97 100644 --- a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm @@ -3,9 +3,24 @@ v_5 == (v_4) * (v_4 + -1) 0 == (v_5) * (1) -v_7 == (v_6) * (v_6 + -1) -0 == (v_7) * (1) -v_9 == (v_8) * (v_8 + -1) -0 == (v_9) * (1) -v_4 + 2 * v_6 + 4 * v_8 == (v_2 + -1 * v_3 + 3) * (1) --1 * v_8 + 1 == (v_1) * (1) +1 == (v_6) * (1) +v_8 == (v_7) * (-1 * v_4 + v_6) +-1 * v_9 + 1 == (v_8) * (1) +v_10 == (v_9) * (-1 * v_4 + v_6) +0 == (v_10) * (1) +v_12 == (v_11) * (v_11 + -1) +0 == (v_12) * (1) +1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_11 + v_13) +-1 * v_16 + 1 == (v_15) * (1) +v_17 == (v_16) * (-1 * v_11 + v_13) +0 == (v_17) * (1) +v_19 == (v_18) * (v_18 + -1) +0 == (v_19) * (1) +1 == (v_20) * (1) +v_22 == (v_21) * (-1 * v_18 + v_20) +-1 * v_23 + 1 == (v_22) * (1) +v_24 == (v_23) * (-1 * v_18 + v_20) +0 == (v_24) * (1) +v_2 + -1 * v_3 + 3 == (v_9 + 2 * v_16 + 4 * v_23) * (1) +-1 * v_23 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_than/less_than_main.no b/src/tests/stdlib/comparator/less_than/less_than_main.no index f0cba5462..a4eb8e7b9 100644 --- a/src/tests/stdlib/comparator/less_than/less_than_main.no +++ b/src/tests/stdlib/comparator/less_than/less_than_main.no @@ -5,6 +5,6 @@ fn main(pub lhs: Field, rhs: Field) -> Bool { // todo bug: this also throws error "method call only work on custom types" let lhs_bigint = int::Uint8.new(lhs); let rhs_bigint = int::Uint8.new(rhs); - // let res = comparator::uint8_less_than(lhs_bigint, rhs_bigint); - return true; + let res = comparator::uint8_less_than(lhs_bigint, rhs_bigint); + return res; } \ No newline at end of file diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index a00c50d48..92a114847 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -3,7 +3,15 @@ mod comparator; use std::{path::Path, str::FromStr}; use crate::{ - backends::r1cs::{R1csBn254Field, R1CS}, circuit_writer::CircuitWriter, compiler::{typecheck_next_file, Sources}, error::Result, inputs::parse_inputs, mast, stdlib::init_stdlib_dep, type_checker::TypeChecker, witness::CompiledCircuit + backends::r1cs::{R1csBn254Field, R1CS}, + circuit_writer::CircuitWriter, + compiler::{typecheck_next_file, Sources}, + error::Result, + inputs::parse_inputs, + mast, + stdlib::init_stdlib_dep, + type_checker::TypeChecker, + witness::CompiledCircuit, }; fn test_stdlib( From 9b94dfbfc1bf0af62568d75d2cc7dc5b5bf327d5 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 11 Oct 2024 17:08:25 +0800 Subject: [PATCH 13/36] remove bits builtin --- src/stdlib/bits.rs | 150 --------------------------------------------- src/stdlib/mod.rs | 2 +- 2 files changed, 1 insertion(+), 151 deletions(-) diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index 6b7d49643..8baa88f65 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -1,13 +1,11 @@ use std::vec; -use ark_ff::One; use kimchi::{o1_utils::FieldHelpers, turshi::helper::CairoFieldHelpers}; use crate::{ backends::Backend, circuit_writer::{CircuitWriter, VarInfo}, constants::Span, - constraints::boolean, error::Result, parser::types::GenericParameters, var::{ConstOrCell, Value, Var}, @@ -15,8 +13,6 @@ use crate::{ use super::{FnInfoType, Module}; -const TO_BITS_FN: &str = "to_bits(const LEN: Field, val: Field) -> [Bool; LEN]"; -const FROM_BITS_FN: &str = "from_bits(bits: [Bool; LEN]) -> Field"; const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field"; pub struct BitsLib {} @@ -26,150 +22,11 @@ impl Module for BitsLib { fn get_fns() -> Vec<(&'static str, FnInfoType)> { vec![ - (TO_BITS_FN, to_bits), - (FROM_BITS_FN, from_bits), (NTH_BIT_FN, nth_bit), ] } } -fn to_bits( - compiler: &mut CircuitWriter, - generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // should be two input vars - assert_eq!(vars.len(), 2); - - // but the better practice would be to retrieve the value from the generics - let bitlen = generics.get("LEN") as usize; - - // num should be greater than 0 - assert!(bitlen > 0); - - let modulus_bits: usize = B::Field::modulus_biguint() - .bits() - .try_into() - .expect("modulus is too large"); - - assert!(bitlen <= (modulus_bits - 1)); - - // alternatively, it can be retrieved from the first var, but it is not recommended - // let num_var = &vars[0]; - - // second var is the value to convert - let var_info = &vars[1]; - let var = &var_info.var; - assert_eq!(var.len(), 1); - - let val = match &var[0] { - ConstOrCell::Cell(cvar) => cvar.clone(), - ConstOrCell::Const(cst) => { - // extract the first bitlen bits - let bits = cst - .to_bits() - .iter() - .take(bitlen) - .copied() - // convert to ConstOrVar - .map(|b| ConstOrCell::Const(B::Field::from(b))) - .collect::>(); - - return Ok(Some(Var::new(bits, span))); - } - }; - - // convert value to bits - let mut bits = Vec::with_capacity(bitlen); - let mut e2 = B::Field::one(); - let mut lc: Option = None; - - for i in 0..bitlen { - let bit = compiler - .backend - .new_internal_var(Value::NthBit(val.clone(), i), span); - - // constrain it to be either 0 or 1 - // bits[i] * (bits[i] - 1 ) === 0; - boolean::check(compiler, &ConstOrCell::Cell(bit.clone()), span); - - // lc += bits[i] * e2; - let weighted_bit = compiler.backend.mul_const(&bit, &e2, span); - lc = if i == 0 { - Some(weighted_bit) - } else { - Some(compiler.backend.add(&lc.unwrap(), &weighted_bit, span)) - }; - - bits.push(bit.clone()); - e2 = e2 + e2; - } - - compiler.backend.assert_eq_var(&val, &lc.unwrap(), span); - - let bits_cvars = bits.into_iter().map(ConstOrCell::Cell).collect(); - Ok(Some(Var::new(bits_cvars, span))) -} - -fn from_bits( - compiler: &mut CircuitWriter, - generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // only one input var - assert_eq!(vars.len(), 1); - - let var_info = &vars[0]; - let bitlen = generics.get("LEN") as usize; - - let modulus_bits: usize = B::Field::modulus_biguint() - .bits() - .try_into() - .expect("modulus is too large"); - - assert!(bitlen <= (modulus_bits - 1)); - - let bits_vars: Vec<_> = var_info - .var - .cvars - .iter() - .map(|c| match c { - ConstOrCell::Cell(c) => c.clone(), - ConstOrCell::Const(cst) => { - // use a cell var to represent the const for now - // later we will refactor the backend handle ConstOrCell arguments, so we don't have deal with this everywhere - compiler - .backend - .add_constant(Some("converted constant"), *cst, span) - } - }) - .collect(); - - // this might not be necessary since it should be checked in the type checker - assert_eq!(bitlen, bits_vars.len()); - - let mut e2 = B::Field::one(); - let mut lc: Option = None; - - // accumulate the contribution of each bit - for bit in bits_vars { - let weighted_bit = compiler.backend.mul_const(&bit, &e2, span); - - lc = match lc { - None => Some(weighted_bit), - Some(v) => Some(compiler.backend.add(&v, &weighted_bit, span)), - }; - - e2 = e2 + e2; - } - - let cvar = ConstOrCell::Cell(lc.unwrap()); - - Ok(Some(Var::new_cvar(cvar, span))) -} - fn nth_bit( compiler: &mut CircuitWriter, _generics: &GenericParameters, @@ -206,16 +63,9 @@ fn nth_bit( } }; - // create a cell var for the symbolic value representing the nth bit. - // it seems we will always have to create cell vars to allocate the symbolic values that involve non-deterministic calculations. - // it is non-deterministic because it involves non-deterministic arithmetic on a cell var. let bit = compiler .backend .new_internal_var(Value::NthBit(val.clone(), nth), span); - // constrain it to be either 0 or 1 - // bit * (bit - 1 ) === 0; - // boolean::check(compiler, &ConstOrCell::Cell(bit.clone()), span); - Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) } diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index e89ee71a5..78ad860c6 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -89,7 +89,7 @@ pub fn init_stdlib_dep( node_id: usize, ) -> usize { // list the stdlib dependency in order - let libs = vec!["int", "bits", "comparator"]; + let libs = vec!["bits", "int", "comparator"]; let mut node_id = node_id; From 64c2352435c554b4d74df6f6cc1a2a80062ff6cd Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 11 Oct 2024 17:09:06 +0800 Subject: [PATCH 14/36] download stdlib --- src/cli/cmd_build_and_check.rs | 29 ++++++++++++++++++-- src/cli/mod.rs | 3 ++ src/cli/packages.rs | 50 +++++++++++++++++++++++++++++++++- src/stdlib/mod.rs | 3 +- src/tests/examples.rs | 4 +-- src/tests/stdlib/mod.rs | 2 +- 6 files changed, 84 insertions(+), 7 deletions(-) diff --git a/src/cli/cmd_build_and_check.rs b/src/cli/cmd_build_and_check.rs index ffa456756..c8186365b 100644 --- a/src/cli/cmd_build_and_check.rs +++ b/src/cli/cmd_build_and_check.rs @@ -11,14 +11,16 @@ use crate::{ r1cs::{snarkjs::SnarkjsExporter, R1CS}, Backend, BackendField, BackendKind, }, - cli::packages::path_to_package, + cli::packages::{path_to_package, path_to_stdlib}, compiler::{compile, generate_witness, typecheck_next_file, Sources}, inputs::{parse_inputs, JsonInputs}, + stdlib::init_stdlib_dep, type_checker::TypeChecker, }; use super::packages::{ - get_deps_of_package, is_lib, validate_package_and_get_manifest, DependencyGraph, UserRepo, + download_stdlib, get_deps_of_package, is_lib, + validate_package_and_get_manifest, DependencyGraph, UserRepo, }; const COMPILED_DIR: &str = "compiled"; @@ -137,6 +139,26 @@ pub fn cmd_check(args: CmdCheck) -> miette::Result<()> { Ok(()) } +fn add_stdlib( + sources: &mut Sources, + tast: &mut TypeChecker, + node_id: usize, +) -> miette::Result { + let mut node_id = node_id; + + // check if the release folder exists, otherwise download the latest release + // todo: check the latest version and compare it with the current version, to decide if download is needed + let stdlib_dir = path_to_stdlib(); + + if !stdlib_dir.exists() { + download_stdlib()?; + } + + node_id = init_stdlib_dep(sources, tast, node_id, stdlib_dir.as_ref()); + + Ok(node_id) +} + fn produce_all_asts(path: &PathBuf) -> miette::Result<(Sources, TypeChecker)> { // find manifest let manifest = validate_package_and_get_manifest(&path, false)?; @@ -161,6 +183,9 @@ fn produce_all_asts(path: &PathBuf) -> miette::Result<(Sources, Type let mut tast = TypeChecker::new(); + // adding stdlib + add_stdlib(&mut sources, &mut tast, node_id)?; + for dep in dep_graph.from_leaves_to_roots() { let path = path_to_package(&dep); diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 84fb80ee8..6e470dd29 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -15,3 +15,6 @@ pub const NONAME_DIRECTORY: &str = ".noname"; /// The directory under [NONAME_DIRECTORY] containing all package-related files. pub const PACKAGE_DIRECTORY: &str = "packages"; + +/// The directory under [NONAME_DIRECTORY] containing all the latest noname release. +pub const RELEASE_DIRECTORY: &str = "release"; diff --git a/src/cli/packages.rs b/src/cli/packages.rs index 03060c6d2..28effedcd 100644 --- a/src/cli/packages.rs +++ b/src/cli/packages.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use super::{ manifest::{read_manifest, Manifest}, - NONAME_DIRECTORY, PACKAGE_DIRECTORY, + NONAME_DIRECTORY, PACKAGE_DIRECTORY, RELEASE_DIRECTORY, }; /// A dependency is a Github `user/repo` pair. @@ -241,6 +241,16 @@ pub(crate) fn path_to_package(dep: &UserRepo) -> PathBuf { package_dir.join(&dep.user).join(&dep.repo) } +pub(crate) fn path_to_stdlib() -> PathBuf { + let home_dir: PathBuf = dirs::home_dir() + .expect("could not find home directory of current user") + .try_into() + .expect("invalid UTF8 path"); + let noname_dir = home_dir.join(NONAME_DIRECTORY); + + noname_dir.join(RELEASE_DIRECTORY).join("src/stdlib/native") +} + /// download package from github pub fn download_from_github(dep: &UserRepo) -> Result<()> { let url = format!( @@ -264,6 +274,44 @@ pub fn download_from_github(dep: &UserRepo) -> Result<()> { Ok(()) } +pub fn download_stdlib() -> Result<()> { + // Hardcoded repository details and target branch + let repo_owner = "katat"; + let repo_name = "noname"; + let target_branch = "release"; + let repo_url = format!( + "https://github.com/{owner}/{repo}.git", + owner = repo_owner, + repo = repo_name + ); + + let home_dir: PathBuf = dirs::home_dir() + .expect("could not find home directory of current user") + .try_into() + .expect("invalid UTF8 path"); + let noname_dir = home_dir.join(NONAME_DIRECTORY); + let release_dir = noname_dir.join("release"); + + // Clone the repository and checkout the specified branch to the temporary directory + let output = process::Command::new("git") + .arg("clone") + .arg("--branch") + .arg(target_branch) + .arg("--single-branch") + .arg(repo_url) + .arg(release_dir) + .output() + .expect("failed to execute git clone command"); + + if !output.status.success() { + miette::bail!(format!( + "Could not clone branch `{target_branch}` of repository `{repo_owner}/{repo_name}`." + )); + } + + Ok(()) +} + pub fn is_lib(path: &PathBuf) -> bool { path.join("src").join("lib.no").exists() } diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 78ad860c6..b66c56f27 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -87,6 +87,7 @@ pub fn init_stdlib_dep( sources: &mut Sources, tast: &mut TypeChecker, node_id: usize, + path_prefix: &str, ) -> usize { // list the stdlib dependency in order let libs = vec!["bits", "int", "comparator"]; @@ -95,7 +96,7 @@ pub fn init_stdlib_dep( for lib in libs { let module = UserRepo::new(&format!("std/{}", lib)); - let prefix_stdlib = Path::new("src/stdlib/native/"); + let prefix_stdlib = Path::new(path_prefix); let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}.no"))).unwrap(); node_id = typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); diff --git a/src/tests/examples.rs b/src/tests/examples.rs index ddffba037..1aa9a31f3 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -38,7 +38,7 @@ fn test_file( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); let this_module = None; let _node_id = typecheck_next_file( &mut tast, @@ -102,7 +102,7 @@ fn test_file( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); let this_module = None; let _node_id = typecheck_next_file( &mut tast, diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 92a114847..107d44fe9 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -36,7 +36,7 @@ fn test_stdlib( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); let this_module = None; let _node_id = typecheck_next_file( From 3c0b022937370883118ecc926c18411928a26550 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 11 Oct 2024 17:30:31 +0800 Subject: [PATCH 15/36] add comparator methods for uint8 --- src/stdlib/mod.rs | 2 +- src/stdlib/native/bits.no | 25 ------------------- src/stdlib/native/comparator.no | 4 --- src/stdlib/native/int.no | 10 +++++++- .../less_eq_than/less_eq_than_main.no | 6 ++++- .../comparator/less_than/less_than_main.no | 9 ++++--- 6 files changed, 20 insertions(+), 36 deletions(-) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index b66c56f27..cde82ead8 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -90,7 +90,7 @@ pub fn init_stdlib_dep( path_prefix: &str, ) -> usize { // list the stdlib dependency in order - let libs = vec!["bits", "int", "comparator"]; + let libs = vec!["bits", "comparator", "int"]; let mut node_id = node_id; diff --git a/src/stdlib/native/bits.no b/src/stdlib/native/bits.no index 65825cdf1..d991b04de 100644 --- a/src/stdlib/native/bits.no +++ b/src/stdlib/native/bits.no @@ -1,6 +1,5 @@ hint fn nth_bit(value: Field, const nth: Field) -> Field; -// obviously writing this in native is much simpler than the builtin version fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { let mut bits = [false; LEN]; let mut lc1 = 0; @@ -14,31 +13,7 @@ fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { let false_val = false; for index in 0..LEN { - // maybe add a unconstrained / unsafe attribute before bits::nth_bit, such that: - // bits[index] = unsafe bits::nth_bit(value, index); - // here we need to ensure the related variables are constrained: - // 1. value: constrained to be equal with the sum of bits, which involves the index as well - // 2. index: a cell index in bits - // 3. bits: the output bits - // beyond the notation purpose, what security measures can we take to help guide this unsafe operation? - // one idea is to rely on this unsafe attribute to check if it is non-deterministic when constraining the bits[index] - // eg. - // - bits::nth_bit(value, index) is non-deterministic - // - a metadata can be added to the var of the bits as non-deterministic - // - when CS trying to constrain the non-deterministic var, - // it will raise an error if the var is not marked unsafe via the attribute unsafe - // thus, it seems we also need to add the attribute to the builtin function signature - // eg. `unsafe nth_bit(val: Field, const nth: Field) -> Bool` - // while the unsafe attribute in `bits[index] = unsafe bits::nth_bit(value, index);` - // is for the users to acknowledge they are responsible for having additional constraints. - // This approach makes it explicit whether an expression is non-deterministic at the first place. - // On the other hand, circom lang determines whether it is non-deterministic by folding the arithmetic operation. - let bit_num = unsafe nth_bit(value, index); - // nth_bit is a hint function, and it doesn't constraint the value of the bits as boolean, - // although its return type is boolean. - // can we make the arithmetic operation compatible with boolean? - // or just make a stdlib to convert boolean to Field while adding the constraint? // constrain the bit_num to be 0 or 1 assert_eq(bit_num * (bit_num - 1), 0); diff --git a/src/stdlib/native/comparator.no b/src/stdlib/native/comparator.no index 43a4e14bf..2a71890db 100644 --- a/src/stdlib/native/comparator.no +++ b/src/stdlib/native/comparator.no @@ -38,7 +38,3 @@ fn less_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { fn less_eq_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { return less_than(LEN, lhs, rhs + 1); } - -fn uint8_less_than(lhs: int::Uint8, rhs: int::Uint8) -> Bool { - return less_than(8, lhs.inner, rhs.inner); -} \ No newline at end of file diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no index 9b0bc1067..6cbac6de7 100644 --- a/src/stdlib/native/int.no +++ b/src/stdlib/native/int.no @@ -1,7 +1,7 @@ use std::bits; +use std::comparator; struct Uint8 { - // todo: maybe add a const attribute to Field to forbid reassignment inner: Field, bit_len: Field, } @@ -14,4 +14,12 @@ fn Uint8.new(val: Field) -> Uint8 { inner: val, bit_len: 8 }; +} + +fn Uint8.less_than(self, rhs: Uint8) -> Bool { + return comparator::less_than(8, self.inner, rhs.inner); +} + +fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { + return comparator::less_eq_than(8, self.inner, rhs.inner); } \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no index 5f998ae9f..f46cbfb38 100644 --- a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no @@ -1,6 +1,10 @@ use std::comparator; +use std::int; fn main(pub lhs: Field, rhs: Field) -> Bool { - let res = comparator::less_eq_than(2, lhs, rhs); + let lhs_u8 = int::Uint8.new(lhs); + let rhs_u8 = int::Uint8.new(rhs); + + let res = lhs_u8.less_eq_than(rhs_u8); return res; } \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_than/less_than_main.no b/src/tests/stdlib/comparator/less_than/less_than_main.no index a4eb8e7b9..bb8cb1f0c 100644 --- a/src/tests/stdlib/comparator/less_than/less_than_main.no +++ b/src/tests/stdlib/comparator/less_than/less_than_main.no @@ -2,9 +2,10 @@ use std::comparator; use std::int; fn main(pub lhs: Field, rhs: Field) -> Bool { - // todo bug: this also throws error "method call only work on custom types" - let lhs_bigint = int::Uint8.new(lhs); - let rhs_bigint = int::Uint8.new(rhs); - let res = comparator::uint8_less_than(lhs_bigint, rhs_bigint); + let lhs_u8 = int::Uint8.new(lhs); + let rhs_u8 = int::Uint8.new(rhs); + + let res = lhs_u8.less_than(rhs_u8); + return res; } \ No newline at end of file From f54a53c2d4981a89388ffb2e8dd0d8c76286f66e Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 11 Oct 2024 17:45:16 +0800 Subject: [PATCH 16/36] fmt --- src/cli/cmd_build_and_check.rs | 4 ++-- src/stdlib/bits.rs | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/cli/cmd_build_and_check.rs b/src/cli/cmd_build_and_check.rs index c8186365b..20ff93dba 100644 --- a/src/cli/cmd_build_and_check.rs +++ b/src/cli/cmd_build_and_check.rs @@ -19,8 +19,8 @@ use crate::{ }; use super::packages::{ - download_stdlib, get_deps_of_package, is_lib, - validate_package_and_get_manifest, DependencyGraph, UserRepo, + download_stdlib, get_deps_of_package, is_lib, validate_package_and_get_manifest, + DependencyGraph, UserRepo, }; const COMPILED_DIR: &str = "compiled"; diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index 8baa88f65..cc5898bfc 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -21,9 +21,7 @@ impl Module for BitsLib { const MODULE: &'static str = "bits"; fn get_fns() -> Vec<(&'static str, FnInfoType)> { - vec![ - (NTH_BIT_FN, nth_bit), - ] + vec![(NTH_BIT_FN, nth_bit)] } } From 5c8918edef242e98fad325be4399b00011078f5e Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 11 Oct 2024 18:09:21 +0800 Subject: [PATCH 17/36] remove unused import --- src/stdlib/native/comparator.no | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stdlib/native/comparator.no b/src/stdlib/native/comparator.no index 2a71890db..dd11c6c62 100644 --- a/src/stdlib/native/comparator.no +++ b/src/stdlib/native/comparator.no @@ -1,5 +1,4 @@ use std::bits; -use std::int; // Instead of comparing bit by bit, we check the carry bit: // lhs + (1 << LEN) - rhs From 18825a3a4a53572bd038c679e823ffc5d2e17ff9 Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Tue, 22 Oct 2024 19:57:45 +0800 Subject: [PATCH 18/36] Fix: Allow struct fields to propagate constants (#204) * fix: make struct field propagate constants --- src/mast/mod.rs | 252 ++++++++++++++++++++++++++---------- src/negative_tests.rs | 52 ++++++++ src/stdlib/native/int.no | 10 +- src/type_checker/checker.rs | 32 +++++ src/type_checker/mod.rs | 23 ++++ 5 files changed, 300 insertions(+), 69 deletions(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index f16a2cbac..89ad4cc6b 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -9,7 +9,8 @@ use crate::{ imports::FnKind, parser::{ types::{ - FnSig, ForLoopArgument, GenericParameters, Range, Stmt, StmtKind, Symbolic, Ty, TyKind, + FnSig, ForLoopArgument, GenericParameters, Ident, Range, Stmt, StmtKind, Symbolic, Ty, + TyKind, }, CustomType, Expr, ExprKind, FunctionDef, Op2, }, @@ -29,19 +30,51 @@ pub struct ExprMonoInfo { /// The generic types shouldn't be presented in this field. pub typ: Option, - // todo: see if we can do constant folding on the expression nodes. - // - it is possible to remove this field, as the constant value can be extracted from folded expression node - /// Numeric value of the expression - /// applicable to BigInt type - pub constant: Option, + /// Propagated constant value + pub constant: Option, } -impl ExprMonoInfo { - pub fn new(expr: Expr, typ: Option, value: Option) -> Self { - if value.is_some() && !matches!(typ, Some(TyKind::Field { constant: true })) { - panic!("value can only be set for BigInt type"); +#[derive(Debug, Clone)] +pub enum PropagatedConstant { + Single(u32), + Array(Vec), + Custom(HashMap), +} + +impl PropagatedConstant { + pub fn as_single(&self) -> u32 { + match self { + PropagatedConstant::Single(v) => *v, + _ => panic!("expected single value"), } + } + pub fn as_array(&self) -> Vec { + match self { + PropagatedConstant::Array(v) => v.iter().map(|c| c.as_single()).collect(), + _ => panic!("expected array value"), + } + } + + pub fn as_custom(&self) -> HashMap { + match self { + PropagatedConstant::Custom(v) => { + v.iter().map(|(k, c)| (k.clone(), c.as_single())).collect() + } + _ => panic!("expected custom value"), + } + } +} + +/// impl From trait for single value +impl From for PropagatedConstant { + fn from(v: u32) -> Self { + PropagatedConstant::Single(v) + } +} + +impl ExprMonoInfo { + pub fn new(expr: Expr, typ: Option, value: Option) -> Self { Self { expr, typ, @@ -68,18 +101,18 @@ pub struct MTypeInfo { pub typ: TyKind, /// Store constant value - pub value: Option, + pub constant: Option, /// The span of the variable declaration. pub span: Span, } impl MTypeInfo { - pub fn new(typ: &TyKind, span: Span, value: Option) -> Self { + pub fn new(typ: &TyKind, span: Span, value: Option) -> Self { Self { typ: typ.clone(), span, - value, + constant: value, } } } @@ -216,11 +249,11 @@ impl FnSig { } // const NN: Field _ => { - let cst = observed_arg.constant; + let cst = observed_arg.constant.clone(); if is_generic_parameter(sig_arg.name.value.as_str()) && cst.is_some() { self.generics.assign( &sig_arg.name.value, - cst.unwrap(), + cst.unwrap().as_single(), observed_arg.expr.span, )?; } @@ -269,6 +302,10 @@ where functions_instantiated: HashMap, // new method name as the key, old method name as the value methods_instantiated: HashMap<(FullyQualified, String), String>, + // cache for [PropagatedConstant] values from instantiated methods + cst_method_cache: HashMap<(FullyQualified, String), PropagatedConstant>, + // cache for [PropagatedConstant] values from instantiated functions + cst_fn_cache: HashMap, } impl MastCtx { @@ -278,6 +315,8 @@ impl MastCtx { generic_func_scope: Some(0), functions_instantiated: HashMap::new(), methods_instantiated: HashMap::new(), + cst_method_cache: HashMap::new(), + cst_fn_cache: HashMap::new(), } } @@ -509,7 +548,12 @@ fn monomorphize_expr( }, ); - let cst = None; + // propagate the constant value + let cst = lhs_mono.constant.and_then(|c| match c { + PropagatedConstant::Custom(map) => map.get(rhs).cloned(), + _ => None, + }); + ExprMonoInfo::new(mexpr, typ, cst) } @@ -558,10 +602,19 @@ fn monomorphize_expr( .as_ref() .and_then(|sig| sig.return_type.clone().map(|t| t.kind)); - ExprMonoInfo::new(mexpr, typ, None) + // retrieve the constant value from the cache + let cst = ctx.cst_fn_cache.get(&mono_qualified).cloned(); + + ExprMonoInfo::new(mexpr, typ, cst) } else { // monomorphize the function call - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let (fn_info_mono, typ, cst) = + instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + + // cache the constant value + if let Some(cst) = cst.clone() { + ctx.cst_fn_cache.insert(mono_qualified.clone(), cst); + } let fn_name_mono = &fn_info_mono.sig().name; let mexpr = expr.to_mast( @@ -577,7 +630,7 @@ fn monomorphize_expr( let new_qualified = FullyQualified::new(module, &fn_name_mono.value); ctx.add_monomorphized_fn(old_qualified, new_qualified, fn_info_mono); - ExprMonoInfo::new(mexpr, typ, None) + ExprMonoInfo::new(mexpr, typ, cst) } } @@ -649,10 +702,23 @@ fn monomorphize_expr( }, ); let typ = resolved_sig.return_type.clone().map(|t| t.kind); - ExprMonoInfo::new(mexpr, typ, None) + + // retrieve the constant value from the cache + let cst = ctx + .cst_method_cache + .get(&(struct_qualified.clone(), method_name.value.clone())) + .cloned(); + + ExprMonoInfo::new(mexpr, typ, cst) } else { // monomorphize the function call - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let (fn_info_mono, typ, cst) = + instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + // cache the constant value + if let Some(cst) = cst.clone() { + ctx.cst_method_cache + .insert((struct_qualified.clone(), method_name.value.clone()), cst); + } let fn_name_mono = &fn_info_mono.sig().name; let mexpr = expr.to_mast( @@ -668,7 +734,7 @@ fn monomorphize_expr( ctx.tast .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); - ExprMonoInfo::new(mexpr, typ, None) + ExprMonoInfo::new(mexpr, typ, cst) } } @@ -739,9 +805,14 @@ fn monomorphize_expr( Some(v) => { let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); - ExprMonoInfo::new(mexpr, typ, v.to_u32()) + ExprMonoInfo::new( + mexpr, + typ, + Some(PropagatedConstant::from(v.to_u32().unwrap())), + ) } - None => { + // keep as is + _ => { let mexpr = expr.to_mast( ctx, &ExprKind::BinaryOp { @@ -778,7 +849,11 @@ fn monomorphize_expr( let cst: u32 = inner.try_into().expect("biguint too large"); let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(inner.clone())); - ExprMonoInfo::new(mexpr, Some(TyKind::Field { constant: true }), Some(cst)) + ExprMonoInfo::new( + mexpr, + Some(TyKind::Field { constant: true }), + Some(PropagatedConstant::from(cst)), + ) } ExprKind::Bool(inner) => { @@ -794,10 +869,10 @@ fn monomorphize_expr( let res = if is_generic_parameter(&name.value) { let mtype = mono_fn_env.get_type_info(&name.value).unwrap(); - let mexpr = - expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(mtype.value.unwrap()))); + let cst = mtype.constant.clone().unwrap().as_single(); + let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(cst))); - ExprMonoInfo::new(mexpr, Some(mtype.typ.clone()), mtype.value) + ExprMonoInfo::new(mexpr, Some(mtype.typ.clone()), mtype.constant.clone()) } else if is_type(&name.value) { let mtype = TyKind::Custom { module: module.clone(), @@ -820,7 +895,11 @@ fn monomorphize_expr( let cst: u32 = bigint.clone().try_into().expect("biguint too large"); let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(bigint)); - ExprMonoInfo::new(mexpr, Some(TyKind::Field { constant: true }), Some(cst)) + ExprMonoInfo::new( + mexpr, + Some(TyKind::Field { constant: true }), + Some(PropagatedConstant::from(cst)), + ) } else { // otherwise it's a local variable let mexpr = expr.to_mast( @@ -832,7 +911,7 @@ fn monomorphize_expr( ); let mtype = mono_fn_env.get_type_info(&name.value).unwrap().clone(); - ExprMonoInfo::new(mexpr, Some(mtype.typ), mtype.value) + ExprMonoInfo::new(mexpr, Some(mtype.typ), mtype.constant) }; res @@ -955,24 +1034,46 @@ fn monomorphize_expr( )); } - fields_mono.push((ident, observed_mono.expr.clone())); + fields_mono.push(( + ident, + observed_mono.expr.clone(), + observed_mono.constant.clone(), + )); } let mexpr = expr.to_mast( ctx, &ExprKind::CustomTypeDeclaration { custom: custom.clone(), - fields: fields_mono, + // extract a tuple of first two elements + fields: fields_mono + .iter() + .map(|(a, b, _)| (a.clone(), b.clone())) + .collect(), }, ); + let cst_fields = { + let csts = HashMap::from_iter( + fields_mono + .into_iter() + .filter(|(_, _, cst)| cst.is_some()) + .map(|(ident, _, cst)| (ident, cst.unwrap())), + ); + if csts.is_empty() { + None + } else { + Some(PropagatedConstant::Custom(csts)) + } + }; + ExprMonoInfo::new( mexpr, Some(TyKind::Custom { module: module.clone(), name: name.clone(), }), - None, + cst_fields, ) } ExprKind::RepeatedArrayInit { item, size } => { @@ -989,7 +1090,7 @@ fn monomorphize_expr( ); if let Some(cst) = size_mono.constant { - let arr_typ = TyKind::Array(Box::new(item_typ), cst); + let arr_typ = TyKind::Array(Box::new(item_typ), cst.as_single()); ExprMonoInfo::new(mexpr, Some(arr_typ), None) } else { return Err(error(ErrorKind::InvalidArraySize, expr.span)); @@ -1011,23 +1112,30 @@ pub fn monomorphize_block( mono_fn_env: &mut MonomorphizedFnEnv, stmts: &[Stmt], expected_return: Option<&Ty>, -) -> Result<(Vec, Option)> { +) -> Result<(Vec, Option)> { mono_fn_env.nest(); - let mut return_typ = None; + let mut ret_expr_mono = None; let mut stmts_mono = vec![]; for stmt in stmts { - if let Some((stmt, ret_typ)) = monomorphize_stmt(ctx, mono_fn_env, stmt)? { + if let Some((stmt, expr_mono)) = monomorphize_stmt(ctx, mono_fn_env, stmt)? { stmts_mono.push(stmt); - if ret_typ.is_some() { - return_typ = ret_typ; + // only return stmt can return `ExprMonoInfo` which contains propagated constants + if expr_mono.is_some() { + ret_expr_mono = expr_mono; } } } + let return_typ = if let Some(expr_mono) = ret_expr_mono.clone() { + expr_mono.typ + } else { + None + }; + // check the return if let (Some(expected), Some(observed)) = (expected_return, return_typ.clone()) { if !observed.match_expected(&expected.kind, true) { @@ -1040,7 +1148,7 @@ pub fn monomorphize_block( mono_fn_env.pop(); - Ok((stmts_mono, return_typ)) + Ok((stmts_mono, ret_expr_mono)) } /// Monomorphize a statement. @@ -1048,7 +1156,7 @@ pub fn monomorphize_stmt( ctx: &mut MastCtx, mono_fn_env: &mut MonomorphizedFnEnv, stmt: &Stmt, -) -> Result)>> { +) -> Result)>> { let res = match &stmt.kind { StmtKind::Assign { mutable, lhs, rhs } => { let rhs_mono = monomorphize_expr(ctx, rhs, mono_fn_env)?; @@ -1093,7 +1201,9 @@ pub fn monomorphize_stmt( return Err(error(ErrorKind::InvalidRangeSize, stmt.span)); } - if start_mono.constant.unwrap() > end_mono.constant.unwrap() { + if start_mono.constant.unwrap().as_single() + > end_mono.constant.unwrap().as_single() + { return Err(error(ErrorKind::InvalidRangeSize, stmt.span)); } @@ -1147,11 +1257,11 @@ pub fn monomorphize_stmt( StmtKind::Return(res) => { let expr_mono = monomorphize_expr(ctx, res, mono_fn_env)?; let stmt_mono = Stmt { - kind: StmtKind::Return(Box::new(expr_mono.expr)), + kind: StmtKind::Return(Box::new(expr_mono.expr.clone())), span: stmt.span, }; - Some((stmt_mono, expr_mono.typ)) + Some((stmt_mono, Some(expr_mono))) } StmtKind::Comment(_) => None, }; @@ -1168,7 +1278,7 @@ pub fn instantiate_fn_call( fn_info: FnInfo, args: &[ExprMonoInfo], span: Span, -) -> Result<(FnInfo, Option)> { +) -> Result<(FnInfo, Option, Option)> { ctx.start_monomorphize_func(); let fn_sig = fn_info.sig(); @@ -1192,7 +1302,11 @@ pub fn instantiate_fn_call( let val = fn_sig.generics.get(gen); mono_fn_env.store_type( gen, - &MTypeInfo::new(&TyKind::Field { constant: true }, span, Some(val)), + &MTypeInfo::new( + &TyKind::Field { constant: true }, + span, + Some(PropagatedConstant::from(val)), + ), )?; } @@ -1208,7 +1322,7 @@ pub fn instantiate_fn_call( let typ = mono_info.typ.as_ref().expect("expected a value"); mono_fn_env.store_type( arg_name, - &MTypeInfo::new(typ, mono_info.expr.span, mono_info.constant), + &MTypeInfo::new(typ, mono_info.expr.span, mono_info.constant.clone()), )?; } @@ -1216,32 +1330,40 @@ pub fn instantiate_fn_call( let ret_typed = sig_typed.return_type.clone(); // construct the monomorphized function AST - let func_def = match fn_info.kind { - FnKind::BuiltIn(_, handle) => FnInfo { - kind: FnKind::BuiltIn(sig_typed, handle), - is_hint: fn_info.is_hint, - span: fn_info.span, - }, + let (func_def, mono_info) = match fn_info.kind { + FnKind::BuiltIn(_, handle) => ( + FnInfo { + kind: FnKind::BuiltIn(sig_typed, handle), + ..fn_info + }, + // todo: we will need to propagate the constant value from builtin function as well + None, + ), FnKind::Native(fn_def) => { - let (stmts_typed, _) = + let (stmts_typed, mono_info) = monomorphize_block(ctx, mono_fn_env, &fn_def.body, ret_typed.as_ref())?; - FnInfo { - kind: FnKind::Native(FunctionDef { - sig: sig_typed, - body: stmts_typed, - span: fn_def.span, - is_hint: fn_def.is_hint, - }), - is_hint: fn_info.is_hint, - span: fn_info.span, - } + ( + FnInfo { + kind: FnKind::Native(FunctionDef { + sig: sig_typed, + body: stmts_typed, + span: fn_def.span, + is_hint: fn_def.is_hint, + }), + is_hint: fn_info.is_hint, + span: fn_info.span, + }, + mono_info, + ) } }; ctx.finish_monomorphize_func(); - Ok((func_def, ret_typed.map(|t| t.kind))) + let cst = mono_info.and_then(|c| c.constant); + + Ok((func_def, ret_typed.map(|t| t.kind), cst)) } pub fn error(kind: ErrorKind, span: Span) -> Error { Error::new("mast", kind, span) diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 1ed4dc063..0759a6333 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -498,3 +498,55 @@ fn test_nonhint_call_with_unsafe() { ErrorKind::UnexpectedUnsafeAttribute )); } + +#[test] +fn test_no_cst_struct_field_prop() { + let code = r#" + struct Thing { + val: Field, + } + + fn gen(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; + } + + fn main(pub xx: Field) { + let thing = Thing { val: xx }; + + let arr = gen(thing.val); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ArgumentTypeMismatch(..) + )); +} + +#[test] +fn test_mut_cst_struct_field_prop() { + let code = r#" + struct Thing { + val: Field, + } + + fn gen(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; + } + + fn main(pub xx: Field) { + let mut thing = Thing { val: 3 }; + thing.val = xx; + + let arr = gen(thing.val); + assert_eq(arr[0], xx); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ArgumentTypeMismatch(..) + )); +} diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no index 6cbac6de7..84bd7a98e 100644 --- a/src/stdlib/native/int.no +++ b/src/stdlib/native/int.no @@ -7,19 +7,21 @@ struct Uint8 { } fn Uint8.new(val: Field) -> Uint8 { + let bit_len = 8; + // range check - let ignore_ = bits::to_bits(8, val); + let ignore_ = bits::to_bits(bit_len, val); return Uint8 { inner: val, - bit_len: 8 + bit_len: bit_len }; } fn Uint8.less_than(self, rhs: Uint8) -> Bool { - return comparator::less_than(8, self.inner, rhs.inner); + return comparator::less_than(self.bit_len, self.inner, rhs.inner); } fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { - return comparator::less_eq_than(8, self.inner, rhs.inner); + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); } \ No newline at end of file diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 74a8ace86..d5fb3f54e 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -224,6 +224,7 @@ impl TypeChecker { .compute_type(lhs, typed_fn_env)? .expect("type-checker bug: lhs access on an empty var"); + // todo: check and update the const field type for other cases // lhs can be a local variable or a path to an array let lhs_name = match &lhs.kind { // `name = ` @@ -290,6 +291,31 @@ impl TypeChecker { )); } + // update struct field type + if let ExprKind::FieldAccess { + lhs, + rhs: field_name, + } = &lhs.kind + { + // get variable behind lhs + let lhs_node = self + .compute_type(lhs, typed_fn_env)? + .expect("type-checker bug: lhs access on an empty var"); + + // obtain the qualified name of the struct + let (module, struct_name) = match lhs_node.typ { + TyKind::Custom { module, name } => (module, name), + _ => { + return Err( + self.error(ErrorKind::FieldAccessOnNonCustomStruct, lhs.span) + ) + } + }; + + let qualified = FullyQualified::new(&module, &struct_name); + self.update_struct_field(&qualified, &field_name.value, rhs_typ.typ); + } + None } @@ -538,6 +564,12 @@ impl TypeChecker { expr.span, )); } + + // If the observed type is a Field type, then init that struct field as the observed type. + // This is because the field type can be a constant or not, which needs to be propagated. + if matches!(observed_typ.typ, TyKind::Field { .. }) { + self.update_struct_field(&qualified, &defined.0, observed_typ.typ.clone()); + } } let res = ExprTyInfo::new_anon(TyKind::Custom { diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 2b826f8a2..7d9da5061 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -150,6 +150,29 @@ impl TypeChecker { .expect("couldn't find the struct for storing the method"); struct_info.methods.remove(method_name); } + + /// Update the type of a struct field. + /// When the assignment is done, we need to update the type of the field. + /// This is only for the case of updating field types to either a constant or a variable. + pub fn update_struct_field( + &mut self, + qualified: &FullyQualified, + field_name: &str, + typ: TyKind, + ) { + let struct_info = self + .structs + .get_mut(qualified) + .expect("couldn't find the struct for storing the method"); + + // update the field type + for field in struct_info.fields.iter_mut() { + if field.0 == field_name { + field.1 = typ; + return; + } + } + } } impl TypeChecker { From ca35175d8f8d525fa9993f88c3141bd126934db7 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 18 Oct 2024 14:51:02 +0800 Subject: [PATCH 19/36] add uint16/32/64 --- src/stdlib/native/int.no | 81 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no index 84bd7a98e..4390c3db6 100644 --- a/src/stdlib/native/int.no +++ b/src/stdlib/native/int.no @@ -1,6 +1,7 @@ use std::bits; use std::comparator; +// u8 struct Uint8 { inner: Field, bit_len: Field, @@ -18,10 +19,90 @@ fn Uint8.new(val: Field) -> Uint8 { }; } +// u16 +struct Uint16 { + inner: Field, + bit_len: Field, +} + +fn Uint16.new(val: Field) -> Uint16 { + let bit_len = 16; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint16 { + inner: val, + bit_len: bit_len + }; +} + +// u32 +struct Uint32 { + inner: Field, + bit_len: Field, +} + +fn Uint32.new(val: Field) -> Uint32 { + let bit_len = 32; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint32 { + inner: val, + bit_len: bit_len + }; +} + +// u64 +struct Uint64 { + inner: Field, + bit_len: Field, +} + +fn Uint64.new(val: Field) -> Uint64 { + let bit_len = 64; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint64 { + inner: val, + bit_len: bit_len + }; +} + +// implement comparator + fn Uint8.less_than(self, rhs: Uint8) -> Bool { return comparator::less_than(self.bit_len, self.inner, rhs.inner); } fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint16.less_than(self, rhs: Uint16) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint16.less_eq_than(self, rhs: Uint16) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint32.less_than(self, rhs: Uint32) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint32.less_eq_than(self, rhs: Uint32) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint64.less_than(self, rhs: Uint64) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint64.less_eq_than(self, rhs: Uint64) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); } \ No newline at end of file From 55f441a8b0b81214465a58e824a724b1430bceba Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 18 Oct 2024 14:51:09 +0800 Subject: [PATCH 20/36] update tests --- .../less_eq_than/less_eq_than_main.no | 6 +- .../stdlib/comparator/less_than/less_than.asm | 27 ++- .../comparator/less_than/less_than_main.no | 7 +- src/tests/stdlib/comparator/mod.rs | 182 +++++++++++++----- src/tests/stdlib/mod.rs | 38 +++- 5 files changed, 189 insertions(+), 71 deletions(-) diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no index f46cbfb38..f916a96a2 100644 --- a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no @@ -1,10 +1,8 @@ use std::comparator; -use std::int; fn main(pub lhs: Field, rhs: Field) -> Bool { - let lhs_u8 = int::Uint8.new(lhs); - let rhs_u8 = int::Uint8.new(rhs); + let bit_len = 2; + let res = comparator::less_eq_than(bit_len, lhs, rhs); - let res = lhs_u8.less_eq_than(rhs_u8); return res; } \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_than/less_than.asm b/src/tests/stdlib/comparator/less_than/less_than.asm index 8c33ac348..beac8c576 100644 --- a/src/tests/stdlib/comparator/less_than/less_than.asm +++ b/src/tests/stdlib/comparator/less_than/less_than.asm @@ -3,9 +3,24 @@ v_5 == (v_4) * (v_4 + -1) 0 == (v_5) * (1) -v_7 == (v_6) * (v_6 + -1) -0 == (v_7) * (1) -v_9 == (v_8) * (v_8 + -1) -0 == (v_9) * (1) -v_4 + 2 * v_6 + 4 * v_8 == (v_2 + -1 * v_3 + 4) * (1) --1 * v_8 + 1 == (v_1) * (1) +1 == (v_6) * (1) +v_8 == (v_7) * (-1 * v_4 + v_6) +-1 * v_9 + 1 == (v_8) * (1) +v_10 == (v_9) * (-1 * v_4 + v_6) +0 == (v_10) * (1) +v_12 == (v_11) * (v_11 + -1) +0 == (v_12) * (1) +1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_11 + v_13) +-1 * v_16 + 1 == (v_15) * (1) +v_17 == (v_16) * (-1 * v_11 + v_13) +0 == (v_17) * (1) +v_19 == (v_18) * (v_18 + -1) +0 == (v_19) * (1) +1 == (v_20) * (1) +v_22 == (v_21) * (-1 * v_18 + v_20) +-1 * v_23 + 1 == (v_22) * (1) +v_24 == (v_23) * (-1 * v_18 + v_20) +0 == (v_24) * (1) +v_2 + -1 * v_3 + 4 == (v_9 + 2 * v_16 + 4 * v_23) * (1) +-1 * v_23 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_than/less_than_main.no b/src/tests/stdlib/comparator/less_than/less_than_main.no index bb8cb1f0c..e39d9cf95 100644 --- a/src/tests/stdlib/comparator/less_than/less_than_main.no +++ b/src/tests/stdlib/comparator/less_than/less_than_main.no @@ -1,11 +1,8 @@ use std::comparator; -use std::int; fn main(pub lhs: Field, rhs: Field) -> Bool { - let lhs_u8 = int::Uint8.new(lhs); - let rhs_u8 = int::Uint8.new(rhs); - - let res = lhs_u8.less_than(rhs_u8); + let bit_len = 2; + let res = comparator::less_than(bit_len, lhs, rhs); return res; } \ No newline at end of file diff --git a/src/tests/stdlib/comparator/mod.rs b/src/tests/stdlib/comparator/mod.rs index bc85f91fe..ed4ce618c 100644 --- a/src/tests/stdlib/comparator/mod.rs +++ b/src/tests/stdlib/comparator/mod.rs @@ -1,89 +1,177 @@ -use crate::error; +use crate::error::{self, ErrorKind}; -use super::test_stdlib; +use super::{test_stdlib, test_stdlib_code}; use error::Result; +use rstest::rstest; -#[test] -fn test_less_than_true() -> Result<()> { - let public_inputs = r#"{"lhs": "0"}"#; - let private_inputs = r#"{"rhs": "1"}"#; +// code template +static LESS_THAN_TPL: &str = r#" +use std::comparator; +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let lhs_u = int::{}.new(lhs); + let rhs_u = int::{}.new(rhs); + + let res = lhs_u.less_than(rhs_u); + + return res; +} +"#; +static LESS_THAN_EQ_TPL: &str = r#" +use std::comparator; +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let lhs_u = int::{}.new(lhs); + let rhs_u = int::{}.new(rhs); + + let res = lhs_u.less_eq_than(rhs_u); + + return res; +} +"#; + +#[rstest] +#[case(r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_less_than( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { test_stdlib( "comparator/less_than/less_than_main.no", - "comparator/less_than/less_than.asm", + Some("comparator/less_than/less_than.asm"), public_inputs, private_inputs, - vec!["1"], + expected_output, )?; Ok(()) } -// test false #[test] -fn test_less_than_false() -> Result<()> { - let public_inputs = r#"{"lhs": "1"}"#; +fn test_less_than_witness_failure() -> Result<()> { + let public_inputs = r#"{"lhs": "4"}"#; let private_inputs = r#"{"rhs": "0"}"#; - test_stdlib( + let err = test_stdlib( "comparator/less_than/less_than_main.no", - "comparator/less_than/less_than.asm", + None, public_inputs, private_inputs, - vec!["0"], - )?; + vec![], + ) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); Ok(()) } -#[test] -fn test_less_eq_than_true_1() -> Result<()> { - let public_inputs = r#"{"lhs": "0"}"#; - let private_inputs = r#"{"rhs": "1"}"#; - - test_stdlib( - "comparator/less_eq_than/less_eq_than_main.no", - "comparator/less_eq_than/less_eq_than.asm", - public_inputs, - private_inputs, - vec!["1"], - )?; +#[rstest] +#[case("Uint8", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint16", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +#[case("Uint32", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint64", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_uint_less_than( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + // Replace placeholders with the given integer type. + let code = LESS_THAN_TPL.replace("{}", int_type); + + // Call the test function with the given inputs and expected output. + test_stdlib_code(&code, None, public_inputs, private_inputs, expected_output)?; Ok(()) } -#[test] -fn test_less_eq_than_true_2() -> Result<()> { - let public_inputs = r#"{"lhs": "1"}"#; - let private_inputs = r#"{"rhs": "1"}"#; +#[rstest] +#[case("Uint8", r#"{"lhs": "256"}"#, r#"{"rhs": "0"}"#)] // Uint8 overflow +#[case("Uint16", r#"{"lhs": "65536"}"#, r#"{"rhs": "0"}"#)] // Uint16 overflow +#[case("Uint32", r#"{"lhs": "4294967296"}"#, r#"{"rhs": "0"}"#)] // Uint32 overflow +#[case("Uint64", r#"{"lhs": "18446744073709551616"}"#, r#"{"rhs": "0"}"#)] // Uint64 overflow +fn test_uint_less_than_range_failure( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, +) -> Result<()> { + let code = LESS_THAN_TPL.replace("{}", int_type); + + // Test that the provided inputs result in an error due to overflow. + let err = test_stdlib_code(&code, None, public_inputs, private_inputs, vec!["0"]) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} +// Test for less than or equal scenarios +#[rstest] +#[case(r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] // True case (lhs < rhs) +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "1"}"#, vec!["1"])] // True case (lhs == rhs) +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] // False case +fn test_less_eq_than( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { test_stdlib( "comparator/less_eq_than/less_eq_than_main.no", - "comparator/less_eq_than/less_eq_than.asm", + Some("comparator/less_eq_than/less_eq_than.asm"), public_inputs, private_inputs, - vec!["1"], + expected_output, )?; Ok(()) } -#[test] -fn test_less_eq_than_false() -> Result<()> { - let public_inputs = r#"{"lhs": "1"}"#; - let private_inputs = r#"{"rhs": "0"}"#; +// implement the rest for less than eq - test_stdlib( - "comparator/less_eq_than/less_eq_than_main.no", - "comparator/less_eq_than/less_eq_than.asm", - public_inputs, - private_inputs, - vec!["0"], - )?; +#[rstest] +#[case("Uint8", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint16", r#"{"lhs": "1"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint32", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint64", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_uint_less_eq_than( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + let code = LESS_THAN_EQ_TPL.replace("{}", int_type); + + test_stdlib_code(&code, None, public_inputs, private_inputs, expected_output)?; Ok(()) } -// test value overflow modulus -// it shouldn't need user to enter the bit length -// should have a way to restrict and type check the value to a certain bit length +#[rstest] +#[case("Uint8", r#"{"lhs": "256"}"#, r#"{"rhs": "0"}"#)] // Uint8 overflow +#[case("Uint16", r#"{"lhs": "65536"}"#, r#"{"rhs": "0"}"#)] // Uint16 overflow +#[case("Uint32", r#"{"lhs": "4294967296"}"#, r#"{"rhs": "0"}"#)] // Uint32 overflow +#[case("Uint64", r#"{"lhs": "18446744073709551616"}"#, r#"{"rhs": "0"}"#)] // Uint64 overflow +fn test_uint_less_eq_than_range_failure( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, +) -> Result<()> { + let code = LESS_THAN_EQ_TPL.replace("{}", int_type); + + let err = test_stdlib_code(&code, None, public_inputs, private_inputs, vec!["0"]) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 107d44fe9..5442e9016 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -16,18 +16,38 @@ use crate::{ fn test_stdlib( path: &str, - asm_path: &str, + asm_path: Option<&str>, public_inputs: &str, private_inputs: &str, expected_public_output: Vec<&str>, ) -> Result>> { - let r1cs = R1CS::new(); let root = env!("CARGO_MANIFEST_DIR"); let prefix_path = Path::new(root).join("src/tests/stdlib"); // read noname file let code = std::fs::read_to_string(prefix_path.clone().join(path)).unwrap(); + let compiled_circuit = test_stdlib_code( + &code, + asm_path, + public_inputs, + private_inputs, + expected_public_output, + )?; + + Ok(compiled_circuit) +} + +fn test_stdlib_code( + code: &str, + asm_path: Option<&str>, + public_inputs: &str, + private_inputs: &str, + expected_public_output: Vec<&str>, +) -> Result>> { + let r1cs = R1CS::new(); + let root = env!("CARGO_MANIFEST_DIR"); + // parse inputs let public_inputs = parse_inputs(public_inputs).unwrap(); let private_inputs = parse_inputs(private_inputs).unwrap(); @@ -43,8 +63,8 @@ fn test_stdlib( &mut tast, this_module, &mut sources, - path.to_string(), - code.clone(), + "test.no".to_string(), + code.to_string(), node_id, ) .unwrap(); @@ -53,9 +73,8 @@ fn test_stdlib( let compiled_circuit = CircuitWriter::generate_circuit(mast, r1cs)?; // this should check the constraints - let generated_witness = compiled_circuit - .generate_witness(public_inputs.clone(), private_inputs.clone()) - .unwrap(); + let generated_witness = + compiled_circuit.generate_witness(public_inputs.clone(), private_inputs.clone())?; let expected_public_output = expected_public_output .iter() @@ -76,9 +95,10 @@ fn test_stdlib( } // check the ASM - if compiled_circuit.circuit.backend.num_constraints() < 100 { + if asm_path.is_some() && compiled_circuit.circuit.backend.num_constraints() < 100 { let prefix_asm = Path::new(root).join("src/tests/stdlib/"); - let expected_asm = std::fs::read_to_string(prefix_asm.clone().join(asm_path)).unwrap(); + let expected_asm = + std::fs::read_to_string(prefix_asm.clone().join(asm_path.unwrap())).unwrap(); let obtained_asm = compiled_circuit.asm(&Sources::new(), false); if obtained_asm != expected_asm { From 2b47a8d7342a12e675a367cab1849ea8f3556043 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:16:55 +0800 Subject: [PATCH 21/36] fix: remove bit_len from uints --- src/stdlib/native/int.no | 38 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no index 4390c3db6..4ad93c3b1 100644 --- a/src/stdlib/native/int.no +++ b/src/stdlib/native/int.no @@ -4,7 +4,6 @@ use std::comparator; // u8 struct Uint8 { inner: Field, - bit_len: Field, } fn Uint8.new(val: Field) -> Uint8 { @@ -14,15 +13,13 @@ fn Uint8.new(val: Field) -> Uint8 { let ignore_ = bits::to_bits(bit_len, val); return Uint8 { - inner: val, - bit_len: bit_len + inner: val }; } // u16 struct Uint16 { - inner: Field, - bit_len: Field, + inner: Field } fn Uint16.new(val: Field) -> Uint16 { @@ -32,15 +29,13 @@ fn Uint16.new(val: Field) -> Uint16 { let ignore_ = bits::to_bits(bit_len, val); return Uint16 { - inner: val, - bit_len: bit_len + inner: val }; } // u32 struct Uint32 { - inner: Field, - bit_len: Field, + inner: Field } fn Uint32.new(val: Field) -> Uint32 { @@ -50,15 +45,13 @@ fn Uint32.new(val: Field) -> Uint32 { let ignore_ = bits::to_bits(bit_len, val); return Uint32 { - inner: val, - bit_len: bit_len + inner: val }; } // u64 struct Uint64 { - inner: Field, - bit_len: Field, + inner: Field } fn Uint64.new(val: Field) -> Uint64 { @@ -68,41 +61,40 @@ fn Uint64.new(val: Field) -> Uint64 { let ignore_ = bits::to_bits(bit_len, val); return Uint64 { - inner: val, - bit_len: bit_len + inner: val }; } // implement comparator fn Uint8.less_than(self, rhs: Uint8) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_than(8, self.inner, rhs.inner); } fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_eq_than(8, self.inner, rhs.inner); } fn Uint16.less_than(self, rhs: Uint16) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_than(16, self.inner, rhs.inner); } fn Uint16.less_eq_than(self, rhs: Uint16) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_eq_than(16, self.inner, rhs.inner); } fn Uint32.less_than(self, rhs: Uint32) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_than(32, self.inner, rhs.inner); } fn Uint32.less_eq_than(self, rhs: Uint32) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_eq_than(32, self.inner, rhs.inner); } fn Uint64.less_than(self, rhs: Uint64) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_than(64, self.inner, rhs.inner); } fn Uint64.less_eq_than(self, rhs: Uint64) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_eq_than(64, self.inner, rhs.inner); } \ No newline at end of file From 8ef099379ed8c88a260fb47719b81a232caa1c86 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:32:14 +0800 Subject: [PATCH 22/36] simplify to_bits --- .../asm/kimchi/generic_builtin_bits.asm | 128 ++++++++---------- .../fixture/asm/r1cs/generic_builtin_bits.asm | 46 +++---- src/stdlib/native/bits.no | 37 ++--- .../comparator/less_eq_than/less_eq_than.asm | 40 +++--- .../stdlib/comparator/less_than/less_than.asm | 40 +++--- 5 files changed, 123 insertions(+), 168 deletions(-) diff --git a/examples/fixture/asm/kimchi/generic_builtin_bits.asm b/examples/fixture/asm/kimchi/generic_builtin_bits.asm index a7843d1c4..87ceb2cce 100644 --- a/examples/fixture/asm/kimchi/generic_builtin_bits.asm +++ b/examples/fixture/asm/kimchi/generic_builtin_bits.asm @@ -1,9 +1,6 @@ @ noname.0.7.0 @ public inputs: 1 -DoubleGeneric<1> -DoubleGeneric<1,0,-1,0,-1> -DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<1,0,0,0,-1> DoubleGeneric<1,1> @@ -17,12 +14,6 @@ DoubleGeneric<1> DoubleGeneric<1,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,0,-1> -DoubleGeneric<1,1> -DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,0,-1,0,-1> -DoubleGeneric<0,0,-1,1> -DoubleGeneric<1> DoubleGeneric<1,1> DoubleGeneric<1,1,-1> DoubleGeneric<0,0,-1,1> @@ -34,13 +25,6 @@ DoubleGeneric<1> DoubleGeneric<1,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<2,0,-1> -DoubleGeneric<1,1> -DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,1,-1> -DoubleGeneric<1,0,-1,0,-1> -DoubleGeneric<0,0,-1,1> -DoubleGeneric<1> DoubleGeneric<1,1> DoubleGeneric<1,1,-1> DoubleGeneric<0,0,-1,1> @@ -52,6 +36,13 @@ DoubleGeneric<1> DoubleGeneric<1,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,1,-1> DoubleGeneric<4,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> @@ -77,61 +68,52 @@ DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> DoubleGeneric<1,0,0,0,-2> -(0,0) -> (55,1) -> (74,1) -> (75,0) -(1,0) -> (2,0) -> (5,0) -(1,2) -> (2,1) -(2,2) -> (3,0) -(4,0) -> (6,0) -> (23,0) -> (41,0) -(5,1) -> (6,1) -(6,2) -> (7,1) -> (11,1) -(7,2) -> (10,0) -(8,0) -> (11,0) -> (13,0) -> (14,0) -(8,1) -> (9,0) -(9,2) -> (10,1) -(11,2) -> (12,0) -(13,2) -> (16,0) -> (17,0) -> (56,0) -> (63,0) -> (64,0) -(14,1) -> (15,0) -(16,2) -> (36,0) -(17,1) -> (18,0) -(19,0) -> (20,0) -> (22,0) -(19,2) -> (20,1) -(20,2) -> (21,0) -(22,1) -> (23,1) -(23,2) -> (24,1) -> (28,1) -(24,2) -> (27,0) -(25,0) -> (28,0) -> (30,0) -> (31,0) -(25,1) -> (26,0) -(26,2) -> (27,1) -(28,2) -> (29,0) -(30,2) -> (33,0) -> (34,0) -> (59,0) -> (66,0) -> (67,0) -(31,1) -> (32,0) -(33,2) -> (36,1) -(34,1) -> (35,0) -(36,2) -> (54,0) -(37,0) -> (38,0) -> (40,0) -(37,2) -> (38,1) -(38,2) -> (39,0) -(40,1) -> (41,1) -(41,2) -> (42,1) -> (46,1) -(42,2) -> (45,0) -(43,0) -> (46,0) -> (48,0) -> (49,0) +(0,0) -> (46,1) -> (65,1) -> (66,0) +(1,0) -> (3,0) -> (14,0) -> (25,0) +(2,1) -> (3,1) +(3,2) -> (4,1) -> (8,1) +(4,2) -> (7,0) +(5,0) -> (8,0) -> (10,0) -> (11,0) +(5,1) -> (6,0) +(6,2) -> (7,1) +(8,2) -> (9,0) +(10,2) -> (35,0) -> (36,0) -> (47,0) -> (54,0) -> (55,0) +(11,1) -> (12,0) +(13,1) -> (14,1) +(14,2) -> (15,1) -> (19,1) +(15,2) -> (18,0) +(16,0) -> (19,0) -> (21,0) -> (22,0) +(16,1) -> (17,0) +(17,2) -> (18,1) +(19,2) -> (20,0) +(21,2) -> (38,0) -> (39,0) -> (50,0) -> (57,0) -> (58,0) +(22,1) -> (23,0) +(24,1) -> (25,1) +(25,2) -> (26,1) -> (30,1) +(26,2) -> (29,0) +(27,0) -> (30,0) -> (32,0) -> (33,0) +(27,1) -> (28,0) +(28,2) -> (29,1) +(30,2) -> (31,0) +(32,2) -> (42,0) -> (43,0) -> (51,0) -> (61,0) -> (62,0) +(33,1) -> (34,0) +(35,2) -> (41,0) +(36,1) -> (37,0) +(38,2) -> (41,1) +(39,1) -> (40,0) +(41,2) -> (45,0) +(42,2) -> (45,1) (43,1) -> (44,0) -(44,2) -> (45,1) -(46,2) -> (47,0) -(48,2) -> (51,0) -> (52,0) -> (60,0) -> (70,0) -> (71,0) -(49,1) -> (50,0) -(51,2) -> (54,1) -(52,1) -> (53,0) -(54,2) -> (55,0) -(56,1) -> (57,0) -(57,2) -> (58,0) -(60,1) -> (61,0) -(61,2) -> (62,0) -(63,2) -> (69,0) -(64,1) -> (65,0) -(66,2) -> (69,1) -(67,1) -> (68,0) -(69,2) -> (73,0) -(70,2) -> (73,1) -(71,1) -> (72,0) -(73,2) -> (74,0) +(45,2) -> (46,0) +(47,1) -> (48,0) +(48,2) -> (49,0) +(51,1) -> (52,0) +(52,2) -> (53,0) +(54,2) -> (60,0) +(55,1) -> (56,0) +(57,2) -> (60,1) +(58,1) -> (59,0) +(60,2) -> (64,0) +(61,2) -> (64,1) +(62,1) -> (63,0) +(64,2) -> (65,0) diff --git a/examples/fixture/asm/r1cs/generic_builtin_bits.asm b/examples/fixture/asm/r1cs/generic_builtin_bits.asm index fb72f65a4..0988a4dbe 100644 --- a/examples/fixture/asm/r1cs/generic_builtin_bits.asm +++ b/examples/fixture/asm/r1cs/generic_builtin_bits.asm @@ -1,30 +1,24 @@ @ noname.0.7.0 @ public inputs: 1 -v_3 == (v_2) * (v_2 + -1) -0 == (v_3) * (1) -1 == (v_4) * (1) -v_6 == (v_5) * (-1 * v_2 + v_4) --1 * v_7 + 1 == (v_6) * (1) -v_8 == (v_7) * (-1 * v_2 + v_4) -0 == (v_8) * (1) -v_10 == (v_9) * (v_9 + -1) -0 == (v_10) * (1) -1 == (v_11) * (1) -v_13 == (v_12) * (-1 * v_9 + v_11) --1 * v_14 + 1 == (v_13) * (1) -v_15 == (v_14) * (-1 * v_9 + v_11) -0 == (v_15) * (1) -v_17 == (v_16) * (v_16 + -1) -0 == (v_17) * (1) -1 == (v_18) * (1) -v_20 == (v_19) * (-1 * v_16 + v_18) --1 * v_21 + 1 == (v_20) * (1) -v_22 == (v_21) * (-1 * v_16 + v_18) -0 == (v_22) * (1) -v_1 == (v_7 + 2 * v_14 + 4 * v_21) * (1) -1 == (-1 * v_7 + 1) * (1) -1 == (v_14) * (1) -1 == (-1 * v_21 + 1) * (1) -v_1 == (v_7 + 2 * v_14 + 4 * v_21) * (1) +1 == (v_3) * (1) +v_5 == (v_4) * (-1 * v_2 + v_3) +-1 * v_6 + 1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_2 + v_3) +0 == (v_7) * (1) +1 == (v_9) * (1) +v_11 == (v_10) * (-1 * v_8 + v_9) +-1 * v_12 + 1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_8 + v_9) +0 == (v_13) * (1) +1 == (v_15) * (1) +v_17 == (v_16) * (-1 * v_14 + v_15) +-1 * v_18 + 1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_14 + v_15) +0 == (v_19) * (1) +v_1 == (v_6 + 2 * v_12 + 4 * v_18) * (1) +1 == (-1 * v_6 + 1) * (1) +1 == (v_12) * (1) +1 == (-1 * v_18 + 1) * (1) +v_1 == (v_6 + 2 * v_12 + 4 * v_18) * (1) 2 == (v_1) * (1) diff --git a/src/stdlib/native/bits.no b/src/stdlib/native/bits.no index d991b04de..5043b49cc 100644 --- a/src/stdlib/native/bits.no +++ b/src/stdlib/native/bits.no @@ -1,13 +1,22 @@ hint fn nth_bit(value: Field, const nth: Field) -> Field; +fn from_bits(bits: [Bool; LEN]) -> Field { + let mut lc1 = 0; + let mut e2 = 1; + let zero = 0; + + for index in 0..LEN { + lc1 = lc1 + if bits[index] {e2} else {zero}; + e2 = e2 + e2; + } + return lc1; +} + fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { let mut bits = [false; LEN]; let mut lc1 = 0; let mut e2 = 1; - let one = 1; - let zero = 0; - // todo: ITE should allow literals let true_val = true; let false_val = false; @@ -15,27 +24,9 @@ fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { for index in 0..LEN { let bit_num = unsafe nth_bit(value, index); - // constrain the bit_num to be 0 or 1 - assert_eq(bit_num * (bit_num - 1), 0); - - // convert the bit_num to boolean bits[index] = if bit_num == 1 {true_val} else {false_val}; - - lc1 = lc1 + if bits[index] {e2} else {zero}; - e2 = e2 + e2; } - assert_eq(lc1, value); + + assert_eq(from_bits(bits), value); return bits; } - -fn from_bits(bits: [Bool; LEN]) -> Field { - let mut lc1 = 0; - let mut e2 = 1; - let zero = 0; - - for index in 0..LEN { - lc1 = lc1 + if bits[index] {e2} else {zero}; - e2 = e2 + e2; - } - return lc1; -} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm index 4608ebd97..32f492c19 100644 --- a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm @@ -1,26 +1,20 @@ @ noname.0.7.0 @ public inputs: 2 -v_5 == (v_4) * (v_4 + -1) -0 == (v_5) * (1) -1 == (v_6) * (1) -v_8 == (v_7) * (-1 * v_4 + v_6) --1 * v_9 + 1 == (v_8) * (1) -v_10 == (v_9) * (-1 * v_4 + v_6) -0 == (v_10) * (1) -v_12 == (v_11) * (v_11 + -1) -0 == (v_12) * (1) -1 == (v_13) * (1) -v_15 == (v_14) * (-1 * v_11 + v_13) --1 * v_16 + 1 == (v_15) * (1) -v_17 == (v_16) * (-1 * v_11 + v_13) -0 == (v_17) * (1) -v_19 == (v_18) * (v_18 + -1) -0 == (v_19) * (1) -1 == (v_20) * (1) -v_22 == (v_21) * (-1 * v_18 + v_20) --1 * v_23 + 1 == (v_22) * (1) -v_24 == (v_23) * (-1 * v_18 + v_20) -0 == (v_24) * (1) -v_2 + -1 * v_3 + 3 == (v_9 + 2 * v_16 + 4 * v_23) * (1) --1 * v_23 + 1 == (v_1) * (1) +1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_4 + v_5) +-1 * v_8 + 1 == (v_7) * (1) +v_9 == (v_8) * (-1 * v_4 + v_5) +0 == (v_9) * (1) +1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_10 + v_11) +-1 * v_14 + 1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_10 + v_11) +0 == (v_15) * (1) +1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_16 + v_17) +-1 * v_20 + 1 == (v_19) * (1) +v_21 == (v_20) * (-1 * v_16 + v_17) +0 == (v_21) * (1) +v_2 + -1 * v_3 + 3 == (v_8 + 2 * v_14 + 4 * v_20) * (1) +-1 * v_20 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_than/less_than.asm b/src/tests/stdlib/comparator/less_than/less_than.asm index beac8c576..cc615ad37 100644 --- a/src/tests/stdlib/comparator/less_than/less_than.asm +++ b/src/tests/stdlib/comparator/less_than/less_than.asm @@ -1,26 +1,20 @@ @ noname.0.7.0 @ public inputs: 2 -v_5 == (v_4) * (v_4 + -1) -0 == (v_5) * (1) -1 == (v_6) * (1) -v_8 == (v_7) * (-1 * v_4 + v_6) --1 * v_9 + 1 == (v_8) * (1) -v_10 == (v_9) * (-1 * v_4 + v_6) -0 == (v_10) * (1) -v_12 == (v_11) * (v_11 + -1) -0 == (v_12) * (1) -1 == (v_13) * (1) -v_15 == (v_14) * (-1 * v_11 + v_13) --1 * v_16 + 1 == (v_15) * (1) -v_17 == (v_16) * (-1 * v_11 + v_13) -0 == (v_17) * (1) -v_19 == (v_18) * (v_18 + -1) -0 == (v_19) * (1) -1 == (v_20) * (1) -v_22 == (v_21) * (-1 * v_18 + v_20) --1 * v_23 + 1 == (v_22) * (1) -v_24 == (v_23) * (-1 * v_18 + v_20) -0 == (v_24) * (1) -v_2 + -1 * v_3 + 4 == (v_9 + 2 * v_16 + 4 * v_23) * (1) --1 * v_23 + 1 == (v_1) * (1) +1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_4 + v_5) +-1 * v_8 + 1 == (v_7) * (1) +v_9 == (v_8) * (-1 * v_4 + v_5) +0 == (v_9) * (1) +1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_10 + v_11) +-1 * v_14 + 1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_10 + v_11) +0 == (v_15) * (1) +1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_16 + v_17) +-1 * v_20 + 1 == (v_19) * (1) +v_21 == (v_20) * (-1 * v_16 + v_17) +0 == (v_21) * (1) +v_2 + -1 * v_3 + 4 == (v_8 + 2 * v_14 + 4 * v_20) * (1) +-1 * v_20 + 1 == (v_1) * (1) From 1f8826a6ded7d806b979bd29c74ad84f6e8c9b9b Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:44:12 +0800 Subject: [PATCH 23/36] add const var STDLIB_DIRECTORY --- src/stdlib/mod.rs | 3 +++ src/tests/examples.rs | 6 +++--- src/tests/stdlib/mod.rs | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index cde82ead8..efa5a6e4e 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -20,6 +20,9 @@ pub mod bits; pub mod builtins; pub mod crypto; +/// The directory under [NONAME_DIRECTORY] containing the native stdlib. +pub const STDLIB_DIRECTORY: &str = "src/stdlib/native/"; + pub enum AllStdModules { Builtins, Crypto, diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 1aa9a31f3..dde1093ef 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -10,7 +10,7 @@ use crate::{ }, compiler::{compile, typecheck_next_file, Sources}, inputs::{parse_inputs, ExtField}, - stdlib::init_stdlib_dep, + stdlib::{init_stdlib_dep, STDLIB_DIRECTORY}, type_checker::TypeChecker, }; @@ -38,7 +38,7 @@ fn test_file( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); let this_module = None; let _node_id = typecheck_next_file( &mut tast, @@ -102,7 +102,7 @@ fn test_file( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); let this_module = None; let _node_id = typecheck_next_file( &mut tast, diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 5442e9016..18f8668e3 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -9,7 +9,7 @@ use crate::{ error::Result, inputs::parse_inputs, mast, - stdlib::init_stdlib_dep, + stdlib::{init_stdlib_dep, STDLIB_DIRECTORY}, type_checker::TypeChecker, witness::CompiledCircuit, }; @@ -56,7 +56,7 @@ fn test_stdlib_code( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); let this_module = None; let _node_id = typecheck_next_file( From 8834fffb71ace2df1097522c23b08ed2eb7ab829 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:52:49 +0800 Subject: [PATCH 24/36] move native stdlib to their own folder --- src/stdlib/mod.rs | 3 ++- src/stdlib/native/{bits.no => bits/lib.no} | 0 src/stdlib/native/{comparator.no => comparator/lib.no} | 0 src/stdlib/native/{int.no => int/lib.no} | 0 4 files changed, 2 insertions(+), 1 deletion(-) rename src/stdlib/native/{bits.no => bits/lib.no} (100%) rename src/stdlib/native/{comparator.no => comparator/lib.no} (100%) rename src/stdlib/native/{int.no => int/lib.no} (100%) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index efa5a6e4e..65497e20a 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -100,7 +100,8 @@ pub fn init_stdlib_dep( for lib in libs { let module = UserRepo::new(&format!("std/{}", lib)); let prefix_stdlib = Path::new(path_prefix); - let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}.no"))).unwrap(); + println!("Loading stdlib: {}", prefix_stdlib.join(format!("{lib}/lib.no")).display()); + let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}/lib.no"))).unwrap(); node_id = typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); } diff --git a/src/stdlib/native/bits.no b/src/stdlib/native/bits/lib.no similarity index 100% rename from src/stdlib/native/bits.no rename to src/stdlib/native/bits/lib.no diff --git a/src/stdlib/native/comparator.no b/src/stdlib/native/comparator/lib.no similarity index 100% rename from src/stdlib/native/comparator.no rename to src/stdlib/native/comparator/lib.no diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int/lib.no similarity index 100% rename from src/stdlib/native/int.no rename to src/stdlib/native/int/lib.no From a822a1b9fd1d83d5465c0e0da20e5e46e7373c99 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:52:54 +0800 Subject: [PATCH 25/36] fmt --- src/stdlib/mod.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 65497e20a..043f6a147 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -100,7 +100,10 @@ pub fn init_stdlib_dep( for lib in libs { let module = UserRepo::new(&format!("std/{}", lib)); let prefix_stdlib = Path::new(path_prefix); - println!("Loading stdlib: {}", prefix_stdlib.join(format!("{lib}/lib.no")).display()); + println!( + "Loading stdlib: {}", + prefix_stdlib.join(format!("{lib}/lib.no")).display() + ); let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}/lib.no"))).unwrap(); node_id = typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); From 3f91646ab0d0934cd33e60112fa3ac33dd95191a Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Thu, 24 Oct 2024 14:53:48 +0800 Subject: [PATCH 26/36] Add u16/32/64 (#206) * add uint16/32/64 * allow [0][0] --- src/parser/expr.rs | 6 ++- src/stdlib/native/int.no | 108 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 src/stdlib/native/int.no diff --git a/src/parser/expr.rs b/src/parser/expr.rs index d070290bb..db68f2eba 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -545,9 +545,11 @@ impl Expr { // sanity check if !matches!( self.kind, - ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } + ExprKind::Variable { .. } + | ExprKind::FieldAccess { .. } + | ExprKind::ArrayAccess { .. } ) { - panic!("an array access can only follow a variable"); + panic!("an array access can only follow a variable or another array access"); } // array[idx] diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no new file mode 100644 index 000000000..4390c3db6 --- /dev/null +++ b/src/stdlib/native/int.no @@ -0,0 +1,108 @@ +use std::bits; +use std::comparator; + +// u8 +struct Uint8 { + inner: Field, + bit_len: Field, +} + +fn Uint8.new(val: Field) -> Uint8 { + let bit_len = 8; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint8 { + inner: val, + bit_len: bit_len + }; +} + +// u16 +struct Uint16 { + inner: Field, + bit_len: Field, +} + +fn Uint16.new(val: Field) -> Uint16 { + let bit_len = 16; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint16 { + inner: val, + bit_len: bit_len + }; +} + +// u32 +struct Uint32 { + inner: Field, + bit_len: Field, +} + +fn Uint32.new(val: Field) -> Uint32 { + let bit_len = 32; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint32 { + inner: val, + bit_len: bit_len + }; +} + +// u64 +struct Uint64 { + inner: Field, + bit_len: Field, +} + +fn Uint64.new(val: Field) -> Uint64 { + let bit_len = 64; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint64 { + inner: val, + bit_len: bit_len + }; +} + +// implement comparator + +fn Uint8.less_than(self, rhs: Uint8) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint16.less_than(self, rhs: Uint16) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint16.less_eq_than(self, rhs: Uint16) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint32.less_than(self, rhs: Uint32) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint32.less_eq_than(self, rhs: Uint32) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint64.less_than(self, rhs: Uint64) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint64.less_eq_than(self, rhs: Uint64) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} \ No newline at end of file From c2c78c9b9796beb03d74046600efcd3e32fb7c48 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 15:34:59 +0800 Subject: [PATCH 27/36] remove deadcode --- src/parser/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index dd838663c..511354478 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -220,7 +220,6 @@ mod tests { let code = r#"main(pub public_input: [Fel; 3], private_input: [Fel; 3]) -> [Fel; 3] { return public_input; }"#; let tokens = &mut Token::parse(0, code).unwrap(); let ctx = &mut ParserCtx::default(); - let is_hint = false; let parsed = FunctionDef::parse(ctx, tokens).unwrap(); println!("{:?}", parsed); } From 328ffbf8a1921f91269e355e18207b1a0a7b261d Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 17:10:42 +0800 Subject: [PATCH 28/36] point to zksecurity repo --- src/cli/packages.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cli/packages.rs b/src/cli/packages.rs index 28effedcd..a22818bc0 100644 --- a/src/cli/packages.rs +++ b/src/cli/packages.rs @@ -276,7 +276,7 @@ pub fn download_from_github(dep: &UserRepo) -> Result<()> { pub fn download_stdlib() -> Result<()> { // Hardcoded repository details and target branch - let repo_owner = "katat"; + let repo_owner = "zksecurity"; let repo_name = "noname"; let target_branch = "release"; let repo_url = format!( From 42a520a8927a3f76aeb49f29bfda3f4bd11fa510 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 25 Oct 2024 11:18:30 +0800 Subject: [PATCH 29/36] remove deadcode --- src/stdlib/native/int.no | 108 --------------------------------------- 1 file changed, 108 deletions(-) delete mode 100644 src/stdlib/native/int.no diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no deleted file mode 100644 index 4390c3db6..000000000 --- a/src/stdlib/native/int.no +++ /dev/null @@ -1,108 +0,0 @@ -use std::bits; -use std::comparator; - -// u8 -struct Uint8 { - inner: Field, - bit_len: Field, -} - -fn Uint8.new(val: Field) -> Uint8 { - let bit_len = 8; - - // range check - let ignore_ = bits::to_bits(bit_len, val); - - return Uint8 { - inner: val, - bit_len: bit_len - }; -} - -// u16 -struct Uint16 { - inner: Field, - bit_len: Field, -} - -fn Uint16.new(val: Field) -> Uint16 { - let bit_len = 16; - - // range check - let ignore_ = bits::to_bits(bit_len, val); - - return Uint16 { - inner: val, - bit_len: bit_len - }; -} - -// u32 -struct Uint32 { - inner: Field, - bit_len: Field, -} - -fn Uint32.new(val: Field) -> Uint32 { - let bit_len = 32; - - // range check - let ignore_ = bits::to_bits(bit_len, val); - - return Uint32 { - inner: val, - bit_len: bit_len - }; -} - -// u64 -struct Uint64 { - inner: Field, - bit_len: Field, -} - -fn Uint64.new(val: Field) -> Uint64 { - let bit_len = 64; - - // range check - let ignore_ = bits::to_bits(bit_len, val); - - return Uint64 { - inner: val, - bit_len: bit_len - }; -} - -// implement comparator - -fn Uint8.less_than(self, rhs: Uint8) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); -} - -fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); -} - -fn Uint16.less_than(self, rhs: Uint16) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); -} - -fn Uint16.less_eq_than(self, rhs: Uint16) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); -} - -fn Uint32.less_than(self, rhs: Uint32) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); -} - -fn Uint32.less_eq_than(self, rhs: Uint32) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); -} - -fn Uint64.less_than(self, rhs: Uint64) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); -} - -fn Uint64.less_eq_than(self, rhs: Uint64) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); -} \ No newline at end of file From be8127d151e9419f09e283db18ceb5b369d3eb9a Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 25 Oct 2024 12:38:49 +0800 Subject: [PATCH 30/36] doc bits --- src/stdlib/native/bits/lib.no | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/stdlib/native/bits/lib.no b/src/stdlib/native/bits/lib.no index 5043b49cc..3375799df 100644 --- a/src/stdlib/native/bits/lib.no +++ b/src/stdlib/native/bits/lib.no @@ -1,5 +1,6 @@ hint fn nth_bit(value: Field, const nth: Field) -> Field; +// Convert a bit array to a Field fn from_bits(bits: [Bool; LEN]) -> Field { let mut lc1 = 0; let mut e2 = 1; @@ -12,6 +13,7 @@ fn from_bits(bits: [Bool; LEN]) -> Field { return lc1; } +// Convert a Field to a bit array fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { let mut bits = [false; LEN]; let mut lc1 = 0; From 665ef2edf2d8a32b4ee330deb8c72b0db75b3178 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 25 Oct 2024 12:58:20 +0800 Subject: [PATCH 31/36] use main branch as the release branch for downloading --- src/cli/packages.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cli/packages.rs b/src/cli/packages.rs index a22818bc0..7c06c2af4 100644 --- a/src/cli/packages.rs +++ b/src/cli/packages.rs @@ -278,7 +278,7 @@ pub fn download_stdlib() -> Result<()> { // Hardcoded repository details and target branch let repo_owner = "zksecurity"; let repo_name = "noname"; - let target_branch = "release"; + let target_branch = "main"; let repo_url = format!( "https://github.com/{owner}/{repo}.git", owner = repo_owner, From ab75d858d971f2b8b637e0c6e318a0ec615d94bc Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 25 Oct 2024 12:58:31 +0800 Subject: [PATCH 32/36] clean up --- src/cli/packages.rs | 4 +++- src/stdlib/mod.rs | 4 ---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/cli/packages.rs b/src/cli/packages.rs index 7c06c2af4..263c12401 100644 --- a/src/cli/packages.rs +++ b/src/cli/packages.rs @@ -7,6 +7,8 @@ use camino::Utf8PathBuf as PathBuf; use miette::{Context, IntoDiagnostic, Result}; use serde::{Deserialize, Serialize}; +use crate::stdlib::STDLIB_DIRECTORY; + use super::{ manifest::{read_manifest, Manifest}, NONAME_DIRECTORY, PACKAGE_DIRECTORY, RELEASE_DIRECTORY, @@ -248,7 +250,7 @@ pub(crate) fn path_to_stdlib() -> PathBuf { .expect("invalid UTF8 path"); let noname_dir = home_dir.join(NONAME_DIRECTORY); - noname_dir.join(RELEASE_DIRECTORY).join("src/stdlib/native") + noname_dir.join(RELEASE_DIRECTORY).join(STDLIB_DIRECTORY) } /// download package from github diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 043f6a147..11848eda7 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -100,10 +100,6 @@ pub fn init_stdlib_dep( for lib in libs { let module = UserRepo::new(&format!("std/{}", lib)); let prefix_stdlib = Path::new(path_prefix); - println!( - "Loading stdlib: {}", - prefix_stdlib.join(format!("{lib}/lib.no")).display() - ); let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}/lib.no"))).unwrap(); node_id = typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); From 34cf5c476ff86715b952fdb93590ea2974778605 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 25 Oct 2024 13:50:16 +0800 Subject: [PATCH 33/36] ci: init stdlib from the latest code in pr --- .github/workflows/snarkjs.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/snarkjs.sh b/.github/workflows/snarkjs.sh index 7031f163b..a3bc6ddf1 100755 --- a/.github/workflows/snarkjs.sh +++ b/.github/workflows/snarkjs.sh @@ -9,6 +9,10 @@ fi DIR_PATH=$1 CURVE=$2 +# Init stdlib in .noname/release/src/stdlib instead of downloading +echo "Overriding stdlib in .noname/release/src/stdlib..." +mkdir -p ~/.noname/release/src/stdlib/ && cp -r /app/noname/src/stdlib/* ~/.noname/release/src/stdlib/ + # Ensure the circuit directory exists and is initialized echo "Initializing a new Noname package..." noname new --path circuit_noname From 4cfb0c52d7a23c1e746f0cc181de7b1c5f8b1e3a Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 25 Oct 2024 15:50:03 +0800 Subject: [PATCH 34/36] add docs to stdlib --- src/stdlib/native/bits/lib.no | 53 ++++++++++++++++++++++++++--- src/stdlib/native/comparator/lib.no | 47 ++++++++++++++++--------- 2 files changed, 79 insertions(+), 21 deletions(-) diff --git a/src/stdlib/native/bits/lib.no b/src/stdlib/native/bits/lib.no index 3375799df..1d7097eb2 100644 --- a/src/stdlib/native/bits/lib.no +++ b/src/stdlib/native/bits/lib.no @@ -1,34 +1,77 @@ +/// a hint function (unconstrained) to extracts the `nth` bit from a given `value`. +/// Its current implementation points to `std::bits::nth_bit`. So it has an empty body in definition. +/// +/// # Parameters +/// - `value`: The `Field` value from which to extract the bit. +/// - `nth`: The position of the bit to extract (0-indexed). +/// +/// # Returns +/// - `Field`: The value of the `nth` bit (0 or 1). +/// hint fn nth_bit(value: Field, const nth: Field) -> Field; -// Convert a bit array to a Field +/// Converts an array of boolean values (`bits`) into a `Field` value. +/// +/// # Parameters +/// - `bits`: An array of `Bool` values representing bits, where each `true` represents `1` and `false` represents `0`. +/// +/// # Returns +/// - `Field`: A `Field` value that represents the integer obtained from the binary representation of `bits`. +/// +/// # Example +/// ``` +/// let bits = [true, false, true]; // Represents the binary value 101 +/// let result = from_bits(bits); +/// `result` should be 5 as 101 in binary equals 5 in decimal. +/// ``` fn from_bits(bits: [Bool; LEN]) -> Field { let mut lc1 = 0; let mut e2 = 1; let zero = 0; for index in 0..LEN { - lc1 = lc1 + if bits[index] {e2} else {zero}; + lc1 = lc1 + if bits[index] { e2 } else { zero }; e2 = e2 + e2; } return lc1; } -// Convert a Field to a bit array +/// Converts a `Field` value into an array of boolean values (`bits`) representing its binary form. +/// +/// # Parameters +/// - `LEN`: The length of the resulting bit array. Determines how many bits are considered in the conversion. +/// - `value`: The `Field` value to convert into binary representation. +/// +/// # Returns +/// - `[Bool; LEN]`: An array of boolean values where each `true` represents `1` and `false` represents `0`. +/// +/// # Example +/// ``` +/// let value = 5; // Binary representation: 101 +/// let bits = to_bits(3, value); +/// `bits` should be [true, false, true] corresponding to the binary 101. +/// ``` +/// +/// # Panics +/// - The function asserts that `from_bits(bits)` equals `value`, ensuring the conversion is correct. fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { let mut bits = [false; LEN]; let mut lc1 = 0; let mut e2 = 1; - // todo: ITE should allow literals + // TODO: ITE should allow literals. let true_val = true; let false_val = false; for index in 0..LEN { let bit_num = unsafe nth_bit(value, index); - bits[index] = if bit_num == 1 {true_val} else {false_val}; + // constraint the bit values to booleans + bits[index] = if bit_num == 1 { true_val } else { false_val }; } + // constraint the accumulative contributions of bits to be equal to the value assert_eq(from_bits(bits), value); + return bits; } diff --git a/src/stdlib/native/comparator/lib.no b/src/stdlib/native/comparator/lib.no index dd11c6c62..58ad95593 100644 --- a/src/stdlib/native/comparator/lib.no +++ b/src/stdlib/native/comparator/lib.no @@ -1,26 +1,33 @@ use std::bits; -// Instead of comparing bit by bit, we check the carry bit: -// lhs + (1 << LEN) - rhs -// proof: -// lhs + (1 << LEN) will add a carry bit, valued 1, to the bit array representing lhs, -// resulted in a bit array of length LEN + 1, named as sum_bits. -// if `lhs < rhs``, then `lhs - rhs < 0`, thus `(1 << LEN) + lhs - rhs < (1 << LEN)` -// then, the carry bit of sum_bits is 0. -// otherwise, the carry bit of sum_bits is 1. +/// Checks if `lhs` is less than `rhs` by evaluating the carry bit after addition and subtraction. +/// +/// # Parameters +/// - `LEN`: The assumped bit length of both `lhs` and `rhs`. +/// - `lhs`: The left-hand side `Field` value to be compared. +/// - `rhs`: The right-hand side `Field` value to be compared. +/// +/// # Returns +/// - `Bool`: `true` if `lhs` is less than `rhs`, otherwise `false`. +/// +/// # Proof +/// - Adding `pow2` to `lhs` ensures a carry bit is added to the result, creating a bit array of length `LEN + 1`. +/// - If `lhs < rhs`, then `lhs - rhs < 0`, making `(1 << LEN) + lhs - rhs` less than `1 << LEN`, resulting in a carry bit of `0`. +/// - Otherwise, the carry bit will be `1`. +/// fn less_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { let carry_bit_len = LEN + 1; - // 1 << LEN + // Calculate 2^LEN using bit shifts. let mut pow2 = 1; for ii in 0..LEN { pow2 = pow2 + pow2; } + // Calculate the adjusted sum to determine the carry bit. let sum = (pow2 + lhs) - rhs; let sum_bit = bits::to_bits(carry_bit_len, sum); - // todo: modify the ife to allow literals let b1 = false; let b2 = true; let res = if sum_bit[LEN] { b1 } else { b2 }; @@ -28,12 +35,20 @@ fn less_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { return res; } -// Less than or equal to. -// based on the proof of less_than(): -// adding 1 to the rhs, can upper bound by 1 for the lhs: -// lhs < rhs + 1 -// is equivalent to -// lhs <= rhs +/// Checks if `lhs` is less than or equal to `rhs` using the `less_than` function. +/// +/// # Parameters +/// - `LEN`: The assumped bit length of both `lhs` and `rhs`. +/// - `lhs`: The left-hand side `Field` value to be compared. +/// - `rhs`: The right-hand side `Field` value to be compared. +/// +/// # Returns +/// - `Bool`: `true` if `lhs` is less than or equal to `rhs`, otherwise `false`. +/// +/// # Proof +/// By adding 1 to rhs can increase upper bound by 1 for the lhs. +/// Thus, `lhs < lhs + 1` => `lhs <= rhs`. +/// ``` fn less_eq_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { return less_than(LEN, lhs, rhs + 1); } From eb1def114d5883f41b71f538ed9d362697880a40 Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Mon, 28 Oct 2024 14:32:06 +0800 Subject: [PATCH 35/36] Support multiplexer (#207) * add multiplexer stdlib --- src/stdlib/mod.rs | 2 +- src/stdlib/native/multiplexer/lib.no | 105 ++++++++++++++++++ src/tests/stdlib/mod.rs | 1 + src/tests/stdlib/multiplexer/mod.rs | 48 ++++++++ .../multiplexer/select_element/main.asm | 36 ++++++ .../stdlib/multiplexer/select_element/main.no | 6 + 6 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 src/stdlib/native/multiplexer/lib.no create mode 100644 src/tests/stdlib/multiplexer/mod.rs create mode 100644 src/tests/stdlib/multiplexer/select_element/main.asm create mode 100644 src/tests/stdlib/multiplexer/select_element/main.no diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 11848eda7..ae7aed820 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -93,7 +93,7 @@ pub fn init_stdlib_dep( path_prefix: &str, ) -> usize { // list the stdlib dependency in order - let libs = vec!["bits", "comparator", "int"]; + let libs = vec!["bits", "comparator", "multiplexer", "int"]; let mut node_id = node_id; diff --git a/src/stdlib/native/multiplexer/lib.no b/src/stdlib/native/multiplexer/lib.no new file mode 100644 index 000000000..54e4b5a11 --- /dev/null +++ b/src/stdlib/native/multiplexer/lib.no @@ -0,0 +1,105 @@ +use std::comparator; + +/// Multiplies two vectors of the same length and returns the accumulated sum (dot product). +/// +/// # Parameters +/// - `lhs`: A vector (array) of `Field` elements. +/// - `rhs`: A vector (array) of `Field` elements. +/// +/// # Returns +/// - `Field`: The accumulated sum resulting from the element-wise multiplication of `lhs` and `rhs`. +/// +/// # Panics +/// - The function assumes that `lhs` and `rhs` have the same length, `LEN`. +/// +/// # Example +/// ``` +/// let lhs = [1, 2, 3]; +/// let rhs = [4, 5, 6]; +/// let result = escalar_product(lhs, rhs); +/// result should be 1*4 + 2*5 + 3*6 = 32 +/// ``` +fn escalar_product(lhs: [Field; LEN], rhs: [Field; LEN]) -> Field { + let mut lc = 0; + for idx in 0..LEN { + lc = lc + (lhs[idx] * rhs[idx]); + } + return lc; +} + +/// Generates a selector array of a given length `LEN` with all zeros except for a one at the specified `target_idx`. +/// +/// # Parameters +/// - `LEN`: The length of the output array. +/// - `target_idx`: The index where the value should be 1. The rest of the array will be filled with zeros. +/// +/// # Returns +/// - `[Field; LEN]`: An array of length `LEN` where all elements are zero except for a single `1` at `target_idx`. +/// +/// # Panics +/// - This function asserts that there is exactly one `1` in the generated array, ensuring `target_idx` is within bounds. +/// +/// # Example +/// ``` +/// let selector = gen_selector_arr(5, 2); +/// `selector` should be [0, 0, 1, 0, 0] +/// ``` +fn gen_selector_arr(const LEN: Field, target_idx: Field) -> [Field; LEN] { + let mut selector = [0; LEN]; + let mut lc = 0; + let one = 1; + let zero = 0; + + for idx in 0..LEN { + selector[idx] = if idx == target_idx { one } else { zero }; + lc = lc + selector[idx]; + } + + // Ensures there is exactly one '1' in the range of LEN. + assert(lc == 1); + + return selector; +} + +/// Selects an element from a 2D array based on a `target_idx` and returns a vector of length `WIDLEN`. +/// +/// # Parameters +/// - `arr`: A 2D array of dimensions `[ARRLEN][WIDLEN]` containing `Field` elements. +/// - `target_idx`: The index that determines which row of `arr` to select. +/// +/// # Returns +/// - `[Field; WIDLEN]`: A vector representing the selected row from `arr`. +/// +/// # Algorithm +/// 1. Generate a selector array using `gen_selector_arr` that has a `1` at `target_idx` and `0`s elsewhere. +/// 2. For each column index `idx` of the 2D array: +/// - Extract the `idx`-th element from each row into a temporary array. +/// - Use `escalar_product` with the temporary array and the selector array to `select` the value corresponding to `target_idx`. +/// 3. Reset the temporary array for the next iteration. +/// 4. Return the vector containing the selected row. +/// +/// # Example +/// ``` +/// let arr = [[1, 2], [3, 4], [5, 6]]; +/// let result = select_element(arr, 1); +/// `result` should be [3, 4] as it selects the second row (index 1). +/// ``` +fn select_element(arr: [[Field; WIDLEN]; ARRLEN], target_idx: Field) -> [Field; WIDLEN] { + let mut result = [0; WIDLEN]; + + let selector_arr = gen_selector_arr(ARRLEN, target_idx); + let mut one_len_arr = [0; ARRLEN]; + + for idx in 0..WIDLEN { + for jdx in 0..ARRLEN { + one_len_arr[jdx] = arr[jdx][idx]; + } + // Only one element in `selector_arr` is `1`, so the result is the element in `one_len_arr` + // at the same index as the `1` in `selector_arr`. + result[idx] = escalar_product(one_len_arr, selector_arr); + + // Reset the temporary array for the next column. + one_len_arr = [0; ARRLEN]; + } + return result; +} diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 18f8668e3..6f81d2a9f 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -1,4 +1,5 @@ mod comparator; +mod multiplexer; use std::{path::Path, str::FromStr}; diff --git a/src/tests/stdlib/multiplexer/mod.rs b/src/tests/stdlib/multiplexer/mod.rs new file mode 100644 index 000000000..cd921dd1d --- /dev/null +++ b/src/tests/stdlib/multiplexer/mod.rs @@ -0,0 +1,48 @@ +use crate::error::{self}; + +use super::test_stdlib; +use error::Result; +use rstest::rstest; + +#[rstest] +#[case(r#"{"xx": [["0", "1", "2"], ["3", "4", "5"], ["6", "7", "8"]]}"#, r#"{"sel": "1"}"#, vec!["3", "4", "5"])] +fn test_in_range( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + test_stdlib( + "multiplexer/select_element/main.no", + Some("multiplexer/select_element/main.asm"), + public_inputs, + private_inputs, + expected_output, + )?; + + Ok(()) +} + +// require the select idx to be in range +#[rstest] +#[case(r#"{"xx": [["0", "1", "2"], ["3", "4", "5"], ["6", "7", "8"]]}"#, r#"{"sel": "3"}"#, vec![])] +fn test_out_range( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + use crate::error::ErrorKind; + + let err = test_stdlib( + "multiplexer/select_element/main.no", + Some("multiplexer/select_element/main.asm"), + public_inputs, + private_inputs, + expected_output, + ) + .err() + .expect("Expected error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} diff --git a/src/tests/stdlib/multiplexer/select_element/main.asm b/src/tests/stdlib/multiplexer/select_element/main.asm new file mode 100644 index 000000000..c2fca003d --- /dev/null +++ b/src/tests/stdlib/multiplexer/select_element/main.asm @@ -0,0 +1,36 @@ +@ noname.0.7.0 +@ public inputs: 12 + +0 == (v_14) * (1) +v_16 == (v_15) * (v_13 + -1 * v_14) +-1 * v_17 + 1 == (v_16) * (1) +v_18 == (v_17) * (v_13 + -1 * v_14) +0 == (v_18) * (1) +1 == (v_19) * (1) +v_21 == (v_20) * (v_13 + -1 * v_19) +-1 * v_22 + 1 == (v_21) * (1) +v_23 == (v_22) * (v_13 + -1 * v_19) +0 == (v_23) * (1) +2 == (v_24) * (1) +v_26 == (v_25) * (v_13 + -1 * v_24) +-1 * v_27 + 1 == (v_26) * (1) +v_28 == (v_27) * (v_13 + -1 * v_24) +0 == (v_28) * (1) +1 == (v_29) * (1) +v_31 == (v_30) * (-1 * v_17 + -1 * v_22 + -1 * v_27 + v_29) +-1 * v_32 + 1 == (v_31) * (1) +v_33 == (v_32) * (-1 * v_17 + -1 * v_22 + -1 * v_27 + v_29) +0 == (v_33) * (1) +1 == (v_32) * (1) +v_34 == (v_4) * (v_17) +v_35 == (v_7) * (v_22) +v_36 == (v_10) * (v_27) +v_37 == (v_5) * (v_17) +v_38 == (v_8) * (v_22) +v_39 == (v_11) * (v_27) +v_40 == (v_6) * (v_17) +v_41 == (v_9) * (v_22) +v_42 == (v_12) * (v_27) +v_34 + v_35 + v_36 == (v_1) * (1) +v_37 + v_38 + v_39 == (v_2) * (1) +v_40 + v_41 + v_42 == (v_3) * (1) diff --git a/src/tests/stdlib/multiplexer/select_element/main.no b/src/tests/stdlib/multiplexer/select_element/main.no new file mode 100644 index 000000000..b06ae625c --- /dev/null +++ b/src/tests/stdlib/multiplexer/select_element/main.no @@ -0,0 +1,6 @@ +use std::multiplexer; + +fn main(pub xx: [[Field; 3]; 3], sel: Field) -> [Field; 3] { + let chosen_elements = multiplexer::select_element(xx, sel); + return chosen_elements; +} \ No newline at end of file From 83e3491f168c8c1c503b817764c022f7b29d2eaa Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Mon, 28 Oct 2024 15:39:30 +0800 Subject: [PATCH 36/36] Add MIMC stdlib (#208) * add mimc stdlib --- src/error.rs | 2 +- src/mast/mod.rs | 42 ++--- src/parser/types.rs | 30 ++-- src/stdlib/mod.rs | 2 +- src/stdlib/native/mimc/lib.no | 153 +++++++++++++++++ src/tests/stdlib/mimc/mimc_main.no | 7 + src/tests/stdlib/mimc/mod.rs | 207 +++++++++++++++++++++++ src/tests/stdlib/mimc/multi_mimc_main.no | 7 + src/tests/stdlib/mod.rs | 1 + src/type_checker/checker.rs | 6 +- 10 files changed, 423 insertions(+), 34 deletions(-) create mode 100644 src/stdlib/native/mimc/lib.no create mode 100644 src/tests/stdlib/mimc/mimc_main.no create mode 100644 src/tests/stdlib/mimc/mod.rs create mode 100644 src/tests/stdlib/mimc/multi_mimc_main.no diff --git a/src/error.rs b/src/error.rs index c9a56208e..c06b9bf35 100644 --- a/src/error.rs +++ b/src/error.rs @@ -140,7 +140,7 @@ pub enum ErrorKind { GenericValueExpected(String), #[error("conflict generic values during binding for `{0}`: `{1}` and `{2}`")] - ConflictGenericValue(String, u32, u32), + ConflictGenericValue(String, String, String), #[error("unexpected generic parameter: `{0}`")] UnexpectedGenericParameter(String), diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 89ad4cc6b..f2521c156 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -36,27 +36,27 @@ pub struct ExprMonoInfo { #[derive(Debug, Clone)] pub enum PropagatedConstant { - Single(u32), + Single(BigUint), Array(Vec), Custom(HashMap), } impl PropagatedConstant { - pub fn as_single(&self) -> u32 { + pub fn as_single(&self) -> BigUint { match self { - PropagatedConstant::Single(v) => *v, + PropagatedConstant::Single(v) => v.clone(), _ => panic!("expected single value"), } } - pub fn as_array(&self) -> Vec { + pub fn as_array(&self) -> Vec { match self { PropagatedConstant::Array(v) => v.iter().map(|c| c.as_single()).collect(), _ => panic!("expected array value"), } } - pub fn as_custom(&self) -> HashMap { + pub fn as_custom(&self) -> HashMap { match self { PropagatedConstant::Custom(v) => { v.iter().map(|(k, c)| (k.clone(), c.as_single())).collect() @@ -67,8 +67,8 @@ impl PropagatedConstant { } /// impl From trait for single value -impl From for PropagatedConstant { - fn from(v: u32) -> Self { +impl From for PropagatedConstant { + fn from(v: BigUint) -> Self { PropagatedConstant::Single(v) } } @@ -224,7 +224,10 @@ impl FnSig { TyKind::Array(ty, size) => TyKind::Array(Box::new(self.resolve_type(ty, ctx)), *size), TyKind::GenericSizedArray(ty, sym) => { let val = sym.eval(&self.generics, &ctx.tast); - TyKind::Array(Box::new(self.resolve_type(ty, ctx)), val) + TyKind::Array( + Box::new(self.resolve_type(ty, ctx)), + val.to_u32().expect("array size exceeded u32"), + ) } _ => typ.clone(), } @@ -381,9 +384,9 @@ impl MastCtx { impl Symbolic { /// Evaluate symbolic size to an integer. - pub fn eval(&self, gens: &GenericParameters, tast: &TypeChecker) -> u32 { + pub fn eval(&self, gens: &GenericParameters, tast: &TypeChecker) -> BigUint { match self { - Symbolic::Concrete(v) => *v, + Symbolic::Concrete(v) => v.clone(), Symbolic::Constant(var) => { let qualified = FullyQualified::local(var.value.clone()); let cst = tast.const_info(&qualified).expect("constant not found"); @@ -805,11 +808,7 @@ fn monomorphize_expr( Some(v) => { let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); - ExprMonoInfo::new( - mexpr, - typ, - Some(PropagatedConstant::from(v.to_u32().unwrap())), - ) + ExprMonoInfo::new(mexpr, typ, Some(PropagatedConstant::from(v))) } // keep as is _ => { @@ -846,13 +845,12 @@ fn monomorphize_expr( } ExprKind::BigUInt(inner) => { - let cst: u32 = inner.try_into().expect("biguint too large"); let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(inner.clone())); ExprMonoInfo::new( mexpr, Some(TyKind::Field { constant: true }), - Some(PropagatedConstant::from(cst)), + Some(PropagatedConstant::from(inner.clone())), ) } @@ -892,13 +890,12 @@ fn monomorphize_expr( // if it's a variable, // check if it's a constant first let bigint: BigUint = cst.value[0].into(); - let cst: u32 = bigint.clone().try_into().expect("biguint too large"); - let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(bigint)); + let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(bigint.clone())); ExprMonoInfo::new( mexpr, Some(TyKind::Field { constant: true }), - Some(PropagatedConstant::from(cst)), + Some(PropagatedConstant::from(bigint)), ) } else { // otherwise it's a local variable @@ -1090,7 +1087,10 @@ fn monomorphize_expr( ); if let Some(cst) = size_mono.constant { - let arr_typ = TyKind::Array(Box::new(item_typ), cst.as_single()); + let arr_typ = TyKind::Array( + Box::new(item_typ), + cst.as_single().to_u32().expect("array size too large"), + ); ExprMonoInfo::new(mexpr, Some(arr_typ), None) } else { return Err(error(ErrorKind::InvalidArraySize, expr.span)); diff --git a/src/parser/types.rs b/src/parser/types.rs index 0691f1356..4ce8645c8 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -1,4 +1,5 @@ use educe::Educe; +use num_bigint::BigUint; use std::{ collections::{HashMap, HashSet}, fmt::Display, @@ -7,7 +8,7 @@ use std::{ }; use ark_ff::Field; -use num_traits::ToPrimitive; +use num_traits::FromPrimitive; use serde::{Deserialize, Serialize}; use crate::{ @@ -15,7 +16,6 @@ use crate::{ constants::Span, error::{Error, ErrorKind, Result}, lexer::{Keyword, Token, TokenKind, Tokens}, - mast::ExprMonoInfo, stdlib::builtins::BUILTIN_FN_NAMES, syntax::{is_generic_parameter, is_type}, }; @@ -181,7 +181,7 @@ pub enum ModulePath { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum Symbolic { /// A literal number - Concrete(u32), + Concrete(BigUint), /// Point to a constant variable Constant(Ident), /// Generic parameter @@ -230,7 +230,7 @@ impl Symbolic { /// Parse from an expression node recursively. pub fn parse(node: &Expr) -> Result { match &node.kind { - ExprKind::BigUInt(n) => Ok(Symbolic::Concrete(n.to_u32().unwrap())), + ExprKind::BigUInt(n) => Ok(Symbolic::Concrete(n.clone())), ExprKind::Variable { module: _, name } => { if is_generic_parameter(&name.value) { Ok(Symbolic::Generic(name.clone())) @@ -571,7 +571,11 @@ impl FnSig { // resolve the generic parameter match sym { Symbolic::Generic(ident) => { - self.generics.assign(&ident.value, *observed_size, span)?; + self.generics.assign( + &ident.value, + BigUint::from_u32(*observed_size).unwrap(), + span, + )?; } _ => unreachable!("no operation allowed on symbolic size in function argument"), } @@ -633,7 +637,7 @@ impl FnSig { let generics = generics .iter() - .map(|(name, value)| format!("{}={}", name, value.unwrap())) + .map(|(name, value)| format!("{}={}", name, value.as_ref().unwrap())) .collect::>() .join("#"); @@ -742,7 +746,7 @@ pub struct ResolvedSig { #[derive(Debug, Default, Clone, Serialize, Deserialize)] /// Generic parameters for a function signature pub struct GenericParameters { - pub parameters: HashMap>, + pub parameters: HashMap>, pub resolved_sig: Option, } @@ -758,11 +762,13 @@ impl GenericParameters { } /// Get the value of a generic parameter - pub fn get(&self, name: &str) -> u32 { + pub fn get(&self, name: &str) -> BigUint { self.parameters .get(name) .expect("generic parameter not found") + .as_ref() .expect("generic value not assigned") + .clone() } /// Returns whether the generic parameters are empty @@ -771,7 +777,7 @@ impl GenericParameters { } /// Bind a generic parameter to a value - pub fn assign(&mut self, name: &String, value: u32, span: Span) -> Result<()> { + pub fn assign(&mut self, name: &String, value: BigUint, span: Span) -> Result<()> { let existing = self.parameters.get(name); match existing { Some(Some(v)) => { @@ -781,7 +787,11 @@ impl GenericParameters { Err(Error::new( "mast", - ErrorKind::ConflictGenericValue(name.to_string(), *v, value), + ErrorKind::ConflictGenericValue( + name.to_string(), + v.to_str_radix(10), + value.to_str_radix(10), + ), span, )) } diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index ae7aed820..e01489059 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -93,7 +93,7 @@ pub fn init_stdlib_dep( path_prefix: &str, ) -> usize { // list the stdlib dependency in order - let libs = vec!["bits", "comparator", "multiplexer", "int"]; + let libs = vec!["bits", "comparator", "multiplexer", "mimc", "int"]; let mut node_id = node_id; diff --git a/src/stdlib/native/mimc/lib.no b/src/stdlib/native/mimc/lib.no new file mode 100644 index 000000000..7681f1e4d --- /dev/null +++ b/src/stdlib/native/mimc/lib.no @@ -0,0 +1,153 @@ +/// MIMC hash function using exponentiation with 7. +/// This function allows a maximum of 91 rounds, with each round using exponentiation of 7. +/// +/// # Parameters +/// - `value`: The input value to be hashed. +/// - `key`: The secret key used in the hash computation. +/// +/// # Returns +/// - `Field`: The resulting hash after `ROUNDS` of MIMC operations. +/// +/// # Constraints +/// - `ROUNDS` must be within the range of constants defined in `csts`. +fn mimc7_cipher(value: Field, key: Field) -> Field { + let rounds = 91; + // Initial value: sum of the key and input value. + let init = key + value; + // Variable for the accumulative result of exponentiation with 7. + let mut exp7 = 0; + + // Predefined constants for each round of the hash function. + let csts = [ + 0, + 20888961410941983456478427210666206549300505294776164667214940546594746570981, + 15265126113435022738560151911929040668591755459209400716467504685752745317193, + 8334177627492981984476504167502758309043212251641796197711684499645635709656, + 1374324219480165500871639364801692115397519265181803854177629327624133579404, + 11442588683664344394633565859260176446561886575962616332903193988751292992472, + 2558901189096558760448896669327086721003508630712968559048179091037845349145, + 11189978595292752354820141775598510151189959177917284797737745690127318076389, + 3262966573163560839685415914157855077211340576201936620532175028036746741754, + 17029914891543225301403832095880481731551830725367286980611178737703889171730, + 4614037031668406927330683909387957156531244689520944789503628527855167665518, + 19647356996769918391113967168615123299113119185942498194367262335168397100658, + 5040699236106090655289931820723926657076483236860546282406111821875672148900, + 2632385916954580941368956176626336146806721642583847728103570779270161510514, + 17691411851977575435597871505860208507285462834710151833948561098560743654671, + 11482807709115676646560379017491661435505951727793345550942389701970904563183, + 8360838254132998143349158726141014535383109403565779450210746881879715734773, + 12663821244032248511491386323242575231591777785787269938928497649288048289525, + 3067001377342968891237590775929219083706800062321980129409398033259904188058, + 8536471869378957766675292398190944925664113548202769136103887479787957959589, + 19825444354178182240559170937204690272111734703605805530888940813160705385792, + 16703465144013840124940690347975638755097486902749048533167980887413919317592, + 13061236261277650370863439564453267964462486225679643020432589226741411380501, + 10864774797625152707517901967943775867717907803542223029967000416969007792571, + 10035653564014594269791753415727486340557376923045841607746250017541686319774, + 3446968588058668564420958894889124905706353937375068998436129414772610003289, + 4653317306466493184743870159523234588955994456998076243468148492375236846006, + 8486711143589723036499933521576871883500223198263343024003617825616410932026, + 250710584458582618659378487568129931785810765264752039738223488321597070280, + 2104159799604932521291371026105311735948154964200596636974609406977292675173, + 16313562605837709339799839901240652934758303521543693857533755376563489378839, + 6032365105133504724925793806318578936233045029919447519826248813478479197288, + 14025118133847866722315446277964222215118620050302054655768867040006542798474, + 7400123822125662712777833064081316757896757785777291653271747396958201309118, + 1744432620323851751204287974553233986555641872755053103823939564833813704825, + 8316378125659383262515151597439205374263247719876250938893842106722210729522, + 6739722627047123650704294650168547689199576889424317598327664349670094847386, + 21211457866117465531949733809706514799713333930924902519246949506964470524162, + 13718112532745211817410303291774369209520657938741992779396229864894885156527, + 5264534817993325015357427094323255342713527811596856940387954546330728068658, + 18884137497114307927425084003812022333609937761793387700010402412840002189451, + 5148596049900083984813839872929010525572543381981952060869301611018636120248, + 19799686398774806587970184652860783461860993790013219899147141137827718662674, + 19240878651604412704364448729659032944342952609050243268894572835672205984837, + 10546185249390392695582524554167530669949955276893453512788278945742408153192, + 5507959600969845538113649209272736011390582494851145043668969080335346810411, + 18177751737739153338153217698774510185696788019377850245260475034576050820091, + 19603444733183990109492724100282114612026332366576932662794133334264283907557, + 10548274686824425401349248282213580046351514091431715597441736281987273193140, + 1823201861560942974198127384034483127920205835821334101215923769688644479957, + 11867589662193422187545516240823411225342068709600734253659804646934346124945, + 18718569356736340558616379408444812528964066420519677106145092918482774343613, + 10530777752259630125564678480897857853807637120039176813174150229243735996839, + 20486583726592018813337145844457018474256372770211860618687961310422228379031, + 12690713110714036569415168795200156516217175005650145422920562694422306200486, + 17386427286863519095301372413760745749282643730629659997153085139065756667205, + 2216432659854733047132347621569505613620980842043977268828076165669557467682, + 6309765381643925252238633914530877025934201680691496500372265330505506717193, + 20806323192073945401862788605803131761175139076694468214027227878952047793390, + 4037040458505567977365391535756875199663510397600316887746139396052445718861, + 19948974083684238245321361840704327952464170097132407924861169241740046562673, + 845322671528508199439318170916419179535949348988022948153107378280175750024, + 16222384601744433420585982239113457177459602187868460608565289920306145389382, + 10232118865851112229330353999139005145127746617219324244541194256766741433339, + 6699067738555349409504843460654299019000594109597429103342076743347235369120, + 6220784880752427143725783746407285094967584864656399181815603544365010379208, + 6129250029437675212264306655559561251995722990149771051304736001195288083309, + 10773245783118750721454994239248013870822765715268323522295722350908043393604, + 4490242021765793917495398271905043433053432245571325177153467194570741607167, + 19596995117319480189066041930051006586888908165330319666010398892494684778526, + 837850695495734270707668553360118467905109360511302468085569220634750561083, + 11803922811376367215191737026157445294481406304781326649717082177394185903907, + 10201298324909697255105265958780781450978049256931478989759448189112393506592, + 13564695482314888817576351063608519127702411536552857463682060761575100923924, + 9262808208636973454201420823766139682381973240743541030659775288508921362724, + 173271062536305557219323722062711383294158572562695717740068656098441040230, + 18120430890549410286417591505529104700901943324772175772035648111937818237369, + 20484495168135072493552514219686101965206843697794133766912991150184337935627, + 19155651295705203459475805213866664350848604323501251939850063308319753686505, + 11971299749478202793661982361798418342615500543489781306376058267926437157297, + 18285310723116790056148596536349375622245669010373674803854111592441823052978, + 7069216248902547653615508023941692395371990416048967468982099270925308100727, + 6465151453746412132599596984628739550147379072443683076388208843341824127379, + 16143532858389170960690347742477978826830511669766530042104134302796355145785, + 19362583304414853660976404410208489566967618125972377176980367224623492419647, + 1702213613534733786921602839210290505213503664731919006932367875629005980493, + 10781825404476535814285389902565833897646945212027592373510689209734812292327, + 4212716923652881254737947578600828255798948993302968210248673545442808456151, + 7594017890037021425366623750593200398174488805473151513558919864633711506220, + 18979889247746272055963929241596362599320706910852082477600815822482192194401, + 13602139229813231349386885113156901793661719180900395818909719758150455500533, + ]; + + // Iterate through each round to compute the hash. + for round in 0..rounds { + // Calculate intermediate values based on the round. + let exp1_else = (key + exp7) + csts[round]; + let exp1 = if round == 0 { init } else { exp1_else }; + let exp2 = exp1 * exp1; + let exp4 = exp2 * exp2; + let exp6 = exp4 * exp2; + let exp7_then = exp6 * exp1; + let exp7_else = exp7_then + key; + // Update exp7 based on whether it's the last round. + exp7 = if round != (rounds - 1) { exp7_then } else { exp7_else }; + } + + // Return the final hash value. + return exp7; +} + +/// MIMC hash function for multiple values. +/// Uses the `mimc7_cipher` function iteratively to hash an array of values. +/// +/// # Parameters +/// - `values`: An array of `Field` values to be hashed. +/// - `key`: The secret key used in the hash computation. +/// +/// # Returns +/// - `Field`: The resulting hash after processing all input values. +/// +fn mimc7_hash(values: [Field; LEN], key: Field) -> Field { + // Initialize with the key. + let mut res = key; + // Iterate over each value in the input array. + for value in values { + // Update the result with the MIMC hash of the value and the current result. + res = res + (value + mimc7_cipher(value, res)); + } + // Return the final accumulated result. + return res; +} diff --git a/src/tests/stdlib/mimc/mimc_main.no b/src/tests/stdlib/mimc/mimc_main.no new file mode 100644 index 000000000..1e104d1f8 --- /dev/null +++ b/src/tests/stdlib/mimc/mimc_main.no @@ -0,0 +1,7 @@ +use std::mimc; + +fn main(pub key: Field, val: Field) -> Field { + let res = mimc::mimc7_cipher(val, key); + + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/mimc/mod.rs b/src/tests/stdlib/mimc/mod.rs new file mode 100644 index 000000000..cf11059be --- /dev/null +++ b/src/tests/stdlib/mimc/mod.rs @@ -0,0 +1,207 @@ +use crate::{ + backends::r1cs::R1csBn254Field, + error::{self}, +}; + +use super::test_stdlib; +use ark_ff::Field; +use error::Result; +use num_bigint::BigUint; +use num_traits::Zero; +use rstest::rstest; +use std::str::FromStr; + +/// Parses a decimal string into an Fq field element. +fn fq_from_str(s: &str) -> R1csBn254Field { + R1csBn254Field::from_str(s).unwrap() +} + +/// MiMC7 hash function implementation. +fn mimc7(x_in: R1csBn254Field, k: R1csBn254Field, n_rounds: usize) -> R1csBn254Field { + // Round constants c[91] + let c_strings = [ + "0", + "20888961410941983456478427210666206549300505294776164667214940546594746570981", + "15265126113435022738560151911929040668591755459209400716467504685752745317193", + "8334177627492981984476504167502758309043212251641796197711684499645635709656", + "1374324219480165500871639364801692115397519265181803854177629327624133579404", + "11442588683664344394633565859260176446561886575962616332903193988751292992472", + "2558901189096558760448896669327086721003508630712968559048179091037845349145", + "11189978595292752354820141775598510151189959177917284797737745690127318076389", + "3262966573163560839685415914157855077211340576201936620532175028036746741754", + "17029914891543225301403832095880481731551830725367286980611178737703889171730", + "4614037031668406927330683909387957156531244689520944789503628527855167665518", + "19647356996769918391113967168615123299113119185942498194367262335168397100658", + "5040699236106090655289931820723926657076483236860546282406111821875672148900", + "2632385916954580941368956176626336146806721642583847728103570779270161510514", + "17691411851977575435597871505860208507285462834710151833948561098560743654671", + "11482807709115676646560379017491661435505951727793345550942389701970904563183", + "8360838254132998143349158726141014535383109403565779450210746881879715734773", + "12663821244032248511491386323242575231591777785787269938928497649288048289525", + "3067001377342968891237590775929219083706800062321980129409398033259904188058", + "8536471869378957766675292398190944925664113548202769136103887479787957959589", + "19825444354178182240559170937204690272111734703605805530888940813160705385792", + "16703465144013840124940690347975638755097486902749048533167980887413919317592", + "13061236261277650370863439564453267964462486225679643020432589226741411380501", + "10864774797625152707517901967943775867717907803542223029967000416969007792571", + "10035653564014594269791753415727486340557376923045841607746250017541686319774", + "3446968588058668564420958894889124905706353937375068998436129414772610003289", + "4653317306466493184743870159523234588955994456998076243468148492375236846006", + "8486711143589723036499933521576871883500223198263343024003617825616410932026", + "250710584458582618659378487568129931785810765264752039738223488321597070280", + "2104159799604932521291371026105311735948154964200596636974609406977292675173", + "16313562605837709339799839901240652934758303521543693857533755376563489378839", + "6032365105133504724925793806318578936233045029919447519826248813478479197288", + "14025118133847866722315446277964222215118620050302054655768867040006542798474", + "7400123822125662712777833064081316757896757785777291653271747396958201309118", + "1744432620323851751204287974553233986555641872755053103823939564833813704825", + "8316378125659383262515151597439205374263247719876250938893842106722210729522", + "6739722627047123650704294650168547689199576889424317598327664349670094847386", + "21211457866117465531949733809706514799713333930924902519246949506964470524162", + "13718112532745211817410303291774369209520657938741992779396229864894885156527", + "5264534817993325015357427094323255342713527811596856940387954546330728068658", + "18884137497114307927425084003812022333609937761793387700010402412840002189451", + "5148596049900083984813839872929010525572543381981952060869301611018636120248", + "19799686398774806587970184652860783461860993790013219899147141137827718662674", + "19240878651604412704364448729659032944342952609050243268894572835672205984837", + "10546185249390392695582524554167530669949955276893453512788278945742408153192", + "5507959600969845538113649209272736011390582494851145043668969080335346810411", + "18177751737739153338153217698774510185696788019377850245260475034576050820091", + "19603444733183990109492724100282114612026332366576932662794133334264283907557", + "10548274686824425401349248282213580046351514091431715597441736281987273193140", + "1823201861560942974198127384034483127920205835821334101215923769688644479957", + "11867589662193422187545516240823411225342068709600734253659804646934346124945", + "18718569356736340558616379408444812528964066420519677106145092918482774343613", + "10530777752259630125564678480897857853807637120039176813174150229243735996839", + "20486583726592018813337145844457018474256372770211860618687961310422228379031", + "12690713110714036569415168795200156516217175005650145422920562694422306200486", + "17386427286863519095301372413760745749282643730629659997153085139065756667205", + "2216432659854733047132347621569505613620980842043977268828076165669557467682", + "6309765381643925252238633914530877025934201680691496500372265330505506717193", + "20806323192073945401862788605803131761175139076694468214027227878952047793390", + "4037040458505567977365391535756875199663510397600316887746139396052445718861", + "19948974083684238245321361840704327952464170097132407924861169241740046562673", + "845322671528508199439318170916419179535949348988022948153107378280175750024", + "16222384601744433420585982239113457177459602187868460608565289920306145389382", + "10232118865851112229330353999139005145127746617219324244541194256766741433339", + "6699067738555349409504843460654299019000594109597429103342076743347235369120", + "6220784880752427143725783746407285094967584864656399181815603544365010379208", + "6129250029437675212264306655559561251995722990149771051304736001195288083309", + "10773245783118750721454994239248013870822765715268323522295722350908043393604", + "4490242021765793917495398271905043433053432245571325177153467194570741607167", + "19596995117319480189066041930051006586888908165330319666010398892494684778526", + "837850695495734270707668553360118467905109360511302468085569220634750561083", + "11803922811376367215191737026157445294481406304781326649717082177394185903907", + "10201298324909697255105265958780781450978049256931478989759448189112393506592", + "13564695482314888817576351063608519127702411536552857463682060761575100923924", + "9262808208636973454201420823766139682381973240743541030659775288508921362724", + "173271062536305557219323722062711383294158572562695717740068656098441040230", + "18120430890549410286417591505529104700901943324772175772035648111937818237369", + "20484495168135072493552514219686101965206843697794133766912991150184337935627", + "19155651295705203459475805213866664350848604323501251939850063308319753686505", + "11971299749478202793661982361798418342615500543489781306376058267926437157297", + "18285310723116790056148596536349375622245669010373674803854111592441823052978", + "7069216248902547653615508023941692395371990416048967468982099270925308100727", + "6465151453746412132599596984628739550147379072443683076388208843341824127379", + "16143532858389170960690347742477978826830511669766530042104134302796355145785", + "19362583304414853660976404410208489566967618125972377176980367224623492419647", + "1702213613534733786921602839210290505213503664731919006932367875629005980493", + "10781825404476535814285389902565833897646945212027592373510689209734812292327", + "4212716923652881254737947578600828255798948993302968210248673545442808456151", + "7594017890037021425366623750593200398174488805473151513558919864633711506220", + "18979889247746272055963929241596362599320706910852082477600815822482192194401", + "13602139229813231349386885113156901793661719180900395818909719758150455500533", + ]; + + // Convert constants to field elements + let c: Vec = c_strings + .iter() + .map(|&s| R1csBn254Field::from_str(s).unwrap()) + .collect(); + + let mut t7 = Vec::with_capacity(n_rounds - 1); + let mut out = R1csBn254Field::zero(); + + for i in 0..n_rounds { + let t = if i == 0 { + k + x_in + } else { + k + t7[i - 1] + c[i] + }; + + let t2 = t.square(); // t^2 + let t4 = t2.square(); // t^4 + let t6 = t4 * t2; // t^6 + + if i < n_rounds - 1 { + let t7_i = t6 * t; // t^7 + t7.push(t7_i); + } else { + out = t6 * t + k; // Final output: t^7 + k + } + } + + out +} + +fn multi_mimc7(k: R1csBn254Field, values: Vec, n_rounds: usize) -> R1csBn254Field { + let mut res = k; + for x in values { + res = res + x + mimc7(x, res, n_rounds); + } + res +} + +#[rstest] +#[case(0, 1, 91)] +fn test_mimc(#[case] key: u32, #[case] val: u32, #[case] n_rounds: usize) -> Result<()> { + let public_inputs = format!(r#"{{"key": "{}"}}"#, key); + let private_inputs = format!(r#"{{"val": "{}"}}"#, val); + + let x = fq_from_str(val.to_string().as_str()); + let k = fq_from_str(key.to_string().as_str()); + + let expected_output: BigUint = mimc7(x, k, n_rounds).into(); + + test_stdlib( + "mimc/mimc_main.no", + None, + &public_inputs, + &private_inputs, + vec![&expected_output.to_string()], + )?; + + Ok(()) +} + +#[rstest] +#[case(0, vec![1, 2, 3])] +fn test_multi_mimc(#[case] key: u32, #[case] values: Vec) -> Result<()> { + let k = fq_from_str(key.to_string().as_str()); + let x = values + .iter() + .map(|v| fq_from_str(v.to_string().as_str())) + .collect(); + + let expected_output: BigUint = multi_mimc7(k, x, 91).into(); + + let public_inputs = format!(r#"{{"key": "{}"}}"#, key); + // convert to ["1", "2", ...] + let private_inputs = format!( + r#"{{"values": {:?}}}"#, + values + .iter() + .map(|v| v.to_string()) + .collect::>() + ); + + test_stdlib( + "mimc/multi_mimc_main.no", + None, + &public_inputs, + &private_inputs, + vec![&expected_output.to_string()], + )?; + + Ok(()) +} diff --git a/src/tests/stdlib/mimc/multi_mimc_main.no b/src/tests/stdlib/mimc/multi_mimc_main.no new file mode 100644 index 000000000..73f56d104 --- /dev/null +++ b/src/tests/stdlib/mimc/multi_mimc_main.no @@ -0,0 +1,7 @@ +use std::mimc; + +fn main(pub key: Field, values: [Field; 3]) -> Field { + let res = mimc::mimc7_hash(values, key); + + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 6f81d2a9f..3bac61a6f 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -1,4 +1,5 @@ mod comparator; +mod mimc; mod multiplexer; use std::{path::Path, str::FromStr}; diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index ffa802c9f..cb69360c9 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use num_traits::ToPrimitive; use serde::{Deserialize, Serialize}; use crate::{ @@ -619,7 +620,10 @@ impl TypeChecker { let sym = Symbolic::parse(size)?; let res = if let Symbolic::Concrete(size) = sym { // if sym is a concrete variant, then just return concrete array type - ExprTyInfo::new_anon(TyKind::Array(Box::new(item_node.typ), size)) + ExprTyInfo::new_anon(TyKind::Array( + Box::new(item_node.typ), + size.to_u32().expect("array size too large"), + )) } else { // use generic array as the size node might include generic parameters or constant vars ExprTyInfo::new_anon(TyKind::GenericSizedArray(