Fix bug in type checker
This commit is contained in:
parent
4d00bbb6ba
commit
7dd60aa9c5
1 changed files with 23 additions and 23 deletions
|
@ -14,13 +14,13 @@ pub fn type_check<'source>(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, Type>) -> Type {
|
fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, Type>) -> Type {
|
||||||
match ast.expr.clone() {
|
match &mut ast.expr {
|
||||||
EmptyLiteral() => Type::Unit,
|
EmptyLiteral() => Type::Unit,
|
||||||
IntLiteral(_) => Type::Int,
|
IntLiteral(_) => Type::Int,
|
||||||
BoolLiteral(_) => Type::Bool,
|
BoolLiteral(_) => Type::Bool,
|
||||||
Identifier(name) => symbols.get(name).clone(),
|
Identifier(name) => symbols.get(name).clone(),
|
||||||
UnaryOp(op, mut expr) => {
|
UnaryOp(op, ref mut expr) => {
|
||||||
let expr_types = vec![type_check(&mut expr, symbols)];
|
let expr_types = vec![type_check(expr, symbols)];
|
||||||
|
|
||||||
let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(&format!("unary_{op}"))
|
let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(&format!("unary_{op}"))
|
||||||
else {
|
else {
|
||||||
|
@ -36,10 +36,10 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
|
||||||
|
|
||||||
(**sig_ret_type).clone()
|
(**sig_ret_type).clone()
|
||||||
}
|
}
|
||||||
BinaryOp(mut left, op, mut right) => match op {
|
BinaryOp(ref mut left, op, ref mut right) => match *op {
|
||||||
"==" | "!=" => {
|
"==" | "!=" => {
|
||||||
let left_type = type_check(&mut left, symbols);
|
let left_type = type_check(left, symbols);
|
||||||
let right_type = type_check(&mut right, symbols);
|
let right_type = type_check(right, symbols);
|
||||||
if left_type != right_type {
|
if left_type != right_type {
|
||||||
panic!("Mismatched types being compared with {op}");
|
panic!("Mismatched types being compared with {op}");
|
||||||
}
|
}
|
||||||
|
@ -50,16 +50,16 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
|
||||||
panic!("Non-variable on left side of assignment!");
|
panic!("Non-variable on left side of assignment!");
|
||||||
}
|
}
|
||||||
|
|
||||||
let left_type = type_check(&mut left, symbols);
|
let left_type = type_check(left, symbols);
|
||||||
let right_type = type_check(&mut right, symbols);
|
let right_type = type_check(right, symbols);
|
||||||
if left_type != right_type {
|
if left_type != right_type {
|
||||||
panic!("Mismatched types in assignment!");
|
panic!("Mismatched types in assignment!");
|
||||||
}
|
}
|
||||||
left_type
|
left_type
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
let left_type = type_check(&mut left, symbols);
|
let left_type = type_check(left, symbols);
|
||||||
let right_type = type_check(&mut right, symbols);
|
let right_type = type_check(right, symbols);
|
||||||
let arg_types = vec![left_type, right_type];
|
let arg_types = vec![left_type, right_type];
|
||||||
|
|
||||||
let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else {
|
let Type::Func(sig_arg_types, sig_ret_type) = symbols.get(op) else {
|
||||||
|
@ -76,8 +76,8 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
|
||||||
(**sig_ret_type).clone()
|
(**sig_ret_type).clone()
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
VarDeclaration(name, mut expr, type_expr) => {
|
VarDeclaration(name, ref mut expr, ref mut type_expr) => {
|
||||||
let type_var = type_check(&mut expr, symbols);
|
let type_var = type_check(expr, symbols);
|
||||||
|
|
||||||
if let Some(type_expr) = type_expr {
|
if let Some(type_expr) = type_expr {
|
||||||
let expected_type = match type_expr {
|
let expected_type = match type_expr {
|
||||||
|
@ -96,14 +96,14 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
|
||||||
symbols.insert(name, type_var);
|
symbols.insert(name, type_var);
|
||||||
Type::Unit
|
Type::Unit
|
||||||
}
|
}
|
||||||
Conditional(mut condition_expr, mut then_expr, else_expr) => {
|
Conditional(ref mut condition_expr, ref mut then_expr, ref mut else_expr) => {
|
||||||
if !matches!(type_check(&mut condition_expr, symbols), Type::Bool) {
|
if !matches!(type_check(condition_expr, symbols), Type::Bool) {
|
||||||
panic!("Non-bool as if-then-else condition!");
|
panic!("Non-bool as if-then-else condition!");
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(mut else_expr) = else_expr {
|
if let Some(ref mut else_expr) = else_expr {
|
||||||
let then_type = type_check(&mut then_expr, symbols);
|
let then_type = type_check(then_expr, symbols);
|
||||||
let else_type = type_check(&mut else_expr, symbols);
|
let else_type = type_check(else_expr, symbols);
|
||||||
if then_type == else_type {
|
if then_type == else_type {
|
||||||
then_type
|
then_type
|
||||||
} else {
|
} else {
|
||||||
|
@ -113,11 +113,11 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
|
||||||
Type::Unit
|
Type::Unit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
While(mut condition_expr, mut do_expr) => {
|
While(ref mut condition_expr, ref mut do_expr) => {
|
||||||
if !matches!(type_check(&mut condition_expr, symbols), Type::Bool) {
|
if !matches!(type_check(condition_expr, symbols), Type::Bool) {
|
||||||
panic!("Non-bool as while-do condition!");
|
panic!("Non-bool as while-do condition!");
|
||||||
}
|
}
|
||||||
type_check(&mut do_expr, symbols);
|
type_check(do_expr, symbols);
|
||||||
Type::Unit
|
Type::Unit
|
||||||
}
|
}
|
||||||
FunCall(name, args) => {
|
FunCall(name, args) => {
|
||||||
|
@ -139,12 +139,12 @@ fn get_type<'source>(ast: &mut AstNode<'source>, symbols: &mut SymTab<'source, T
|
||||||
|
|
||||||
(**sig_ret_type).clone()
|
(**sig_ret_type).clone()
|
||||||
}
|
}
|
||||||
Block(expressions) => {
|
Block(ref mut expressions) => {
|
||||||
symbols.push_level();
|
symbols.push_level();
|
||||||
|
|
||||||
let mut type_var = Type::Unit;
|
let mut type_var = Type::Unit;
|
||||||
for mut expression in expressions {
|
for expression in expressions {
|
||||||
type_var = type_check(&mut expression, symbols);
|
type_var = type_check(expression, symbols);
|
||||||
}
|
}
|
||||||
|
|
||||||
symbols.remove_level();
|
symbols.remove_level();
|
||||||
|
|
Reference in a new issue