From 18825a3a4a53572bd038c679e823ffc5d2e17ff9 Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Tue, 22 Oct 2024 19:57:45 +0800 Subject: [PATCH 1/8] Fix: Allow struct fields to propagate constants (#204) * fix: make struct field propagate constants --- src/mast/mod.rs | 252 ++++++++++++++++++++++++++---------- src/negative_tests.rs | 52 ++++++++ src/stdlib/native/int.no | 10 +- src/type_checker/checker.rs | 32 +++++ src/type_checker/mod.rs | 23 ++++ 5 files changed, 300 insertions(+), 69 deletions(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index f16a2cbac..89ad4cc6b 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -9,7 +9,8 @@ use crate::{ imports::FnKind, parser::{ types::{ - FnSig, ForLoopArgument, GenericParameters, Range, Stmt, StmtKind, Symbolic, Ty, TyKind, + FnSig, ForLoopArgument, GenericParameters, Ident, Range, Stmt, StmtKind, Symbolic, Ty, + TyKind, }, CustomType, Expr, ExprKind, FunctionDef, Op2, }, @@ -29,19 +30,51 @@ pub struct ExprMonoInfo { /// The generic types shouldn't be presented in this field. pub typ: Option, - // todo: see if we can do constant folding on the expression nodes. - // - it is possible to remove this field, as the constant value can be extracted from folded expression node - /// Numeric value of the expression - /// applicable to BigInt type - pub constant: Option, + /// Propagated constant value + pub constant: Option, } -impl ExprMonoInfo { - pub fn new(expr: Expr, typ: Option, value: Option) -> Self { - if value.is_some() && !matches!(typ, Some(TyKind::Field { constant: true })) { - panic!("value can only be set for BigInt type"); +#[derive(Debug, Clone)] +pub enum PropagatedConstant { + Single(u32), + Array(Vec), + Custom(HashMap), +} + +impl PropagatedConstant { + pub fn as_single(&self) -> u32 { + match self { + PropagatedConstant::Single(v) => *v, + _ => panic!("expected single value"), } + } + pub fn as_array(&self) -> Vec { + match self { + PropagatedConstant::Array(v) => v.iter().map(|c| c.as_single()).collect(), + _ => panic!("expected array value"), + } + } + + pub fn as_custom(&self) -> HashMap { + match self { + PropagatedConstant::Custom(v) => { + v.iter().map(|(k, c)| (k.clone(), c.as_single())).collect() + } + _ => panic!("expected custom value"), + } + } +} + +/// impl From trait for single value +impl From for PropagatedConstant { + fn from(v: u32) -> Self { + PropagatedConstant::Single(v) + } +} + +impl ExprMonoInfo { + pub fn new(expr: Expr, typ: Option, value: Option) -> Self { Self { expr, typ, @@ -68,18 +101,18 @@ pub struct MTypeInfo { pub typ: TyKind, /// Store constant value - pub value: Option, + pub constant: Option, /// The span of the variable declaration. pub span: Span, } impl MTypeInfo { - pub fn new(typ: &TyKind, span: Span, value: Option) -> Self { + pub fn new(typ: &TyKind, span: Span, value: Option) -> Self { Self { typ: typ.clone(), span, - value, + constant: value, } } } @@ -216,11 +249,11 @@ impl FnSig { } // const NN: Field _ => { - let cst = observed_arg.constant; + let cst = observed_arg.constant.clone(); if is_generic_parameter(sig_arg.name.value.as_str()) && cst.is_some() { self.generics.assign( &sig_arg.name.value, - cst.unwrap(), + cst.unwrap().as_single(), observed_arg.expr.span, )?; } @@ -269,6 +302,10 @@ where functions_instantiated: HashMap, // new method name as the key, old method name as the value methods_instantiated: HashMap<(FullyQualified, String), String>, + // cache for [PropagatedConstant] values from instantiated methods + cst_method_cache: HashMap<(FullyQualified, String), PropagatedConstant>, + // cache for [PropagatedConstant] values from instantiated functions + cst_fn_cache: HashMap, } impl MastCtx { @@ -278,6 +315,8 @@ impl MastCtx { generic_func_scope: Some(0), functions_instantiated: HashMap::new(), methods_instantiated: HashMap::new(), + cst_method_cache: HashMap::new(), + cst_fn_cache: HashMap::new(), } } @@ -509,7 +548,12 @@ fn monomorphize_expr( }, ); - let cst = None; + // propagate the constant value + let cst = lhs_mono.constant.and_then(|c| match c { + PropagatedConstant::Custom(map) => map.get(rhs).cloned(), + _ => None, + }); + ExprMonoInfo::new(mexpr, typ, cst) } @@ -558,10 +602,19 @@ fn monomorphize_expr( .as_ref() .and_then(|sig| sig.return_type.clone().map(|t| t.kind)); - ExprMonoInfo::new(mexpr, typ, None) + // retrieve the constant value from the cache + let cst = ctx.cst_fn_cache.get(&mono_qualified).cloned(); + + ExprMonoInfo::new(mexpr, typ, cst) } else { // monomorphize the function call - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let (fn_info_mono, typ, cst) = + instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + + // cache the constant value + if let Some(cst) = cst.clone() { + ctx.cst_fn_cache.insert(mono_qualified.clone(), cst); + } let fn_name_mono = &fn_info_mono.sig().name; let mexpr = expr.to_mast( @@ -577,7 +630,7 @@ fn monomorphize_expr( let new_qualified = FullyQualified::new(module, &fn_name_mono.value); ctx.add_monomorphized_fn(old_qualified, new_qualified, fn_info_mono); - ExprMonoInfo::new(mexpr, typ, None) + ExprMonoInfo::new(mexpr, typ, cst) } } @@ -649,10 +702,23 @@ fn monomorphize_expr( }, ); let typ = resolved_sig.return_type.clone().map(|t| t.kind); - ExprMonoInfo::new(mexpr, typ, None) + + // retrieve the constant value from the cache + let cst = ctx + .cst_method_cache + .get(&(struct_qualified.clone(), method_name.value.clone())) + .cloned(); + + ExprMonoInfo::new(mexpr, typ, cst) } else { // monomorphize the function call - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let (fn_info_mono, typ, cst) = + instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + // cache the constant value + if let Some(cst) = cst.clone() { + ctx.cst_method_cache + .insert((struct_qualified.clone(), method_name.value.clone()), cst); + } let fn_name_mono = &fn_info_mono.sig().name; let mexpr = expr.to_mast( @@ -668,7 +734,7 @@ fn monomorphize_expr( ctx.tast .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); - ExprMonoInfo::new(mexpr, typ, None) + ExprMonoInfo::new(mexpr, typ, cst) } } @@ -739,9 +805,14 @@ fn monomorphize_expr( Some(v) => { let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); - ExprMonoInfo::new(mexpr, typ, v.to_u32()) + ExprMonoInfo::new( + mexpr, + typ, + Some(PropagatedConstant::from(v.to_u32().unwrap())), + ) } - None => { + // keep as is + _ => { let mexpr = expr.to_mast( ctx, &ExprKind::BinaryOp { @@ -778,7 +849,11 @@ fn monomorphize_expr( let cst: u32 = inner.try_into().expect("biguint too large"); let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(inner.clone())); - ExprMonoInfo::new(mexpr, Some(TyKind::Field { constant: true }), Some(cst)) + ExprMonoInfo::new( + mexpr, + Some(TyKind::Field { constant: true }), + Some(PropagatedConstant::from(cst)), + ) } ExprKind::Bool(inner) => { @@ -794,10 +869,10 @@ fn monomorphize_expr( let res = if is_generic_parameter(&name.value) { let mtype = mono_fn_env.get_type_info(&name.value).unwrap(); - let mexpr = - expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(mtype.value.unwrap()))); + let cst = mtype.constant.clone().unwrap().as_single(); + let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(cst))); - ExprMonoInfo::new(mexpr, Some(mtype.typ.clone()), mtype.value) + ExprMonoInfo::new(mexpr, Some(mtype.typ.clone()), mtype.constant.clone()) } else if is_type(&name.value) { let mtype = TyKind::Custom { module: module.clone(), @@ -820,7 +895,11 @@ fn monomorphize_expr( let cst: u32 = bigint.clone().try_into().expect("biguint too large"); let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(bigint)); - ExprMonoInfo::new(mexpr, Some(TyKind::Field { constant: true }), Some(cst)) + ExprMonoInfo::new( + mexpr, + Some(TyKind::Field { constant: true }), + Some(PropagatedConstant::from(cst)), + ) } else { // otherwise it's a local variable let mexpr = expr.to_mast( @@ -832,7 +911,7 @@ fn monomorphize_expr( ); let mtype = mono_fn_env.get_type_info(&name.value).unwrap().clone(); - ExprMonoInfo::new(mexpr, Some(mtype.typ), mtype.value) + ExprMonoInfo::new(mexpr, Some(mtype.typ), mtype.constant) }; res @@ -955,24 +1034,46 @@ fn monomorphize_expr( )); } - fields_mono.push((ident, observed_mono.expr.clone())); + fields_mono.push(( + ident, + observed_mono.expr.clone(), + observed_mono.constant.clone(), + )); } let mexpr = expr.to_mast( ctx, &ExprKind::CustomTypeDeclaration { custom: custom.clone(), - fields: fields_mono, + // extract a tuple of first two elements + fields: fields_mono + .iter() + .map(|(a, b, _)| (a.clone(), b.clone())) + .collect(), }, ); + let cst_fields = { + let csts = HashMap::from_iter( + fields_mono + .into_iter() + .filter(|(_, _, cst)| cst.is_some()) + .map(|(ident, _, cst)| (ident, cst.unwrap())), + ); + if csts.is_empty() { + None + } else { + Some(PropagatedConstant::Custom(csts)) + } + }; + ExprMonoInfo::new( mexpr, Some(TyKind::Custom { module: module.clone(), name: name.clone(), }), - None, + cst_fields, ) } ExprKind::RepeatedArrayInit { item, size } => { @@ -989,7 +1090,7 @@ fn monomorphize_expr( ); if let Some(cst) = size_mono.constant { - let arr_typ = TyKind::Array(Box::new(item_typ), cst); + let arr_typ = TyKind::Array(Box::new(item_typ), cst.as_single()); ExprMonoInfo::new(mexpr, Some(arr_typ), None) } else { return Err(error(ErrorKind::InvalidArraySize, expr.span)); @@ -1011,23 +1112,30 @@ pub fn monomorphize_block( mono_fn_env: &mut MonomorphizedFnEnv, stmts: &[Stmt], expected_return: Option<&Ty>, -) -> Result<(Vec, Option)> { +) -> Result<(Vec, Option)> { mono_fn_env.nest(); - let mut return_typ = None; + let mut ret_expr_mono = None; let mut stmts_mono = vec![]; for stmt in stmts { - if let Some((stmt, ret_typ)) = monomorphize_stmt(ctx, mono_fn_env, stmt)? { + if let Some((stmt, expr_mono)) = monomorphize_stmt(ctx, mono_fn_env, stmt)? { stmts_mono.push(stmt); - if ret_typ.is_some() { - return_typ = ret_typ; + // only return stmt can return `ExprMonoInfo` which contains propagated constants + if expr_mono.is_some() { + ret_expr_mono = expr_mono; } } } + let return_typ = if let Some(expr_mono) = ret_expr_mono.clone() { + expr_mono.typ + } else { + None + }; + // check the return if let (Some(expected), Some(observed)) = (expected_return, return_typ.clone()) { if !observed.match_expected(&expected.kind, true) { @@ -1040,7 +1148,7 @@ pub fn monomorphize_block( mono_fn_env.pop(); - Ok((stmts_mono, return_typ)) + Ok((stmts_mono, ret_expr_mono)) } /// Monomorphize a statement. @@ -1048,7 +1156,7 @@ pub fn monomorphize_stmt( ctx: &mut MastCtx, mono_fn_env: &mut MonomorphizedFnEnv, stmt: &Stmt, -) -> Result)>> { +) -> Result)>> { let res = match &stmt.kind { StmtKind::Assign { mutable, lhs, rhs } => { let rhs_mono = monomorphize_expr(ctx, rhs, mono_fn_env)?; @@ -1093,7 +1201,9 @@ pub fn monomorphize_stmt( return Err(error(ErrorKind::InvalidRangeSize, stmt.span)); } - if start_mono.constant.unwrap() > end_mono.constant.unwrap() { + if start_mono.constant.unwrap().as_single() + > end_mono.constant.unwrap().as_single() + { return Err(error(ErrorKind::InvalidRangeSize, stmt.span)); } @@ -1147,11 +1257,11 @@ pub fn monomorphize_stmt( StmtKind::Return(res) => { let expr_mono = monomorphize_expr(ctx, res, mono_fn_env)?; let stmt_mono = Stmt { - kind: StmtKind::Return(Box::new(expr_mono.expr)), + kind: StmtKind::Return(Box::new(expr_mono.expr.clone())), span: stmt.span, }; - Some((stmt_mono, expr_mono.typ)) + Some((stmt_mono, Some(expr_mono))) } StmtKind::Comment(_) => None, }; @@ -1168,7 +1278,7 @@ pub fn instantiate_fn_call( fn_info: FnInfo, args: &[ExprMonoInfo], span: Span, -) -> Result<(FnInfo, Option)> { +) -> Result<(FnInfo, Option, Option)> { ctx.start_monomorphize_func(); let fn_sig = fn_info.sig(); @@ -1192,7 +1302,11 @@ pub fn instantiate_fn_call( let val = fn_sig.generics.get(gen); mono_fn_env.store_type( gen, - &MTypeInfo::new(&TyKind::Field { constant: true }, span, Some(val)), + &MTypeInfo::new( + &TyKind::Field { constant: true }, + span, + Some(PropagatedConstant::from(val)), + ), )?; } @@ -1208,7 +1322,7 @@ pub fn instantiate_fn_call( let typ = mono_info.typ.as_ref().expect("expected a value"); mono_fn_env.store_type( arg_name, - &MTypeInfo::new(typ, mono_info.expr.span, mono_info.constant), + &MTypeInfo::new(typ, mono_info.expr.span, mono_info.constant.clone()), )?; } @@ -1216,32 +1330,40 @@ pub fn instantiate_fn_call( let ret_typed = sig_typed.return_type.clone(); // construct the monomorphized function AST - let func_def = match fn_info.kind { - FnKind::BuiltIn(_, handle) => FnInfo { - kind: FnKind::BuiltIn(sig_typed, handle), - is_hint: fn_info.is_hint, - span: fn_info.span, - }, + let (func_def, mono_info) = match fn_info.kind { + FnKind::BuiltIn(_, handle) => ( + FnInfo { + kind: FnKind::BuiltIn(sig_typed, handle), + ..fn_info + }, + // todo: we will need to propagate the constant value from builtin function as well + None, + ), FnKind::Native(fn_def) => { - let (stmts_typed, _) = + let (stmts_typed, mono_info) = monomorphize_block(ctx, mono_fn_env, &fn_def.body, ret_typed.as_ref())?; - FnInfo { - kind: FnKind::Native(FunctionDef { - sig: sig_typed, - body: stmts_typed, - span: fn_def.span, - is_hint: fn_def.is_hint, - }), - is_hint: fn_info.is_hint, - span: fn_info.span, - } + ( + FnInfo { + kind: FnKind::Native(FunctionDef { + sig: sig_typed, + body: stmts_typed, + span: fn_def.span, + is_hint: fn_def.is_hint, + }), + is_hint: fn_info.is_hint, + span: fn_info.span, + }, + mono_info, + ) } }; ctx.finish_monomorphize_func(); - Ok((func_def, ret_typed.map(|t| t.kind))) + let cst = mono_info.and_then(|c| c.constant); + + Ok((func_def, ret_typed.map(|t| t.kind), cst)) } pub fn error(kind: ErrorKind, span: Span) -> Error { Error::new("mast", kind, span) diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 1ed4dc063..0759a6333 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -498,3 +498,55 @@ fn test_nonhint_call_with_unsafe() { ErrorKind::UnexpectedUnsafeAttribute )); } + +#[test] +fn test_no_cst_struct_field_prop() { + let code = r#" + struct Thing { + val: Field, + } + + fn gen(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; + } + + fn main(pub xx: Field) { + let thing = Thing { val: xx }; + + let arr = gen(thing.val); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ArgumentTypeMismatch(..) + )); +} + +#[test] +fn test_mut_cst_struct_field_prop() { + let code = r#" + struct Thing { + val: Field, + } + + fn gen(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; + } + + fn main(pub xx: Field) { + let mut thing = Thing { val: 3 }; + thing.val = xx; + + let arr = gen(thing.val); + assert_eq(arr[0], xx); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ArgumentTypeMismatch(..) + )); +} diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no index 6cbac6de7..84bd7a98e 100644 --- a/src/stdlib/native/int.no +++ b/src/stdlib/native/int.no @@ -7,19 +7,21 @@ struct Uint8 { } fn Uint8.new(val: Field) -> Uint8 { + let bit_len = 8; + // range check - let ignore_ = bits::to_bits(8, val); + let ignore_ = bits::to_bits(bit_len, val); return Uint8 { inner: val, - bit_len: 8 + bit_len: bit_len }; } fn Uint8.less_than(self, rhs: Uint8) -> Bool { - return comparator::less_than(8, self.inner, rhs.inner); + return comparator::less_than(self.bit_len, self.inner, rhs.inner); } fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { - return comparator::less_eq_than(8, self.inner, rhs.inner); + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); } \ No newline at end of file diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 74a8ace86..d5fb3f54e 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -224,6 +224,7 @@ impl TypeChecker { .compute_type(lhs, typed_fn_env)? .expect("type-checker bug: lhs access on an empty var"); + // todo: check and update the const field type for other cases // lhs can be a local variable or a path to an array let lhs_name = match &lhs.kind { // `name = ` @@ -290,6 +291,31 @@ impl TypeChecker { )); } + // update struct field type + if let ExprKind::FieldAccess { + lhs, + rhs: field_name, + } = &lhs.kind + { + // get variable behind lhs + let lhs_node = self + .compute_type(lhs, typed_fn_env)? + .expect("type-checker bug: lhs access on an empty var"); + + // obtain the qualified name of the struct + let (module, struct_name) = match lhs_node.typ { + TyKind::Custom { module, name } => (module, name), + _ => { + return Err( + self.error(ErrorKind::FieldAccessOnNonCustomStruct, lhs.span) + ) + } + }; + + let qualified = FullyQualified::new(&module, &struct_name); + self.update_struct_field(&qualified, &field_name.value, rhs_typ.typ); + } + None } @@ -538,6 +564,12 @@ impl TypeChecker { expr.span, )); } + + // If the observed type is a Field type, then init that struct field as the observed type. + // This is because the field type can be a constant or not, which needs to be propagated. + if matches!(observed_typ.typ, TyKind::Field { .. }) { + self.update_struct_field(&qualified, &defined.0, observed_typ.typ.clone()); + } } let res = ExprTyInfo::new_anon(TyKind::Custom { diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 2b826f8a2..7d9da5061 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -150,6 +150,29 @@ impl TypeChecker { .expect("couldn't find the struct for storing the method"); struct_info.methods.remove(method_name); } + + /// Update the type of a struct field. + /// When the assignment is done, we need to update the type of the field. + /// This is only for the case of updating field types to either a constant or a variable. + pub fn update_struct_field( + &mut self, + qualified: &FullyQualified, + field_name: &str, + typ: TyKind, + ) { + let struct_info = self + .structs + .get_mut(qualified) + .expect("couldn't find the struct for storing the method"); + + // update the field type + for field in struct_info.fields.iter_mut() { + if field.0 == field_name { + field.1 = typ; + return; + } + } + } } impl TypeChecker { From ca35175d8f8d525fa9993f88c3141bd126934db7 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 18 Oct 2024 14:51:02 +0800 Subject: [PATCH 2/8] add uint16/32/64 --- src/stdlib/native/int.no | 81 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no index 84bd7a98e..4390c3db6 100644 --- a/src/stdlib/native/int.no +++ b/src/stdlib/native/int.no @@ -1,6 +1,7 @@ use std::bits; use std::comparator; +// u8 struct Uint8 { inner: Field, bit_len: Field, @@ -18,10 +19,90 @@ fn Uint8.new(val: Field) -> Uint8 { }; } +// u16 +struct Uint16 { + inner: Field, + bit_len: Field, +} + +fn Uint16.new(val: Field) -> Uint16 { + let bit_len = 16; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint16 { + inner: val, + bit_len: bit_len + }; +} + +// u32 +struct Uint32 { + inner: Field, + bit_len: Field, +} + +fn Uint32.new(val: Field) -> Uint32 { + let bit_len = 32; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint32 { + inner: val, + bit_len: bit_len + }; +} + +// u64 +struct Uint64 { + inner: Field, + bit_len: Field, +} + +fn Uint64.new(val: Field) -> Uint64 { + let bit_len = 64; + + // range check + let ignore_ = bits::to_bits(bit_len, val); + + return Uint64 { + inner: val, + bit_len: bit_len + }; +} + +// implement comparator + fn Uint8.less_than(self, rhs: Uint8) -> Bool { return comparator::less_than(self.bit_len, self.inner, rhs.inner); } fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint16.less_than(self, rhs: Uint16) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint16.less_eq_than(self, rhs: Uint16) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint32.less_than(self, rhs: Uint32) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint32.less_eq_than(self, rhs: Uint32) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint64.less_than(self, rhs: Uint64) -> Bool { + return comparator::less_than(self.bit_len, self.inner, rhs.inner); +} + +fn Uint64.less_eq_than(self, rhs: Uint64) -> Bool { + return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); } \ No newline at end of file From 55f441a8b0b81214465a58e824a724b1430bceba Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 18 Oct 2024 14:51:09 +0800 Subject: [PATCH 3/8] update tests --- .../less_eq_than/less_eq_than_main.no | 6 +- .../stdlib/comparator/less_than/less_than.asm | 27 ++- .../comparator/less_than/less_than_main.no | 7 +- src/tests/stdlib/comparator/mod.rs | 182 +++++++++++++----- src/tests/stdlib/mod.rs | 38 +++- 5 files changed, 189 insertions(+), 71 deletions(-) diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no index f46cbfb38..f916a96a2 100644 --- a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no @@ -1,10 +1,8 @@ use std::comparator; -use std::int; fn main(pub lhs: Field, rhs: Field) -> Bool { - let lhs_u8 = int::Uint8.new(lhs); - let rhs_u8 = int::Uint8.new(rhs); + let bit_len = 2; + let res = comparator::less_eq_than(bit_len, lhs, rhs); - let res = lhs_u8.less_eq_than(rhs_u8); return res; } \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_than/less_than.asm b/src/tests/stdlib/comparator/less_than/less_than.asm index 8c33ac348..beac8c576 100644 --- a/src/tests/stdlib/comparator/less_than/less_than.asm +++ b/src/tests/stdlib/comparator/less_than/less_than.asm @@ -3,9 +3,24 @@ v_5 == (v_4) * (v_4 + -1) 0 == (v_5) * (1) -v_7 == (v_6) * (v_6 + -1) -0 == (v_7) * (1) -v_9 == (v_8) * (v_8 + -1) -0 == (v_9) * (1) -v_4 + 2 * v_6 + 4 * v_8 == (v_2 + -1 * v_3 + 4) * (1) --1 * v_8 + 1 == (v_1) * (1) +1 == (v_6) * (1) +v_8 == (v_7) * (-1 * v_4 + v_6) +-1 * v_9 + 1 == (v_8) * (1) +v_10 == (v_9) * (-1 * v_4 + v_6) +0 == (v_10) * (1) +v_12 == (v_11) * (v_11 + -1) +0 == (v_12) * (1) +1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_11 + v_13) +-1 * v_16 + 1 == (v_15) * (1) +v_17 == (v_16) * (-1 * v_11 + v_13) +0 == (v_17) * (1) +v_19 == (v_18) * (v_18 + -1) +0 == (v_19) * (1) +1 == (v_20) * (1) +v_22 == (v_21) * (-1 * v_18 + v_20) +-1 * v_23 + 1 == (v_22) * (1) +v_24 == (v_23) * (-1 * v_18 + v_20) +0 == (v_24) * (1) +v_2 + -1 * v_3 + 4 == (v_9 + 2 * v_16 + 4 * v_23) * (1) +-1 * v_23 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_than/less_than_main.no b/src/tests/stdlib/comparator/less_than/less_than_main.no index bb8cb1f0c..e39d9cf95 100644 --- a/src/tests/stdlib/comparator/less_than/less_than_main.no +++ b/src/tests/stdlib/comparator/less_than/less_than_main.no @@ -1,11 +1,8 @@ use std::comparator; -use std::int; fn main(pub lhs: Field, rhs: Field) -> Bool { - let lhs_u8 = int::Uint8.new(lhs); - let rhs_u8 = int::Uint8.new(rhs); - - let res = lhs_u8.less_than(rhs_u8); + let bit_len = 2; + let res = comparator::less_than(bit_len, lhs, rhs); return res; } \ No newline at end of file diff --git a/src/tests/stdlib/comparator/mod.rs b/src/tests/stdlib/comparator/mod.rs index bc85f91fe..ed4ce618c 100644 --- a/src/tests/stdlib/comparator/mod.rs +++ b/src/tests/stdlib/comparator/mod.rs @@ -1,89 +1,177 @@ -use crate::error; +use crate::error::{self, ErrorKind}; -use super::test_stdlib; +use super::{test_stdlib, test_stdlib_code}; use error::Result; +use rstest::rstest; -#[test] -fn test_less_than_true() -> Result<()> { - let public_inputs = r#"{"lhs": "0"}"#; - let private_inputs = r#"{"rhs": "1"}"#; +// code template +static LESS_THAN_TPL: &str = r#" +use std::comparator; +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let lhs_u = int::{}.new(lhs); + let rhs_u = int::{}.new(rhs); + + let res = lhs_u.less_than(rhs_u); + + return res; +} +"#; +static LESS_THAN_EQ_TPL: &str = r#" +use std::comparator; +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let lhs_u = int::{}.new(lhs); + let rhs_u = int::{}.new(rhs); + + let res = lhs_u.less_eq_than(rhs_u); + + return res; +} +"#; + +#[rstest] +#[case(r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_less_than( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { test_stdlib( "comparator/less_than/less_than_main.no", - "comparator/less_than/less_than.asm", + Some("comparator/less_than/less_than.asm"), public_inputs, private_inputs, - vec!["1"], + expected_output, )?; Ok(()) } -// test false #[test] -fn test_less_than_false() -> Result<()> { - let public_inputs = r#"{"lhs": "1"}"#; +fn test_less_than_witness_failure() -> Result<()> { + let public_inputs = r#"{"lhs": "4"}"#; let private_inputs = r#"{"rhs": "0"}"#; - test_stdlib( + let err = test_stdlib( "comparator/less_than/less_than_main.no", - "comparator/less_than/less_than.asm", + None, public_inputs, private_inputs, - vec!["0"], - )?; + vec![], + ) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); Ok(()) } -#[test] -fn test_less_eq_than_true_1() -> Result<()> { - let public_inputs = r#"{"lhs": "0"}"#; - let private_inputs = r#"{"rhs": "1"}"#; - - test_stdlib( - "comparator/less_eq_than/less_eq_than_main.no", - "comparator/less_eq_than/less_eq_than.asm", - public_inputs, - private_inputs, - vec!["1"], - )?; +#[rstest] +#[case("Uint8", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint16", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +#[case("Uint32", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint64", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_uint_less_than( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + // Replace placeholders with the given integer type. + let code = LESS_THAN_TPL.replace("{}", int_type); + + // Call the test function with the given inputs and expected output. + test_stdlib_code(&code, None, public_inputs, private_inputs, expected_output)?; Ok(()) } -#[test] -fn test_less_eq_than_true_2() -> Result<()> { - let public_inputs = r#"{"lhs": "1"}"#; - let private_inputs = r#"{"rhs": "1"}"#; +#[rstest] +#[case("Uint8", r#"{"lhs": "256"}"#, r#"{"rhs": "0"}"#)] // Uint8 overflow +#[case("Uint16", r#"{"lhs": "65536"}"#, r#"{"rhs": "0"}"#)] // Uint16 overflow +#[case("Uint32", r#"{"lhs": "4294967296"}"#, r#"{"rhs": "0"}"#)] // Uint32 overflow +#[case("Uint64", r#"{"lhs": "18446744073709551616"}"#, r#"{"rhs": "0"}"#)] // Uint64 overflow +fn test_uint_less_than_range_failure( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, +) -> Result<()> { + let code = LESS_THAN_TPL.replace("{}", int_type); + + // Test that the provided inputs result in an error due to overflow. + let err = test_stdlib_code(&code, None, public_inputs, private_inputs, vec!["0"]) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} +// Test for less than or equal scenarios +#[rstest] +#[case(r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] // True case (lhs < rhs) +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "1"}"#, vec!["1"])] // True case (lhs == rhs) +#[case(r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] // False case +fn test_less_eq_than( + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { test_stdlib( "comparator/less_eq_than/less_eq_than_main.no", - "comparator/less_eq_than/less_eq_than.asm", + Some("comparator/less_eq_than/less_eq_than.asm"), public_inputs, private_inputs, - vec!["1"], + expected_output, )?; Ok(()) } -#[test] -fn test_less_eq_than_false() -> Result<()> { - let public_inputs = r#"{"lhs": "1"}"#; - let private_inputs = r#"{"rhs": "0"}"#; +// implement the rest for less than eq - test_stdlib( - "comparator/less_eq_than/less_eq_than_main.no", - "comparator/less_eq_than/less_eq_than.asm", - public_inputs, - private_inputs, - vec!["0"], - )?; +#[rstest] +#[case("Uint8", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint16", r#"{"lhs": "1"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint32", r#"{"lhs": "0"}"#, r#"{"rhs": "1"}"#, vec!["1"])] +#[case("Uint64", r#"{"lhs": "1"}"#, r#"{"rhs": "0"}"#, vec!["0"])] +fn test_uint_less_eq_than( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, + #[case] expected_output: Vec<&str>, +) -> Result<()> { + let code = LESS_THAN_EQ_TPL.replace("{}", int_type); + + test_stdlib_code(&code, None, public_inputs, private_inputs, expected_output)?; Ok(()) } -// test value overflow modulus -// it shouldn't need user to enter the bit length -// should have a way to restrict and type check the value to a certain bit length +#[rstest] +#[case("Uint8", r#"{"lhs": "256"}"#, r#"{"rhs": "0"}"#)] // Uint8 overflow +#[case("Uint16", r#"{"lhs": "65536"}"#, r#"{"rhs": "0"}"#)] // Uint16 overflow +#[case("Uint32", r#"{"lhs": "4294967296"}"#, r#"{"rhs": "0"}"#)] // Uint32 overflow +#[case("Uint64", r#"{"lhs": "18446744073709551616"}"#, r#"{"rhs": "0"}"#)] // Uint64 overflow +fn test_uint_less_eq_than_range_failure( + #[case] int_type: &str, + #[case] public_inputs: &str, + #[case] private_inputs: &str, +) -> Result<()> { + let code = LESS_THAN_EQ_TPL.replace("{}", int_type); + + let err = test_stdlib_code(&code, None, public_inputs, private_inputs, vec!["0"]) + .err() + .expect("expected witness error"); + + assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); + + Ok(()) +} diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 107d44fe9..5442e9016 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -16,18 +16,38 @@ use crate::{ fn test_stdlib( path: &str, - asm_path: &str, + asm_path: Option<&str>, public_inputs: &str, private_inputs: &str, expected_public_output: Vec<&str>, ) -> Result>> { - let r1cs = R1CS::new(); let root = env!("CARGO_MANIFEST_DIR"); let prefix_path = Path::new(root).join("src/tests/stdlib"); // read noname file let code = std::fs::read_to_string(prefix_path.clone().join(path)).unwrap(); + let compiled_circuit = test_stdlib_code( + &code, + asm_path, + public_inputs, + private_inputs, + expected_public_output, + )?; + + Ok(compiled_circuit) +} + +fn test_stdlib_code( + code: &str, + asm_path: Option<&str>, + public_inputs: &str, + private_inputs: &str, + expected_public_output: Vec<&str>, +) -> Result>> { + let r1cs = R1CS::new(); + let root = env!("CARGO_MANIFEST_DIR"); + // parse inputs let public_inputs = parse_inputs(public_inputs).unwrap(); let private_inputs = parse_inputs(private_inputs).unwrap(); @@ -43,8 +63,8 @@ fn test_stdlib( &mut tast, this_module, &mut sources, - path.to_string(), - code.clone(), + "test.no".to_string(), + code.to_string(), node_id, ) .unwrap(); @@ -53,9 +73,8 @@ fn test_stdlib( let compiled_circuit = CircuitWriter::generate_circuit(mast, r1cs)?; // this should check the constraints - let generated_witness = compiled_circuit - .generate_witness(public_inputs.clone(), private_inputs.clone()) - .unwrap(); + let generated_witness = + compiled_circuit.generate_witness(public_inputs.clone(), private_inputs.clone())?; let expected_public_output = expected_public_output .iter() @@ -76,9 +95,10 @@ fn test_stdlib( } // check the ASM - if compiled_circuit.circuit.backend.num_constraints() < 100 { + if asm_path.is_some() && compiled_circuit.circuit.backend.num_constraints() < 100 { let prefix_asm = Path::new(root).join("src/tests/stdlib/"); - let expected_asm = std::fs::read_to_string(prefix_asm.clone().join(asm_path)).unwrap(); + let expected_asm = + std::fs::read_to_string(prefix_asm.clone().join(asm_path.unwrap())).unwrap(); let obtained_asm = compiled_circuit.asm(&Sources::new(), false); if obtained_asm != expected_asm { From 2b47a8d7342a12e675a367cab1849ea8f3556043 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:16:55 +0800 Subject: [PATCH 4/8] fix: remove bit_len from uints --- src/stdlib/native/int.no | 38 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no index 4390c3db6..4ad93c3b1 100644 --- a/src/stdlib/native/int.no +++ b/src/stdlib/native/int.no @@ -4,7 +4,6 @@ use std::comparator; // u8 struct Uint8 { inner: Field, - bit_len: Field, } fn Uint8.new(val: Field) -> Uint8 { @@ -14,15 +13,13 @@ fn Uint8.new(val: Field) -> Uint8 { let ignore_ = bits::to_bits(bit_len, val); return Uint8 { - inner: val, - bit_len: bit_len + inner: val }; } // u16 struct Uint16 { - inner: Field, - bit_len: Field, + inner: Field } fn Uint16.new(val: Field) -> Uint16 { @@ -32,15 +29,13 @@ fn Uint16.new(val: Field) -> Uint16 { let ignore_ = bits::to_bits(bit_len, val); return Uint16 { - inner: val, - bit_len: bit_len + inner: val }; } // u32 struct Uint32 { - inner: Field, - bit_len: Field, + inner: Field } fn Uint32.new(val: Field) -> Uint32 { @@ -50,15 +45,13 @@ fn Uint32.new(val: Field) -> Uint32 { let ignore_ = bits::to_bits(bit_len, val); return Uint32 { - inner: val, - bit_len: bit_len + inner: val }; } // u64 struct Uint64 { - inner: Field, - bit_len: Field, + inner: Field } fn Uint64.new(val: Field) -> Uint64 { @@ -68,41 +61,40 @@ fn Uint64.new(val: Field) -> Uint64 { let ignore_ = bits::to_bits(bit_len, val); return Uint64 { - inner: val, - bit_len: bit_len + inner: val }; } // implement comparator fn Uint8.less_than(self, rhs: Uint8) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_than(8, self.inner, rhs.inner); } fn Uint8.less_eq_than(self, rhs: Uint8) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_eq_than(8, self.inner, rhs.inner); } fn Uint16.less_than(self, rhs: Uint16) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_than(16, self.inner, rhs.inner); } fn Uint16.less_eq_than(self, rhs: Uint16) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_eq_than(16, self.inner, rhs.inner); } fn Uint32.less_than(self, rhs: Uint32) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_than(32, self.inner, rhs.inner); } fn Uint32.less_eq_than(self, rhs: Uint32) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_eq_than(32, self.inner, rhs.inner); } fn Uint64.less_than(self, rhs: Uint64) -> Bool { - return comparator::less_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_than(64, self.inner, rhs.inner); } fn Uint64.less_eq_than(self, rhs: Uint64) -> Bool { - return comparator::less_eq_than(self.bit_len, self.inner, rhs.inner); + return comparator::less_eq_than(64, self.inner, rhs.inner); } \ No newline at end of file From 8ef099379ed8c88a260fb47719b81a232caa1c86 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:32:14 +0800 Subject: [PATCH 5/8] simplify to_bits --- .../asm/kimchi/generic_builtin_bits.asm | 128 ++++++++---------- .../fixture/asm/r1cs/generic_builtin_bits.asm | 46 +++---- src/stdlib/native/bits.no | 37 ++--- .../comparator/less_eq_than/less_eq_than.asm | 40 +++--- .../stdlib/comparator/less_than/less_than.asm | 40 +++--- 5 files changed, 123 insertions(+), 168 deletions(-) diff --git a/examples/fixture/asm/kimchi/generic_builtin_bits.asm b/examples/fixture/asm/kimchi/generic_builtin_bits.asm index a7843d1c4..87ceb2cce 100644 --- a/examples/fixture/asm/kimchi/generic_builtin_bits.asm +++ b/examples/fixture/asm/kimchi/generic_builtin_bits.asm @@ -1,9 +1,6 @@ @ noname.0.7.0 @ public inputs: 1 -DoubleGeneric<1> -DoubleGeneric<1,0,-1,0,-1> -DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<1,0,0,0,-1> DoubleGeneric<1,1> @@ -17,12 +14,6 @@ DoubleGeneric<1> DoubleGeneric<1,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,0,-1> -DoubleGeneric<1,1> -DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,0,-1,0,-1> -DoubleGeneric<0,0,-1,1> -DoubleGeneric<1> DoubleGeneric<1,1> DoubleGeneric<1,1,-1> DoubleGeneric<0,0,-1,1> @@ -34,13 +25,6 @@ DoubleGeneric<1> DoubleGeneric<1,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<2,0,-1> -DoubleGeneric<1,1> -DoubleGeneric<1,0,-1,0,1> -DoubleGeneric<1,1,-1> -DoubleGeneric<1,0,-1,0,-1> -DoubleGeneric<0,0,-1,1> -DoubleGeneric<1> DoubleGeneric<1,1> DoubleGeneric<1,1,-1> DoubleGeneric<0,0,-1,1> @@ -52,6 +36,13 @@ DoubleGeneric<1> DoubleGeneric<1,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,1,-1> DoubleGeneric<4,0,-1> DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> @@ -77,61 +68,52 @@ DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> DoubleGeneric<1,0,0,0,-2> -(0,0) -> (55,1) -> (74,1) -> (75,0) -(1,0) -> (2,0) -> (5,0) -(1,2) -> (2,1) -(2,2) -> (3,0) -(4,0) -> (6,0) -> (23,0) -> (41,0) -(5,1) -> (6,1) -(6,2) -> (7,1) -> (11,1) -(7,2) -> (10,0) -(8,0) -> (11,0) -> (13,0) -> (14,0) -(8,1) -> (9,0) -(9,2) -> (10,1) -(11,2) -> (12,0) -(13,2) -> (16,0) -> (17,0) -> (56,0) -> (63,0) -> (64,0) -(14,1) -> (15,0) -(16,2) -> (36,0) -(17,1) -> (18,0) -(19,0) -> (20,0) -> (22,0) -(19,2) -> (20,1) -(20,2) -> (21,0) -(22,1) -> (23,1) -(23,2) -> (24,1) -> (28,1) -(24,2) -> (27,0) -(25,0) -> (28,0) -> (30,0) -> (31,0) -(25,1) -> (26,0) -(26,2) -> (27,1) -(28,2) -> (29,0) -(30,2) -> (33,0) -> (34,0) -> (59,0) -> (66,0) -> (67,0) -(31,1) -> (32,0) -(33,2) -> (36,1) -(34,1) -> (35,0) -(36,2) -> (54,0) -(37,0) -> (38,0) -> (40,0) -(37,2) -> (38,1) -(38,2) -> (39,0) -(40,1) -> (41,1) -(41,2) -> (42,1) -> (46,1) -(42,2) -> (45,0) -(43,0) -> (46,0) -> (48,0) -> (49,0) +(0,0) -> (46,1) -> (65,1) -> (66,0) +(1,0) -> (3,0) -> (14,0) -> (25,0) +(2,1) -> (3,1) +(3,2) -> (4,1) -> (8,1) +(4,2) -> (7,0) +(5,0) -> (8,0) -> (10,0) -> (11,0) +(5,1) -> (6,0) +(6,2) -> (7,1) +(8,2) -> (9,0) +(10,2) -> (35,0) -> (36,0) -> (47,0) -> (54,0) -> (55,0) +(11,1) -> (12,0) +(13,1) -> (14,1) +(14,2) -> (15,1) -> (19,1) +(15,2) -> (18,0) +(16,0) -> (19,0) -> (21,0) -> (22,0) +(16,1) -> (17,0) +(17,2) -> (18,1) +(19,2) -> (20,0) +(21,2) -> (38,0) -> (39,0) -> (50,0) -> (57,0) -> (58,0) +(22,1) -> (23,0) +(24,1) -> (25,1) +(25,2) -> (26,1) -> (30,1) +(26,2) -> (29,0) +(27,0) -> (30,0) -> (32,0) -> (33,0) +(27,1) -> (28,0) +(28,2) -> (29,1) +(30,2) -> (31,0) +(32,2) -> (42,0) -> (43,0) -> (51,0) -> (61,0) -> (62,0) +(33,1) -> (34,0) +(35,2) -> (41,0) +(36,1) -> (37,0) +(38,2) -> (41,1) +(39,1) -> (40,0) +(41,2) -> (45,0) +(42,2) -> (45,1) (43,1) -> (44,0) -(44,2) -> (45,1) -(46,2) -> (47,0) -(48,2) -> (51,0) -> (52,0) -> (60,0) -> (70,0) -> (71,0) -(49,1) -> (50,0) -(51,2) -> (54,1) -(52,1) -> (53,0) -(54,2) -> (55,0) -(56,1) -> (57,0) -(57,2) -> (58,0) -(60,1) -> (61,0) -(61,2) -> (62,0) -(63,2) -> (69,0) -(64,1) -> (65,0) -(66,2) -> (69,1) -(67,1) -> (68,0) -(69,2) -> (73,0) -(70,2) -> (73,1) -(71,1) -> (72,0) -(73,2) -> (74,0) +(45,2) -> (46,0) +(47,1) -> (48,0) +(48,2) -> (49,0) +(51,1) -> (52,0) +(52,2) -> (53,0) +(54,2) -> (60,0) +(55,1) -> (56,0) +(57,2) -> (60,1) +(58,1) -> (59,0) +(60,2) -> (64,0) +(61,2) -> (64,1) +(62,1) -> (63,0) +(64,2) -> (65,0) diff --git a/examples/fixture/asm/r1cs/generic_builtin_bits.asm b/examples/fixture/asm/r1cs/generic_builtin_bits.asm index fb72f65a4..0988a4dbe 100644 --- a/examples/fixture/asm/r1cs/generic_builtin_bits.asm +++ b/examples/fixture/asm/r1cs/generic_builtin_bits.asm @@ -1,30 +1,24 @@ @ noname.0.7.0 @ public inputs: 1 -v_3 == (v_2) * (v_2 + -1) -0 == (v_3) * (1) -1 == (v_4) * (1) -v_6 == (v_5) * (-1 * v_2 + v_4) --1 * v_7 + 1 == (v_6) * (1) -v_8 == (v_7) * (-1 * v_2 + v_4) -0 == (v_8) * (1) -v_10 == (v_9) * (v_9 + -1) -0 == (v_10) * (1) -1 == (v_11) * (1) -v_13 == (v_12) * (-1 * v_9 + v_11) --1 * v_14 + 1 == (v_13) * (1) -v_15 == (v_14) * (-1 * v_9 + v_11) -0 == (v_15) * (1) -v_17 == (v_16) * (v_16 + -1) -0 == (v_17) * (1) -1 == (v_18) * (1) -v_20 == (v_19) * (-1 * v_16 + v_18) --1 * v_21 + 1 == (v_20) * (1) -v_22 == (v_21) * (-1 * v_16 + v_18) -0 == (v_22) * (1) -v_1 == (v_7 + 2 * v_14 + 4 * v_21) * (1) -1 == (-1 * v_7 + 1) * (1) -1 == (v_14) * (1) -1 == (-1 * v_21 + 1) * (1) -v_1 == (v_7 + 2 * v_14 + 4 * v_21) * (1) +1 == (v_3) * (1) +v_5 == (v_4) * (-1 * v_2 + v_3) +-1 * v_6 + 1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_2 + v_3) +0 == (v_7) * (1) +1 == (v_9) * (1) +v_11 == (v_10) * (-1 * v_8 + v_9) +-1 * v_12 + 1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_8 + v_9) +0 == (v_13) * (1) +1 == (v_15) * (1) +v_17 == (v_16) * (-1 * v_14 + v_15) +-1 * v_18 + 1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_14 + v_15) +0 == (v_19) * (1) +v_1 == (v_6 + 2 * v_12 + 4 * v_18) * (1) +1 == (-1 * v_6 + 1) * (1) +1 == (v_12) * (1) +1 == (-1 * v_18 + 1) * (1) +v_1 == (v_6 + 2 * v_12 + 4 * v_18) * (1) 2 == (v_1) * (1) diff --git a/src/stdlib/native/bits.no b/src/stdlib/native/bits.no index d991b04de..5043b49cc 100644 --- a/src/stdlib/native/bits.no +++ b/src/stdlib/native/bits.no @@ -1,13 +1,22 @@ hint fn nth_bit(value: Field, const nth: Field) -> Field; +fn from_bits(bits: [Bool; LEN]) -> Field { + let mut lc1 = 0; + let mut e2 = 1; + let zero = 0; + + for index in 0..LEN { + lc1 = lc1 + if bits[index] {e2} else {zero}; + e2 = e2 + e2; + } + return lc1; +} + fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { let mut bits = [false; LEN]; let mut lc1 = 0; let mut e2 = 1; - let one = 1; - let zero = 0; - // todo: ITE should allow literals let true_val = true; let false_val = false; @@ -15,27 +24,9 @@ fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { for index in 0..LEN { let bit_num = unsafe nth_bit(value, index); - // constrain the bit_num to be 0 or 1 - assert_eq(bit_num * (bit_num - 1), 0); - - // convert the bit_num to boolean bits[index] = if bit_num == 1 {true_val} else {false_val}; - - lc1 = lc1 + if bits[index] {e2} else {zero}; - e2 = e2 + e2; } - assert_eq(lc1, value); + + assert_eq(from_bits(bits), value); return bits; } - -fn from_bits(bits: [Bool; LEN]) -> Field { - let mut lc1 = 0; - let mut e2 = 1; - let zero = 0; - - for index in 0..LEN { - lc1 = lc1 + if bits[index] {e2} else {zero}; - e2 = e2 + e2; - } - return lc1; -} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm index 4608ebd97..32f492c19 100644 --- a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm @@ -1,26 +1,20 @@ @ noname.0.7.0 @ public inputs: 2 -v_5 == (v_4) * (v_4 + -1) -0 == (v_5) * (1) -1 == (v_6) * (1) -v_8 == (v_7) * (-1 * v_4 + v_6) --1 * v_9 + 1 == (v_8) * (1) -v_10 == (v_9) * (-1 * v_4 + v_6) -0 == (v_10) * (1) -v_12 == (v_11) * (v_11 + -1) -0 == (v_12) * (1) -1 == (v_13) * (1) -v_15 == (v_14) * (-1 * v_11 + v_13) --1 * v_16 + 1 == (v_15) * (1) -v_17 == (v_16) * (-1 * v_11 + v_13) -0 == (v_17) * (1) -v_19 == (v_18) * (v_18 + -1) -0 == (v_19) * (1) -1 == (v_20) * (1) -v_22 == (v_21) * (-1 * v_18 + v_20) --1 * v_23 + 1 == (v_22) * (1) -v_24 == (v_23) * (-1 * v_18 + v_20) -0 == (v_24) * (1) -v_2 + -1 * v_3 + 3 == (v_9 + 2 * v_16 + 4 * v_23) * (1) --1 * v_23 + 1 == (v_1) * (1) +1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_4 + v_5) +-1 * v_8 + 1 == (v_7) * (1) +v_9 == (v_8) * (-1 * v_4 + v_5) +0 == (v_9) * (1) +1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_10 + v_11) +-1 * v_14 + 1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_10 + v_11) +0 == (v_15) * (1) +1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_16 + v_17) +-1 * v_20 + 1 == (v_19) * (1) +v_21 == (v_20) * (-1 * v_16 + v_17) +0 == (v_21) * (1) +v_2 + -1 * v_3 + 3 == (v_8 + 2 * v_14 + 4 * v_20) * (1) +-1 * v_20 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_than/less_than.asm b/src/tests/stdlib/comparator/less_than/less_than.asm index beac8c576..cc615ad37 100644 --- a/src/tests/stdlib/comparator/less_than/less_than.asm +++ b/src/tests/stdlib/comparator/less_than/less_than.asm @@ -1,26 +1,20 @@ @ noname.0.7.0 @ public inputs: 2 -v_5 == (v_4) * (v_4 + -1) -0 == (v_5) * (1) -1 == (v_6) * (1) -v_8 == (v_7) * (-1 * v_4 + v_6) --1 * v_9 + 1 == (v_8) * (1) -v_10 == (v_9) * (-1 * v_4 + v_6) -0 == (v_10) * (1) -v_12 == (v_11) * (v_11 + -1) -0 == (v_12) * (1) -1 == (v_13) * (1) -v_15 == (v_14) * (-1 * v_11 + v_13) --1 * v_16 + 1 == (v_15) * (1) -v_17 == (v_16) * (-1 * v_11 + v_13) -0 == (v_17) * (1) -v_19 == (v_18) * (v_18 + -1) -0 == (v_19) * (1) -1 == (v_20) * (1) -v_22 == (v_21) * (-1 * v_18 + v_20) --1 * v_23 + 1 == (v_22) * (1) -v_24 == (v_23) * (-1 * v_18 + v_20) -0 == (v_24) * (1) -v_2 + -1 * v_3 + 4 == (v_9 + 2 * v_16 + 4 * v_23) * (1) --1 * v_23 + 1 == (v_1) * (1) +1 == (v_5) * (1) +v_7 == (v_6) * (-1 * v_4 + v_5) +-1 * v_8 + 1 == (v_7) * (1) +v_9 == (v_8) * (-1 * v_4 + v_5) +0 == (v_9) * (1) +1 == (v_11) * (1) +v_13 == (v_12) * (-1 * v_10 + v_11) +-1 * v_14 + 1 == (v_13) * (1) +v_15 == (v_14) * (-1 * v_10 + v_11) +0 == (v_15) * (1) +1 == (v_17) * (1) +v_19 == (v_18) * (-1 * v_16 + v_17) +-1 * v_20 + 1 == (v_19) * (1) +v_21 == (v_20) * (-1 * v_16 + v_17) +0 == (v_21) * (1) +v_2 + -1 * v_3 + 4 == (v_8 + 2 * v_14 + 4 * v_20) * (1) +-1 * v_20 + 1 == (v_1) * (1) From 1f8826a6ded7d806b979bd29c74ad84f6e8c9b9b Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:44:12 +0800 Subject: [PATCH 6/8] add const var STDLIB_DIRECTORY --- src/stdlib/mod.rs | 3 +++ src/tests/examples.rs | 6 +++--- src/tests/stdlib/mod.rs | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index cde82ead8..efa5a6e4e 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -20,6 +20,9 @@ pub mod bits; pub mod builtins; pub mod crypto; +/// The directory under [NONAME_DIRECTORY] containing the native stdlib. +pub const STDLIB_DIRECTORY: &str = "src/stdlib/native/"; + pub enum AllStdModules { Builtins, Crypto, diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 1aa9a31f3..dde1093ef 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -10,7 +10,7 @@ use crate::{ }, compiler::{compile, typecheck_next_file, Sources}, inputs::{parse_inputs, ExtField}, - stdlib::init_stdlib_dep, + stdlib::{init_stdlib_dep, STDLIB_DIRECTORY}, type_checker::TypeChecker, }; @@ -38,7 +38,7 @@ fn test_file( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); let this_module = None; let _node_id = typecheck_next_file( &mut tast, @@ -102,7 +102,7 @@ fn test_file( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); let this_module = None; let _node_id = typecheck_next_file( &mut tast, diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 5442e9016..18f8668e3 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -9,7 +9,7 @@ use crate::{ error::Result, inputs::parse_inputs, mast, - stdlib::init_stdlib_dep, + stdlib::{init_stdlib_dep, STDLIB_DIRECTORY}, type_checker::TypeChecker, witness::CompiledCircuit, }; @@ -56,7 +56,7 @@ fn test_stdlib_code( let mut sources = Sources::new(); let mut tast = TypeChecker::new(); let mut node_id = 0; - node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, "src/stdlib/native/"); + node_id = init_stdlib_dep(&mut sources, &mut tast, node_id, STDLIB_DIRECTORY); let this_module = None; let _node_id = typecheck_next_file( From 8834fffb71ace2df1097522c23b08ed2eb7ab829 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:52:49 +0800 Subject: [PATCH 7/8] move native stdlib to their own folder --- src/stdlib/mod.rs | 3 ++- src/stdlib/native/{bits.no => bits/lib.no} | 0 src/stdlib/native/{comparator.no => comparator/lib.no} | 0 src/stdlib/native/{int.no => int/lib.no} | 0 4 files changed, 2 insertions(+), 1 deletion(-) rename src/stdlib/native/{bits.no => bits/lib.no} (100%) rename src/stdlib/native/{comparator.no => comparator/lib.no} (100%) rename src/stdlib/native/{int.no => int/lib.no} (100%) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index efa5a6e4e..65497e20a 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -100,7 +100,8 @@ pub fn init_stdlib_dep( for lib in libs { let module = UserRepo::new(&format!("std/{}", lib)); let prefix_stdlib = Path::new(path_prefix); - let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}.no"))).unwrap(); + println!("Loading stdlib: {}", prefix_stdlib.join(format!("{lib}/lib.no")).display()); + let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}/lib.no"))).unwrap(); node_id = typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); } diff --git a/src/stdlib/native/bits.no b/src/stdlib/native/bits/lib.no similarity index 100% rename from src/stdlib/native/bits.no rename to src/stdlib/native/bits/lib.no diff --git a/src/stdlib/native/comparator.no b/src/stdlib/native/comparator/lib.no similarity index 100% rename from src/stdlib/native/comparator.no rename to src/stdlib/native/comparator/lib.no diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int/lib.no similarity index 100% rename from src/stdlib/native/int.no rename to src/stdlib/native/int/lib.no From a822a1b9fd1d83d5465c0e0da20e5e46e7373c99 Mon Sep 17 00:00:00 2001 From: kata Date: Thu, 24 Oct 2024 12:52:54 +0800 Subject: [PATCH 8/8] fmt --- src/stdlib/mod.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 65497e20a..043f6a147 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -100,7 +100,10 @@ pub fn init_stdlib_dep( for lib in libs { let module = UserRepo::new(&format!("std/{}", lib)); let prefix_stdlib = Path::new(path_prefix); - println!("Loading stdlib: {}", prefix_stdlib.join(format!("{lib}/lib.no")).display()); + println!( + "Loading stdlib: {}", + prefix_stdlib.join(format!("{lib}/lib.no")).display() + ); let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}/lib.no"))).unwrap(); node_id = typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap();