diff --git a/numbat/src/interpreter/mod.rs b/numbat/src/interpreter/mod.rs index b5e93dda..b0554056 100644 --- a/numbat/src/interpreter/mod.rs +++ b/numbat/src/interpreter/mod.rs @@ -209,8 +209,9 @@ mod tests { .transform(statements) .expect("No name resolution errors for inputs in this test suite"); let mut typechecker = crate::typechecker::TypeChecker::default(); + let statements_typechecked = typechecker - .check(&statements_transformed) + .check(&mut crate::Environment::default(), &statements_transformed) .expect("No type check errors for inputs in this test suite"); BytecodeInterpreter::new().interpret_statements( &mut InterpreterSettings::default(), diff --git a/numbat/src/lib.rs b/numbat/src/lib.rs index aa728087..345b59c1 100644 --- a/numbat/src/lib.rs +++ b/numbat/src/lib.rs @@ -66,6 +66,7 @@ use resolver::CodeSource; use resolver::Resolver; use resolver::ResolverError; use thiserror::Error; +pub use typechecker::Environment; use typechecker::{TypeCheckError, TypeChecker}; pub use diagnostic::Diagnostic; @@ -102,6 +103,7 @@ type Result = std::result::Result>; pub struct Context { prefix_transformer: Transformer, typechecker: TypeChecker, + env: Environment, interpreter: BytecodeInterpreter, resolver: Resolver, load_currency_module_on_demand: bool, @@ -113,6 +115,7 @@ impl Context { Context { prefix_transformer: Transformer::new(), typechecker: TypeChecker::default(), + env: Environment::default(), interpreter: BytecodeInterpreter::new(), resolver: Resolver::new(module_importer), load_currency_module_on_demand: false, @@ -179,7 +182,7 @@ impl Context { .iter() .filter(|name| !name.starts_with('_')) .map(move |name| { - let (signature, meta) = self.typechecker.lookup_function(name).unwrap(); + let (signature, meta) = self.env.lookup_function(name).unwrap(); ( name.clone(), meta.name.clone(), @@ -468,7 +471,7 @@ impl Context { return help; } - if let Some((fn_signature, fn_metadata)) = self.typechecker.lookup_function(keyword) { + if let Some((fn_signature, fn_metadata)) = self.env.lookup_function(keyword) { let metadata = fn_metadata.clone(); let mut help = m::text("Function: "); @@ -592,10 +595,11 @@ impl Context { let transformed_statements = result?; let typechecker_old = self.typechecker.clone(); + let env_old = self.env.clone(); let result = self .typechecker - .check(&transformed_statements) + .check(&mut self.env, &transformed_statements) .map_err(|err| NumbatError::TypeCheckError(*err)); if result.is_err() { @@ -611,6 +615,7 @@ impl Context { // self.prefix_transformer = prefix_transformer_old.clone(); self.typechecker = typechecker_old.clone(); + self.env = env_old.clone(); if self.load_currency_module_on_demand { if let Err(NumbatError::TypeCheckError(TypeCheckError::UnknownIdentifier( @@ -831,6 +836,7 @@ impl Context { // self.prefix_transformer = prefix_transformer_old; self.typechecker = typechecker_old; + self.env = env_old; self.interpreter = interpreter_old; } diff --git a/numbat/src/typechecker/environment.rs b/numbat/src/typechecker/environment.rs index 0f885982..a4852400 100644 --- a/numbat/src/typechecker/environment.rs +++ b/numbat/src/typechecker/environment.rs @@ -1,4 +1,4 @@ -use crate::ast::{TypeAnnotation, TypeParameterBound}; +use crate::ast::{self, TypeAnnotation, TypeParameterBound}; use crate::dimension::DimensionRegistry; use crate::pretty_print::PrettyPrint; use crate::span::Span; @@ -166,6 +166,22 @@ impl Environment { } } + pub(crate) fn get_proper_function_reference<'a>( + &self, + expr: &ast::Expression<'a>, + ) -> Option<(&'a str, &FunctionSignature)> { + match expr { + ast::Expression::Identifier(_, name) => self + .get_function_info(name) + .map(|(signature, _)| (*name, signature)), + _ => None, + } + } + + pub fn lookup_function(&self, name: &str) -> Option<(&FunctionSignature, &FunctionMetadata)> { + self.get_function_info(name) + } + pub(crate) fn generalize_types(&mut self, dtype_variables: &[TypeVariable]) { for (_, kind) in self.identifiers.iter_mut() { match kind { diff --git a/numbat/src/typechecker/mod.rs b/numbat/src/typechecker/mod.rs index faa13588..89c48a73 100644 --- a/numbat/src/typechecker/mod.rs +++ b/numbat/src/typechecker/mod.rs @@ -31,7 +31,8 @@ use crate::{decorator, ffi, suggestion}; use const_evaluation::evaluate_const_expr; use constraints::{Constraint, ConstraintSet, ConstraintSolverError, TrivialResolution}; -use environment::{Environment, FunctionMetadata, FunctionSignature}; +pub use environment::Environment; +use environment::{FunctionMetadata, FunctionSignature}; use itertools::Itertools; use name_generator::NameGenerator; use num_traits::Zero; @@ -59,7 +60,6 @@ pub struct TypeChecker { type_namespace: Namespace, value_namespace: Namespace, - env: Environment, name_generator: NameGenerator, constraints: ConstraintSet, @@ -159,11 +159,10 @@ impl TypeChecker { } } - fn identifier_type(&self, span: Span, name: &str) -> Result { - Ok(self.env.get_identifier_type(name).ok_or_else(|| { + fn identifier_type(&self, env: &Environment, span: Span, name: &str) -> Result { + Ok(env.get_identifier_type(name).ok_or_else(|| { let suggestion = suggestion::did_you_mean( - self.env - .iter_identifiers() + env.iter_identifiers() .map(|k| k.as_str()) .chain(["true", "false"]) // These are parsed as keywords, but can act like identifiers .chain(ffi::procedures().values().map(|p| p.name)), @@ -173,19 +172,6 @@ impl TypeChecker { })?) } - fn get_proper_function_reference<'a>( - &self, - expr: &ast::Expression<'a>, - ) -> Option<(&'a str, &FunctionSignature)> { - match expr { - ast::Expression::Identifier(_, name) => self - .env - .get_function_info(name) - .map(|(signature, _)| (*name, signature)), - _ => None, - } - } - fn proper_function_call<'a>( &mut self, span: &Span, @@ -297,6 +283,7 @@ impl TypeChecker { fn elaborate_expression<'a>( &mut self, + env: &Environment, ast: &ast::Expression<'a>, ) -> Result> { Ok(match ast { @@ -315,7 +302,7 @@ impl TypeChecker { typed_ast::Expression::Scalar(*span, *n, TypeScheme::concrete(Type::scalar())) } ast::Expression::Identifier(span, name) => { - let type_scheme = self.identifier_type(*span, name)?.clone(); + let type_scheme = self.identifier_type(env, *span, name)?.clone(); let ty = match type_scheme { TypeScheme::Concrete(ty) => ty, @@ -332,7 +319,7 @@ impl TypeChecker { typed_ast::Expression::Identifier(*span, name, TypeScheme::concrete(ty)) } ast::Expression::UnitIdentifier(span, prefix, name, full_name) => { - let type_scheme = self.identifier_type(*span, name)?.clone(); + let type_scheme = self.identifier_type(env, *span, name)?.clone(); let qt = type_scheme.instantiate(&mut self.name_generator); @@ -349,7 +336,7 @@ impl TypeChecker { ) } ast::Expression::UnaryOperator { op, expr, span_op } => { - let checked_expr = self.elaborate_expression(expr)?; + let checked_expr = self.elaborate_expression(env, expr)?; let type_ = checked_expr.get_type(); match op { @@ -390,8 +377,8 @@ impl TypeChecker { rhs, span_op, } => { - let lhs_checked = self.elaborate_expression(lhs)?; - let rhs_checked = self.elaborate_expression(rhs)?; + let lhs_checked = self.elaborate_expression(env, lhs)?; + let rhs_checked = self.elaborate_expression(env, rhs)?; let lhs_type = lhs_checked.get_type(); let rhs_type = rhs_checked.get_type(); @@ -766,7 +753,7 @@ impl TypeChecker { ast::Expression::FunctionCall(span, full_span, callable, args) => { let arguments_checked = args .iter() - .map(|a| self.elaborate_expression(a)) + .map(|a| self.elaborate_expression(env, a)) .collect::>>()?; let argument_types = arguments_checked .iter() @@ -777,20 +764,17 @@ impl TypeChecker { // to a (proper) function, or it can be an arbitrary complicated expression // that evaluates to a function "pointer". - if let Some((name, signature)) = self.get_proper_function_reference(callable) { - // TODO: there is probably a better way to get around borrowing issues here - let signature = signature.clone(); - + if let Some((name, signature)) = env.get_proper_function_reference(callable) { self.proper_function_call( span, full_span, name, - &signature, + signature, arguments_checked, argument_types, )? } else { - let callable_checked = self.elaborate_expression(callable)?; + let callable_checked = self.elaborate_expression(env, callable)?; let callable_type = callable_checked.get_type(); let parameter_types = (0..arguments_checked.len()) @@ -885,13 +869,13 @@ impl TypeChecker { } => Ok(typed_ast::StringPart::Interpolation { span: *span, format_specifiers: format_specifiers.as_ref().copied(), - expr: Box::new(self.elaborate_expression(expr)?), + expr: Box::new(self.elaborate_expression(env, expr)?), }), }) .collect::>()?, ), ast::Expression::Condition(span, condition, then, else_) => { - let condition = self.elaborate_expression(condition)?; + let condition = self.elaborate_expression(env, condition)?; if self .add_equal_constraint(&condition.get_type(), &Type::Boolean) @@ -902,8 +886,8 @@ impl TypeChecker { ))); } - let then = self.elaborate_expression(then)?; - let else_ = self.elaborate_expression(else_)?; + let then = self.elaborate_expression(env, then)?; + let else_ = self.elaborate_expression(env, else_)?; let then_type = then.get_type(); let else_type = else_.get_type(); @@ -937,7 +921,7 @@ impl TypeChecker { let name = *name; let fields_checked = fields .iter() - .map(|(_, n, v)| Ok((*n, self.elaborate_expression(v)?))) + .map(|(_, n, v)| Ok((*n, self.elaborate_expression(env, v)?))) .collect::>>()?; let Some(struct_info) = self.structs.get(name).cloned() else { @@ -1012,7 +996,7 @@ impl TypeChecker { } ast::Expression::AccessField(full_span, ident_span, expr, field_name) => { let field_name = *field_name; - let expr_checked = self.elaborate_expression(expr)?; + let expr_checked = self.elaborate_expression(env, expr)?; let type_ = expr_checked.get_type(); @@ -1062,7 +1046,7 @@ impl TypeChecker { ast::Expression::List(span, elements) => { let elements_checked = elements .iter() - .map(|e| self.elaborate_expression(e)) + .map(|e| self.elaborate_expression(env, e)) .collect::>>()?; let element_types: Vec = @@ -1109,8 +1093,9 @@ impl TypeChecker { }) } - fn _elaborate_inner<'a>( + fn _elaborate_definition_impl<'a>( &mut self, + env: &Environment, definition: ElaborationDefinitionArgs<'a, '_>, ) -> Result<(typed_ast::Expression<'a>, typed_ast::Type)> { let ElaborationDefinitionArgs { @@ -1125,7 +1110,7 @@ impl TypeChecker { elaboration_kind, } = definition; - let expr_checked = self.elaborate_expression(expr)?; + let expr_checked = self.elaborate_expression(env, expr)?; let type_deduced = expr_checked.get_type(); if let Some(type_annotation) = type_annotation { @@ -1183,6 +1168,7 @@ impl TypeChecker { fn elaborate_define_variable<'a>( &mut self, + env: &mut Environment, define_variable: &ast::DefineVariable<'a>, ) -> Result> { let DefineVariable { @@ -1193,20 +1179,23 @@ impl TypeChecker { decorators, } = define_variable; - let (expr_checked, type_deduced) = self._elaborate_inner(ElaborationDefinitionArgs { - identifier_span: *identifier_span, - expr, - type_annotation_span: None, - type_annotation: type_annotation.as_ref(), - operation: "variable definition", - expected_name: "specified dimension", - actual_name: " actual dimension", - actual_name_for_fix: "right hand side expression", - elaboration_kind: "definition", - })?; + let (expr_checked, type_deduced) = self._elaborate_definition_impl( + env, + ElaborationDefinitionArgs { + identifier_span: *identifier_span, + expr, + type_annotation_span: None, + type_annotation: type_annotation.as_ref(), + operation: "variable definition", + expected_name: "specified dimension", + actual_name: " actual dimension", + actual_name_for_fix: "right hand side expression", + elaboration_kind: "definition", + }, + )?; for (name, _) in decorator::name_and_aliases(identifier, decorators) { - self.env.add( + env.add( name.to_owned(), type_deduced.clone(), *identifier_span, @@ -1234,13 +1223,14 @@ impl TypeChecker { fn elaborate_statement<'a>( &mut self, + env: &mut Environment, ast: &ast::Statement<'a>, ) -> Result> { Ok(match ast { ast::Statement::Expression(expr) => { - let checked_expr = self.elaborate_expression(expr)?; + let checked_expr = self.elaborate_expression(env, expr)?; for &identifier in LAST_RESULT_IDENTIFIERS { - self.env.add_predefined( + env.add_predefined( identifier.into(), TypeScheme::concrete(checked_expr.get_type()), ); @@ -1249,7 +1239,7 @@ impl TypeChecker { } ast::Statement::DefineVariable(define_variable) => { typed_ast::Statement::DefineVariable( - self.elaborate_define_variable(define_variable)?, + self.elaborate_define_variable(env, define_variable)?, ) } ast::Statement::DefineBaseUnit(span, unit_name, type_annotation, decorators) => { @@ -1279,7 +1269,7 @@ impl TypeChecker { .into() }; for (name, _) in decorator::name_and_aliases(unit_name, decorators) { - self.env.add( + env.add( name.to_string(), Type::Dimension(type_specified.clone()), *span, @@ -1302,8 +1292,9 @@ impl TypeChecker { type_annotation, decorators, } => { - let (expr_checked, type_deduced) = - self._elaborate_inner(ElaborationDefinitionArgs { + let (expr_checked, type_deduced) = self._elaborate_definition_impl( + env, + ElaborationDefinitionArgs { identifier_span: *identifier_span, expr, type_annotation_span: type_annotation_span.as_ref().copied(), @@ -1313,10 +1304,11 @@ impl TypeChecker { actual_name: " actual dimension", actual_name_for_fix: "right hand side expression", elaboration_kind: "unit definition", - })?; + }, + )?; for (name, _) in decorator::name_and_aliases(identifier, decorators) { - self.env.add( + env.add( name.to_string(), type_deduced.clone(), *identifier_span, @@ -1362,7 +1354,7 @@ impl TypeChecker { // Save the environment and namespaces to avoid polluting // their parents with the locals of this function - self.env.save(); + env.save(); self.type_namespace.save(); self.value_namespace.save(); @@ -1420,7 +1412,7 @@ impl TypeChecker { )); } - self.env.add_scheme( + env.add_scheme( parameter.to_string(), TypeScheme::make_quantified(parameter_type.clone()), *parameter_span, @@ -1458,7 +1450,7 @@ impl TypeChecker { let fn_type = TypeScheme::Concrete(Type::Fn(parameter_types, Box::new(return_type.clone()))); - self.env.add_function( + env.add_function( function_name.to_string(), FunctionSignature { name: function_name.to_string(), @@ -1484,12 +1476,13 @@ impl TypeChecker { let mut typed_local_variables = vec![]; for local_variable in local_variables { - typed_local_variables.push(self.elaborate_define_variable(local_variable)?); + typed_local_variables + .push(self.elaborate_define_variable(env, local_variable)?); } let body_checked = body .as_ref() - .map(|expr| self.elaborate_expression(expr)) + .map(|expr| self.elaborate_expression(env, expr)) .transpose()?; let return_type_inferred = if let Some(ref expr) = body_checked { @@ -1571,7 +1564,7 @@ impl TypeChecker { .ok(); // Copy identifier for the new function into local env: - let (signature, metadata) = self.env.get_function_info(function_name).unwrap(); + let (signature, metadata) = env.get_function_info(function_name).unwrap(); let signature = signature.clone(); let metadata = metadata.clone(); @@ -1579,8 +1572,8 @@ impl TypeChecker { // add the function name to the environment self.value_namespace.restore(); self.type_namespace.restore(); - self.env.restore(); - self.env.add_function( + env.restore(); + env.add_function( function_name.to_string(), signature.clone(), metadata.clone(), @@ -1663,7 +1656,7 @@ impl TypeChecker { let checked_args = args .iter() - .map(|e| self.elaborate_expression(e)) + .map(|e| self.elaborate_expression(env, e)) .collect::>>()?; typed_ast::Statement::ProcedureCall(kind.clone(), checked_args) @@ -1682,7 +1675,7 @@ impl TypeChecker { let checked_args = args .iter() - .map(|e| self.elaborate_expression(e)) + .map(|e| self.elaborate_expression(env, e)) .collect::>>()?; match kind { @@ -1787,6 +1780,7 @@ impl TypeChecker { fn check_statement<'a>( &mut self, + env: &mut Environment, statement: &ast::Statement<'a>, ) -> Result> { self.constraints.clear(); @@ -1795,7 +1789,7 @@ impl TypeChecker { // Elaborate the program/statement: turn the AST into a typed AST, possibly // with unification variables, i.e. type variables that will only later be // filled in after the constraints have been solved. - let mut elaborated_statement = self.elaborate_statement(statement)?; + let mut elaborated_statement = self.elaborate_statement(env, statement)?; // Solve constraints let (substitution, dtype_variables) = @@ -1818,7 +1812,7 @@ impl TypeChecker { TypeCheckError::SubstitutionError(elaborated_statement.pretty_print().to_string(), e) })?; - self.env.apply(&substitution).map_err(|e| { + env.apply(&substitution).map_err(|e| { TypeCheckError::SubstitutionError(elaborated_statement.pretty_print().to_string(), e) })?; @@ -1877,7 +1871,7 @@ impl TypeChecker { elaborated_statement.update_readable_types(&self.registry); - self.env.generalize_types(&dtype_variables); + env.generalize_types(&dtype_variables); // Check if there is a typed hole in the statement if let Some((span, type_of_hole)) = elaborated_statement.find_typed_hole()? { @@ -1887,8 +1881,7 @@ impl TypeChecker { .to_readable_type(&self.registry, true) .to_string(), elaborated_statement.pretty_print().to_string(), - self.env - .iter_relevant_matches() + env.iter_relevant_matches() .filter(|(_, t)| t == &type_of_hole) .take(10) .map(|(n, _)| n) @@ -1902,12 +1895,13 @@ impl TypeChecker { pub fn check<'a>( &mut self, + env: &mut Environment, statements: &[ast::Statement<'a>], ) -> Result>> { let mut checked_statements = vec![]; for statement in statements { - checked_statements.push(self.check_statement(statement)?); + checked_statements.push(self.check_statement(env, statement)?); } Ok(checked_statements) @@ -1916,8 +1910,4 @@ impl TypeChecker { pub(crate) fn registry(&self) -> &DimensionRegistry { &self.registry } - - pub fn lookup_function(&self, name: &str) -> Option<(&FunctionSignature, &FunctionMetadata)> { - self.env.get_function_info(name) - } } diff --git a/numbat/src/typechecker/tests/mod.rs b/numbat/src/typechecker/tests/mod.rs index ae71a29c..06119543 100644 --- a/numbat/src/typechecker/tests/mod.rs +++ b/numbat/src/typechecker/tests/mod.rs @@ -5,7 +5,7 @@ use crate::parser::parse; use crate::prefix_transformer::Transformer; use crate::typechecker::{Result, TypeCheckError}; use crate::typed_ast::{self, DType}; -use crate::Statement; +use crate::{Environment, Statement}; use super::type_scheme::TypeScheme; use super::TypeChecker; @@ -50,6 +50,8 @@ fn type_c() -> DType { } fn run_typecheck(input: &str) -> Result> { + let mut env = Environment::default(); + let statements = parse(TEST_PRELUDE, 0) .expect("No parse errors for inputs in this test suite") .into_iter() @@ -60,7 +62,7 @@ fn run_typecheck(input: &str) -> Result> { .map_err(|err| Box::new(err.into()))?; TypeChecker::default() - .check(&transformed_statements) + .check(&mut env, &transformed_statements) .map(|mut statements_checked| statements_checked.pop().unwrap()) } diff --git a/numbat/src/typed_ast.rs b/numbat/src/typed_ast.rs index f1221bb2..6ab79204 100644 --- a/numbat/src/typed_ast.rs +++ b/numbat/src/typed_ast.rs @@ -1370,6 +1370,7 @@ mod tests { use crate::ast::ReplaceSpans; use crate::markup::{Formatter, PlainTextFormatter}; use crate::prefix_transformer::Transformer; + use crate::Environment; fn parse(code: &str) -> Statement { let statements = crate::parser::parse( @@ -1436,8 +1437,10 @@ mod tests { let mut transformer = Transformer::new(); let transformed_statements = transformer.transform(statements).unwrap().replace_spans(); + let mut env = Environment::default(); + crate::typechecker::TypeChecker::default() - .check(&transformed_statements) + .check(&mut env, &transformed_statements) .unwrap() .last() .unwrap()