Skip to content

Commit

Permalink
Made proper_function_call static and moved `get_proper_function_ref…
Browse files Browse the repository at this point in the history
…erence` to `Environment` to prove to the compiler that mutable references don't overlap
  • Loading branch information
rben01 committed Oct 27, 2024
1 parent da18b78 commit 1a276f3
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 134 deletions.
8 changes: 8 additions & 0 deletions numbat/src/typechecker/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ impl ConstraintSet {
result
}

pub(crate) fn add_equal_constraint(&mut self, lhs: &Type, rhs: &Type) -> TrivialResolution {
self.add(Constraint::Equal(lhs.clone(), rhs.clone()))
}

pub(crate) fn add_dtype_constraint(&mut self, type_: &Type) -> TrivialResolution {
self.add(Constraint::IsDType(type_.clone()))
}

pub fn clear(&mut self) {
self.constraints.clear();
}
Expand Down
12 changes: 12 additions & 0 deletions numbat/src/typechecker/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ impl Environment {
}
}
}

pub(crate) fn get_proper_function_reference<'a>(
&self,
expr: &crate::ast::Expression<'a>,
) -> Option<(&'a str, &FunctionSignature)> {
match expr {
crate::ast::Expression::Identifier(_, name) => self
.get_function_info(name)
.map(|(signature, _)| (*name, signature)),
_ => None,
}
}
}

impl ApplySubstitution for Environment {
Expand Down
272 changes: 138 additions & 134 deletions numbat/src/typechecker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,131 @@ fn dtype(e: &Expression) -> Result<DType> {
}
}

struct ProperFunctionCallArgs<'a, 'b> {
registry: &'b DimensionRegistry,
constraints: &'b mut ConstraintSet,
name_generator: &'b mut NameGenerator,
span: &'b Span,
full_span: &'b Span,
function_name: &'a str,
signature: &'b FunctionSignature,
arguments: Vec<typed_ast::Expression<'a>>,
argument_types: Vec<Type>,
}

fn proper_function_call<'a>(
ProperFunctionCallArgs {
registry,
constraints,
name_generator,
span,
full_span,
function_name,
signature,
arguments,
argument_types,
}: ProperFunctionCallArgs<'a, '_>,
) -> Result<typed_ast::Expression<'a>> {
let FunctionSignature {
name: _,
definition_span,
type_parameters: _,
parameters,
return_type_annotation: _,
fn_type,
} = signature;

let fn_type = match fn_type {
TypeScheme::Concrete(t) => {
// This branch is needed for recursive functions, where the type of the function
// is not yet known (and not yet quantified).
t.clone()
}
TypeScheme::Quantified(_, _) => {
let qt = fn_type.instantiate(name_generator);

for Bound::IsDim(t) in qt.bounds.iter() {
constraints.add_dtype_constraint(t).ok();
}

qt.inner
}
};

let Type::Fn(parameter_types, return_type) = fn_type else {
unreachable!("Expected function type, got {:#?}", fn_type);
};

let arity_range = parameters.len()..=parameters.len();

if !arity_range.contains(&arguments.len()) {
return Err(Box::new(TypeCheckError::WrongArity {
callable_span: *span,
callable_name: function_name.to_owned(),
callable_definition_span: Some(*definition_span),
arity: arity_range,
num_args: arguments.len(),
}));
}

for (idx, ((parameter_span, parameter_type), argument_type)) in parameters
.iter()
.map(|p| p.0)
.zip(parameter_types.iter())
.zip(argument_types)
.enumerate()
{
if constraints
.add_equal_constraint(parameter_type, &argument_type)
.is_trivially_violated()
{
match (parameter_type, &argument_type) {
(Type::Dimension(parameter_dtype), Type::Dimension(argument_dtype)) => {
return Err(Box::new(TypeCheckError::IncompatibleDimensions(
IncompatibleDimensionsError {
span_operation: *span,
operation: format_compact!(
"argument {num} of function call to '{name}'",
num = idx + 1,
name = function_name
),
span_expected: parameter_span,
expected_name: "parameter type",
expected_dimensions: registry.get_derived_entry_names_for(
&parameter_dtype.to_base_representation(),
),
expected_type: parameter_dtype.to_base_representation(),
span_actual: arguments[idx].full_span(),
actual_name: " argument type",
actual_name_for_fix: "function argument",
actual_dimensions: registry.get_derived_entry_names_for(
&argument_dtype.to_base_representation(),
),
actual_type: argument_dtype.to_base_representation(),
},
)));
}
_ => {
return Err(Box::new(TypeCheckError::IncompatibleTypesInFunctionCall(
Some(parameter_span),
parameter_type.clone(),
arguments[idx].full_span(),
argument_type.clone(),
)));
}
}
}
}

Ok(typed_ast::Expression::FunctionCall(
*span,
*full_span,
function_name,
arguments,
TypeScheme::concrete(return_type.as_ref().clone()),
))
}

