diff --git a/src/compiler/ast.rs b/src/compiler/ast.rs index 38fafd6..ea3fe42 100644 --- a/src/compiler/ast.rs +++ b/src/compiler/ast.rs @@ -1,6 +1,12 @@ use crate::compiler::token::CodeLocation; use std::fmt; +#[derive(Debug, PartialEq)] +pub enum TypeExpression { + Int(CodeLocation), + Bool(CodeLocation), +} + #[derive(Debug, PartialEq)] pub enum Expression<'source> { EmptyLiteral(CodeLocation), @@ -14,7 +20,12 @@ pub enum Expression<'source> { &'source str, Box>, ), - VarDeclaration(CodeLocation, &'source str, Box>), + VarDeclaration( + CodeLocation, + &'source str, + Box>, + Option, + ), Conditional( CodeLocation, Box>, @@ -38,7 +49,7 @@ impl<'source> Expression<'source> { Expression::BoolLiteral(loc, _) => *loc, Expression::Identifier(loc, _) => *loc, Expression::UnaryOp(loc, _, _) => *loc, - Expression::VarDeclaration(loc, _, _) => *loc, + Expression::VarDeclaration(loc, _, _, _) => *loc, Expression::BinaryOp(loc, _, _, _) => *loc, Expression::Conditional(loc, _, _, _) => *loc, Expression::While(loc, _, _) => *loc, @@ -70,7 +81,7 @@ impl<'source> Expression<'source> { Expression::BoolLiteral(_, val) => val.to_string(), Expression::Identifier(_, name) => name.to_string(), Expression::UnaryOp(_, op, _) => op.to_string(), - Expression::VarDeclaration(_, name, _) => name.to_string(), + Expression::VarDeclaration(_, name, _, _) => name.to_string(), Expression::BinaryOp(_, _, op, _) => op.to_string(), Expression::Conditional(_, condition, _, _) => format!("if {}", condition), Expression::While(_, condition, _) => format!("while {}", condition), diff --git a/src/compiler/interpreter.rs b/src/compiler/interpreter.rs index 24dcd07..ce92f4d 100644 --- a/src/compiler/interpreter.rs +++ b/src/compiler/interpreter.rs @@ -78,7 +78,7 @@ pub fn interpret<'source>( op_fn(&[interpret(left, symbols), interpret(right, symbols)]) } }, - VarDeclaration(_, name, expr) => { + VarDeclaration(_, name, expr, _) => { let val = interpret(expr, symbols); symbols.insert(name, val); Value::None() diff --git a/src/compiler/parser/mod.rs b/src/compiler/parser/mod.rs index eb6abea..ad2a7fb 100644 --- a/src/compiler/parser/mod.rs +++ b/src/compiler/parser/mod.rs @@ -3,7 +3,10 @@ mod parser_utilities; mod tests; use crate::compiler::{ - ast::Expression::{self, *}, + ast::{ + Expression::{self, *}, + TypeExpression::{self}, + }, parser::parser_utilities::*, token::{Token, TokenType}, }; @@ -162,9 +165,22 @@ fn parse_var_declaration<'source>( ) -> Expression<'source> { consume_string(pos, tokens, "var"); let name_token = consume_type(pos, tokens, TokenType::Identifier); + + let mut type_expr = None; + if peek(pos, tokens).text == ":" { + consume_string(pos, tokens, ":"); + + let type_token = consume_type(pos, tokens, TokenType::Identifier); + type_expr = match type_token.text { + "Int" => Some(TypeExpression::Int(type_token.loc)), + "Bool" => Some(TypeExpression::Bool(type_token.loc)), + _ => panic! {"Unknown type indicator!"}, + } + } + consume_string(pos, tokens, "="); let value = parse_expression(0, pos, tokens); - VarDeclaration(name_token.loc, name_token.text, Box::new(value)) + VarDeclaration(name_token.loc, name_token.text, Box::new(value), type_expr) } fn parse_conditional<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { diff --git a/src/compiler/parser/tests.rs b/src/compiler/parser/tests.rs index d9a62e9..5dfa1f7 100644 --- a/src/compiler/parser/tests.rs +++ b/src/compiler/parser/tests.rs @@ -104,8 +104,8 @@ macro_rules! empty_ast { } macro_rules! var_ast { - ($x:expr, $y:expr) => { - VarDeclaration(CodeLocation::new(usize::MAX, usize::MAX), $x, $y) + ($x:expr, $y:expr, $z:expr) => { + VarDeclaration(CodeLocation::new(usize::MAX, usize::MAX), $x, $y, $z) }; } @@ -525,19 +525,48 @@ fn test_block_missing_semicolon() { #[test] fn test_var_basic() { let result = parse(&tokenize("var x = 1")); - assert_eq!(result, var_ast!("x", int_ast_b!(1))); + assert_eq!(result, var_ast!("x", int_ast_b!(1), None)); let result = parse(&tokenize("{ var x = 1; x = 2; }")); assert_eq!( result, block_ast!(vec![ - var_ast!("x", int_ast_b!(1)), + var_ast!("x", int_ast_b!(1), None), bin_ast!(id_ast_b!("x"), "=", int_ast_b!(2)), empty_ast!() ]) ); } +#[test] +fn test_var_typed() { + let result = parse(&tokenize("var x: Int = 1")); + assert_eq!( + result, + var_ast!( + "x", + int_ast_b!(1), + Some(TypeExpression::Int(CodeLocation::new( + usize::MAX, + usize::MAX + ))) + ) + ); + + let result = parse(&tokenize("var x: Bool = true")); + assert_eq!( + result, + var_ast!( + "x", + bool_ast_b!(true), + Some(TypeExpression::Bool(CodeLocation::new( + usize::MAX, + usize::MAX + ))) + ) + ); +} + #[test] #[should_panic] fn test_var_chain() { diff --git a/src/compiler/tokenizer.rs b/src/compiler/tokenizer.rs index 2b5e265..a4826ab 100644 --- a/src/compiler/tokenizer.rs +++ b/src/compiler/tokenizer.rs @@ -11,7 +11,7 @@ pub fn tokenize(code: &str) -> Vec { TokenType::Operator, Regex::new(r"^(==|!=|<=|>=|=|<|>|\+|-|\*|/|\%)").unwrap(), ), - (TokenType::Punctuation, Regex::new(r"^[\(\){},;]").unwrap()), + (TokenType::Punctuation, Regex::new(r"^[\(\){},;:]").unwrap()), (TokenType::Integer, Regex::new(r"^[0-9]+").unwrap()), ( TokenType::Identifier, @@ -177,7 +177,7 @@ mod tests { #[test] fn test_tokenize_punctuation_basic() { let loc = CodeLocation::new(usize::MAX, usize::MAX); - let result = tokenize("{var = (1 + 2, 3);}"); + let result = tokenize("{var = (1 + 2, 3);:}"); use TokenType::*; assert_eq!( @@ -194,6 +194,7 @@ mod tests { Token::new("3", Integer, loc), Token::new(")", Punctuation, loc), Token::new(";", Punctuation, loc), + Token::new(":", Punctuation, loc), Token::new("}", Punctuation, loc), ) ); diff --git a/src/compiler/type_checker.rs b/src/compiler/type_checker.rs index 6f78a95..fbd082f 100644 --- a/src/compiler/type_checker.rs +++ b/src/compiler/type_checker.rs @@ -1,5 +1,8 @@ use crate::compiler::{ - ast::Expression::{self, *}, + ast::{ + Expression::{self, *}, + TypeExpression, + }, symtab::SymTab, variable::Type, }; @@ -84,8 +87,23 @@ pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'sour (**sig_ret_type).clone() } }, - VarDeclaration(_, name, expr) => { + VarDeclaration(_, name, expr, type_expr) => { let type_var = type_check(expr, symbols); + + if let Some(type_expr) = type_expr { + let expected_type = match type_expr { + TypeExpression::Int(_) => Type::Int, + TypeExpression::Bool(_) => Type::Bool, + }; + + if type_var != expected_type { + panic!( + "Expected type {:?} does not match actual type {:?} in var declaration", + expected_type, type_var + ) + } + } + symbols.insert(name, type_var); Type::Unit }