diff --git a/src/compiler/parser.rs b/src/compiler/parser.rs index 0b747cc..e3bfb71 100644 --- a/src/compiler/parser.rs +++ b/src/compiler/parser.rs @@ -59,6 +59,14 @@ fn next_expect_strings<'source>( } } +fn next_expect_string<'source>( + pos: &mut usize, + tokens: &[Token<'source>], + expected_string: &str, +) -> Token<'source> { + next_expect_strings(pos, tokens, &vec![expected_string]) +} + fn next_expect_type<'source>( pos: &mut usize, tokens: &[Token<'source>], @@ -83,8 +91,20 @@ fn parse_identifier<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expr Identifier(token.text) } +fn parse_parenthesized<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { + next_expect_string(pos, tokens, "("); + let expression = parse_expression(pos, tokens); + next_expect_string(pos, tokens, ")"); + expression +} + fn parse_factor<'source>(pos: &mut usize, tokens: &[Token<'source>]) -> Expression<'source> { - match peek(pos, tokens).token_type { + let token = &peek(pos, tokens); + + if token.text == "(" { + return parse_parenthesized(pos, tokens); + } + match token.token_type { TokenType::Integer => parse_int_literal(pos, tokens), TokenType::Identifier => parse_identifier(pos, tokens), _ => panic!("Unexpected {}", peek(pos, tokens)), @@ -245,4 +265,66 @@ mod tests { ) ); } + + #[test] + fn test_parenthesized() { + let result = parse(&vec![ + new_id("("), + new_int("1"), + new_id("+"), + new_int("2"), + new_id(")"), + new_id("*"), + new_int("3"), + ]); + assert_eq!( + result, + BinaryOp( + Box::new(BinaryOp( + Box::new(IntLiteral(1)), + "+", + Box::new(IntLiteral(2)) + )), + "*", + Box::new(IntLiteral(3)), + ) + ); + + let result = parse(&vec![ + new_id("("), + new_id("("), + new_int("1"), + new_id("-"), + new_int("2"), + new_id(")"), + new_id(")"), + new_id("/"), + new_int("3"), + ]); + assert_eq!( + result, + BinaryOp( + Box::new(BinaryOp( + Box::new(IntLiteral(1)), + "-", + Box::new(IntLiteral(2)) + )), + "/", + Box::new(IntLiteral(3)), + ) + ); + } + + #[test] + #[should_panic] + fn test_parenthesized_mismatched() { + parse(&vec![ + new_id("("), + new_int("1"), + new_id("+"), + new_int("2"), + new_id("*"), + new_int("3"), + ]); + } }