1
0
Fork 0

Fix bug in type checker

This commit is contained in:
Vili Sinervä 2025-02-05 23:00:39 +02:00
parent 4d00bbb6ba
commit 7dd60aa9c5
No known key found for this signature in database
GPG key ID: DF8FEAF54EFAC996

View file

@ -14,13 +14,13 @@ pub fn type_check<'source>(
} }
fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, Type>) -> Type { fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, Type>) -> Type {
match ast.expr.clone() { match &mut ast.expr {
EmptyLiteral() => Type::Unit, EmptyLiteral() => Type::Unit,
IntLiteral(_) => Type::Int, IntLiteral(_) => Type::Int,
BoolLiteral(_) => Type::Bool, BoolLiteral(_) => Type::Bool,
Identifier(name) => symbols.get(name).clone(), Identifier(name) => symbols.get(name).clone(),
UnaryOp(op, mut expr) => { UnaryOp(op, ref mut expr) => {
let expr_types = vec![type_check(&mut expr, symbols)]; let expr_types = vec![type_check(expr, symbols)];
let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(&format!("unary_{op}")) let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(&format!("unary_{op}"))
else { else {
@ -36,10 +36,10 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
(**sig_ret_type).clone() (**sig_ret_type).clone()
} }
BinaryOp(mut left, op, mut right) => match op { BinaryOp(ref mut left, op, ref mut right) => match *op {
"==" | "!=" => { "==" | "!=" => {
let left_type = type_check(&mut left, symbols); let left_type = type_check(left, symbols);
let right_type = type_check(&mut right, symbols); let right_type = type_check(right, symbols);
if left_type != right_type { if left_type != right_type {
panic!("Mismatched types being compared with {op}"); panic!("Mismatched types being compared with {op}");
} }
@ -50,16 +50,16 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
panic!("Non-variable on left side of assignment!"); panic!("Non-variable on left side of assignment!");
} }
let left_type = type_check(&mut left, symbols); let left_type = type_check(left, symbols);
let right_type = type_check(&mut right, symbols); let right_type = type_check(right, symbols);
if left_type != right_type { if left_type != right_type {
panic!("Mismatched types in assignment!"); panic!("Mismatched types in assignment!");
} }
left_type left_type
} }
_ => { _ => {
let left_type = type_check(&mut left, symbols); let left_type = type_check(left, symbols);
let right_type = type_check(&mut right, symbols); let right_type = type_check(right, symbols);
let arg_types = vec![left_type, right_type]; let arg_types = vec![left_type, right_type];
let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else { let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else {
@ -76,8 +76,8 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
(**sig_ret_type).clone() (**sig_ret_type).clone()
} }
}, },
VarDeclaration(name, mut expr, type_expr) => { VarDeclaration(name, ref mut expr, ref mut type_expr) => {
let type_var = type_check(&mut expr, symbols); let type_var = type_check(expr, symbols);
if let Some(type_expr) = type_expr { if let Some(type_expr) = type_expr {
let expected_type = match type_expr { let expected_type = match type_expr {
@ -96,14 +96,14 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
symbols.insert(name, type_var); symbols.insert(name, type_var);
Type::Unit Type::Unit
} }
Conditional(mut condition_expr, mut then_expr, else_expr) => { Conditional(ref mut condition_expr, ref mut then_expr, ref mut else_expr) => {
if !matches!(type_check(&mut condition_expr, symbols), Type::Bool) { if !matches!(type_check(condition_expr, symbols), Type::Bool) {
panic!("Non-bool as if-then-else condition!"); panic!("Non-bool as if-then-else condition!");
} }
if let Some(mut else_expr) = else_expr { if let Some(ref mut else_expr) = else_expr {
let then_type = type_check(&mut then_expr, symbols); let then_type = type_check(then_expr, symbols);
let else_type = type_check(&mut else_expr, symbols); let else_type = type_check(else_expr, symbols);
if then_type == else_type { if then_type == else_type {
then_type then_type
} else { } else {
@ -113,11 +113,11 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
Type::Unit Type::Unit
} }
} }
While(mut condition_expr, mut do_expr) => { While(ref mut condition_expr, ref mut do_expr) => {
if !matches!(type_check(&mut condition_expr, symbols), Type::Bool) { if !matches!(type_check(condition_expr, symbols), Type::Bool) {
panic!("Non-bool as while-do condition!"); panic!("Non-bool as while-do condition!");
} }
type_check(&mut do_expr, symbols); type_check(do_expr, symbols);
Type::Unit Type::Unit
} }
FunCall(name, args) => { FunCall(name, args) => {
@ -139,12 +139,12 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
(**sig_ret_type).clone() (**sig_ret_type).clone()
} }
Block(expressions) => { Block(ref mut expressions) => {
symbols.push_level(); symbols.push_level();
let mut type_var = Type::Unit; let mut type_var = Type::Unit;
for mut expression in expressions { for expression in expressions {
type_var = type_check(&mut expression, symbols); type_var = type_check(expression, symbols);
} }
symbols.remove_level(); symbols.remove_level();