Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: builtins #33

Merged
merged 2 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/imports.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use std::{collections::HashMap, fmt};

use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};

use crate::{
circuit_writer::{CircuitWriter, VarInfo},
constants::Span,
error::Result,
parser::types::{FnSig, FunctionDef},
stdlib::{parse_fn_sigs, BUILTIN_FNS_DEFS},
type_checker::{FnInfo, TypeChecker},
var::Var,
};
Expand Down Expand Up @@ -73,6 +71,3 @@ impl fmt::Debug for FnKind {
}
}

// static of built-in functions
pub static BUILTIN_FNS: Lazy<HashMap<String, FnInfo>> =
Lazy::new(|| parse_fn_sigs(&BUILTIN_FNS_DEFS));
5 changes: 2 additions & 3 deletions src/name_resolution/expr.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::{
cli::packages::UserRepo,
error::Result,
imports::BUILTIN_FNS,
parser::{types::ModulePath, CustomType, Expr, ExprKind},
stdlib::QUALIFIED_BUILTINS,
stdlib::{BUILTIN_FN_NAMES, QUALIFIED_BUILTINS},
};

use super::context::NameResCtx;
Expand All @@ -22,7 +21,7 @@ impl NameResCtx {
fn_name,
args,
} => {
if matches!(module, ModulePath::Local) && BUILTIN_FNS.get(&fn_name.value).is_some()
if matches!(module, ModulePath::Local) && BUILTIN_FN_NAMES.contains(&fn_name.value)
{
// if it's a builtin, use `std::builtin`
*module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS));
Expand Down
4 changes: 2 additions & 2 deletions src/parser/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::{
cli::packages::UserRepo,
constants::{Field, Span},
error::{ErrorKind, Result},
imports::BUILTIN_FNS,
lexer::{Keyword, Token, TokenKind, Tokens},
stdlib::BUILTIN_FN_NAMES,
syntax::is_type,
};

