diff --git a/README.md b/README.md index 263ed68..59d5ef0 100644 --- a/README.md +++ b/README.md @@ -104,3 +104,17 @@ fn main(): i32 { let a: bool = 5 != 5 && 5 == 5; } ``` + +#### Conditionals +Just like `if`/`else` statements in most languages +``` +let x: i32 = 5; +if(x == 4){ + println(0); +} else if(x == 5) { + println(10); +} else { + println(x); +} +``` +Conditions **must** be booleans diff --git a/grammar/Fusion.g4 b/grammar/Fusion.g4 index 387fcfc..b7748d0 100644 --- a/grammar/Fusion.g4 +++ b/grammar/Fusion.g4 @@ -5,6 +5,7 @@ file: statement* EOF; statement : function | call SEMI + | if | declaration | assignment | block @@ -25,6 +26,10 @@ assignment: ID EQUAL expr SEMI; call: ID L_PAREN expr? (COMMA expr)* R_PAREN; +if: IF L_PAREN expr R_PAREN block else?; + +else: ELSE (block | if); + return: RETURN expr SEMI; expr @@ -56,6 +61,8 @@ RETURN: 'return'; FUNCTION: 'fn'; CONST: 'const'; LET: 'let'; +IF: 'if'; +ELSE: 'else'; // symbols SEMI: ';'; diff --git a/include/ast/ast.h b/include/ast/ast.h index 68cec38..78b8903 100644 --- a/include/ast/ast.h +++ b/include/ast/ast.h @@ -230,4 +230,21 @@ class UnaryOperator : public Expression { UnaryOperator(UnaryOpType type, shared_ptr rhs, Token* token); void xml(int level) override; }; + +class Conditional : public Node { + public: + shared_ptr condition; + shared_ptr body; + std::optional> else_if; + + Conditional(shared_ptr condition, + shared_ptr body, + std::optional> else_if, + Token* token); + Conditional(shared_ptr condition, + shared_ptr body, + Token* token); + + void xml(int level) override; +}; } // namespace ast diff --git a/include/ast/builder.h b/include/ast/builder.h index 1042854..6bdc266 100644 --- a/include/ast/builder.h +++ b/include/ast/builder.h @@ -45,4 +45,6 @@ class Builder : public FusionBaseVisitor { std::any visitAndOrCond(FusionParser::AndOrCondContext* ctx) override; std::any visitUnary(FusionParser::UnaryContext* ctx) override; std::any visitAssignment(FusionParser::AssignmentContext* ctx) override; + std::any visitIf(FusionParser::IfContext* ctx) override; + std::any visitElse(FusionParser::ElseContext* ctx) override; }; diff --git a/include/ast/passes/pass.h b/include/ast/passes/pass.h index 783e2aa..71b9d0a 100644 --- a/include/ast/passes/pass.h +++ b/include/ast/passes/pass.h @@ -29,6 +29,7 @@ class Pass { virtual void visit_return(shared_ptr); virtual void visit_binary_operator(shared_ptr); virtual void visit_unary_operator(shared_ptr); + virtual void visit_conditional(shared_ptr); static void run_passes(shared_ptr ast, shared_ptr); }; diff --git a/include/ast/passes/type_check.h b/include/ast/passes/type_check.h index 71e382c..0080fff 100644 --- a/include/ast/passes/type_check.h +++ b/include/ast/passes/type_check.h @@ -20,4 +20,5 @@ class TypeCheck : public Pass { void visit_return(shared_ptr) override; void visit_binary_operator(shared_ptr) override; void visit_unary_operator(shared_ptr) override; + void visit_conditional(shared_ptr) override; }; diff --git a/include/backend/backend.h b/include/backend/backend.h index 87d68c4..7a460fc 100644 --- a/include/backend/backend.h +++ b/include/backend/backend.h @@ -36,4 +36,5 @@ class Backend { mlir::Value visit_return(shared_ptr); mlir::Value visit_binary_operator(shared_ptr); mlir::Value visit_unary_operator(shared_ptr); + mlir::Value visit_conditional(shared_ptr); }; diff --git a/include/backend/expressions/flow.h b/include/backend/expressions/flow.h new file mode 100644 index 0000000..cae401e --- /dev/null +++ b/include/backend/expressions/flow.h @@ -0,0 +1,9 @@ +#pragma once + +#include "mlir/IR/Block.h" +#include "mlir/IR/Value.h" + +namespace flow { +void branch(mlir::Value condition, mlir::Block* b_true, mlir::Block* b_false); +void jump(mlir::Block* block); +} // namespace flow diff --git a/include/errors/errors.h b/include/errors/errors.h index 3cccfb9..c2225c4 100644 --- a/include/errors/errors.h +++ b/include/errors/errors.h @@ -10,6 +10,14 @@ class CompileTimeException : public std::exception { const char* what() const noexcept override { return msg.c_str(); } }; +class RunTimeException : public std::exception { + protected: + std::string msg; + + public: + const char* what() const noexcept override { return msg.c_str(); } +}; + #define DEF_COMPILE_TIME_EXCEPTION(NAME) \ class NAME : public CompileTimeException { \ public: \ @@ -21,6 +29,16 @@ class CompileTimeException : public std::exception { } \ } +#define DEF_RUNTIME_TIME_EXCEPTION(NAME) \ + class NAME : public CompileTimeException { \ + public: \ + NAME(const std::string& description) { \ + std::stringstream buf; \ + buf << #NAME << ": " << description << std::endl; \ + msg = buf.str(); \ + } \ + } + DEF_COMPILE_TIME_EXCEPTION(MainError); DEF_COMPILE_TIME_EXCEPTION(SymbolError); @@ -32,3 +50,5 @@ DEF_COMPILE_TIME_EXCEPTION(SyntaxError); DEF_COMPILE_TIME_EXCEPTION(TypeError); DEF_COMPILE_TIME_EXCEPTION(AssignError); + +DEF_RUNTIME_TIME_EXCEPTION(BackendError); diff --git a/include/shared/context.h b/include/shared/context.h index 8414ffd..a4749bd 100644 --- a/include/shared/context.h +++ b/include/shared/context.h @@ -5,6 +5,7 @@ #include #include #include +#include #include "shared/type/type.h" @@ -25,6 +26,8 @@ extern TypePtr none; extern TypePtr bool_; extern std::vector primitives; +extern std::stack function_stack; extern void initialize_context(); +extern mlir::LLVM::LLVMFuncOp current_function(); } // namespace ctx diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d587e55..7de009e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -12,6 +12,7 @@ set( "${CMAKE_CURRENT_SOURCE_DIR}/backend/builtin/builtin.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/backend/builtin/print.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/backend/expressions/arithmetic.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/backend/expressions/flow.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/ast/builder.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/ast/symbol/scope.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/ast/symbol/symbol.cpp" diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 5241d7b..44ca919 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -404,3 +404,34 @@ void ast::UnaryOperator::xml(int level) { std::cout << std::string(level * 4, ' ') << "\n"; } + +ast::Conditional::Conditional(shared_ptr condition, + shared_ptr body, + std::optional> else_if, + Token* token) + : Node(token) { + this->condition = condition; + this->body = body; + this->else_if = else_if; +} + +ast::Conditional::Conditional(shared_ptr condition, + shared_ptr body, + Token* token) + : ast::Conditional(condition, body, std::nullopt, token) {} + +void ast::Conditional::xml(int level) { + std::cout << std::string(level * 4, ' ') << "\n"; + std::cout << std::string((level + 1) * 4, ' ') << "\n"; + condition->xml(level + 2); + body->xml(level + 2); + std::cout << std::string((level + 1) * 4, ' ') << "\n"; + + if (else_if.has_value()) { + std::cout << std::string((level + 1) * 4, ' ') << "\n"; + else_if.value()->xml(level + 2); + std::cout << std::string((level + 1) * 4, ' ') << "\n"; + } + + std::cout << std::string(level * 4, ' ') << "\n"; +} diff --git a/src/ast/builder.cpp b/src/ast/builder.cpp index 2f7c9df..a44694d 100644 --- a/src/ast/builder.cpp +++ b/src/ast/builder.cpp @@ -64,6 +64,10 @@ std::any Builder::visitStatement(FusionParser::StatementContext* ctx) { return visit(ctx->return_()); } + if (ctx->if_() != nullptr) { + return visit(ctx->if_()); + } + throw std::runtime_error("found an invalid statement"); } @@ -379,3 +383,30 @@ std::any Builder::visitAssignment(FusionParser::AssignmentContext* ctx) { auto assn = make_shared(var, expr, token); return to_node(assn); } +std::any Builder::visitIf(FusionParser::IfContext* ctx) { + Token* token = ctx->IF()->getSymbol(); + + auto condition = cast_node(ast::Expression, visit(ctx->expr())); + auto block = cast_node(ast::Block, visit(ctx->block())); + + std::optional> else_if = std::nullopt; + if (ctx->else_() != nullptr) { + else_if = cast_node(ast::Conditional, visit(ctx->else_())); + } + + auto node = make_shared(condition, block, else_if, token); + return to_node(node); +} + +std::any Builder::visitElse(FusionParser::ElseContext* ctx) { + Token* token = ctx->ELSE()->getSymbol(); + + if (ctx->if_() != nullptr) { + return cast_node(ast::Node, visit(ctx->if_())); + } + + auto block = cast_node(ast::Block, visit(ctx->block())); + auto condition = make_shared(true, token); + auto node = make_shared(condition, block, token); + return to_node(node); +} diff --git a/src/ast/passes/pass.cpp b/src/ast/passes/pass.cpp index bce2f24..b8f35dc 100644 --- a/src/ast/passes/pass.cpp +++ b/src/ast/passes/pass.cpp @@ -51,6 +51,7 @@ void Pass::visit(shared_ptr node) { try_visit(node, ast::Return, this->visit_return); try_visit(node, ast::BinaryOperator, this->visit_binary_operator); try_visit(node, ast::UnaryOperator, this->visit_unary_operator); + try_visit(node, ast::Conditional, this->visit_conditional); throw std::runtime_error("node not added to pass manager"); } @@ -103,3 +104,12 @@ void Pass::visit_binary_operator(shared_ptr node) { void Pass::visit_unary_operator(shared_ptr node) { visit(node->rhs); } + +void Pass::visit_conditional(shared_ptr node) { + visit(node->condition); + visit(node->body); + + if (node->else_if.has_value()) { + visit(node->else_if.value()); + } +} diff --git a/src/ast/passes/type_check.cpp b/src/ast/passes/type_check.cpp index 95a5b9e..57045d8 100644 --- a/src/ast/passes/type_check.cpp +++ b/src/ast/passes/type_check.cpp @@ -144,6 +144,17 @@ void TypeCheck::visit_unary_operator(shared_ptr node) { } } +void TypeCheck::visit_conditional(shared_ptr node) { + visit(node->condition); + check_bool(node->condition->get_type(), node->token->getLine()); + + visit(node->body); + + if (node->else_if.has_value()) { + visit(node->else_if.value()); + } +} + void TypeCheck::check_numeric(TypePtr type, size_t line) { if (!type->is_numeric()) { throw TypeError(line, "type(" + type->get_name() + ") is not numeric"); diff --git a/src/backend/backend.cpp b/src/backend/backend.cpp index 0aff875..055862d 100644 --- a/src/backend/backend.cpp +++ b/src/backend/backend.cpp @@ -1,5 +1,6 @@ #include "backend/backend.h" #include "backend/builtin/builtin.h" +#include "errors/errors.h" #include "shared/context.h" #include "llvm/IR/LegacyPassManager.h" @@ -28,8 +29,7 @@ shared_ptr Backend::traverse(shared_ptr ast) { visit(ast); if (mlir::failed(mlir::verify(*ctx::module))) { - ctx::module->emitError("module failed to verify"); - return nullptr; + throw BackendError("backend failed to build"); } return ast; diff --git a/src/backend/expressions/flow.cpp b/src/backend/expressions/flow.cpp new file mode 100644 index 0000000..e19c679 --- /dev/null +++ b/src/backend/expressions/flow.cpp @@ -0,0 +1,14 @@ +#include "backend/expressions/flow.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "shared/context.h" + +void flow::branch(mlir::Value condition, + mlir::Block* b_true, + mlir::Block* b_false) { + ctx::builder->create(*ctx::loc, condition, b_true, + b_false); +} + +void flow::jump(mlir::Block* block) { + ctx::builder->create(*ctx::loc, block); +} diff --git a/src/backend/visitor.cpp b/src/backend/visitor.cpp index 89c0218..2004ff5 100644 --- a/src/backend/visitor.cpp +++ b/src/backend/visitor.cpp @@ -3,6 +3,7 @@ #include "ast/ast.h" #include "backend/backend.h" #include "backend/expressions/arithmetic.h" +#include "backend/expressions/flow.h" #include "backend/types/boolean.h" #include "backend/types/character.h" #include "backend/types/integer.h" @@ -32,6 +33,7 @@ mlir::Value Backend::visit(shared_ptr node) { try_visit(node, ast::Return, this->visit_return); try_visit(node, ast::BinaryOperator, this->visit_binary_operator); try_visit(node, ast::UnaryOperator, this->visit_unary_operator); + try_visit(node, ast::Conditional, this->visit_conditional); throw std::runtime_error("node not added to backend visit function"); } @@ -110,6 +112,7 @@ mlir::Value Backend::visit_function(shared_ptr node) { mlir::Block* b_body = func.addEntryBlock(); ctx::builder->setInsertionPointToStart(b_body); + ctx::function_stack.push(func); for (size_t i = 0; i < node->params.size(); i++) { mlir::Value address = visit(node->params[i]); @@ -119,6 +122,7 @@ mlir::Value Backend::visit_function(shared_ptr node) { visit(node->body); + ctx::function_stack.pop(); ctx::builder->setInsertionPointToEnd(ctx::module->getBody()); return nullptr; @@ -156,3 +160,26 @@ mlir::Value Backend::visit_unary_operator(shared_ptr node) { return arithmetic::unary_operation(rhs, node->type, node->get_type()); } + +mlir::Value Backend::visit_conditional(shared_ptr node) { + mlir::Value condition = visit(node->condition); + mlir::Block* b_cond = ctx::current_function().addBlock(); + mlir::Block* b_else = ctx::current_function().addBlock(); + mlir::Block* b_exit = ctx::current_function().addBlock(); + + flow::branch(condition, b_cond, b_else); + + ctx::builder->setInsertionPointToStart(b_cond); + visit(node->body); + flow::jump(b_exit); + + ctx::builder->setInsertionPointToStart(b_else); + if (node->else_if.has_value()) { + visit(node->else_if.value()); + } + flow::jump(b_exit); + + ctx::builder->setInsertionPointToStart(b_exit); + + return nullptr; +} diff --git a/src/compiler.cpp b/src/compiler.cpp index 2e22117..5705b02 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -69,7 +69,12 @@ void Compiler::xml() { } void Compiler::build_backend() { - backend->traverse(this->builder->get_ast()); + try { + backend->traverse(this->builder->get_ast()); + } catch (CompileTimeException const& e) { + std::cerr << e.what() << std::endl; + exit(1); + } } void Compiler::to_object(std::string filename) { diff --git a/src/shared/context.cpp b/src/shared/context.cpp index fa8e7bd..952ba1c 100644 --- a/src/shared/context.cpp +++ b/src/shared/context.cpp @@ -1,4 +1,5 @@ #include "shared/context.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "shared/type/boolean.h" #include "shared/type/character.h" #include "shared/type/float.h" @@ -20,6 +21,7 @@ TypePtr ctx::none; TypePtr ctx::bool_; std::vector ctx::primitives; +std::stack ctx::function_stack; void ctx::initialize_context() { context.loadDialect(); @@ -40,3 +42,11 @@ void ctx::initialize_context() { module = std::make_unique( mlir::ModuleOp::create(builder->getUnknownLoc())); } + +mlir::LLVM::LLVMFuncOp ctx::current_function() { + if (function_stack.empty()) { + throw std::runtime_error("you are not within a function"); + } + + return function_stack.top(); +} diff --git a/tests/input/conditionals/basic.in b/tests/input/conditionals/basic.in new file mode 100644 index 0000000..144bb8b --- /dev/null +++ b/tests/input/conditionals/basic.in @@ -0,0 +1,6 @@ +fn main(): i32 { + if(true) { + println(5); + } + return 0; +} diff --git a/tests/input/conditionals/else.in b/tests/input/conditionals/else.in new file mode 100644 index 0000000..6af0f30 --- /dev/null +++ b/tests/input/conditionals/else.in @@ -0,0 +1,8 @@ +fn main(): i32 { + if(false) { + println(5); + } else { + println(10); + } + return 0; +} diff --git a/tests/input/conditionals/else_if.in b/tests/input/conditionals/else_if.in new file mode 100644 index 0000000..5797389 --- /dev/null +++ b/tests/input/conditionals/else_if.in @@ -0,0 +1,10 @@ +fn main(): i32 { + if(false) { + println(5); + } else if(true) { + println(15); + } else { + println(10); + } + return 0; +} diff --git a/tests/input/conditionals/if.in b/tests/input/conditionals/if.in new file mode 100644 index 0000000..d35c612 --- /dev/null +++ b/tests/input/conditionals/if.in @@ -0,0 +1,8 @@ +fn main(): i32 { + if(true) { + println(5); + } else { + println(10); + } + return 0; +} diff --git a/tests/input/errors/type/else_if.in b/tests/input/errors/type/else_if.in new file mode 100644 index 0000000..84ed5cd --- /dev/null +++ b/tests/input/errors/type/else_if.in @@ -0,0 +1,8 @@ +fn main(): i32 { + if(false) { + println(5); + } else if('c') { + println(9); + } + return 0; +} diff --git a/tests/input/errors/type/if.in b/tests/input/errors/type/if.in new file mode 100644 index 0000000..349ba6c --- /dev/null +++ b/tests/input/errors/type/if.in @@ -0,0 +1,6 @@ +fn main(): i32 { + if(1) { + println(5); + } + return 0; +} diff --git a/tests/output/conditionals/basic.out b/tests/output/conditionals/basic.out new file mode 100644 index 0000000..7ed6ff8 --- /dev/null +++ b/tests/output/conditionals/basic.out @@ -0,0 +1 @@ +5 diff --git a/tests/output/conditionals/else.out b/tests/output/conditionals/else.out new file mode 100644 index 0000000..f599e28 --- /dev/null +++ b/tests/output/conditionals/else.out @@ -0,0 +1 @@ +10 diff --git a/tests/output/conditionals/else_if.out b/tests/output/conditionals/else_if.out new file mode 100644 index 0000000..60d3b2f --- /dev/null +++ b/tests/output/conditionals/else_if.out @@ -0,0 +1 @@ +15 diff --git a/tests/output/conditionals/if.out b/tests/output/conditionals/if.out new file mode 100644 index 0000000..7ed6ff8 --- /dev/null +++ b/tests/output/conditionals/if.out @@ -0,0 +1 @@ +5 diff --git a/tests/output/errors/type/else_if.out b/tests/output/errors/type/else_if.out new file mode 100644 index 0000000..7cd033a --- /dev/null +++ b/tests/output/errors/type/else_if.out @@ -0,0 +1 @@ +TypeError on Line 4 \ No newline at end of file diff --git a/tests/output/errors/type/if.out b/tests/output/errors/type/if.out new file mode 100644 index 0000000..cae1341 --- /dev/null +++ b/tests/output/errors/type/if.out @@ -0,0 +1 @@ +TypeError on Line 2 \ No newline at end of file