From 7dd60aa9c502aa3bdeb4c43e8120c42f8d0ea301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vili=20Sinerv=C3=A4?= Date: Wed, 5 Feb 2025 23:00:39 +0200 Subject: [PATCH] Fix bug in type checker --- src/compiler/type_checker.rs | 46 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/compiler/type_checker.rs b/src/compiler/type_checker.rs index 0bf0381..cdef99c 100644 --- a/src/compiler/type_checker.rs +++ b/src/compiler/type_checker.rs @@ -14,13 +14,13 @@ pub fn type_check<'source>( } fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, Type>) -> Type { - match ast.expr.clone() { + match &mut ast.expr { EmptyLiteral() => Type::Unit, IntLiteral(_) => Type::Int, BoolLiteral(_) => Type::Bool, Identifier(name) => symbols.get(name).clone(), - UnaryOp(op, mut expr) => { - let expr_types = vec![type_check(&mut expr, symbols)]; + UnaryOp(op, ref mut expr) => { + let expr_types = vec![type_check(expr, symbols)]; let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(&format!("unary_{op}")) else { @@ -36,10 +36,10 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T (**sig_ret_type).clone() } - BinaryOp(mut left, op, mut right) => match op { + BinaryOp(ref mut left, op, ref mut right) => match *op { "==" | "!=" => { - let left_type = type_check(&mut left, symbols); - let right_type = type_check(&mut right, symbols); + let left_type = type_check(left, symbols); + let right_type = type_check(right, symbols); if left_type != right_type { panic!("Mismatched types being compared with {op}"); } @@ -50,16 +50,16 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T panic!("Non-variable on left side of assignment!"); } - let left_type = type_check(&mut left, symbols); - let right_type = type_check(&mut right, symbols); + let left_type = type_check(left, symbols); + let right_type = type_check(right, symbols); if left_type != right_type { panic!("Mismatched types in assignment!"); } left_type } _ => { - let left_type = type_check(&mut left, symbols); - let right_type = type_check(&mut right, symbols); + let left_type = type_check(left, symbols); + let right_type = type_check(right, symbols); let arg_types = vec![left_type, right_type]; let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else { @@ -76,8 +76,8 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T (**sig_ret_type).clone() } }, - VarDeclaration(name, mut expr, type_expr) => { - let type_var = type_check(&mut expr, symbols); + VarDeclaration(name, ref mut expr, ref mut type_expr) => { + let type_var = type_check(expr, symbols); if let Some(type_expr) = type_expr { let expected_type = match type_expr { @@ -96,14 +96,14 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T symbols.insert(name, type_var); Type::Unit } - Conditional(mut condition_expr, mut then_expr, else_expr) => { - if !matches!(type_check(&mut condition_expr, symbols), Type::Bool) { + Conditional(ref mut condition_expr, ref mut then_expr, ref mut else_expr) => { + if !matches!(type_check(condition_expr, symbols), Type::Bool) { panic!("Non-bool as if-then-else condition!"); } - if let Some(mut else_expr) = else_expr { - let then_type = type_check(&mut then_expr, symbols); - let else_type = type_check(&mut else_expr, symbols); + if let Some(ref mut else_expr) = else_expr { + let then_type = type_check(then_expr, symbols); + let else_type = type_check(else_expr, symbols); if then_type == else_type { then_type } else { @@ -113,11 +113,11 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T Type::Unit } } - While(mut condition_expr, mut do_expr) => { - if !matches!(type_check(&mut condition_expr, symbols), Type::Bool) { + While(ref mut condition_expr, ref mut do_expr) => { + if !matches!(type_check(condition_expr, symbols), Type::Bool) { panic!("Non-bool as while-do condition!"); } - type_check(&mut do_expr, symbols); + type_check(do_expr, symbols); Type::Unit } FunCall(name, args) => { @@ -139,12 +139,12 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T (**sig_ret_type).clone() } - Block(expressions) => { + Block(ref mut expressions) => { symbols.push_level(); let mut type_var = Type::Unit; - for mut expression in expressions { - type_var = type_check(&mut expression, symbols); + for expression in expressions { + type_var = type_check(expression, symbols); } symbols.remove_level();