Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into feat/ir-enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
katat committed Jan 23, 2025
2 parents 884877e + cd26e9a commit a30a66b
Show file tree
Hide file tree
Showing 22 changed files with 207 additions and 44 deletions.
2 changes: 1 addition & 1 deletion examples/assignment.no
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct Thing {
xx: Field,
pub xx: Field,
}

fn try_to_mutate(thing: Thing) {
Expand Down
4 changes: 2 additions & 2 deletions examples/hint.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct Thing {
xx: Field,
yy: Field,
pub xx: Field,
pub yy: Field,
}

fn init_arr(const LEN: Field) -> [Field; LEN] {
Expand Down
4 changes: 2 additions & 2 deletions examples/types.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct Thing {
xx: Field,
yy: Field,
pub xx: Field,
pub yy: Field,
}

fn main(pub xx: Field, pub yy: Field) {
Expand Down
4 changes: 2 additions & 2 deletions examples/types_array.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct Thing {
xx: Field,
yy: Field,
pub xx: Field,
pub yy: Field,
}

fn main(pub xx: Field, pub yy: Field) {
Expand Down
4 changes: 2 additions & 2 deletions examples/types_array_output.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct Thing {
xx: Field,
yy: Field,
pub xx: Field,
pub yy: Field,
}

fn main(pub xx: Field, pub yy: Field) -> [Thing; 2] {
Expand Down
2 changes: 1 addition & 1 deletion src/circuit_writer/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ impl<B: Backend> IRWriter<B> {
// find range of field
let mut start = 0;
let mut len = 0;
for (field, field_typ) in &struct_info.fields {
for (field, field_typ, _attribute) in &struct_info.fields {
if field == &rhs.value {
len = self.size_of(field_typ);
break;
Expand Down
4 changes: 2 additions & 2 deletions src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ impl<B: Backend> CircuitWriter<B> {
.clone();

let mut offset = 0;
for (_field_name, field_typ) in &struct_info.fields {
for (_field_name, field_typ, _attribute) in &struct_info.fields {
let len = self.size_of(field_typ);
let range = offset..(offset + len);
self.constrain_inputs_to_main(&input[range], field_typ, span)?;
Expand Down Expand Up @@ -501,7 +501,7 @@ impl<B: Backend> CircuitWriter<B> {
// find range of field
let mut start = 0;
let mut len = 0;
for (field, field_typ) in &struct_info.fields {
for (field, field_typ, _attribute) in &struct_info.fields {
if field == &rhs.value {
len = self.size_of(field_typ);
break;
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ pub enum ErrorKind {
#[error("division by zero")]
DivisionByZero,

#[error("cannot access private field `{1}` of struct `{0}` from outside its methods.")]
PrivateFieldAccess(String, String),

#[error("lhs `{0}` is less than rhs `{1}`")]
NegativeLhsLessThanRhs(String, String),

Expand Down
2 changes: 1 addition & 1 deletion src/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl<B: Backend> CompiledCircuit<B> {

// parse each field
let mut res = vec![];
for (field_name, field_ty) in fields {
for (field_name, field_ty, _attribute) in fields {
let value = map.remove(field_name).ok_or_else(|| {
ParsingError::MissingStructFieldIdent(field_name.to_string())
})?;
Expand Down
16 changes: 13 additions & 3 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,8 @@ fn monomorphize_expr<B: Backend>(
let typ = struct_info
.fields
.iter()
.find(|(name, _)| name == &rhs.value)
.map(|(_, typ)| typ.clone());
.find(|(name, _, _)| name == &rhs.value)
.map(|(_, typ, _)| typ.clone());

let mexpr = expr.to_mast(
ctx,
Expand Down Expand Up @@ -1067,7 +1067,17 @@ fn monomorphize_expr<B: Backend>(
let else_mono = monomorphize_expr(ctx, else_, mono_fn_env)?;

// make sure that the type of then_ and else_ match
if then_mono.typ != else_mono.typ {
let is_match = match (&then_mono.typ, &else_mono.typ) {
// generics not allowed as they should have been monomorphized
(Some(then_typ), Some(else_typ)) => then_typ.match_expected(else_typ, true),
_ => Err(Error::new(
"If-Else Monomorphization",
ErrorKind::UnexpectedError("Could not resolve type for the `if-else` branch"),
expr.span,
))?,
};

if !is_match {
Err(Error::new(
"If-Else Monomorphization",
ErrorKind::UnexpectedError(
Expand Down
2 changes: 1 addition & 1 deletion src/name_resolution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ impl NameResCtx {
self.resolve(module, true)?;

// we resolve the fully-qualified types of the fields
for (_field_name, field_typ) in fields {
for (_field_name, field_typ, _attribute) in fields {
self.resolve_typ_kind(&mut field_typ.kind)?;
}

Expand Down
48 changes: 46 additions & 2 deletions src/negative_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,29 @@ fn test_generic_custom_type_mismatched() {
));
}

#[test]
fn test_generic_mutated_cst_var_in_loop() {
let code = r#"
fn gen(const LEN: Field) -> [Field; LEN] {
return [0; LEN];
}
fn main(pub xx: Field) {
let mut loopvar = 1;
for ii in 0..3 {
loopvar = loopvar + 1;
}
let arr = gen(loopvar);
}
"#;

let res = tast_pass(code).0;
assert!(matches!(
res.unwrap_err().kind,
ErrorKind::ArgumentTypeMismatch(..)
));
}

#[test]
fn test_array_bounds() {
let code = r#"
Expand Down Expand Up @@ -700,7 +723,7 @@ fn test_nonhint_call_with_unsafe() {
fn test_no_cst_struct_field_prop() {
let code = r#"
struct Thing {
val: Field,
pub val: Field,
}
fn gen(const LEN: Field) -> [Field; LEN] {
Expand All @@ -725,7 +748,7 @@ fn test_no_cst_struct_field_prop() {
fn test_mut_cst_struct_field_prop() {
let code = r#"
struct Thing {
val: Field,
pub val: Field,
}
fn gen(const LEN: Field) -> [Field; LEN] {
Expand All @@ -747,3 +770,24 @@ fn test_mut_cst_struct_field_prop() {
ErrorKind::ArgumentTypeMismatch(..)
));
}

#[test]
fn test_private_field_access() {
let code = r#"
struct Room {
pub beds: Field, // public
size: Field // private
}
fn main(pub beds: Field) {
let room = Room {beds: beds, size: 10};
room.size = 5; // not allowed
}
"#;

let res = tast_pass(code).0;
assert!(matches!(
res.unwrap_err().kind,
ErrorKind::PrivateFieldAccess(..)
));
}
2 changes: 2 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ mod tests {
let parsed = StructDef::parse(ctx, tokens);
assert!(parsed.is_err());
assert!(parsed.as_ref().err().is_some());

println!("{:?}", parsed);
match &parsed.as_ref().err().unwrap().kind {
ErrorKind::ExpectedTokenNotKeyword(keyword, _) => {
assert_eq!(keyword, "pub");
Expand Down
38 changes: 34 additions & 4 deletions src/parser/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ use serde::{Deserialize, Serialize};
use crate::{
constants::Span,
error::{ErrorKind, Result},
lexer::{Token, TokenKind, Tokens},
lexer::{Keyword, Token, TokenKind, Tokens},
syntax::is_type,
};

use super::{
types::{Ident, ModulePath, Ty, TyKind},
types::{Attribute, AttributeKind, Ident, ModulePath, Ty, TyKind},
Error, ParserCtx,
};

Expand All @@ -17,7 +17,7 @@ pub struct StructDef {
//pub attribute: Attribute,
pub module: ModulePath, // name resolution
pub name: CustomType,
pub fields: Vec<(Ident, Ty)>,
pub fields: Vec<(Ident, Ty, Option<Attribute>)>,
pub span: Span,
}

Expand Down Expand Up @@ -55,6 +55,36 @@ impl StructDef {
tokens.bump(ctx);
break;
}

// check for pub keyword
// struct Foo { pub a: Field, b: Field }
// ^
let attribute = if matches!(
tokens.peek(),
Some(Token {
kind: TokenKind::Keyword(Keyword::Pub),
..
})
) {
let token = tokens.bump(ctx).unwrap();
// next token shouldn't be :
if tokens.peek().unwrap().kind == TokenKind::Colon {
return Err(ctx.error(
ErrorKind::ExpectedTokenNotKeyword(
"pub".to_string(),
TokenKind::Identifier("".to_string()),
),
token.span,
));
}
Some(Attribute {
kind: AttributeKind::Pub,
span: token.span,
})
} else {
None
};

// struct Foo { a: Field, b: Field }
// ^
let field_name = Ident::parse(ctx, tokens)?;
Expand All @@ -67,7 +97,7 @@ impl StructDef {
// ^^^^^
let field_ty = Ty::parse(ctx, tokens)?;
span = span.merge_with(field_ty.span);
fields.push((field_name, field_ty));
fields.push((field_name, field_ty, attribute));

// struct Foo { a: Field, b: Field }
// ^ ^
Expand Down
2 changes: 1 addition & 1 deletion src/stdlib/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ fn assert_eq_values<B: Backend>(

// compare each field recursively
let mut offset = 0;
for (_, field_type) in &struct_info.fields {
for (_, field_type, _) in &struct_info.fields {
let field_size = compiler.size_of(field_type);
let mut field_comparisons = assert_eq_values(
compiler,
Expand Down
20 changes: 19 additions & 1 deletion src/stdlib/native/int/lib.no
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,22 @@ fn Uint32.mod(self, rhs: Uint32) -> Uint32 {
fn Uint64.mod(self, rhs: Uint64) -> Uint64 {
let res = self.divmod(rhs);
return res[1];
}
}

// implement to field
fn Uint8.to_field(self) -> Field {
return self.inner;
}

fn Uint16.to_field(self) -> Field {
return self.inner;
}

fn Uint32.to_field(self) -> Field {
return self.inner;
}

fn Uint64.to_field(self) -> Field {
return self.inner;
}

2 changes: 1 addition & 1 deletion src/tests/modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use mimoo::liblib;
// test a library's type that links to its own type
struct Inner {
inner: Field,
pub inner: Field,
}
struct Lib {
Expand Down
2 changes: 1 addition & 1 deletion src/tests/stdlib/uints/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn main(pub lhs: Field, rhs: Field) -> Field {
let res = lhs_u.{opr}(rhs_u);
return res.inner;
return res.to_field();
}
"#;

Expand Down
Loading

0 comments on commit a30a66b

Please sign in to comment.