1
0
Fork 0

Encapsulate expressions in AstNode struct

This commit is contained in:
Vili Sinervä 2025-02-04 18:45:17 +02:00
parent 0d19f447f9
commit ca2eeb9e50
No known key found for this signature in database
GPG key ID: DF8FEAF54EFAC996
5 changed files with 153 additions and 157 deletions

View file

@ -1,4 +1,5 @@
use crate::compiler::token::CodeLocation; use crate::compiler::token::CodeLocation;
use crate::compiler::variable::Type;
use std::fmt; use std::fmt;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
@ -7,57 +8,55 @@ pub enum TypeExpression {
Bool(CodeLocation), Bool(CodeLocation),
} }
#[derive(Debug, PartialEq)]
pub struct AstNode<'source> {
pub loc: CodeLocation,
pub node_type: Type,
pub expr: Expression<'source>,
}
impl<'source> AstNode<'source> {
pub fn new(loc: CodeLocation, expr: Expression<'source>) -> AstNode<'source> {
AstNode {
loc,
expr,
node_type: Type::Unit,
}
}
}
impl<'source> fmt::Display for AstNode<'source> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{} {} at {}",
self.expr.expr_type_str(),
self.expr.val_string(),
self.loc
)
}
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Expression<'source> { pub enum Expression<'source> {
EmptyLiteral(CodeLocation), EmptyLiteral(),
IntLiteral(CodeLocation, i64), IntLiteral(i64),
BoolLiteral(CodeLocation, bool), BoolLiteral(bool),
Identifier(CodeLocation, &'source str), Identifier(&'source str),
UnaryOp(CodeLocation, &'source str, Box<Expression<'source>>), UnaryOp(&'source str, Box<AstNode<'source>>),
BinaryOp( BinaryOp(Box<AstNode<'source>>, &'source str, Box<AstNode<'source>>),
CodeLocation, VarDeclaration(&'source str, Box<AstNode<'source>>, Option<TypeExpression>),
Box<Expression<'source>>,
&'source str,
Box<Expression<'source>>,
),
VarDeclaration(
CodeLocation,
&'source str,
Box<Expression<'source>>,
Option<TypeExpression>,
),
Conditional( Conditional(
CodeLocation, Box<AstNode<'source>>,
Box<Expression<'source>>, Box<AstNode<'source>>,
Box<Expression<'source>>, Option<Box<AstNode<'source>>>,
Option<Box<Expression<'source>>>,
), ),
While( While(Box<AstNode<'source>>, Box<AstNode<'source>>),
CodeLocation, FunCall(&'source str, Vec<AstNode<'source>>),
Box<Expression<'source>>, Block(Vec<AstNode<'source>>),
Box<Expression<'source>>,
),
FunCall(CodeLocation, &'source str, Vec<Expression<'source>>),
Block(CodeLocation, Vec<Expression<'source>>),
} }
impl<'source> Expression<'source> { impl<'source> Expression<'source> {
pub fn loc(&self) -> CodeLocation {
match self {
Expression::EmptyLiteral(loc) => *loc,
Expression::IntLiteral(loc, _) => *loc,
Expression::BoolLiteral(loc, _) => *loc,
Expression::Identifier(loc, _) => *loc,
Expression::UnaryOp(loc, _, _) => *loc,
Expression::VarDeclaration(loc, _, _, _) => *loc,
Expression::BinaryOp(loc, _, _, _) => *loc,
Expression::Conditional(loc, _, _, _) => *loc,
Expression::While(loc, _, _) => *loc,
Expression::FunCall(loc, _, _) => *loc,
Expression::Block(loc, _) => *loc,
}
}
fn expr_type_str(&self) -> &str { fn expr_type_str(&self) -> &str {
match self { match self {
Expression::EmptyLiteral(..) => "Empty literal", Expression::EmptyLiteral(..) => "Empty literal",
@ -68,7 +67,7 @@ impl<'source> Expression<'source> {
Expression::VarDeclaration(..) => "Variable declaration", Expression::VarDeclaration(..) => "Variable declaration",
Expression::BinaryOp(..) => "Binary operation", Expression::BinaryOp(..) => "Binary operation",
Expression::Conditional(..) => "Conditional", Expression::Conditional(..) => "Conditional",
Expression::While(_, _, _) => "While loop", Expression::While(..) => "While loop",
Expression::FunCall(..) => "Function call", Expression::FunCall(..) => "Function call",
Expression::Block(..) => "Block", Expression::Block(..) => "Block",
} }
@ -77,28 +76,16 @@ impl<'source> Expression<'source> {
fn val_string(&self) -> String { fn val_string(&self) -> String {
match self { match self {
Expression::EmptyLiteral(..) => "".to_string(), Expression::EmptyLiteral(..) => "".to_string(),
Expression::IntLiteral(_, val) => val.to_string(), Expression::IntLiteral(val) => val.to_string(),
Expression::BoolLiteral(_, val) => val.to_string(), Expression::BoolLiteral(val) => val.to_string(),
Expression::Identifier(_, name) => name.to_string(), Expression::Identifier(name) => name.to_string(),
Expression::UnaryOp(_, op, _) => op.to_string(), Expression::UnaryOp(op, _) => op.to_string(),
Expression::VarDeclaration(_, name, _, _) => name.to_string(), Expression::VarDeclaration(name, _, _) => name.to_string(),
Expression::BinaryOp(_, _, op, _) => op.to_string(), Expression::BinaryOp(_, op, _) => op.to_string(),
Expression::Conditional(_, condition, _, _) => format!("if {}", condition), Expression::Conditional(condition, _, _) => format!("if {:?}", condition),
Expression::While(_, condition, _) => format!("while {}", condition), Expression::While(condition, _) => format!("while {:?}", condition),
Expression::FunCall(_, name, args) => format!("{} with {} args", name, args.len()), Expression::FunCall(name, args) => format!("{} with {} args", name, args.len()),
Expression::Block(_, expressions) => format!("with {} expressions", expressions.len()), Expression::Block(expressions) => format!("with {} expressions", expressions.len()),
} }
} }
} }
impl<'source> fmt::Display for Expression<'source> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{} {} at {}",
self.expr_type_str(),
self.val_string(),
self.loc()
)
}
}

View file

@ -1,19 +1,19 @@
use crate::compiler::{ use crate::compiler::{
ast::Expression::{self, *}, ast::{
AstNode,
Expression::{self, *},
},
symtab::SymTab, symtab::SymTab,
variable::Value, variable::Value,
}; };
pub fn interpret<'source>( pub fn interpret<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, Value>) -> Value {
ast: &Expression<'source>, match &ast.expr {
symbols: &mut SymTab<'source, Value>, EmptyLiteral() => Value::None(),
) -> Value { IntLiteral(val) => Value::Int(*val),
match ast { BoolLiteral(val) => Value::Bool(*val),
EmptyLiteral(_) => Value::None(), Identifier(name) => *symbols.get(name),
IntLiteral(_, val) => Value::Int(*val), UnaryOp(op, expr) => match *op {
BoolLiteral(_, val) => Value::Bool(*val),
Identifier(_, name) => *symbols.get(name),
UnaryOp(_, op, expr) => match *op {
"-" => { "-" => {
let Value::Func(op_fn) = symbols.get("neg") else { let Value::Func(op_fn) = symbols.get("neg") else {
panic!("Operator {} does not correspond to a function!", op); panic!("Operator {} does not correspond to a function!", op);
@ -27,7 +27,7 @@ pub fn interpret<'source>(
op_fn(&[interpret(expr, symbols)]) op_fn(&[interpret(expr, symbols)])
} }
}, },
BinaryOp(_, left, op, right) => match *op { BinaryOp(left, op, right) => match *op {
"and" => { "and" => {
let left_val = interpret(left, symbols); let left_val = interpret(left, symbols);
if let Value::Bool(left_val) = left_val { if let Value::Bool(left_val) = left_val {
@ -63,7 +63,7 @@ pub fn interpret<'source>(
} }
} }
"=" => { "=" => {
if let Expression::Identifier(_, name) = **left { if let Expression::Identifier(name) = left.expr {
let val = interpret(right, symbols); let val = interpret(right, symbols);
*symbols.get(name) = val; *symbols.get(name) = val;
val val
@ -78,12 +78,12 @@ pub fn interpret<'source>(
op_fn(&[interpret(left, symbols), interpret(right, symbols)]) op_fn(&[interpret(left, symbols), interpret(right, symbols)])
} }
}, },
VarDeclaration(_, name, expr, _) => { VarDeclaration(name, expr, _) => {
let val = interpret(expr, symbols); let val = interpret(expr, symbols);
symbols.insert(name, val); symbols.insert(name, val);
Value::None() Value::None()
} }
Conditional(_, condition_expr, then_expr, else_expr) => { Conditional(condition_expr, then_expr, else_expr) => {
let Value::Bool(condition) = interpret(condition_expr, symbols) else { let Value::Bool(condition) = interpret(condition_expr, symbols) else {
panic!("Non-bool as if-then-else condition!"); panic!("Non-bool as if-then-else condition!");
}; };
@ -101,7 +101,7 @@ pub fn interpret<'source>(
Value::None() Value::None()
} }
} }
While(_, condition, do_expr) => { While(condition, do_expr) => {
loop { loop {
let condition = interpret(condition, symbols); let condition = interpret(condition, symbols);
if let Value::Bool(cond) = condition { if let Value::Bool(cond) = condition {
@ -116,7 +116,7 @@ pub fn interpret<'source>(
} }
Value::None() Value::None()
} }
FunCall(_, name, args) => { FunCall(name, args) => {
let mut arg_values = Vec::new(); let mut arg_values = Vec::new();
for arg in args { for arg in args {
arg_values.push(interpret(arg, symbols)); arg_values.push(interpret(arg, symbols));
@ -128,7 +128,7 @@ pub fn interpret<'source>(
function(&arg_values) function(&arg_values)
} }
Block(_, expressions) => { Block(expressions) => {
symbols.push_level(); symbols.push_level();
let mut val = Value::None(); let mut val = Value::None();

View file

@ -4,14 +4,15 @@ mod tests;
use crate::compiler::{ use crate::compiler::{
ast::{ ast::{
Expression::{self, *}, AstNode,
Expression::*,
TypeExpression::{self}, TypeExpression::{self},
}, },
parser::parser_utilities::*, parser::parser_utilities::*,
token::{Token, TokenType}, token::{Token, TokenType},
}; };
pub fn parse<'source>(tokens: &[Token<'source>]) -> Expression<'source> { pub fn parse<'source>(tokens: &[Token<'source>]) -> AstNode<'source> {
let mut pos = 0; let mut pos = 0;
let first_expression = parse_block_level_expressions(&mut pos, tokens); let first_expression = parse_block_level_expressions(&mut pos, tokens);
@ -47,10 +48,10 @@ pub fn parse<'source>(tokens: &[Token<'source>]) -> Expression<'source> {
let last_token = peek(&mut (pos - 1), tokens); let last_token = peek(&mut (pos - 1), tokens);
if last_token.text == ";" { if last_token.text == ";" {
expressions.push(EmptyLiteral(last_token.loc)); expressions.push(AstNode::new(last_token.loc, EmptyLiteral()));
} }
Block(tokens[0].loc, expressions) AstNode::new(tokens[0].loc, Block(expressions))
} else { } else {
first_expression first_expression
} }
@ -61,7 +62,7 @@ pub fn parse<'source>(tokens: &[Token<'source>]) -> Expression<'source> {
fn parse_block_level_expressions<'source>( fn parse_block_level_expressions<'source>(
pos: &mut usize, pos: &mut usize,
tokens: &[Token<'source>], tokens: &[Token<'source>],
) -> Expression<'source> { ) -> AstNode<'source> {
// Special handling for variable declaration, since it is only allowed in very specifc places // Special handling for variable declaration, since it is only allowed in very specifc places
if peek(pos, tokens).text == "var" { if peek(pos, tokens).text == "var" {
parse_var_declaration(pos, tokens) parse_var_declaration(pos, tokens)
@ -74,7 +75,7 @@ fn parse_expression<'source>(
level: usize, level: usize,
pos: &mut usize, pos: &mut usize,
tokens: &[Token<'source>], tokens: &[Token<'source>],
) -> Expression<'source> { ) -> AstNode<'source> {
const OPS: [&[&str]; 8] = [ const OPS: [&[&str]; 8] = [
&["="], // 0 &["="], // 0
&["or"], // 1 &["or"], // 1
@ -93,11 +94,9 @@ fn parse_expression<'source>(
if OPS[level].contains(&peek(pos, tokens).text) { if OPS[level].contains(&peek(pos, tokens).text) {
let operator_token = consume_strings(pos, tokens, OPS[level]); let operator_token = consume_strings(pos, tokens, OPS[level]);
let right = parse_expression(level, pos, tokens); let right = parse_expression(level, pos, tokens);
BinaryOp( AstNode::new(
operator_token.loc, operator_token.loc,
Box::new(left), BinaryOp(Box::new(left), operator_token.text, Box::new(right)),
operator_token.text,
Box::new(right),
) )
} else { } else {
left left
@ -109,11 +108,9 @@ fn parse_expression<'source>(
let operator_token = consume_strings(pos, tokens, OPS[level]); let operator_token = consume_strings(pos, tokens, OPS[level]);
let right = parse_expression(level + 1, pos, tokens); let right = parse_expression(level + 1, pos, tokens);
left = BinaryOp( left = AstNode::new(
operator_token.loc, operator_token.loc,
Box::new(left), BinaryOp(Box::new(left), operator_token.text, Box::new(right)),
operator_token.text,
Box::new(right),
); );
} }
left left
@ -122,7 +119,10 @@ fn parse_expression<'source>(
if OPS[level].contains(&peek(pos, tokens).text) { if OPS[level].contains(&peek(pos, tokens).text) {
let operator_token = consume_strings(pos, tokens, OPS[level]); let operator_token = consume_strings(pos, tokens, OPS[level]);
let right = parse_expression(level, pos, tokens); let right = parse_expression(level, pos, tokens);
UnaryOp(operator_token.loc, operator_token.text, Box::new(right)) AstNode::new(
operator_token.loc,
UnaryOp(operator_token.text, Box::new(right)),
)
} else { } else {
parse_expression(level + 1, pos, tokens) parse_expression(level + 1, pos, tokens)
} }
@ -132,7 +132,7 @@ fn parse_expression<'source>(
} }
} }
fn parse_term<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { fn parse_term<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> AstNode<'source> {
let token = peek(pos, tokens); let token = peek(pos, tokens);
match token.token_type { match token.token_type {
@ -159,10 +159,7 @@ fn parse_term<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression
} }
} }
fn parse_var_declaration<'source>( fn parse_var_declaration<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> AstNode<'source> {
pos: &mut usize,
tokens: &[Token<'source>],
) -> Expression<'source> {
consume_string(pos, tokens, "var"); consume_string(pos, tokens, "var");
let name_token = consume_type(pos, tokens, TokenType::Identifier); let name_token = consume_type(pos, tokens, TokenType::Identifier);
@ -180,10 +177,13 @@ fn parse_var_declaration<'source>(
consume_string(pos, tokens, "="); consume_string(pos, tokens, "=");
let value = parse_expression(0, pos, tokens); let value = parse_expression(0, pos, tokens);
VarDeclaration(name_token.loc, name_token.text, Box::new(value), type_expr) AstNode::new(
name_token.loc,
VarDeclaration(name_token.text, Box::new(value), type_expr),
)
} }
fn parse_conditional<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { fn parse_conditional<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> AstNode<'source> {
let start = consume_string(pos, tokens, "if"); let start = consume_string(pos, tokens, "if");
let condition = Box::new(parse_expression(0, pos, tokens)); let condition = Box::new(parse_expression(0, pos, tokens));
consume_string(pos, tokens, "then"); consume_string(pos, tokens, "then");
@ -197,26 +197,26 @@ fn parse_conditional<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Exp
_ => None, _ => None,
}; };
Conditional(start.loc, condition, then_expr, else_expr) AstNode::new(start.loc, Conditional(condition, then_expr, else_expr))
} }
fn parse_while_loop<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { fn parse_while_loop<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> AstNode<'source> {
let start = consume_string(pos, tokens, "while"); let start = consume_string(pos, tokens, "while");
let condition = Box::new(parse_expression(0, pos, tokens)); let condition = Box::new(parse_expression(0, pos, tokens));
consume_string(pos, tokens, "do"); consume_string(pos, tokens, "do");
let do_expr = Box::new(parse_expression(0, pos, tokens)); let do_expr = Box::new(parse_expression(0, pos, tokens));
While(start.loc, condition, do_expr) AstNode::new(start.loc, While(condition, do_expr))
} }
fn parse_parenthesized<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { fn parse_parenthesized<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> AstNode<'source> {
consume_string(pos, tokens, "("); consume_string(pos, tokens, "(");
let expression = parse_expression(0, pos, tokens); let expression = parse_expression(0, pos, tokens);
consume_string(pos, tokens, ")"); consume_string(pos, tokens, ")");
expression expression
} }
fn parse_block<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { fn parse_block<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> AstNode<'source> {
let start = consume_string(pos, tokens, "{"); let start = consume_string(pos, tokens, "{");
let mut expressions = Vec::new(); let mut expressions = Vec::new();
@ -240,16 +240,16 @@ fn parse_block<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expressio
// If the last expression of the block ended in a semicolon, empty return // If the last expression of the block ended in a semicolon, empty return
let next_token = peek(pos, tokens); let next_token = peek(pos, tokens);
if next_token.text == "}" { if next_token.text == "}" {
expressions.push(EmptyLiteral(next_token.loc)); expressions.push(AstNode::new(next_token.loc, EmptyLiteral()));
break; break;
} }
} }
consume_string(pos, tokens, "}"); consume_string(pos, tokens, "}");
Block(start.loc, expressions) AstNode::new(start.loc, Block(expressions))
} }
fn parse_function<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { fn parse_function<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> AstNode<'source> {
let identifier = consume_type(pos, tokens, TokenType::Identifier); let identifier = consume_type(pos, tokens, TokenType::Identifier);
consume_string(pos, tokens, "("); consume_string(pos, tokens, "(");
@ -266,32 +266,35 @@ fn parse_function<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expres
} }
} }
consume_string(pos, tokens, ")"); consume_string(pos, tokens, ")");
FunCall(identifier.loc, identifier.text, arguments) AstNode::new(identifier.loc, FunCall(identifier.text, arguments))
} }
fn parse_int_literal<'source>(pos: &mut usize, tokens: &[Token]) -> Expression<'source> { fn parse_int_literal<'source>(pos: &mut usize, tokens: &[Token]) -> AstNode<'source> {
let token = consume_type(pos, tokens, TokenType::Integer); let token = consume_type(pos, tokens, TokenType::Integer);
IntLiteral( let expr = IntLiteral(
token.loc,
token token
.text .text
.parse::<i64>() .parse::<i64>()
.unwrap_or_else(|_| panic!("Fatal parser error! Invalid value in token {token}")), .unwrap_or_else(|_| panic!("Fatal parser error! Invalid value in token {token}")),
) );
AstNode::new(token.loc, expr)
} }
fn parse_bool_literal<'source>(pos: &mut usize, tokens: &[Token]) -> Expression<'source> { fn parse_bool_literal<'source>(pos: &mut usize, tokens: &[Token]) -> AstNode<'source> {
let token = consume_type(pos, tokens, TokenType::Identifier); let token = consume_type(pos, tokens, TokenType::Identifier);
match token.text { let expr = match token.text {
"true" => BoolLiteral(token.loc, true), "true" => BoolLiteral(true),
"false" => BoolLiteral(token.loc, false), "false" => BoolLiteral(false),
_ => panic!("Fatal parser error! Expected bool literal but found {token}"), _ => panic!("Fatal parser error! Expected bool literal but found {token}"),
} };
AstNode::new(token.loc, expr)
} }
fn parse_identifier<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { fn parse_identifier<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> AstNode<'source> {
let token = consume_type(pos, tokens, TokenType::Identifier); let token = consume_type(pos, tokens, TokenType::Identifier);
Identifier(token.loc, token.text) AstNode::new(token.loc, Identifier(token.text))
} }