#[derive(Clone, Default)]
pub struct TypeChecker {
structs: HashMap<CompactString, StructInfo>,
Expand Down Expand Up @@ -84,12 +209,11 @@ impl TypeChecker {
}

fn add_equal_constraint(&mut self, lhs: &Type, rhs: &Type) -> TrivialResolution {
self.constraints
.add(Constraint::Equal(lhs.clone(), rhs.clone()))
self.constraints.add_equal_constraint(lhs, rhs)
}

fn add_dtype_constraint(&mut self, type_: &Type) -> TrivialResolution {
self.constraints.add(Constraint::IsDType(type_.clone()))
self.constraints.add_dtype_constraint(type_)
}

fn enforce_dtype(&mut self, type_: &Type, span: Span) -> Result<()> {
Expand Down Expand Up @@ -174,128 +298,6 @@ impl TypeChecker {
})?)
}

fn get_proper_function_reference<'a>(
&self,
expr: &ast::Expression<'a>,
) -> Option<(&'a str, &FunctionSignature)> {
match expr {
ast::Expression::Identifier(_, name) => self
.env
.get_function_info(name)
.map(|(signature, _)| (*name, signature)),
_ => None,
}
}

fn proper_function_call<'a>(
&mut self,
span: &Span,
full_span: &Span,
function_name: &'a str,
signature: &FunctionSignature,
arguments: Vec<typed_ast::Expression<'a>>,
argument_types: Vec<Type>,
) -> Result<typed_ast::Expression<'a>> {
let FunctionSignature {
name: _,
definition_span,
type_parameters: _,
parameters,
return_type_annotation: _,
fn_type,
} = signature;

let fn_type = match fn_type {
TypeScheme::Concrete(t) => {
// This branch is needed for recursive functions, where the type of the function
// is not yet known (and not yet quantified).
t.clone()
}
TypeScheme::Quantified(_, _) => {
let qt = fn_type.instantiate(&mut self.name_generator);

for Bound::IsDim(t) in qt.bounds.iter() {
self.add_dtype_constraint(t).ok();
}

qt.inner
}
};

let Type::Fn(parameter_types, return_type) = fn_type else {
unreachable!("Expected function type, got {:#?}", fn_type);
};

let arity_range = parameters.len()..=parameters.len();

if !arity_range.contains(&arguments.len()) {
return Err(Box::new(TypeCheckError::WrongArity {
callable_span: *span,
callable_name: function_name.to_owned(),
callable_definition_span: Some(*definition_span),
arity: arity_range,
num_args: arguments.len(),
}));
}

for (idx, ((parameter_span, parameter_type), argument_type)) in parameters
.iter()
.map(|p| p.0)
.zip(parameter_types.iter())
.zip(argument_types)
.enumerate()
{
if self
.add_equal_constraint(parameter_type, &argument_type)
.is_trivially_violated()
{
match (parameter_type, &argument_type) {
(Type::Dimension(parameter_dtype), Type::Dimension(argument_dtype)) => {
return Err(Box::new(TypeCheckError::IncompatibleDimensions(
IncompatibleDimensionsError {
span_operation: *span,
operation: format_compact!(
"argument {num} of function call to '{name}'",
num = idx + 1,
name = function_name
),
span_expected: parameter_span,
expected_name: "parameter type",
expected_dimensions: self.registry.get_derived_entry_names_for(
&parameter_dtype.to_base_representation(),
),
expected_type: parameter_dtype.to_base_representation(),
span_actual: arguments[idx].full_span(),
actual_name: " argument type",
actual_name_for_fix: "function argument",
actual_dimensions: self.registry.get_derived_entry_names_for(
&argument_dtype.to_base_representation(),
),
actual_type: argument_dtype.to_base_representation(),
},
)));
}
_ => {
return Err(Box::new(TypeCheckError::IncompatibleTypesInFunctionCall(
Some(parameter_span),
parameter_type.clone(),
arguments[idx].full_span(),
argument_type.clone(),
)));
}
}
}
}

Ok(typed_ast::Expression::FunctionCall(
*span,
*full_span,
function_name,
arguments,
TypeScheme::concrete(return_type.as_ref().clone()),
))
}

fn elaborate_expression<'a>(
&mut self,
ast: &ast::Expression<'a>,
Expand Down Expand Up @@ -778,18 +780,20 @@ impl TypeChecker {
// to a (proper) function, or it can be an arbitrary complicated expression
// that evaluates to a function "pointer".

if let Some((name, signature)) = self.get_proper_function_reference(callable) {
// TODO: there is probably a better way to get around borrowing issues here
let signature = signature.clone();

self.proper_function_call(
if let Some((function_name, signature)) =
self.env.get_proper_function_reference(callable)
{
proper_function_call(ProperFunctionCallArgs {
registry: &mut self.registry,
constraints: &mut self.constraints,
name_generator: &mut self.name_generator,
span,
full_span,
name,
&signature,
arguments_checked,
function_name,
signature,
arguments: arguments_checked,
argument_types,
)?
})?
} else {
let callable_checked = self.elaborate_expression(callable)?;
let callable_type = callable_checked.get_type();
Expand Down

0 comments on commit 1a276f3

Please sign in to comment.