Expand Down Expand Up @@ -785,7 +785,7 @@ impl FunctionDef {
let sig = FnSig::parse(ctx, tokens)?;

// make sure that it doesn't shadow a builtin
if BUILTIN_FNS.get(&sig.name.value).is_some() {
if BUILTIN_FN_NAMES.contains(&sig.name.value) {
return Err(ctx.error(
ErrorKind::ShadowingBuiltIn(sig.name.value.clone()),
sig.name.span,
Expand Down
32 changes: 30 additions & 2 deletions src/stdlib/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,46 @@ use kimchi::circuits::polynomials::poseidon::{POS_ROWS_PER_HASH, ROUNDS_PER_ROW}
use kimchi::mina_poseidon::constants::{PlonkSpongeConstantsKimchi, SpongeConstants};
use kimchi::mina_poseidon::permutation::full_round;

use crate::imports::FnKind;
use crate::lexer::Token;
use crate::parser::types::FnSig;
use crate::parser::ParserCtx;
use crate::type_checker::FnInfo;
use crate::{
circuit_writer::{CircuitWriter, GateKind, VarInfo},
constants::{self, Field, Span},
error::{ErrorKind, Result},
imports::FnHandle,
parser::types::TyKind,
var::{ConstOrCell, Value, Var},
};

const POSEIDON_FN: &str = "poseidon(input: [Field; 2]) -> [Field; 3]";

pub const CRYPTO_FNS: [(&str, FnHandle); 1] = [(POSEIDON_FN, poseidon)];
pub const CRYPTO_SIGS: &[&str] = &[POSEIDON_FN];

pub fn get_crypto_fn(name: &str) -> Option<FnInfo> {
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, name).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();

let fn_handle = match name {
POSEIDON_FN => poseidon,
_ => return None,
};

Some(FnInfo {
kind: FnKind::BuiltIn(sig, fn_handle),
span: Span::default(),
})
}

/// a function returns crypto functions
pub fn crypto_fns() -> Vec<FnInfo> {
CRYPTO_SIGS
.iter()
.map(|sig| get_crypto_fn(sig).unwrap())
.collect()
}

pub fn poseidon(compiler: &mut CircuitWriter, vars: &[VarInfo], span: Span) -> Result<Option<Var>> {
//
Expand Down
98 changes: 43 additions & 55 deletions src/stdlib/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashMap, ops::Neg as _};
use std::{collections::{HashMap, HashSet}, ops::Neg as _};

use ark_ff::{One as _, Zero};
use once_cell::sync::Lazy;
Expand All @@ -7,7 +7,7 @@ use crate::{
circuit_writer::{CircuitWriter, VarInfo},
constants::{Field, Span},
error::{Error, ErrorKind, Result},
imports::{BuiltinModule, FnHandle, FnKind},
imports::{FnHandle, FnKind},
lexer::Token,
parser::{
types::{FnSig, TyKind},
Expand All @@ -17,60 +17,10 @@ use crate::{
var::{ConstOrCell, Var},
};

use self::crypto::CRYPTO_FNS;
use self::crypto::get_crypto_fn;

pub mod crypto;

pub static CRYPTO_MODULE: Lazy<BuiltinModule> = Lazy::new(|| {
let functions = parse_fn_sigs(&CRYPTO_FNS);
BuiltinModule { functions }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you delete the parse_fn_sigs function now?

});

pub fn get_std_fn(submodule: &str, fn_name: &str, span: Span) -> Result<FnInfo> {
match submodule {
"crypto" => CRYPTO_MODULE
.functions
.get(fn_name)
.cloned()
.ok_or_else(|| {
Error::new(
"type-checker",
ErrorKind::UnknownExternalFn(submodule.to_string(), fn_name.to_string()),
span,
)
}),
_ => Err(Error::new(
"type-checker",
ErrorKind::StdImport(submodule.to_string()),
span,
)),
}
}

/// Takes a list of function signatures (as strings) and their associated function pointer,
/// returns the same list but with the parsed functions (as [FunctionSig]).
pub fn parse_fn_sigs(fn_sigs: &[(&str, FnHandle)]) -> HashMap<String, FnInfo> {
let mut functions = HashMap::new();
let ctx = &mut ParserCtx::default();

for (sig, fn_ptr) in fn_sigs {
// filename_id 0 is for builtins
let mut tokens = Token::parse(0, sig).unwrap();

let sig = FnSig::parse(ctx, &mut tokens).unwrap();

functions.insert(
sig.name.value.clone(),
FnInfo {
kind: FnKind::BuiltIn(sig, *fn_ptr),
span: Span::default(),
},
);
}

functions
}

//
// Builtins or utils (imported by default)
// TODO: give a name that's useful for the user,
Expand All @@ -81,8 +31,46 @@ pub const QUALIFIED_BUILTINS: &str = "std/builtins";
const ASSERT_FN: &str = "assert(condition: Bool)";
const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)";

pub const BUILTIN_FNS_DEFS: [(&str, FnHandle); 2] =
[(ASSERT_EQ_FN, assert_eq), (ASSERT_FN, assert)];
/// List of builtin function signatures.
pub const BUILTIN_SIGS: &[&str] = &[ASSERT_FN, ASSERT_EQ_FN];

// Unique set of builtin function names, derived from function signatures.
pub static BUILTIN_FN_NAMES: Lazy<HashSet<String>> = Lazy::new(|| {
BUILTIN_SIGS
.iter()
.map(|s| {
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, s).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();
sig.name.value
})
.collect()
});

pub fn get_builtin_fn(name: &str) -> Option<FnInfo> {
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, name).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();

let fn_handle = match name {
ASSERT_FN => assert,
ASSERT_EQ_FN => assert_eq,
_ => return None,
};

Some(FnInfo {
kind: FnKind::BuiltIn(sig, fn_handle),
span: Span::default(),
})
}

/// a function returns builtin functions
pub fn builtin_fns() -> Vec<FnInfo> {
BUILTIN_SIGS
.iter()
.map(|sig| get_builtin_fn(sig).unwrap())
.collect()
}

/// Asserts that two vars are equal.
fn assert_eq(compiler: &mut CircuitWriter, vars: &[VarInfo], span: Span) -> Result<Option<Var>> {
Expand Down
19 changes: 7 additions & 12 deletions src/type_checker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use crate::{
cli::packages::UserRepo,
constants::{Field, Span},
error::{Error, ErrorKind, Result},
imports::{FnKind, BUILTIN_FNS},
imports::FnKind,
name_resolution::NAST,
parser::{
types::{FuncOrMethod, FunctionDef, ModulePath, RootKind, Ty, TyKind},
CustomType, Expr, StructDef,
},
stdlib::{CRYPTO_MODULE, QUALIFIED_BUILTINS},
stdlib::{builtin_fns, crypto::crypto_fns, QUALIFIED_BUILTINS},
};

pub use checker::{FnInfo, StructInfo};
Expand Down Expand Up @@ -91,12 +91,7 @@ impl TypeChecker {
}

pub(crate) fn fn_info(&self, qualified: &FullyQualified) -> Option<&FnInfo> {
if qualified.module == Some(UserRepo::new("std/builtins")) {
// if it's a built-in: get it from a global
BUILTIN_FNS.get(&qualified.name)
} else {
self.functions.get(qualified)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm puzzled as to why this code existed before, I guess it was useless right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is it was trying to differentiate the crypto builtin functions and native functions.
But since both the builtins and natives are stored in the same place now, this code seem deprecated.

self.functions.get(qualified)
}

pub(crate) fn const_info(&self, qualified: &FullyQualified) -> Option<&ConstInfo> {
Expand Down Expand Up @@ -141,8 +136,8 @@ impl TypeChecker {

// initialize it with the builtins
let builtin_module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS));
for (fn_name, fn_info) in BUILTIN_FNS.iter() {
let qualified = FullyQualified::new(&builtin_module, fn_name);
for fn_info in builtin_fns() {
let qualified = FullyQualified::new(&builtin_module, &fn_info.sig().name.value);
if type_checker
.functions
.insert(qualified, fn_info.clone())
Expand All @@ -154,8 +149,8 @@ impl TypeChecker {

// initialize it with the standard library
let crypto_module = ModulePath::Absolute(UserRepo::new("std/crypto"));
for (fn_name, fn_info) in CRYPTO_MODULE.functions.iter() {
let qualified = FullyQualified::new(&crypto_module, fn_name);
for fn_info in crypto_fns() {
let qualified = FullyQualified::new(&crypto_module, &fn_info.sig().name.value);
if type_checker
.functions
.insert(qualified, fn_info.clone())
Expand Down
Loading