View file

@ -3,7 +3,7 @@ use crate::compiler::{token::CodeLocation, tokenizer::tokenize};
macro_rules! bool_ast { macro_rules! bool_ast {
($x:expr) => { ($x:expr) => {
BoolLiteral(CodeLocation::new(usize::MAX, usize::MAX), $x) AstNode::new(CodeLocation::new(usize::MAX, usize::MAX), BoolLiteral($x))
}; };
} }
@ -15,7 +15,7 @@ macro_rules! bool_ast_b {
macro_rules! int_ast { macro_rules! int_ast {
($x:expr) => { ($x:expr) => {
IntLiteral(CodeLocation::new(usize::MAX, usize::MAX), $x) AstNode::new(CodeLocation::new(usize::MAX, usize::MAX), IntLiteral($x))
}; };
} }
@ -27,7 +27,7 @@ macro_rules! int_ast_b {
macro_rules! id_ast { macro_rules! id_ast {
($x:expr) => { ($x:expr) => {
Identifier(CodeLocation::new(usize::MAX, usize::MAX), $x) AstNode::new(CodeLocation::new(usize::MAX, usize::MAX), Identifier($x))
}; };
} }
@ -39,7 +39,7 @@ macro_rules! id_ast_b {
macro_rules! un_ast { macro_rules! un_ast {
($x:expr, $y:expr) => { ($x:expr, $y:expr) => {
UnaryOp(CodeLocation::new(usize::MAX, usize::MAX), $x, $y) AstNode::new(CodeLocation::new(usize::MAX, usize::MAX), UnaryOp($x, $y))
}; };
} }
@ -51,7 +51,10 @@ macro_rules! un_ast_b {
macro_rules! bin_ast { macro_rules! bin_ast {
($x:expr, $y:expr, $z:expr) => { ($x:expr, $y:expr, $z:expr) => {
BinaryOp(CodeLocation::new(usize::MAX, usize::MAX), $x, $y, $z) AstNode::new(
CodeLocation::new(usize::MAX, usize::MAX),
BinaryOp($x, $y, $z),
)
}; };
} }
@ -63,7 +66,10 @@ macro_rules! bin_ast_b {
macro_rules! con_ast { macro_rules! con_ast {
($x:expr, $y:expr, $z:expr) => { ($x:expr, $y:expr, $z:expr) => {
Conditional(CodeLocation::new(usize::MAX, usize::MAX), $x, $y, $z) AstNode::new(
CodeLocation::new(usize::MAX, usize::MAX),
Conditional($x, $y, $z),
)
}; };
} }
@ -75,7 +81,7 @@ macro_rules! con_ast_b {
macro_rules! fun_ast { macro_rules! fun_ast {
($x:expr, $y:expr) => { ($x:expr, $y:expr) => {
FunCall(CodeLocation::new(usize::MAX, usize::MAX), $x, $y) AstNode::new(CodeLocation::new(usize::MAX, usize::MAX), FunCall($x, $y))
}; };
} }
@ -87,7 +93,7 @@ macro_rules! fun_ast_b {
macro_rules! block_ast { macro_rules! block_ast {
($x:expr) => { ($x:expr) => {
Block(CodeLocation::new(usize::MAX, usize::MAX), $x) AstNode::new(CodeLocation::new(usize::MAX, usize::MAX), Block($x))
}; };
} }
@ -99,19 +105,22 @@ macro_rules! block_ast_b {
macro_rules! empty_ast { macro_rules! empty_ast {
() => { () => {
EmptyLiteral(CodeLocation::new(usize::MAX, usize::MAX)) AstNode::new(CodeLocation::new(usize::MAX, usize::MAX), EmptyLiteral())
}; };
} }
macro_rules! var_ast { macro_rules! var_ast {
($x:expr, $y:expr, $z:expr) => { ($x:expr, $y:expr, $z:expr) => {
VarDeclaration(CodeLocation::new(usize::MAX, usize::MAX), $x, $y, $z) AstNode::new(
CodeLocation::new(usize::MAX, usize::MAX),
VarDeclaration($x, $y, $z),
)
}; };
} }
macro_rules! while_ast { macro_rules! while_ast {
($x:expr, $y:expr) => { ($x:expr, $y:expr) => {
While(CodeLocation::new(usize::MAX, usize::MAX), $x, $y) AstNode::new(CodeLocation::new(usize::MAX, usize::MAX), While($x, $y))
}; };
} }
@ -124,7 +133,7 @@ macro_rules! while_ast_b {
#[test] #[test]
#[should_panic] #[should_panic]
fn test_empty() { fn test_empty() {
parse(&vec![]); parse(&[]);
} }
#[test] #[test]

View file

@ -1,19 +1,16 @@
use crate::compiler::{ use crate::compiler::{
ast::{ ast::{AstNode, Expression::*, TypeExpression},
Expression::{self, *},
TypeExpression,
},
symtab::SymTab, symtab::SymTab,
variable::Type, variable::Type,
}; };
pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'source, Type>) -> Type { pub fn type_check<'source>(ast: &AstNode<'source>, symbols: &mut SymTab<'source, Type>) -> Type {
match ast { match &ast.expr {
EmptyLiteral(_) => Type::Unit, EmptyLiteral() => Type::Unit,
IntLiteral(_, _) => Type::Int, IntLiteral(_) => Type::Int,
BoolLiteral(_, _) => Type::Bool, BoolLiteral(_) => Type::Bool,
Identifier(_, name) => symbols.get(name).clone(), Identifier(name) => symbols.get(name).clone(),
UnaryOp(_, op, expr) => match *op { UnaryOp(op, expr) => match *op {
"-" => { "-" => {
let expr_types = vec![type_check(expr, symbols)]; let expr_types = vec![type_check(expr, symbols)];
@ -47,7 +44,7 @@ pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'sour
(**sig_ret_type).clone() (**sig_ret_type).clone()
} }
}, },
BinaryOp(_, left, op, right) => match *op { BinaryOp(left, op, right) => match *op {
"==" | "!=" => { "==" | "!=" => {
let left_type = type_check(left, symbols); let left_type = type_check(left, symbols);
let right_type = type_check(right, symbols); let right_type = type_check(right, symbols);
@ -57,7 +54,7 @@ pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'sour
Type::Bool Type::Bool
} }
"=" => { "=" => {
if !matches!(**left, Identifier(_, _)) { if !matches!(left.expr, Identifier(_)) {
panic!("Non-variable on left side of assignment!"); panic!("Non-variable on left side of assignment!");
} }
@ -87,7 +84,7 @@ pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'sour
(**sig_ret_type).clone() (**sig_ret_type).clone()
} }
}, },
VarDeclaration(_, name, expr, type_expr) => { VarDeclaration(name, expr, type_expr) => {
let type_var = type_check(expr, symbols); let type_var = type_check(expr, symbols);
if let Some(type_expr) = type_expr { if let Some(type_expr) = type_expr {
@ -107,7 +104,7 @@ pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'sour
symbols.insert(name, type_var); symbols.insert(name, type_var);
Type::Unit Type::Unit
} }
Conditional(_, condition_expr, then_expr, else_expr) => { Conditional(condition_expr, then_expr, else_expr) => {
if !matches!(type_check(condition_expr, symbols), Type::Bool) { if !matches!(type_check(condition_expr, symbols), Type::Bool) {
panic!("Non-bool as if-then-else condition!"); panic!("Non-bool as if-then-else condition!");
} }
@ -124,14 +121,14 @@ pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'sour
Type::Unit Type::Unit
} }
} }
While(_, condition_expr, do_expr) => { While(condition_expr, do_expr) => {
if !matches!(type_check(condition_expr, symbols), Type::Bool) { if !matches!(type_check(condition_expr, symbols), Type::Bool) {
panic!("Non-bool as while-do condition!"); panic!("Non-bool as while-do condition!");
} }
type_check(do_expr, symbols); type_check(do_expr, symbols);
Type::Unit Type::Unit
} }
FunCall(_, name, args) => { FunCall(name, args) => {
let mut arg_types = Vec::new(); let mut arg_types = Vec::new();
for arg in args { for arg in args {
arg_types.push(type_check(arg, symbols)); arg_types.push(type_check(arg, symbols));
@ -150,7 +147,7 @@ pub fn type_check<'source>(ast: &Expression<'source>, symbols: &mut SymTab<'sour
(**sig_ret_type).clone() (**sig_ret_type).clone()
} }
Block(_, expressions) => { Block(expressions) => {
symbols.push_level(); symbols.push_level();
let mut type_var = Type::Unit; let mut type_var = Type::Unit;