diff --git a/src/compiler/parser/mod.rs b/src/compiler/parser/mod.rs index c030dd9..eb6abea 100644 --- a/src/compiler/parser/mod.rs +++ b/src/compiler/parser/mod.rs @@ -77,7 +77,7 @@ fn parse_expression<'source>( &["or"], // 1 &["and"], // 2 &["==", "!="], // 3 - &["<", "<=", "=>", ">"], // 4 + &["<", "<=", ">=", ">"], // 4 &["+", "-"], // 5 &["*", "/", "%"], // 6 &["not", "-"], // 7 diff --git a/src/compiler/symtab.rs b/src/compiler/symtab.rs index dae6f66..a1cae10 100644 --- a/src/compiler/symtab.rs +++ b/src/compiler/symtab.rs @@ -51,7 +51,9 @@ impl<'source> SymTab<'source, Type> { (">", Func(vec![Int, Int], Box::new(Bool))), (">=", Func(vec![Int, Int], Box::new(Bool))), ("not", Func(vec![Bool], Box::new(Bool))), - ("neg", Func(vec![Bool], Box::new(Bool))), + ("neg", Func(vec![Int], Box::new(Int))), + ("or", Func(vec![Bool, Bool], Box::new(Bool))), + ("and", Func(vec![Bool, Bool], Box::new(Bool))), ]); SymTab { diff --git a/src/compiler/type_checker.rs b/src/compiler/type_checker.rs index d3ce620..6f78a95 100644 --- a/src/compiler/type_checker.rs +++ b/src/compiler/type_checker.rs @@ -9,14 +9,140 @@ pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'sour EmptyLiteral(_) => Type::Unit, IntLiteral(_, _) => Type::Int, BoolLiteral(_, _) => Type::Bool, - Identifier(_, _) => todo!(), - UnaryOp(_, _, _) => todo!(), - BinaryOp(_, _, _, _) => todo!(), - VarDeclaration(_, _, _) => Type::Unit, - Conditional(_, _, _, _) => todo!(), - While(_, _, _) => todo!(), - FunCall(_, _, _) => todo!(), - Block(_, _) => todo!(), + Identifier(_, name) => symbols.get(name).clone(), + UnaryOp(_, op, expr) => match *op { + "-" => { + let expr_types = vec![type_check(expr, symbols)]; + + let Type::Func(sig_arg_types, sig_ret_type) = symbols.get("neg") else { + panic!("Identifier {} does not correspond to an operator!", op); + }; + + if expr_types != *sig_arg_types { + panic!( + "Operator {} argument types {:?} don't match expected {:?}", + op, expr_types, *sig_arg_types + ); + } + + (**sig_ret_type).clone() + } + _ => { + let expr_types = vec![type_check(expr, symbols)]; + + let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else { + panic!("Identifier {} does not correspond to an operator!", op); + }; + + if expr_types != *sig_arg_types { + panic!( + "Operator {} argument types {:?} don't match expected {:?}", + op, expr_types, *sig_arg_types + ); + } + + (**sig_ret_type).clone() + } + }, + BinaryOp(_, left, op, right) => match *op { + "==" | "!=" => { + let left_type = type_check(left, symbols); + let right_type = type_check(right, symbols); + if left_type != right_type { + panic!("Mismatched types being compared with {op}"); + } + Type::Bool + } + "=" => { + if !matches!(**left, Identifier(_, _)) { + panic!("Non-variable on left side of assignment!"); + } + + let left_type = type_check(left, symbols); + let right_type = type_check(right, symbols); + if left_type != right_type { + panic!("Mismatched types in assignment!"); + } + left_type + } + _ => { + let left_type = type_check(left, symbols); + let right_type = type_check(right, symbols); + let arg_types = vec![left_type, right_type]; + + let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else { + panic!("Identifier {} does not correspond to an operator!", op); + }; + + if arg_types != *sig_arg_types { + panic!( + "Operator {} argument types {:?} don't match expected {:?}", + op, arg_types, *sig_arg_types + ); + } + + (**sig_ret_type).clone() + } + }, + VarDeclaration(_, name, expr) => { + let type_var = type_check(expr, symbols); + symbols.insert(name, type_var); + Type::Unit + } + Conditional(_, condition_expr, then_expr, else_expr) => { + if !matches!(type_check(condition_expr, symbols), Type::Bool) { + panic!("Non-bool as if-then-else condition!"); + } + + if let Some(else_expr) = else_expr { + let then_type = type_check(then_expr, symbols); + let else_type = type_check(else_expr, symbols); + if then_type == else_type { + then_type + } else { + panic!("Mismatched return values in if-then-else!"); + } + } else { + Type::Unit + } + } + While(_, condition_expr, do_expr) => { + if !matches!(type_check(condition_expr, symbols), Type::Bool) { + panic!("Non-bool as while-do condition!"); + } + type_check(do_expr, symbols); + Type::Unit + } + FunCall(_, name, args) => { + let mut arg_types = Vec::new(); + for arg in args { + arg_types.push(type_check(arg, symbols)); + } + + let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(name) else { + panic!("Identifier {} does not correspond to a function!", name); + }; + + if arg_types != *sig_arg_types { + panic!( + "Function {} argument types {:?} don't match expected {:?}", + name, arg_types, *sig_arg_types + ); + } + + (**sig_ret_type).clone() + } + Block(_, expressions) => { + symbols.push_level(); + + let mut type_var = Type::Unit; + for expression in expressions { + type_var = type_check(expression, symbols); + } + + symbols.remove_level(); + type_var + } } } @@ -43,9 +169,6 @@ mod tests { let result = get_type("var a = true;"); assert_eq!(result, Unit); - - let result = get_type("+"); - assert_eq!(result, Func(vec![Int, Int], Box::new(Int))); } #[test] @@ -87,6 +210,12 @@ mod tests { get_type("var a = 1; a = true"); } + #[test] + #[should_panic] + fn test_assign_non_var() { + get_type("1 = 2"); + } + #[test] fn test_operators() { let result = get_type("true or false"); diff --git a/src/compiler/variable.rs b/src/compiler/variable.rs index bb4f925..2f565f0 100644 --- a/src/compiler/variable.rs +++ b/src/compiler/variable.rs @@ -1,6 +1,6 @@ use std::fmt; -#[derive(PartialEq, Debug)] +#[derive(PartialEq, Debug, Clone)] pub enum Type { Int, Bool,