Skip to content

Commit

Permalink
i32 Arithmetic (#26)
Browse files Browse the repository at this point in the history
* add ast node an builder

* add backend

* update readme
  • Loading branch information
jackparsonss committed Jun 24, 2024
1 parent 09bbeea commit c603d0a
Show file tree
Hide file tree
Showing 21 changed files with 346 additions and 17 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,23 @@ Prints the argument passed into stdout
```
print(5);
```

#### Arithmetic
- addition: `+`
- subtraction: `-`
- power: `^`
- multiplication: `*`
- division: `/`
- modulus: `%`
```
fn main(): i32 {
let a: i32 = 5 + 5;
let s: i32 = 5 - 5;
let p: i32 = 5 ^ 5;
let m: i32 = 5 * 5;
let d: i32 = 5 / 5;
let r: i32 = 5 % 5;
}
```
```
```
30 changes: 24 additions & 6 deletions grammar/Fusion.g4
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ statement
;

declaration:
variable EQ expr SEMI
variable EQUAL expr SEMI
;

block: L_CURLY statement* R_CURLY;
Expand All @@ -25,10 +25,15 @@ call: ID L_PAREN expr? (COMMA expr)* R_PAREN;
return: RETURN expr SEMI;

expr
: call #callExpr
| CHARACTER #literalChar
| INT #literalInt
| ID #identifier
: call #callExpr
| <assoc='right'> expr CARET expr #power
| expr (op=STAR | op=SLASH | op=MOD | op=DSTAR) expr #mulDivMod
| expr (op=PLUS | op=MINUS) expr #addSub
| expr (op=GT | op=LT | op=GE | op=LE) expr #gtLtCond
| expr (op=EQ | op=NE) expr #eqNeCond
| CHARACTER #literalChar
| INT #literalInt
| ID #identifier
;

qualifier: CONST | LET;
Expand All @@ -45,11 +50,24 @@ LET: 'let';
SEMI: ';';
COLON: ':';
COMMA: ',';
EQ: '=';
EQUAL: '=';
L_PAREN: '(';
R_PAREN: ')';
L_CURLY: '{';
R_CURLY: '}';
PLUS: '+';
MINUS: '-';
STAR: '*';
DSTAR: '**';
SLASH: '/';
MOD: '%';
GT: '>';
LT: '<';
GE: '>=';
LE: '<=';
NE: '!=';
EQ: '==';
CARET: '^';

// comments
LINE_COMMENT: '//' .*? ('\n' | EOF) -> skip;
Expand Down
33 changes: 30 additions & 3 deletions include/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,32 @@

#include "shared/type.h"

using antlr4::Token;
using std::make_shared;
using std::shared_ptr;
using antlr4::Token, std::make_shared, std::shared_ptr;

namespace ast {
enum class Qualifier {
Const,
Let,
};

enum class BinaryOpType {
POW,
ADD,
SUB,
MUL,
DIV,
MOD,
GT,
GTE,
LT,
LTE,
EQ,
NE,
};

std::string random_name();
std::string qualifier_to_string(Qualifier qualifier);
std::string binary_op_type_to_string(BinaryOpType type);

class Node {
public:
Expand Down Expand Up @@ -162,4 +176,17 @@ class Return : public Node {
Return(shared_ptr<Expression> expr, Token* token);
void xml(int level) override;
};

class BinaryOperator : public Expression {
public:
BinaryOpType type;
shared_ptr<Expression> lhs;
shared_ptr<Expression> rhs;

BinaryOperator(BinaryOpType type,
shared_ptr<Expression> lhs,
shared_ptr<Expression> rhs,
Token* token);
void xml(int level) override;
};
} // namespace ast
5 changes: 5 additions & 0 deletions include/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,9 @@ class Builder : public FusionBaseVisitor {
std::any visitVariable(FusionParser::VariableContext* ctx) override;
std::any visitReturn(FusionParser::ReturnContext* ctx) override;
std::any visitCallExpr(FusionParser::CallExprContext* ctx) override;
std::any visitPower(FusionParser::PowerContext* ctx) override;
std::any visitMulDivMod(FusionParser::MulDivModContext* ctx) override;
std::any visitAddSub(FusionParser::AddSubContext* ctx) override;
std::any visitGtLtCond(FusionParser::GtLtCondContext* ctx) override;
std::any visitEqNeCond(FusionParser::EqNeCondContext* ctx) override;
};
1 change: 1 addition & 0 deletions include/backend/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ class Backend {
mlir::Value visit_parameter(shared_ptr<ast::Parameter>);
mlir::Value visit_return(shared_ptr<ast::Return>);
mlir::Value visit_character_literal(shared_ptr<ast::CharacterLiteral>);
mlir::Value visit_binary_operator(shared_ptr<ast::BinaryOperator>);
};
18 changes: 18 additions & 0 deletions include/backend/expressions/arithmetic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include "ast/ast.h"
#include "mlir/IR/Value.h"
#include "shared/type.h"

namespace arithmetic {
mlir::Value add(mlir::Value lhs, mlir::Value rhs, TypePtr type);
mlir::Value sub(mlir::Value lhs, mlir::Value rhs, TypePtr type);
mlir::Value mul(mlir::Value lhs, mlir::Value rhs, TypePtr type);
mlir::Value div(mlir::Value lhs, mlir::Value rhs, TypePtr type);
mlir::Value mod(mlir::Value lhs, mlir::Value rhs, TypePtr type);
mlir::Value pow(mlir::Value lhs, mlir::Value rhs, TypePtr type);
mlir::Value binary_operation(mlir::Value lhs,
mlir::Value rhs,
ast::BinaryOpType op_type,
TypePtr type);
} // namespace arithmetic
1 change: 1 addition & 0 deletions include/shared/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ extern mlir::MLIRContext context;

extern shared_ptr<Type> ch;
extern shared_ptr<Type> i32;
extern shared_ptr<Type> f32;
extern shared_ptr<Type> none;

extern void initialize_context();
Expand Down
1 change: 1 addition & 0 deletions include/shared/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Type {

static const Type ch;
static const Type i32;
static const Type f32;
static const Type none;
static const Type unset;
};
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ set(
"${CMAKE_CURRENT_SOURCE_DIR}/backend/types/integer.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/backend/types/character.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/backend/builtin/print.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/backend/expressions/arithmetic.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/ast/builder.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/ast/symbol/scope.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/ast/symbol/symbol.cpp"
Expand Down
54 changes: 54 additions & 0 deletions src/ast/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,35 @@ std::string ast::qualifier_to_string(ast::Qualifier qualifier) {
throw std::runtime_error("invalid qualifier case");
}

std::string ast::binary_op_type_to_string(ast::BinaryOpType type) {
switch (type) {
case ast::BinaryOpType::POW:
return "POW";
case ast::BinaryOpType::ADD:
return "ADD";
case ast::BinaryOpType::SUB:
return "SUB";
case ast::BinaryOpType::MUL:
return "MUL";
case ast::BinaryOpType::DIV:
return "DIV";
case ast::BinaryOpType::MOD:
return "MOD";
case ast::BinaryOpType::GT:
return "GT";
case ast::BinaryOpType::GTE:
return "GTE";
case ast::BinaryOpType::LT:
return "LT";
case ast::BinaryOpType::LTE:
return "LTE";
case ast::BinaryOpType::EQ:
return "EQ";
case ast::BinaryOpType::NE:
return "NE";
}
}

ast::Node::Node(Token* token) {
if (token == nullptr) {
return;
Expand Down Expand Up @@ -274,3 +303,28 @@ void ast::Return::xml(int level) {
this->expr->xml(level + 1);
std::cout << std::string(level * 4, ' ') << "</return>\n";
}

ast::BinaryOperator::BinaryOperator(BinaryOpType type,
shared_ptr<Expression> lhs,
shared_ptr<Expression> rhs,
Token* token)
: Expression(lhs->get_type(), token) {
this->type = type;
this->lhs = lhs;
this->rhs = rhs;
}

void ast::BinaryOperator::xml(int level) {
std::cout << std::string(level * 4, ' ') << "<binary operator type=\""
<< ast::binary_op_type_to_string(type) << "\">\n";

std::cout << std::string((level + 1) * 4, ' ') << "<lhs>\n";
this->lhs->xml(level + 1);
std::cout << std::string((level + 1) * 4, ' ') << "</lhs>\n";

std::cout << std::string((level + 1) * 4, ' ') << "<rhs>\n";
this->rhs->xml(level + 1);
std::cout << std::string((level + 1) * 4, ' ') << "</rhs>\n";

std::cout << std::string(level * 4, ' ') << "</binary operator>\n";
}
64 changes: 63 additions & 1 deletion src/ast/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ std::any Builder::visitStatement(FusionParser::StatementContext* ctx) {
}

std::any Builder::visitDeclaration(FusionParser::DeclarationContext* ctx) {
Token* token = ctx->EQ()->getSymbol();
Token* token = ctx->EQUAL()->getSymbol();

auto expr = cast_node(ast::Expression, visit(ctx->expr()));
auto var = cast_node(ast::Variable, visit(ctx->variable()));
Expand Down Expand Up @@ -199,3 +199,65 @@ std::any Builder::visitReturn(FusionParser::ReturnContext* ctx) {

return to_node(ret);
}

std::any Builder::visitPower(FusionParser::PowerContext* ctx) {
Token* token = ctx->CARET()->getSymbol();
auto lhs = cast_node(ast::Expression, visit(ctx->expr()[0]));
auto rhs = cast_node(ast::Expression, visit(ctx->expr()[1]));
auto binop = make_shared<ast::BinaryOperator>(ast::BinaryOpType::POW, lhs,
rhs, token);

return to_node(binop);
}

std::any Builder::visitMulDivMod(FusionParser::MulDivModContext* ctx) {
Token* token;
ast::BinaryOpType type;

if (ctx->STAR() != nullptr) {
type = ast::BinaryOpType::MUL;
token = ctx->STAR()->getSymbol();
} else if (ctx->SLASH() != nullptr) {
type = ast::BinaryOpType::DIV;
token = ctx->SLASH()->getSymbol();
} else if (ctx->MOD() != nullptr) {
type = ast::BinaryOpType::MOD;
token = ctx->MOD()->getSymbol();
} else {
throw std::runtime_error(
"Unrecognized operator when visiting mul div mod");
}
auto lhs = cast_node(ast::Expression, visit(ctx->expr()[0]));
auto rhs = cast_node(ast::Expression, visit(ctx->expr()[1]));
auto binop = make_shared<ast::BinaryOperator>(type, lhs, rhs, token);

return to_node(binop);
}

std::any Builder::visitAddSub(FusionParser::AddSubContext* ctx) {
Token* token;
ast::BinaryOpType type;

if (ctx->PLUS() != nullptr) {
type = ast::BinaryOpType::ADD;
token = ctx->PLUS()->getSymbol();
} else if (ctx->MINUS() != nullptr) {
type = ast::BinaryOpType::SUB;
token = ctx->MINUS()->getSymbol();
} else {
throw std::runtime_error("Unrecognized operator when visiting add sub");
}
auto lhs = cast_node(ast::Expression, visit(ctx->expr()[0]));
auto rhs = cast_node(ast::Expression, visit(ctx->expr()[1]));
auto binop = make_shared<ast::BinaryOperator>(type, lhs, rhs, token);

return to_node(binop);
}

std::any Builder::visitGtLtCond(FusionParser::GtLtCondContext* ctx) {
throw std::runtime_error("gt, lt conds not implemented yet");
}

std::any Builder::visitEqNeCond(FusionParser::EqNeCondContext* ctx) {
throw std::runtime_error("eq, ne conds not implemented yet");
}
2 changes: 1 addition & 1 deletion src/ast/symbol/symbol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ std::string Symbol::get_name() {
return this->name;
}

std::shared_ptr<Type> Symbol::get_type() {
TypePtr Symbol::get_type() {
return this->type;
}

Expand Down
2 changes: 1 addition & 1 deletion src/ast/symbol/symbol_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void SymbolTable::init_types() {
define(make_shared<BuiltinTypeSymbol>(ctx::ch->get_name()));
}

shared_ptr<ast::Function> make_print(shared_ptr<Type> type) {
shared_ptr<ast::Function> make_print(TypePtr type) {
Token* token = new antlr4::CommonToken(1);
auto body = make_shared<ast::Block>(token);

Expand Down
3 changes: 2 additions & 1 deletion src/backend/backend.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "backend/backend.h"
#include "backend/builtin/print.h"
#include "shared/context.h"

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
Expand All @@ -16,7 +18,6 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "shared/context.h"

Backend::Backend(shared_ptr<ast::Block> ast) {
this->ast = ast;
Expand Down
Loading

0 comments on commit c603d0a

Please sign in to comment.