Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loops #48

Merged
merged 5 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,15 @@ if(x == 4){
}
```
Conditions **must** be booleans

#### Loops
Similar to c-style for loops, they are composed of 4 parts
- Variable declaration
- Loop condition
- Variable assignment
- Body
```
for(let i: i32 = 0; i < 5; i = i + 1){
println(i);
}
```
12 changes: 8 additions & 4 deletions grammar/Fusion.g4
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ statement
: function
| call SEMI
| if
| declaration
| assignment
| loop
| declaration SEMI
| assignment SEMI
| block
| return
;

declaration:
variable EQUAL expr SEMI
variable EQUAL expr
;

block: L_CURLY statement* R_CURLY;
Expand All @@ -22,12 +23,14 @@ function: FUNCTION ID L_PAREN variable? (COMMA variable)* R_PAREN COLON type blo

variable: qualifier ID COLON type;

assignment: ID EQUAL expr SEMI;
assignment: ID EQUAL expr;

call: ID L_PAREN expr? (COMMA expr)* R_PAREN;

if: IF L_PAREN expr R_PAREN block else?;

loop: FOR L_PAREN declaration SEMI expr SEMI assignment R_PAREN block;

else: ELSE (block | if);

return: RETURN expr SEMI;
Expand Down Expand Up @@ -63,6 +66,7 @@ CONST: 'const';
LET: 'let';
IF: 'if';
ELSE: 'else';
FOR: 'for';

// symbols
SEMI: ';';
Expand Down
15 changes: 15 additions & 0 deletions include/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,19 @@ class Conditional : public Node {

void xml(int level) override;
};

class Loop : public Node {
public:
shared_ptr<Declaration> variable;
shared_ptr<Expression> condition;
shared_ptr<Assignment> assignment;
shared_ptr<Block> body;
Loop(shared_ptr<Declaration> variable,
shared_ptr<Expression> condition,
shared_ptr<Assignment> assignment,
shared_ptr<Block> body,
Token* token);

void xml(int level) override;
};
} // namespace ast
1 change: 1 addition & 0 deletions include/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ class Builder : public FusionBaseVisitor {
std::any visitAssignment(FusionParser::AssignmentContext* ctx) override;
std::any visitIf(FusionParser::IfContext* ctx) override;
std::any visitElse(FusionParser::ElseContext* ctx) override;
std::any visitLoop(FusionParser::LoopContext* ctx) override;
};
1 change: 1 addition & 0 deletions include/ast/passes/def_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ class DefRef : public Pass {
void visit_function(shared_ptr<ast::Function>) override;
void visit_variable(shared_ptr<ast::Variable>) override;
void visit_call(shared_ptr<ast::Call>) override;
void visit_loop(shared_ptr<ast::Loop>) override;
};
1 change: 1 addition & 0 deletions include/ast/passes/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Pass {
virtual void visit_binary_operator(shared_ptr<ast::BinaryOperator>);
virtual void visit_unary_operator(shared_ptr<ast::UnaryOperator>);
virtual void visit_conditional(shared_ptr<ast::Conditional>);
virtual void visit_loop(shared_ptr<ast::Loop>);

static void run_passes(shared_ptr<ast::Block> ast, shared_ptr<SymbolTable>);
};
1 change: 1 addition & 0 deletions include/ast/passes/type_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ class TypeCheck : public Pass {
void visit_binary_operator(shared_ptr<ast::BinaryOperator>) override;
void visit_unary_operator(shared_ptr<ast::UnaryOperator>) override;
void visit_conditional(shared_ptr<ast::Conditional>) override;
void visit_loop(shared_ptr<ast::Loop>) override;
};
1 change: 1 addition & 0 deletions include/backend/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ class Backend {
mlir::Value visit_binary_operator(shared_ptr<ast::BinaryOperator>);
mlir::Value visit_unary_operator(shared_ptr<ast::UnaryOperator>);
mlir::Value visit_conditional(shared_ptr<ast::Conditional>);
mlir::Value visit_loop(shared_ptr<ast::Loop>);
};
21 changes: 21 additions & 0 deletions src/ast/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,24 @@ void ast::Conditional::xml(int level) {

std::cout << std::string(level * 4, ' ') << "</conditional>\n";
}

ast::Loop::Loop(shared_ptr<Declaration> variable,
shared_ptr<Expression> condition,
shared_ptr<Assignment> assignment,
shared_ptr<Block> body,
Token* token)
: Node(token) {
this->variable = variable;
this->condition = condition;
this->assignment = assignment;
this->body = body;
}

void ast::Loop::xml(int level) {
std::cout << std::string(level * 4, ' ') << "<loop>\n";
variable->xml(level + 1);
condition->xml(level + 1);
assignment->xml(level + 1);
body->xml(level + 1);
std::cout << std::string(level * 4, ' ') << "</loop>\n";
}
17 changes: 17 additions & 0 deletions src/ast/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ std::any Builder::visitStatement(FusionParser::StatementContext* ctx) {
return visit(ctx->if_());
}

if (ctx->loop() != nullptr) {
return visit(ctx->loop());
}

throw std::runtime_error("found an invalid statement");
}

Expand Down Expand Up @@ -410,3 +414,16 @@ std::any Builder::visitElse(FusionParser::ElseContext* ctx) {
auto node = make_shared<ast::Conditional>(condition, block, token);
return to_node(node);
}

std::any Builder::visitLoop(FusionParser::LoopContext* ctx) {
Token* token = ctx->FOR()->getSymbol();

auto variable = cast_node(ast::Declaration, visit(ctx->declaration()));
auto condition = cast_node(ast::Expression, visit(ctx->expr()));
auto assignment = cast_node(ast::Assignment, visit(ctx->assignment()));
auto body = cast_node(ast::Block, visit(ctx->block()));

auto node =
make_shared<ast::Loop>(variable, condition, assignment, body, token);
return to_node(node);
}
9 changes: 9 additions & 0 deletions src/ast/passes/def_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ void DefRef::visit_call(shared_ptr<ast::Call> node) {
node->set_function(vs->function);
}

void DefRef::visit_loop(shared_ptr<ast::Loop> node) {
symbol_table->push();
visit(node->variable);
visit(node->condition);
visit(node->assignment);
visit(node->body);
symbol_table->pop();
}

bool DefRef::is_builtin(std::string name) {
if (name == "print" || name == "println") {
return true;
Expand Down
8 changes: 8 additions & 0 deletions src/ast/passes/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ void Pass::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::BinaryOperator, this->visit_binary_operator);
try_visit(node, ast::UnaryOperator, this->visit_unary_operator);
try_visit(node, ast::Conditional, this->visit_conditional);
try_visit(node, ast::Loop, this->visit_loop);

throw std::runtime_error("node not added to pass manager");
}
Expand Down Expand Up @@ -113,3 +114,10 @@ void Pass::visit_conditional(shared_ptr<ast::Conditional> node) {
visit(node->else_if.value());
}
}

void Pass::visit_loop(shared_ptr<ast::Loop> node) {
visit(node->variable);
visit(node->condition);
visit(node->assignment);
visit(node->body);
}
9 changes: 9 additions & 0 deletions src/ast/passes/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ void TypeCheck::visit_conditional(shared_ptr<ast::Conditional> node) {
}
}

void TypeCheck::visit_loop(shared_ptr<ast::Loop> node) {
visit(node->variable);
visit(node->condition);
check_bool(node->condition->get_type(), node->token->getLine());

visit(node->assignment);
visit(node->body);
}

void TypeCheck::check_numeric(TypePtr type, size_t line) {
if (!type->is_numeric()) {
throw TypeError(line, "type(" + type->get_name() + ") is not numeric");
Expand Down
23 changes: 23 additions & 0 deletions src/backend/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mlir::Value Backend::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::BinaryOperator, this->visit_binary_operator);
try_visit(node, ast::UnaryOperator, this->visit_unary_operator);
try_visit(node, ast::Conditional, this->visit_conditional);
try_visit(node, ast::Loop, this->visit_loop);

throw std::runtime_error("node not added to backend visit function");
}
Expand Down Expand Up @@ -183,3 +184,25 @@ mlir::Value Backend::visit_conditional(shared_ptr<ast::Conditional> node) {

return nullptr;
}

mlir::Value Backend::visit_loop(shared_ptr<ast::Loop> node) {
mlir::Block* b_cond = ctx::current_function().addBlock();
mlir::Block* b_loop = ctx::current_function().addBlock();
mlir::Block* b_exit = ctx::current_function().addBlock();

visit(node->variable);

flow::jump(b_cond);
ctx::builder->setInsertionPointToStart(b_cond);

mlir::Value condition = visit(node->condition);
flow::branch(condition, b_loop, b_exit);

ctx::builder->setInsertionPointToStart(b_loop);
visit(node->body);
visit(node->assignment);
flow::jump(b_cond);

ctx::builder->setInsertionPointToStart(b_exit);
return nullptr;
}
6 changes: 6 additions & 0 deletions tests/input/errors/symbol/loop.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main(): i32 {
for(let i: i32 = 0; i < 5; i = i + 1) {
}
println(i);
return 0;
}
6 changes: 6 additions & 0 deletions tests/input/errors/type/loop.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main(): i32 {
for(let i: i32 = 0; i; i = i + 1) {
println(5);
}
return 0;
}
6 changes: 6 additions & 0 deletions tests/input/loop/basic.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main(): i32 {
for(let i: i32 = 0; i < 5; i = i + 1) {
println(5);
}
return 0;
}
1 change: 1 addition & 0 deletions tests/output/errors/symbol/loop.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SymbolError on Line 4
1 change: 1 addition & 0 deletions tests/output/errors/type/loop.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TypeError on Line 2
5 changes: 5 additions & 0 deletions tests/output/loop/basic.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
5
5
5
5
5
Loading