Skip to content

Commit

Permalink
Merge pull request #216 from zksecurity/fix/forbid-mut-var-forloop
Browse files Browse the repository at this point in the history
fix: forbid generic call with mutable vars
  • Loading branch information
katat authored Oct 30, 2024
2 parents 2a09ba0 + b6e3a56 commit b15f081
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 10 deletions.
42 changes: 38 additions & 4 deletions book/src/rfc/rfc-0-generic-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,49 @@ fn foo(const NN: Field) {...}
fn foo(arr: [Field; NN]) {...}
```

*Forbid generic function in for-loop*
*Restrictions over generic function in for-loop*
**Mutable Variables as Generic Arguments**: It's prohibited to use mutable variables as generic arguments in generic function calls inside loops. The language doesn't support loop unrolling, so using loop indices or mutable counters as generic parameters is invalid.

Invalid example:
```rust
fn fn_call(const LEN: Field) -> [Field; LEN] {...}

...
for ii in 0..NN {
fn_call(ii); // Error: 'ii' is mutable
}

...
let mut jj = 0;
for ii in 0..NN {
fn_call(jj); // Error: 'jj' is mutable
jj = jj + 1;
}
```

**Allowed Usage with Constants**: You can use constant values or immutable variables as generic arguments within loops.

Valid example:
```rust
let kk = 0;
for ii in 0..NN {
fn_call(kk); // Allowed: 'kk' is constant
}
```

**Exception for Arrays**: Mutable array variables can be used as generic arguments because their sizes are fixed at declaration, even if their contents change.

For example:
```rust
fn fn_call_arr(const arr: [Field; LEN]) -> [Field; LEN] {...}
...

let mut arr = [0; 3];
for ii in 0..NN {
// any function takes the for loop var as its argument should be forbidden
fn_call(ii);
fn_call_arr(arr); // Allowed: array size is fixed
}
```

To allow generic functions in for-loop, we will need to take care of unrolling the loop and instantiating the function with the concrete value of the loop variable. This is not in the scope of this RFC.

*Forbid operations on symbolic value of arguments*
```rust
Expand Down
80 changes: 80 additions & 0 deletions src/negative_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,58 @@ fn test_generic_const_for_loop() {
));
}

#[test]
fn test_generic_const_mut_for_loop() {
let code = r#"
// generic on const argument
fn gen(const LEN: Field) -> [Field; LEN] {
return [0; LEN];
}
fn loop() {
let mut size = 2;
for ii in 0..3 {
gen(size);
}
}
"#;

let res = tast_pass(code).0;

assert!(matches!(
res.unwrap_err().kind,
ErrorKind::VarAccessForbiddenInForLoop(..)
));
}

#[test]
fn test_generic_mut_struct_for_loop() {
let code = r#"
struct Thing {
xx: Field,
}
// generic on const argument
fn gen(const LEN: Field) -> [Field; LEN] {
return [0; LEN];
}
fn loop() {
let mut thing = Thing {xx: 3};
for ii in 0..3 {
gen(thing.xx);
}
}
"#;

let res = tast_pass(code).0;

assert!(matches!(
res.unwrap_err().kind,
ErrorKind::VarAccessForbiddenInForLoop(..)
));
}

#[test]
fn test_generic_const_nested_for_loop() {
let code = r#"
Expand Down Expand Up @@ -201,6 +253,34 @@ fn test_generic_method_cst_for_loop() {
));
}

#[test]
fn test_generic_struct_self_for_loop() {
let code = r#"
struct Thing {
xx: Field,
}
// generic on const argument
fn Thing.gen(self, const LEN: Field) -> [Field; LEN] {
return [self.xx; LEN];
}
fn loop() {
let thing = Thing { xx: 3 };
for ii in 0..3 {
thing.gen(ii);
}
}
"#;

let res = tast_pass(code).0;

assert!(matches!(
res.unwrap_err().kind,
ErrorKind::VarAccessForbiddenInForLoop(..)
));
}

#[test]
fn test_generic_method_array_for_loop() {
let code = r#"
Expand Down
10 changes: 7 additions & 3 deletions src/type_checker/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,13 @@ impl<B: Backend> TypeChecker<B> {

// check if generic is allowed
if method_type.sig.require_monomorphization() && typed_fn_env.is_in_forloop() {
for (observed_arg, expected_arg) in
args.iter().zip(method_type.sig.arguments.iter())
{
for (observed_arg, expected_arg) in args.iter().zip(
method_type
.sig
.arguments
.iter()
.filter(|arg| arg.name.value != "self"),
) {
// check if the arg involves generic vars
if !expected_arg.extract_generic_names().is_empty() {
let mut forbidden_env = typed_fn_env.clone();
Expand Down
20 changes: 17 additions & 3 deletions src/type_checker/fn_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,28 @@ impl TypedFnEnv {
self.current_scope >= prefix_scope
}

pub fn is_forbidden(&self, scope: usize) -> bool {
/// Since currently we don't support unrolling, the generic function calls are assumed to target a same instance.
/// Each loop iteration should instantiate generic function calls with the same parameters.
/// This assumption requires a few type checking rules to forbid the cases that needs unrolling.
/// Forbid rules:
/// - Access to variables within the for loop scope.
/// - Access to mutable variables, except if it is an array.
/// Because once the array is declared, the size is fixed even if the array is mutable,
/// so the generic value resolved from array size will be same for generic function argument.
pub fn is_forbidden(&self, scope: usize, ty_info: TypeInfo) -> bool {
let in_forbidden_scope = if let Some(forloop_scope) = self.forloop_scopes.first() {
scope >= *forloop_scope
} else {
false
};

self.forbid_forloop_scope && in_forbidden_scope
let forbidden_mutable = ty_info.mutable
&& !matches!(
ty_info.typ,
TyKind::GenericSizedArray(..) | TyKind::Array(..)
);

self.forbid_forloop_scope && (in_forbidden_scope || forbidden_mutable)
}

/// Stores type information about a local variable.
Expand Down Expand Up @@ -151,7 +165,7 @@ impl TypedFnEnv {
// TODO: return an error no?
pub fn get_type_info(&self, ident: &str) -> Result<Option<&TypeInfo>> {
if let Some((scope, type_info)) = self.vars.get(ident) {
if self.is_forbidden(*scope) {
if self.is_forbidden(*scope, type_info.clone()) {
return Err(Error::new(
"type-checker",
ErrorKind::VarAccessForbiddenInForLoop(ident.to_string()),
Expand Down

0 comments on commit b15f081

Please sign in to comment.