diff --git a/src/compiler.rs b/src/compiler.rs index 6f75bf1..b0586ea 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -4,6 +4,7 @@ use interpreter::interpret; use parser::parse; use symtab::SymTab; use tokenizer::tokenize; +use type_checker::type_check; mod ast; mod interpreter; @@ -11,12 +12,13 @@ mod parser; mod symtab; mod token; mod tokenizer; -//mod type_checker; +mod type_checker; mod variable; pub fn compile(code: &str) { - let tokens = tokenizer::tokenize(code); - parser::parse(&tokens); + let tokens = tokenize(code); + let ast = parse(&tokens); + type_check(&ast, &mut SymTab::new_type_table()); } pub fn start_interpreter() { @@ -27,7 +29,7 @@ pub fn start_interpreter() { let tokens = tokenize(&code); let ast = parse(&tokens); - let val = interpret(&ast, &mut SymTab::new()); + let val = interpret(&ast, &mut SymTab::new_val_table()); println!("{}", val); } } diff --git a/src/compiler/interpreter.rs b/src/compiler/interpreter.rs index 3db3870..24dcd07 100644 --- a/src/compiler/interpreter.rs +++ b/src/compiler/interpreter.rs @@ -4,7 +4,10 @@ use crate::compiler::{ variable::Value, }; -pub fn interpret<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'source>) -> Value { +pub fn interpret<'source>( + ast: &Expression<'source>, + symbols: &mut SymTab<'source, Value>, +) -> Value { match ast { EmptyLiteral(_) => Value::None(), IntLiteral(_, val) => Value::Int(*val), diff --git a/src/compiler/symtab.rs b/src/compiler/symtab.rs index 7d05ee1..dae6f66 100644 --- a/src/compiler/symtab.rs +++ b/src/compiler/symtab.rs @@ -1,13 +1,13 @@ -use crate::compiler::variable::Value; +use crate::compiler::variable::{Type, Value}; use std::collections::HashMap; #[derive(Default)] -pub struct SymTab<'source> { - tables: Vec>, +pub struct SymTab<'source, T> { + tables: Vec>, } -impl<'source> SymTab<'source> { - pub fn get(&mut self, symbol: &str) -> &mut Value { +impl<'source, T> SymTab<'source, T> { + pub fn get(&mut self, symbol: &str) -> &mut T { for i in (0..self.tables.len()).rev() { if self.tables[i].contains_key(symbol) { return self.tables[i].get_mut(symbol).unwrap(); @@ -16,28 +16,6 @@ impl<'source> SymTab<'source> { panic!("No symbol {} found!", symbol); } - pub fn new() -> SymTab<'source> { - let globals = HashMap::from([ - ("+", Value::Func(Value::add)), - ("*", Value::Func(Value::mul)), - ("-", Value::Func(Value::sub)), - ("/", Value::Func(Value::div)), - ("%", Value::Func(Value::rem)), - ("==", Value::Func(Value::eq)), - ("!=", Value::Func(Value::neq)), - ("<", Value::Func(Value::lt)), - ("<=", Value::Func(Value::le)), - (">", Value::Func(Value::gt)), - (">=", Value::Func(Value::ge)), - ("not", Value::Func(Value::not)), - ("neg", Value::Func(Value::neg)), - ]); - - SymTab { - tables: vec![globals], - } - } - pub fn push_level(&mut self) { self.tables.push(HashMap::new()); } @@ -46,7 +24,7 @@ impl<'source> SymTab<'source> { self.tables.pop(); } - pub fn insert(&mut self, name: &'source str, val: Value) { + pub fn insert(&mut self, name: &'source str, val: T) { if self .tables .last_mut() @@ -58,3 +36,51 @@ impl<'source> SymTab<'source> { } } } + +impl<'source> SymTab<'source, Type> { + pub fn new_type_table() -> SymTab<'source, Type> { + use Type::*; + let globals = HashMap::from([ + ("+", Func(vec![Int, Int], Box::new(Int))), + ("*", Func(vec![Int, Int], Box::new(Int))), + ("-", Func(vec![Int, Int], Box::new(Int))), + ("/", Func(vec![Int, Int], Box::new(Int))), + ("%", Func(vec![Int, Int], Box::new(Int))), + ("<", Func(vec![Int, Int], Box::new(Bool))), + ("<=", Func(vec![Int, Int], Box::new(Bool))), + (">", Func(vec![Int, Int], Box::new(Bool))), + (">=", Func(vec![Int, Int], Box::new(Bool))), + ("not", Func(vec![Bool], Box::new(Bool))), + ("neg", Func(vec![Bool], Box::new(Bool))), + ]); + + SymTab { + tables: vec![globals], + } + } +} + +impl<'source> SymTab<'source, Value> { + pub fn new_val_table() -> SymTab<'source, Value> { + use Value::*; + let globals = HashMap::from([ + ("+", Func(Value::add)), + ("*", Func(Value::mul)), + ("-", Func(Value::sub)), + ("/", Func(Value::div)), + ("%", Func(Value::rem)), + ("==", Func(Value::eq)), + ("!=", Func(Value::neq)), + ("<", Func(Value::lt)), + ("<=", Func(Value::le)), + (">", Func(Value::gt)), + (">=", Func(Value::ge)), + ("not", Func(Value::not)), + ("neg", Func(Value::neg)), + ]); + + SymTab { + tables: vec![globals], + } + } +} diff --git a/src/compiler/type_checker.rs b/src/compiler/type_checker.rs new file mode 100644 index 0000000..d3ce620 --- /dev/null +++ b/src/compiler/type_checker.rs @@ -0,0 +1,208 @@ +use crate::compiler::{ + ast::Expression::{self, *}, + symtab::SymTab, + variable::Type, +}; + +pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'source, Type>) -> Type { + match ast { + EmptyLiteral(_) => Type::Unit, + IntLiteral(_, _) => Type::Int, + BoolLiteral(_, _) => Type::Bool, + Identifier(_, _) => todo!(), + UnaryOp(_, _, _) => todo!(), + BinaryOp(_, _, _, _) => todo!(), + VarDeclaration(_, _, _) => Type::Unit, + Conditional(_, _, _, _) => todo!(), + While(_, _, _) => todo!(), + FunCall(_, _, _) => todo!(), + Block(_, _) => todo!(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::compiler::{parser::parse, tokenizer::tokenize}; + use Type::*; + + fn get_type(code: &str) -> Type { + type_check(&parse(&tokenize(code)), &mut SymTab::new_type_table()) + } + + #[test] + fn test_individual() { + let result = get_type("1"); + assert_eq!(result, Int); + + let result = get_type("true"); + assert_eq!(result, Bool); + + let result = get_type("var a = true; a"); + assert_eq!(result, Bool); + + let result = get_type("var a = true;"); + assert_eq!(result, Unit); + + let result = get_type("+"); + assert_eq!(result, Func(vec![Int, Int], Box::new(Int))); + } + + #[test] + fn test_var_untyped() { + let result = get_type("var a = 1"); + assert_eq!(result, Unit); + + let result = get_type("var a = 1; a"); + assert_eq!(result, Int); + } + + #[test] + fn test_var_typed() { + let result = get_type("var a: Int = 1"); + assert_eq!(result, Unit); + + let result = get_type("var a = 1; a"); + assert_eq!(result, Int); + } + + #[test] + #[should_panic] + fn test_var_typed_mismatch() { + get_type("var a: Int = true"); + } + + #[test] + fn test_assign() { + let result = get_type("var a = 1; a = 2;"); + assert_eq!(result, Unit); + + let result = get_type("var a = 1; a = 2"); + assert_eq!(result, Int); + } + + #[test] + #[should_panic] + fn test_assign_mismatch() { + get_type("var a = 1; a = true"); + } + + #[test] + fn test_operators() { + let result = get_type("true or false"); + assert_eq!(result, Bool); + let result = get_type("true and false"); + assert_eq!(result, Bool); + let result = get_type("true == false"); + assert_eq!(result, Bool); + let result = get_type("1 == 2"); + assert_eq!(result, Bool); + let result = get_type("true != false"); + assert_eq!(result, Bool); + let result = get_type("1 != 2"); + assert_eq!(result, Bool); + let result = get_type("1 < 2"); + assert_eq!(result, Bool); + let result = get_type("1 <= 2"); + assert_eq!(result, Bool); + let result = get_type("1 > 2"); + assert_eq!(result, Bool); + let result = get_type("1 >= 2"); + assert_eq!(result, Bool); + let result = get_type("1 + 2"); + assert_eq!(result, Int); + let result = get_type("1 - 2"); + assert_eq!(result, Int); + let result = get_type("1 * 2"); + assert_eq!(result, Int); + let result = get_type("1 / 2"); + assert_eq!(result, Int); + let result = get_type("1 % 2"); + assert_eq!(result, Int); + let result = get_type("not false"); + assert_eq!(result, Bool); + let result = get_type("-1"); + assert_eq!(result, Int); + } + + #[test] + #[should_panic] + fn test_operators_mismatch() { + get_type("1 == true"); + } + + #[test] + #[should_panic] + fn test_operators_wrong_type() { + get_type("1 and 2"); + } + + #[test] + fn test_conditional() { + let result = get_type("if true then 1"); + assert_eq!(result, Unit); + + let result = get_type("if true then 1 else 2"); + assert_eq!(result, Int); + } + + #[test] + #[should_panic] + fn test_conditional_non_bool() { + get_type("if 1 then 2"); + } + + #[test] + #[should_panic] + fn test_conditional_type_mismatch() { + get_type("if true then 2 else false"); + } + + #[test] + fn test_while() { + let result = get_type("while true do 1"); + assert_eq!(result, Unit); + } + + #[test] + #[should_panic] + fn test_while_non_bool() { + get_type("while 1 do 2"); + } + + #[test] + fn test_block() { + let result = get_type("{1; 2}"); + assert_eq!(result, Int); + + let result = get_type("{1; 2;}"); + assert_eq!(result, Unit); + } + + #[test] + fn test_function() { + let mut tokens = tokenize("foo(1)"); + let mut ast = parse(&tokens); + let mut symtab = SymTab::new_type_table(); + symtab.insert("foo", Func(vec![Int], Box::new(Int))); + let result = type_check(&ast, &mut symtab); + assert_eq!(result, Int); + + tokens = tokenize("foo(1);"); + ast = parse(&tokens); + symtab = SymTab::new_type_table(); + symtab.insert("foo", Func(vec![Int], Box::new(Int))); + let result = type_check(&ast, &mut symtab); + assert_eq!(result, Unit); + } + + #[test] + #[should_panic] + fn test_function_wrong_arg() { + let tokens = tokenize("foo(true)"); + let ast = parse(&tokens); + let mut symtab = SymTab::new_type_table(); + symtab.insert("foo", Func(vec![Int], Box::new(Int))); + type_check(&ast, &mut symtab); + } +} diff --git a/src/compiler/variable.rs b/src/compiler/variable.rs index e4e8d97..bb4f925 100644 --- a/src/compiler/variable.rs +++ b/src/compiler/variable.rs @@ -1,5 +1,13 @@ use std::fmt; +#[derive(PartialEq, Debug)] +pub enum Type { + Int, + Bool, + Func(Vec, Box), + Unit, +} + #[derive(PartialEq, PartialOrd, Debug, Copy, Clone)] pub enum Value { Int(i64),