From f059870e81aa619951c37cedd2434b60b42bb9da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vili=20Sinerv=C3=A4?= Date: Tue, 4 Feb 2025 19:06:19 +0200 Subject: [PATCH] Finalize type checker for now --- src/compiler.rs | 4 +- src/compiler/ast.rs | 6 +-- src/compiler/type_checker.rs | 82 ++++++++++++++++++++++-------------- 3 files changed, 56 insertions(+), 36 deletions(-) diff --git a/src/compiler.rs b/src/compiler.rs index b0586ea..0231340 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -17,8 +17,8 @@ mod variable; pub fn compile(code: &str) { let tokens = tokenize(code); - let ast = parse(&tokens); - type_check(&ast, &mut SymTab::new_type_table()); + let mut ast = parse(&tokens); + type_check(&mut ast, &mut SymTab::new_type_table()); } pub fn start_interpreter() { diff --git a/src/compiler/ast.rs b/src/compiler/ast.rs index 8d5e4d4..cd50a55 100644 --- a/src/compiler/ast.rs +++ b/src/compiler/ast.rs @@ -2,13 +2,13 @@ use crate::compiler::token::CodeLocation; use crate::compiler::variable::Type; use std::fmt; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone, Copy)] pub enum TypeExpression { Int(CodeLocation), Bool(CodeLocation), } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub struct AstNode<'source> { pub loc: CodeLocation, pub node_type: Type, @@ -37,7 +37,7 @@ impl<'source> fmt::Display for AstNode<'source> { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum Expression<'source> { EmptyLiteral(), IntLiteral(i64), diff --git a/src/compiler/type_checker.rs b/src/compiler/type_checker.rs index 54b0ab1..a389e95 100644 --- a/src/compiler/type_checker.rs +++ b/src/compiler/type_checker.rs @@ -4,15 +4,24 @@ use crate::compiler::{ variable::Type, }; -pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, Type>) -> Type { - match &ast.expr { +pub fn type_check<'source>( + ast: &mut AstNode<'source>, + symbols: &mut SymTab<'source, Type>, +) -> Type { + let node_type = get_type(ast, symbols); + ast.node_type = node_type.clone(); + node_type +} + +fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, Type>) -> Type { + match ast.expr.clone() { EmptyLiteral() => Type::Unit, IntLiteral(_) => Type::Int, BoolLiteral(_) => Type::Bool, Identifier(name) => symbols.get(name).clone(), - UnaryOp(op, expr) => match *op { + UnaryOp(op, mut expr) => match op { "-" => { - let expr_types = vec![type_check(expr, symbols)]; + let expr_types = vec![type_check(&mut expr, symbols)]; let Type::Func(sig_arg_types, sig_ret_type) = symbols.get("neg") else { panic!("Identifier {} does not correspond to an operator!", op); @@ -28,7 +37,7 @@ pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, (**sig_ret_type).clone() } _ => { - let expr_types = vec![type_check(expr, symbols)]; + let expr_types = vec![type_check(&mut expr, symbols)]; let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else { panic!("Identifier {} does not correspond to an operator!", op); @@ -44,10 +53,10 @@ pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, (**sig_ret_type).clone() } }, - BinaryOp(left, op, right) => match *op { + BinaryOp(mut left, op, mut right) => match op { "==" | "!=" => { - let left_type = type_check(left, symbols); - let right_type = type_check(right, symbols); + let left_type = type_check(&mut left, symbols); + let right_type = type_check(&mut right, symbols); if left_type != right_type { panic!("Mismatched types being compared with {op}"); } @@ -58,16 +67,16 @@ pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, panic!("Non-variable on left side of assignment!"); } - let left_type = type_check(left, symbols); - let right_type = type_check(right, symbols); + let left_type = type_check(&mut left, symbols); + let right_type = type_check(&mut right, symbols); if left_type != right_type { panic!("Mismatched types in assignment!"); } left_type } _ => { - let left_type = type_check(left, symbols); - let right_type = type_check(right, symbols); + let left_type = type_check(&mut left, symbols); + let right_type = type_check(&mut right, symbols); let arg_types = vec![left_type, right_type]; let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else { @@ -84,8 +93,8 @@ pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, (**sig_ret_type).clone() } }, - VarDeclaration(name, expr, type_expr) => { - let type_var = type_check(expr, symbols); + VarDeclaration(name, mut expr, type_expr) => { + let type_var = type_check(&mut expr, symbols); if let Some(type_expr) = type_expr { let expected_type = match type_expr { @@ -104,14 +113,14 @@ pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, symbols.insert(name, type_var); Type::Unit } - Conditional(condition_expr, then_expr, else_expr) => { - if !matches!(type_check(condition_expr, symbols), Type::Bool) { + Conditional(mut condition_expr, mut then_expr, else_expr) => { + if !matches!(type_check(&mut condition_expr, symbols), Type::Bool) { panic!("Non-bool as if-then-else condition!"); } - if let Some(else_expr) = else_expr { - let then_type = type_check(then_expr, symbols); - let else_type = type_check(else_expr, symbols); + if let Some(mut else_expr) = else_expr { + let then_type = type_check(&mut then_expr, symbols); + let else_type = type_check(&mut else_expr, symbols); if then_type == else_type { then_type } else { @@ -121,17 +130,17 @@ pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, Type::Unit } } - While(condition_expr, do_expr) => { - if !matches!(type_check(condition_expr, symbols), Type::Bool) { + While(mut condition_expr, mut do_expr) => { + if !matches!(type_check(&mut condition_expr, symbols), Type::Bool) { panic!("Non-bool as while-do condition!"); } - type_check(do_expr, symbols); + type_check(&mut do_expr, symbols); Type::Unit } FunCall(name, args) => { let mut arg_types = Vec::new(); - for arg in args { - arg_types.push(type_check(arg, symbols)); + for mut arg in args { + arg_types.push(type_check(&mut arg, symbols)); } let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(name) else { @@ -151,8 +160,8 @@ pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, symbols.push_level(); let mut type_var = Type::Unit; - for expression in expressions { - type_var = type_check(expression, symbols); + for mut expression in expressions { + type_var = type_check(&mut expression, symbols); } symbols.remove_level(); @@ -168,7 +177,7 @@ mod tests { use Type::*; fn get_type(code: &str) -> Type { - type_check(&parse(&tokenize(code)), &mut SymTab::new_type_table()) + type_check(&mut parse(&tokenize(code)), &mut SymTab::new_type_table()) } #[test] @@ -329,14 +338,14 @@ mod tests { let mut ast = parse(&tokens); let mut symtab = SymTab::new_type_table(); symtab.insert("foo", Func(vec![Int], Box::new(Int))); - let result = type_check(&ast, &mut symtab); + let result = type_check(&mut ast, &mut symtab); assert_eq!(result, Int); tokens = tokenize("foo(1);"); ast = parse(&tokens); symtab = SymTab::new_type_table(); symtab.insert("foo", Func(vec![Int], Box::new(Int))); - let result = type_check(&ast, &mut symtab); + let result = type_check(&mut ast, &mut symtab); assert_eq!(result, Unit); } @@ -344,9 +353,20 @@ mod tests { #[should_panic] fn test_function_wrong_arg() { let tokens = tokenize("foo(true)"); - let ast = parse(&tokens); + let mut ast = parse(&tokens); let mut symtab = SymTab::new_type_table(); symtab.insert("foo", Func(vec![Int], Box::new(Int))); - type_check(&ast, &mut symtab); + type_check(&mut ast, &mut symtab); + } + + #[test] + fn test_node_type() { + let tokens = tokenize("1"); + let mut ast = parse(&tokens); + let mut symtab = SymTab::new_type_table(); + + assert_eq!(ast.node_type, Unit); + type_check(&mut ast, &mut symtab); + assert_eq!(ast.node_type, Int); } }