From 1a07a2376a28de61f96e490187046f25f46d4035 Mon Sep 17 00:00:00 2001 From: Jack Parsons Date: Sat, 6 Jul 2024 10:06:22 -0600 Subject: [PATCH] Control Flow (#51) * implementation * fix looping with control flow * update readme --- README.md | 4 ++++ grammar/Fusion.g4 | 4 ++++ include/ast/ast.h | 12 ++++++++++++ include/ast/passes/pass.h | 2 ++ include/backend/backend.h | 8 ++++++++ src/ast/ast.cpp | 8 ++++++++ src/ast/builder.cpp | 10 ++++++++++ src/ast/passes/pass.cpp | 4 ++++ src/backend/visitor.cpp | 38 ++++++++++++++++++++++++++++++++++++++ tests/input/loop/flow.in | 13 +++++++++++++ tests/output/loop/flow.out | 2 ++ 11 files changed, 105 insertions(+) create mode 100644 tests/input/loop/flow.in create mode 100644 tests/output/loop/flow.out diff --git a/README.md b/README.md index 8cc08c5..57e8afe 100644 --- a/README.md +++ b/README.md @@ -130,3 +130,7 @@ for(let i: i32 = 0; i < 5; i = i + 1){ println(i); } ``` + +Loops also allow for control flow: +- `continue`: goes to the next iteration of the loop +- `break`: exits the loop early diff --git a/grammar/Fusion.g4 b/grammar/Fusion.g4 index 88a86ee..e9a524a 100644 --- a/grammar/Fusion.g4 +++ b/grammar/Fusion.g4 @@ -11,6 +11,8 @@ statement | assignment SEMI | block | return + | CONTINUE SEMI + | BREAK SEMI ; declaration: @@ -67,6 +69,8 @@ LET: 'let'; IF: 'if'; ELSE: 'else'; FOR: 'for'; +BREAK: 'break'; +CONTINUE: 'continue'; // symbols SEMI: ';'; diff --git a/include/ast/ast.h b/include/ast/ast.h index 892546f..7a0ae84 100644 --- a/include/ast/ast.h +++ b/include/ast/ast.h @@ -262,4 +262,16 @@ class Loop : public Node { void xml(int level) override; }; + +class Continue : public Node { + public: + Continue(Token* token) : Node(token) {} + void xml(int level); +}; + +class Break : public Node { + public: + Break(Token* token) : Node(token) {} + void xml(int level); +}; } // namespace ast diff --git a/include/ast/passes/pass.h b/include/ast/passes/pass.h index 1937cd5..d33d825 100644 --- a/include/ast/passes/pass.h +++ b/include/ast/passes/pass.h @@ -31,6 +31,8 @@ class Pass { virtual void visit_unary_operator(shared_ptr); virtual void visit_conditional(shared_ptr); virtual void visit_loop(shared_ptr); + virtual void visit_continue(shared_ptr); + virtual void visit_break(shared_ptr); static void run_passes(shared_ptr ast, shared_ptr); }; diff --git a/include/backend/backend.h b/include/backend/backend.h index d890a06..71cf0f4 100644 --- a/include/backend/backend.h +++ b/include/backend/backend.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -14,6 +15,11 @@ using std::shared_ptr; class Backend { private: std::unordered_map variables; + + // used by continue/break + std::stack loop_conditions; + std::stack loop_exits; + mlir::Value visit(shared_ptr); public: @@ -38,4 +44,6 @@ class Backend { mlir::Value visit_unary_operator(shared_ptr); mlir::Value visit_conditional(shared_ptr); mlir::Value visit_loop(shared_ptr); + mlir::Value visit_continue(shared_ptr); + mlir::Value visit_break(shared_ptr); }; diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 225028f..a140a78 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -456,3 +456,11 @@ void ast::Loop::xml(int level) { body->xml(level + 1); std::cout << std::string(level * 4, ' ') << "\n"; } + +void ast::Continue::xml(int level) { + std::cout << std::string(level * 4, ' ') << "\n"; +} + +void ast::Break::xml(int level) { + std::cout << std::string(level * 4, ' ') << "\n"; +} diff --git a/src/ast/builder.cpp b/src/ast/builder.cpp index a1e157b..5cf3335 100644 --- a/src/ast/builder.cpp +++ b/src/ast/builder.cpp @@ -72,6 +72,16 @@ std::any Builder::visitStatement(FusionParser::StatementContext* ctx) { return visit(ctx->loop()); } + if (ctx->CONTINUE() != nullptr) { + auto node = make_shared(ctx->CONTINUE()->getSymbol()); + return to_node(node); + } + + if (ctx->BREAK() != nullptr) { + auto node = make_shared(ctx->BREAK()->getSymbol()); + return to_node(node); + } + throw std::runtime_error("found an invalid statement"); } diff --git a/src/ast/passes/pass.cpp b/src/ast/passes/pass.cpp index ea221c4..f9674d0 100644 --- a/src/ast/passes/pass.cpp +++ b/src/ast/passes/pass.cpp @@ -53,6 +53,8 @@ void Pass::visit(shared_ptr node) { try_visit(node, ast::UnaryOperator, this->visit_unary_operator); try_visit(node, ast::Conditional, this->visit_conditional); try_visit(node, ast::Loop, this->visit_loop); + try_visit(node, ast::Continue, this->visit_continue); + try_visit(node, ast::Break, this->visit_break); throw std::runtime_error("node not added to pass manager"); } @@ -68,6 +70,8 @@ void Pass::visit_character_literal(shared_ptr node) {} void Pass::visit_boolean_literal(shared_ptr node) {} void Pass::visit_variable(shared_ptr node) {} void Pass::visit_parameter(shared_ptr node) {} +void Pass::visit_continue(shared_ptr node) {} +void Pass::visit_break(shared_ptr node) {} void Pass::visit_declaration(shared_ptr node) { visit(node->var); diff --git a/src/backend/visitor.cpp b/src/backend/visitor.cpp index 4a5d991..8d477ef 100644 --- a/src/backend/visitor.cpp +++ b/src/backend/visitor.cpp @@ -35,6 +35,8 @@ mlir::Value Backend::visit(shared_ptr node) { try_visit(node, ast::UnaryOperator, this->visit_unary_operator); try_visit(node, ast::Conditional, this->visit_conditional); try_visit(node, ast::Loop, this->visit_loop); + try_visit(node, ast::Continue, this->visit_continue); + try_visit(node, ast::Break, this->visit_break); throw std::runtime_error("node not added to backend visit function"); } @@ -188,8 +190,12 @@ mlir::Value Backend::visit_conditional(shared_ptr node) { mlir::Value Backend::visit_loop(shared_ptr node) { mlir::Block* b_cond = ctx::current_function().addBlock(); mlir::Block* b_loop = ctx::current_function().addBlock(); + mlir::Block* b_assn = ctx::current_function().addBlock(); mlir::Block* b_exit = ctx::current_function().addBlock(); + loop_conditions.push(b_assn); + loop_exits.push(b_exit); + visit(node->variable); flow::jump(b_cond); @@ -200,9 +206,41 @@ mlir::Value Backend::visit_loop(shared_ptr node) { ctx::builder->setInsertionPointToStart(b_loop); visit(node->body); + flow::jump(b_assn); + + ctx::builder->setInsertionPointToStart(b_assn); visit(node->assignment); flow::jump(b_cond); ctx::builder->setInsertionPointToStart(b_exit); + loop_conditions.pop(); + loop_exits.pop(); + + return nullptr; +} + +mlir::Value Backend::visit_continue(shared_ptr node) { + if (loop_conditions.empty()) { + throw std::runtime_error("backend found continue outside of a loop"); + } + + mlir::Block* b_body = ctx::current_function().addBlock(); + flow::jump(loop_conditions.top()); + + ctx::builder->setInsertionPointToStart(b_body); + + return nullptr; +} + +mlir::Value Backend::visit_break(shared_ptr node) { + if (loop_exits.empty()) { + throw std::runtime_error("backend found break outside of a loop"); + } + + mlir::Block* b_body = ctx::current_function().addBlock(); + flow::jump(loop_exits.top()); + + ctx::builder->setInsertionPointToStart(b_body); + return nullptr; } diff --git a/tests/input/loop/flow.in b/tests/input/loop/flow.in new file mode 100644 index 0000000..5ec6142 --- /dev/null +++ b/tests/input/loop/flow.in @@ -0,0 +1,13 @@ +fn main(): i32 { + for(let i: i32 = 0; i < 5; i = i + 1) { + if(i == 0) { + continue; + } + + if(i > 2) { + break; + } + println(i); + } + return 0; +} diff --git a/tests/output/loop/flow.out b/tests/output/loop/flow.out new file mode 100644 index 0000000..1191247 --- /dev/null +++ b/tests/output/loop/flow.out @@ -0,0 +1,2 @@ +1 +2