diff --git a/numbat/src/typechecker/constraints.rs b/numbat/src/typechecker/constraints.rs index a799a5c0..753001a8 100644 --- a/numbat/src/typechecker/constraints.rs +++ b/numbat/src/typechecker/constraints.rs @@ -68,6 +68,14 @@ impl ConstraintSet { result } + pub(crate) fn add_equal_constraint(&mut self, lhs: &Type, rhs: &Type) -> TrivialResolution { + self.add(Constraint::Equal(lhs.clone(), rhs.clone())) + } + + pub(crate) fn add_dtype_constraint(&mut self, type_: &Type) -> TrivialResolution { + self.add(Constraint::IsDType(type_.clone())) + } + pub fn clear(&mut self) { self.constraints.clear(); } diff --git a/numbat/src/typechecker/environment.rs b/numbat/src/typechecker/environment.rs index 726bab34..51dc6ae0 100644 --- a/numbat/src/typechecker/environment.rs +++ b/numbat/src/typechecker/environment.rs @@ -183,6 +183,18 @@ impl Environment { } } } + + pub(crate) fn get_proper_function_reference<'a>( + &self, + expr: &crate::ast::Expression<'a>, + ) -> Option<(&'a str, &FunctionSignature)> { + match expr { + crate::ast::Expression::Identifier(_, name) => self + .get_function_info(name) + .map(|(signature, _)| (*name, signature)), + _ => None, + } + } } impl ApplySubstitution for Environment { diff --git a/numbat/src/typechecker/mod.rs b/numbat/src/typechecker/mod.rs index 744001ff..e7ea4597 100644 --- a/numbat/src/typechecker/mod.rs +++ b/numbat/src/typechecker/mod.rs @@ -53,6 +53,131 @@ fn dtype(e: &Expression) -> Result { } } +struct ProperFunctionCallArgs<'a, 'b> { + registry: &'b DimensionRegistry, + constraints: &'b mut ConstraintSet, + name_generator: &'b mut NameGenerator, + span: &'b Span, + full_span: &'b Span, + function_name: &'a str, + signature: &'b FunctionSignature, + arguments: Vec>, + argument_types: Vec, +} + +fn proper_function_call<'a>( + ProperFunctionCallArgs { + registry, + constraints, + name_generator, + span, + full_span, + function_name, + signature, + arguments, + argument_types, + }: ProperFunctionCallArgs<'a, '_>, +) -> Result> { + let FunctionSignature { + name: _, + definition_span, + type_parameters: _, + parameters, + return_type_annotation: _, + fn_type, + } = signature; + + let fn_type = match fn_type { + TypeScheme::Concrete(t) => { + // This branch is needed for recursive functions, where the type of the function + // is not yet known (and not yet quantified). + t.clone() + } + TypeScheme::Quantified(_, _) => { + let qt = fn_type.instantiate(name_generator); + + for Bound::IsDim(t) in qt.bounds.iter() { + constraints.add_dtype_constraint(t).ok(); + } + + qt.inner + } + }; + + let Type::Fn(parameter_types, return_type) = fn_type else { + unreachable!("Expected function type, got {:#?}", fn_type); + }; + + let arity_range = parameters.len()..=parameters.len(); + + if !arity_range.contains(&arguments.len()) { + return Err(Box::new(TypeCheckError::WrongArity { + callable_span: *span, + callable_name: function_name.to_owned(), + callable_definition_span: Some(*definition_span), + arity: arity_range, + num_args: arguments.len(), + })); + } + + for (idx, ((parameter_span, parameter_type), argument_type)) in parameters + .iter() + .map(|p| p.0) + .zip(parameter_types.iter()) + .zip(argument_types) + .enumerate() + { + if constraints + .add_equal_constraint(parameter_type, &argument_type) + .is_trivially_violated() + { + match (parameter_type, &argument_type) { + (Type::Dimension(parameter_dtype), Type::Dimension(argument_dtype)) => { + return Err(Box::new(TypeCheckError::IncompatibleDimensions( + IncompatibleDimensionsError { + span_operation: *span, + operation: format_compact!( + "argument {num} of function call to '{name}'", + num = idx + 1, + name = function_name + ), + span_expected: parameter_span, + expected_name: "parameter type", + expected_dimensions: registry.get_derived_entry_names_for( + ¶meter_dtype.to_base_representation(), + ), + expected_type: parameter_dtype.to_base_representation(), + span_actual: arguments[idx].full_span(), + actual_name: " argument type", + actual_name_for_fix: "function argument", + actual_dimensions: registry.get_derived_entry_names_for( + &argument_dtype.to_base_representation(), + ), + actual_type: argument_dtype.to_base_representation(), + }, + ))); + } + _ => { + return Err(Box::new(TypeCheckError::IncompatibleTypesInFunctionCall( + Some(parameter_span), + parameter_type.clone(), + arguments[idx].full_span(), + argument_type.clone(), + ))); + } + } + } + } + + Ok(typed_ast::Expression::FunctionCall( + *span, + *full_span, + function_name, + arguments, + TypeScheme::concrete(return_type.as_ref().clone()), + )) +} + #[derive(Clone, Default)] pub struct TypeChecker { structs: HashMap, @@ -84,12 +209,11 @@ impl TypeChecker { } fn add_equal_constraint(&mut self, lhs: &Type, rhs: &Type) -> TrivialResolution { - self.constraints - .add(Constraint::Equal(lhs.clone(), rhs.clone())) + self.constraints.add_equal_constraint(lhs, rhs) } fn add_dtype_constraint(&mut self, type_: &Type) -> TrivialResolution { - self.constraints.add(Constraint::IsDType(type_.clone())) + self.constraints.add_dtype_constraint(type_) } fn enforce_dtype(&mut self, type_: &Type, span: Span) -> Result<()> { @@ -174,128 +298,6 @@ impl TypeChecker { })?) } - fn get_proper_function_reference<'a>( - &self, - expr: &ast::Expression<'a>, - ) -> Option<(&'a str, &FunctionSignature)> { - match expr { - ast::Expression::Identifier(_, name) => self - .env - .get_function_info(name) - .map(|(signature, _)| (*name, signature)), - _ => None, - } - } - - fn proper_function_call<'a>( - &mut self, - span: &Span, - full_span: &Span, - function_name: &'a str, - signature: &FunctionSignature, - arguments: Vec>, - argument_types: Vec, - ) -> Result> { - let FunctionSignature { - name: _, - definition_span, - type_parameters: _, - parameters, - return_type_annotation: _, - fn_type, - } = signature; - - let fn_type = match fn_type { - TypeScheme::Concrete(t) => { - // This branch is needed for recursive functions, where the type of the function - // is not yet known (and not yet quantified). - t.clone() - } - TypeScheme::Quantified(_, _) => { - let qt = fn_type.instantiate(&mut self.name_generator); - - for Bound::IsDim(t) in qt.bounds.iter() { - self.add_dtype_constraint(t).ok(); - } - - qt.inner - } - }; - - let Type::Fn(parameter_types, return_type) = fn_type else { - unreachable!("Expected function type, got {:#?}", fn_type); - }; - - let arity_range = parameters.len()..=parameters.len(); - - if !arity_range.contains(&arguments.len()) { - return Err(Box::new(TypeCheckError::WrongArity { - callable_span: *span, - callable_name: function_name.to_owned(), - callable_definition_span: Some(*definition_span), - arity: arity_range, - num_args: arguments.len(), - })); - } - - for (idx, ((parameter_span, parameter_type), argument_type)) in parameters - .iter() - .map(|p| p.0) - .zip(parameter_types.iter()) - .zip(argument_types) - .enumerate() - { - if self - .add_equal_constraint(parameter_type, &argument_type) - .is_trivially_violated() - { - match (parameter_type, &argument_type) { - (Type::Dimension(parameter_dtype), Type::Dimension(argument_dtype)) => { - return Err(Box::new(TypeCheckError::IncompatibleDimensions( - IncompatibleDimensionsError { - span_operation: *span, - operation: format_compact!( - "argument {num} of function call to '{name}'", - num = idx + 1, - name = function_name - ), - span_expected: parameter_span, - expected_name: "parameter type", - expected_dimensions: self.registry.get_derived_entry_names_for( - ¶meter_dtype.to_base_representation(), - ), - expected_type: parameter_dtype.to_base_representation(), - span_actual: arguments[idx].full_span(), - actual_name: " argument type", - actual_name_for_fix: "function argument", - actual_dimensions: self.registry.get_derived_entry_names_for( - &argument_dtype.to_base_representation(), - ), - actual_type: argument_dtype.to_base_representation(), - }, - ))); - } - _ => { - return Err(Box::new(TypeCheckError::IncompatibleTypesInFunctionCall( - Some(parameter_span), - parameter_type.clone(), - arguments[idx].full_span(), - argument_type.clone(), - ))); - } - } - } - } - - Ok(typed_ast::Expression::FunctionCall( - *span, - *full_span, - function_name, - arguments, - TypeScheme::concrete(return_type.as_ref().clone()), - )) - } - fn elaborate_expression<'a>( &mut self, ast: &ast::Expression<'a>, @@ -778,18 +780,20 @@ impl TypeChecker { // to a (proper) function, or it can be an arbitrary complicated expression // that evaluates to a function "pointer". - if let Some((name, signature)) = self.get_proper_function_reference(callable) { - // TODO: there is probably a better way to get around borrowing issues here - let signature = signature.clone(); - - self.proper_function_call( + if let Some((function_name, signature)) = + self.env.get_proper_function_reference(callable) + { + proper_function_call(ProperFunctionCallArgs { + registry: &mut self.registry, + constraints: &mut self.constraints, + name_generator: &mut self.name_generator, span, full_span, - name, - &signature, - arguments_checked, + function_name, + signature, + arguments: arguments_checked, argument_types, - )? + })? } else { let callable_checked = self.elaborate_expression(callable)?; let callable_type = callable_checked.get_type();