Skip to content

Commit

Permalink
Loops (#48)
Browse files Browse the repository at this point in the history
* add grammar

* add implementation

* add typecheck and tests

* add proper scoping

* update readme
  • Loading branch information
jackparsonss committed Jul 5, 2024
1 parent e14d21f commit 8d4ab2e
Show file tree
Hide file tree
Showing 20 changed files with 152 additions and 4 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,15 @@ if(x == 4){
}
```
Conditions **must** be booleans

#### Loops
Similar to c-style for loops, they are composed of 4 parts
- Variable declaration
- Loop condition
- Variable assignment
- Body
```
for(let i: i32 = 0; i < 5; i = i + 1){
println(i);
}
```
12 changes: 8 additions & 4 deletions grammar/Fusion.g4
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ statement
: function
| call SEMI
| if
| declaration
| assignment
| loop
| declaration SEMI
| assignment SEMI
| block
| return
;

declaration:
variable EQUAL expr SEMI
variable EQUAL expr
;

block: L_CURLY statement* R_CURLY;
Expand All @@ -22,12 +23,14 @@ function: FUNCTION ID L_PAREN variable? (COMMA variable)* R_PAREN COLON type blo

variable: qualifier ID COLON type;

assignment: ID EQUAL expr SEMI;
assignment: ID EQUAL expr;

call: ID L_PAREN expr? (COMMA expr)* R_PAREN;

if: IF L_PAREN expr R_PAREN block else?;

loop: FOR L_PAREN declaration SEMI expr SEMI assignment R_PAREN block;

else: ELSE (block | if);

return: RETURN expr SEMI;
Expand Down Expand Up @@ -63,6 +66,7 @@ CONST: 'const';
LET: 'let';
IF: 'if';
ELSE: 'else';
FOR: 'for';

// symbols
SEMI: ';';
Expand Down
15 changes: 15 additions & 0 deletions include/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,19 @@ class Conditional : public Node {

void xml(int level) override;
};

class Loop : public Node {
public:
shared_ptr<Declaration> variable;
shared_ptr<Expression> condition;
shared_ptr<Assignment> assignment;
shared_ptr<Block> body;
Loop(shared_ptr<Declaration> variable,
shared_ptr<Expression> condition,
shared_ptr<Assignment> assignment,
shared_ptr<Block> body,
Token* token);

void xml(int level) override;
};
} // namespace ast
1 change: 1 addition & 0 deletions include/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ class Builder : public FusionBaseVisitor {
std::any visitAssignment(FusionParser::AssignmentContext* ctx) override;
std::any visitIf(FusionParser::IfContext* ctx) override;
std::any visitElse(FusionParser::ElseContext* ctx) override;
std::any visitLoop(FusionParser::LoopContext* ctx) override;
};
1 change: 1 addition & 0 deletions include/ast/passes/def_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ class DefRef : public Pass {
void visit_function(shared_ptr<ast::Function>) override;
void visit_variable(shared_ptr<ast::Variable>) override;
void visit_call(shared_ptr<ast::Call>) override;
void visit_loop(shared_ptr<ast::Loop>) override;
};
1 change: 1 addition & 0 deletions include/ast/passes/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Pass {
virtual void visit_binary_operator(shared_ptr<ast::BinaryOperator>);
virtual void visit_unary_operator(shared_ptr<ast::UnaryOperator>);
virtual void visit_conditional(shared_ptr<ast::Conditional>);
virtual void visit_loop(shared_ptr<ast::Loop>);

static void run_passes(shared_ptr<ast::Block> ast, shared_ptr<SymbolTable>);
};
1 change: 1 addition & 0 deletions include/ast/passes/type_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ class TypeCheck : public Pass {
void visit_binary_operator(shared_ptr<ast::BinaryOperator>) override;
void visit_unary_operator(shared_ptr<ast::UnaryOperator>) override;
void visit_conditional(shared_ptr<ast::Conditional>) override;
void visit_loop(shared_ptr<ast::Loop>) override;
};
1 change: 1 addition & 0 deletions include/backend/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ class Backend {
mlir::Value visit_binary_operator(shared_ptr<ast::BinaryOperator>);
mlir::Value visit_unary_operator(shared_ptr<ast::UnaryOperator>);
mlir::Value visit_conditional(shared_ptr<ast::Conditional>);
mlir::Value visit_loop(shared_ptr<ast::Loop>);
};
21 changes: 21 additions & 0 deletions src/ast/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,24 @@ void ast::Conditional::xml(int level) {

std::cout << std::string(level * 4, ' ') << "</conditional>\n";
}

ast::Loop::Loop(shared_ptr<Declaration> variable,
shared_ptr<Expression> condition,
shared_ptr<Assignment> assignment,
shared_ptr<Block> body,
Token* token)
: Node(token) {
this->variable = variable;
this->condition = condition;
this->assignment = assignment;
this->body = body;
}

void ast::Loop::xml(int level) {
std::cout << std::string(level * 4, ' ') << "<loop>\n";
variable->xml(level + 1);
condition->xml(level + 1);
assignment->xml(level + 1);
body->xml(level + 1);
std::cout << std::string(level * 4, ' ') << "</loop>\n";
}
17 changes: 17 additions & 0 deletions src/ast/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ std::any Builder::visitStatement(FusionParser::StatementContext* ctx) {
return visit(ctx->if_());
}

if (ctx->loop() != nullptr) {
return visit(ctx->loop());
}

throw std::runtime_error("found an invalid statement");
}

Expand Down Expand Up @@ -410,3 +414,16 @@ std::any Builder::visitElse(FusionParser::ElseContext* ctx) {
auto node = make_shared<ast::Conditional>(condition, block, token);
return to_node(node);
}

std::any Builder::visitLoop(FusionParser::LoopContext* ctx) {
Token* token = ctx->FOR()->getSymbol();

auto variable = cast_node(ast::Declaration, visit(ctx->declaration()));
auto condition = cast_node(ast::Expression, visit(ctx->expr()));
auto assignment = cast_node(ast::Assignment, visit(ctx->assignment()));
auto body = cast_node(ast::Block, visit(ctx->block()));

auto node =
make_shared<ast::Loop>(variable, condition, assignment, body, token);
return to_node(node);
}
9 changes: 9 additions & 0 deletions src/ast/passes/def_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ void DefRef::visit_call(shared_ptr<ast::Call> node) {
node->set_function(vs->function);
}

void DefRef::visit_loop(shared_ptr<ast::Loop> node) {
symbol_table->push();
visit(node->variable);
visit(node->condition);
visit(node->assignment);
visit(node->body);
symbol_table->pop();
}

bool DefRef::is_builtin(std::string name) {
if (name == "print" || name == "println") {
return true;
Expand Down
8 changes: 8 additions & 0 deletions src/ast/passes/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ void Pass::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::BinaryOperator, this->visit_binary_operator);
try_visit(node, ast::UnaryOperator, this->visit_unary_operator);
try_visit(node, ast::Conditional, this->visit_conditional);
try_visit(node, ast::Loop, this->visit_loop);

throw std::runtime_error("node not added to pass manager");
}
Expand Down Expand Up @@ -113,3 +114,10 @@ void Pass::visit_conditional(shared_ptr<ast::Conditional> node) {
visit(node->else_if.value());
}
}

void Pass::visit_loop(shared_ptr<ast::Loop> node) {
visit(node->variable);
visit(node->condition);
visit(node->assignment);
visit(node->body);
}
9 changes: 9 additions & 0 deletions src/ast/passes/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ void TypeCheck::visit_conditional(shared_ptr<ast::Conditional> node) {
}
}

void TypeCheck::visit_loop(shared_ptr<ast::Loop> node) {
visit(node->variable);
visit(node->condition);
check_bool(node->condition->get_type(), node->token->getLine());

visit(node->assignment);
visit(node->body);
}

void TypeCheck::check_numeric(TypePtr type, size_t line) {
if (!type->is_numeric()) {
throw TypeError(line, "type(" + type->get_name() + ") is not numeric");
Expand Down
23 changes: 23 additions & 0 deletions src/backend/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mlir::Value Backend::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::BinaryOperator, this->visit_binary_operator);
try_visit(node, ast::UnaryOperator, this->visit_unary_operator);
try_visit(node, ast::Conditional, this->visit_conditional);
try_visit(node, ast::Loop, this->visit_loop);

throw std::runtime_error("node not added to backend visit function");
}
Expand Down Expand Up @@ -183,3 +184,25 @@ mlir::Value Backend::visit_conditional(shared_ptr<ast::Conditional> node) {

return nullptr;
}

mlir::Value Backend::visit_loop(shared_ptr<ast::Loop> node) {
mlir::Block* b_cond = ctx::current_function().addBlock();
mlir::Block* b_loop = ctx::current_function().addBlock();
mlir::Block* b_exit = ctx::current_function().addBlock();

visit(node->variable);

flow::jump(b_cond);
ctx::builder->setInsertionPointToStart(b_cond);

mlir::Value condition = visit(node->condition);
flow::branch(condition, b_loop, b_exit);

ctx::builder->setInsertionPointToStart(b_loop);
visit(node->body);
visit(node->assignment);
flow::jump(b_cond);

ctx::builder->setInsertionPointToStart(b_exit);
return nullptr;
}
6 changes: 6 additions & 0 deletions tests/input/errors/symbol/loop.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main(): i32 {
for(let i: i32 = 0; i < 5; i = i + 1) {
}
println(i);
return 0;
}
6 changes: 6 additions & 0 deletions tests/input/errors/type/loop.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main(): i32 {
for(let i: i32 = 0; i; i = i + 1) {
println(5);
}
return 0;
}
6 changes: 6 additions & 0 deletions tests/input/loop/basic.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main(): i32 {
for(let i: i32 = 0; i < 5; i = i + 1) {
println(5);
}
return 0;
}
1 change: 1 addition & 0 deletions tests/output/errors/symbol/loop.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SymbolError on Line 4
1 change: 1 addition & 0 deletions tests/output/errors/type/loop.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TypeError on Line 2
5 changes: 5 additions & 0 deletions tests/output/loop/basic.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
5
5
5
5
5

0 comments on commit 8d4ab2e

Please sign in to comment.