Skip to content

Commit

Permalink
Conditionals (#47)
Browse files Browse the repository at this point in the history
* add helpers

* add grammar

* add node and builder

* add to pass

* add to backend

* add tests

* add type checking

* clean up type check

* update readme
  • Loading branch information
jackparsonss committed Jul 3, 2024
1 parent 895811e commit e14d21f
Show file tree
Hide file tree
Showing 32 changed files with 270 additions and 3 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,17 @@ fn main(): i32 {
let a: bool = 5 != 5 && 5 == 5;
}
```

#### Conditionals
Just like `if`/`else` statements in most languages
```
let x: i32 = 5;
if(x == 4){
println(0);
} else if(x == 5) {
println(10);
} else {
println(x);
}
```
Conditions **must** be booleans
7 changes: 7 additions & 0 deletions grammar/Fusion.g4
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ file: statement* EOF;
statement
: function
| call SEMI
| if
| declaration
| assignment
| block
Expand All @@ -25,6 +26,10 @@ assignment: ID EQUAL expr SEMI;

call: ID L_PAREN expr? (COMMA expr)* R_PAREN;

if: IF L_PAREN expr R_PAREN block else?;

else: ELSE (block | if);

return: RETURN expr SEMI;

expr
Expand Down Expand Up @@ -56,6 +61,8 @@ RETURN: 'return';
FUNCTION: 'fn';
CONST: 'const';
LET: 'let';
IF: 'if';
ELSE: 'else';

// symbols
SEMI: ';';
Expand Down
17 changes: 17 additions & 0 deletions include/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,21 @@ class UnaryOperator : public Expression {
UnaryOperator(UnaryOpType type, shared_ptr<Expression> rhs, Token* token);
void xml(int level) override;
};

class Conditional : public Node {
public:
shared_ptr<Expression> condition;
shared_ptr<Block> body;
std::optional<shared_ptr<Conditional>> else_if;

Conditional(shared_ptr<Expression> condition,
shared_ptr<Block> body,
std::optional<shared_ptr<Conditional>> else_if,
Token* token);
Conditional(shared_ptr<Expression> condition,
shared_ptr<Block> body,
Token* token);

void xml(int level) override;
};
} // namespace ast
2 changes: 2 additions & 0 deletions include/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@ class Builder : public FusionBaseVisitor {
std::any visitAndOrCond(FusionParser::AndOrCondContext* ctx) override;
std::any visitUnary(FusionParser::UnaryContext* ctx) override;
std::any visitAssignment(FusionParser::AssignmentContext* ctx) override;
std::any visitIf(FusionParser::IfContext* ctx) override;
std::any visitElse(FusionParser::ElseContext* ctx) override;
};
1 change: 1 addition & 0 deletions include/ast/passes/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Pass {
virtual void visit_return(shared_ptr<ast::Return>);
virtual void visit_binary_operator(shared_ptr<ast::BinaryOperator>);
virtual void visit_unary_operator(shared_ptr<ast::UnaryOperator>);
virtual void visit_conditional(shared_ptr<ast::Conditional>);

static void run_passes(shared_ptr<ast::Block> ast, shared_ptr<SymbolTable>);
};
1 change: 1 addition & 0 deletions include/ast/passes/type_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ class TypeCheck : public Pass {
void visit_return(shared_ptr<ast::Return>) override;
void visit_binary_operator(shared_ptr<ast::BinaryOperator>) override;
void visit_unary_operator(shared_ptr<ast::UnaryOperator>) override;
void visit_conditional(shared_ptr<ast::Conditional>) override;
};
1 change: 1 addition & 0 deletions include/backend/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,5 @@ class Backend {
mlir::Value visit_return(shared_ptr<ast::Return>);
mlir::Value visit_binary_operator(shared_ptr<ast::BinaryOperator>);
mlir::Value visit_unary_operator(shared_ptr<ast::UnaryOperator>);
mlir::Value visit_conditional(shared_ptr<ast::Conditional>);
};
9 changes: 9 additions & 0 deletions include/backend/expressions/flow.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h"

namespace flow {
void branch(mlir::Value condition, mlir::Block* b_true, mlir::Block* b_false);
void jump(mlir::Block* block);
} // namespace flow
20 changes: 20 additions & 0 deletions include/errors/errors.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ class CompileTimeException : public std::exception {
const char* what() const noexcept override { return msg.c_str(); }
};

class RunTimeException : public std::exception {
protected:
std::string msg;

public:
const char* what() const noexcept override { return msg.c_str(); }
};

#define DEF_COMPILE_TIME_EXCEPTION(NAME) \
class NAME : public CompileTimeException { \
public: \
Expand All @@ -21,6 +29,16 @@ class CompileTimeException : public std::exception {
} \
}

#define DEF_RUNTIME_TIME_EXCEPTION(NAME) \
class NAME : public CompileTimeException { \
public: \
NAME(const std::string& description) { \
std::stringstream buf; \
buf << #NAME << ": " << description << std::endl; \
msg = buf.str(); \
} \
}

DEF_COMPILE_TIME_EXCEPTION(MainError);

DEF_COMPILE_TIME_EXCEPTION(SymbolError);
Expand All @@ -32,3 +50,5 @@ DEF_COMPILE_TIME_EXCEPTION(SyntaxError);
DEF_COMPILE_TIME_EXCEPTION(TypeError);

DEF_COMPILE_TIME_EXCEPTION(AssignError);

DEF_RUNTIME_TIME_EXCEPTION(BackendError);
3 changes: 3 additions & 0 deletions include/shared/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <mlir/IR/Location.h>
#include <mlir/IR/Types.h>
#include <memory>
#include <stack>

#include "shared/type/type.h"

Expand All @@ -25,6 +26,8 @@ extern TypePtr none;
extern TypePtr bool_;

extern std::vector<TypePtr> primitives;
extern std::stack<mlir::LLVM::LLVMFuncOp> function_stack;

extern void initialize_context();
extern mlir::LLVM::LLVMFuncOp current_function();
} // namespace ctx
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ set(
"${CMAKE_CURRENT_SOURCE_DIR}/backend/builtin/builtin.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/backend/builtin/print.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/backend/expressions/arithmetic.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/backend/expressions/flow.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/ast/builder.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/ast/symbol/scope.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/ast/symbol/symbol.cpp"
Expand Down
31 changes: 31 additions & 0 deletions src/ast/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,34 @@ void ast::UnaryOperator::xml(int level) {

std::cout << std::string(level * 4, ' ') << "</binary operator>\n";
}

ast::Conditional::Conditional(shared_ptr<Expression> condition,
shared_ptr<Block> body,
std::optional<shared_ptr<Conditional>> else_if,
Token* token)
: Node(token) {
this->condition = condition;
this->body = body;
this->else_if = else_if;
}

ast::Conditional::Conditional(shared_ptr<Expression> condition,
shared_ptr<Block> body,
Token* token)
: ast::Conditional(condition, body, std::nullopt, token) {}

void ast::Conditional::xml(int level) {
std::cout << std::string(level * 4, ' ') << "<conditional>\n";
std::cout << std::string((level + 1) * 4, ' ') << "<if>\n";
condition->xml(level + 2);
body->xml(level + 2);
std::cout << std::string((level + 1) * 4, ' ') << "</if>\n";

if (else_if.has_value()) {
std::cout << std::string((level + 1) * 4, ' ') << "<else if>\n";
else_if.value()->xml(level + 2);
std::cout << std::string((level + 1) * 4, ' ') << "</else if>\n";
}

std::cout << std::string(level * 4, ' ') << "</conditional>\n";
}
31 changes: 31 additions & 0 deletions src/ast/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ std::any Builder::visitStatement(FusionParser::StatementContext* ctx) {
return visit(ctx->return_());
}

if (ctx->if_() != nullptr) {
return visit(ctx->if_());
}

throw std::runtime_error("found an invalid statement");
}

Expand Down Expand Up @@ -379,3 +383,30 @@ std::any Builder::visitAssignment(FusionParser::AssignmentContext* ctx) {
auto assn = make_shared<ast::Assignment>(var, expr, token);
return to_node(assn);
}
std::any Builder::visitIf(FusionParser::IfContext* ctx) {
Token* token = ctx->IF()->getSymbol();

auto condition = cast_node(ast::Expression, visit(ctx->expr()));
auto block = cast_node(ast::Block, visit(ctx->block()));

std::optional<shared_ptr<ast::Conditional>> else_if = std::nullopt;
if (ctx->else_() != nullptr) {
else_if = cast_node(ast::Conditional, visit(ctx->else_()));
}

auto node = make_shared<ast::Conditional>(condition, block, else_if, token);
return to_node(node);
}

std::any Builder::visitElse(FusionParser::ElseContext* ctx) {
Token* token = ctx->ELSE()->getSymbol();

if (ctx->if_() != nullptr) {
return cast_node(ast::Node, visit(ctx->if_()));
}

auto block = cast_node(ast::Block, visit(ctx->block()));
auto condition = make_shared<ast::BooleanLiteral>(true, token);
auto node = make_shared<ast::Conditional>(condition, block, token);
return to_node(node);
}
10 changes: 10 additions & 0 deletions src/ast/passes/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ void Pass::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::Return, this->visit_return);
try_visit(node, ast::BinaryOperator, this->visit_binary_operator);
try_visit(node, ast::UnaryOperator, this->visit_unary_operator);
try_visit(node, ast::Conditional, this->visit_conditional);

throw std::runtime_error("node not added to pass manager");
}
Expand Down Expand Up @@ -103,3 +104,12 @@ void Pass::visit_binary_operator(shared_ptr<ast::BinaryOperator> node) {
void Pass::visit_unary_operator(shared_ptr<ast::UnaryOperator> node) {
visit(node->rhs);
}

void Pass::visit_conditional(shared_ptr<ast::Conditional> node) {
visit(node->condition);
visit(node->body);

if (node->else_if.has_value()) {
visit(node->else_if.value());
}
}
11 changes: 11 additions & 0 deletions src/ast/passes/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ void TypeCheck::visit_unary_operator(shared_ptr<ast::UnaryOperator> node) {
}
}

void TypeCheck::visit_conditional(shared_ptr<ast::Conditional> node) {
visit(node->condition);
check_bool(node->condition->get_type(), node->token->getLine());

visit(node->body);

if (node->else_if.has_value()) {
visit(node->else_if.value());
}
}

void TypeCheck::check_numeric(TypePtr type, size_t line) {
if (!type->is_numeric()) {
throw TypeError(line, "type(" + type->get_name() + ") is not numeric");
Expand Down
4 changes: 2 additions & 2 deletions src/backend/backend.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "backend/backend.h"
#include "backend/builtin/builtin.h"
#include "errors/errors.h"
#include "shared/context.h"

#include "llvm/IR/LegacyPassManager.h"
Expand Down Expand Up @@ -28,8 +29,7 @@ shared_ptr<ast::Block> Backend::traverse(shared_ptr<ast::Block> ast) {
visit(ast);

if (mlir::failed(mlir::verify(*ctx::module))) {
ctx::module->emitError("module failed to verify");
return nullptr;
throw BackendError("backend failed to build");
}

return ast;
Expand Down
14 changes: 14 additions & 0 deletions src/backend/expressions/flow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "backend/expressions/flow.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "shared/context.h"

void flow::branch(mlir::Value condition,
mlir::Block* b_true,
mlir::Block* b_false) {
ctx::builder->create<mlir::LLVM::CondBrOp>(*ctx::loc, condition, b_true,
b_false);
}

void flow::jump(mlir::Block* block) {
ctx::builder->create<mlir::LLVM::BrOp>(*ctx::loc, block);
}
27 changes: 27 additions & 0 deletions src/backend/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "ast/ast.h"
#include "backend/backend.h"
#include "backend/expressions/arithmetic.h"
#include "backend/expressions/flow.h"
#include "backend/types/boolean.h"
#include "backend/types/character.h"
#include "backend/types/integer.h"
Expand Down Expand Up @@ -32,6 +33,7 @@ mlir::Value Backend::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::Return, this->visit_return);
try_visit(node, ast::BinaryOperator, this->visit_binary_operator);
try_visit(node, ast::UnaryOperator, this->visit_unary_operator);
try_visit(node, ast::Conditional, this->visit_conditional);

throw std::runtime_error("node not added to backend visit function");
}
Expand Down Expand Up @@ -110,6 +112,7 @@ mlir::Value Backend::visit_function(shared_ptr<ast::Function> node) {

mlir::Block* b_body = func.addEntryBlock();
ctx::builder->setInsertionPointToStart(b_body);
ctx::function_stack.push(func);

for (size_t i = 0; i < node->params.size(); i++) {
mlir::Value address = visit(node->params[i]);
Expand All @@ -119,6 +122,7 @@ mlir::Value Backend::visit_function(shared_ptr<ast::Function> node) {

visit(node->body);

ctx::function_stack.pop();
ctx::builder->setInsertionPointToEnd(ctx::module->getBody());

return nullptr;
Expand Down Expand Up @@ -156,3 +160,26 @@ mlir::Value Backend::visit_unary_operator(shared_ptr<ast::UnaryOperator> node) {

return arithmetic::unary_operation(rhs, node->type, node->get_type());
}

mlir::Value Backend::visit_conditional(shared_ptr<ast::Conditional> node) {
mlir::Value condition = visit(node->condition);
mlir::Block* b_cond = ctx::current_function().addBlock();
mlir::Block* b_else = ctx::current_function().addBlock();
mlir::Block* b_exit = ctx::current_function().addBlock();

flow::branch(condition, b_cond, b_else);

ctx::builder->setInsertionPointToStart(b_cond);
visit(node->body);
flow::jump(b_exit);

ctx::builder->setInsertionPointToStart(b_else);
if (node->else_if.has_value()) {
visit(node->else_if.value());
}
flow::jump(b_exit);

ctx::builder->setInsertionPointToStart(b_exit);

return nullptr;
}
7 changes: 6 additions & 1 deletion src/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ void Compiler::xml() {
}

void Compiler::build_backend() {
backend->traverse(this->builder->get_ast());
try {
backend->traverse(this->builder->get_ast());
} catch (CompileTimeException const& e) {
std::cerr << e.what() << std::endl;
exit(1);
}
}

void Compiler::to_object(std::string filename) {
Expand Down
Loading

0 comments on commit e14d21f

Please sign in to comment.