Skip to content

Commit

Permalink
Unary Arithmetic (#34)
Browse files Browse the repository at this point in the history
* test ci thing

* test ci thing

* revert ci

* add unary ops
  • Loading branch information
jackparsonss committed Jun 26, 2024
1 parent a29e174 commit 14cc99a
Show file tree
Hide file tree
Showing 14 changed files with 123 additions and 13 deletions.
2 changes: 2 additions & 0 deletions grammar/Fusion.g4
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ return: RETURN expr SEMI;

expr
: call #callExpr
| (op=MINUS | op=BANG) expr #unary
| <assoc='right'> expr CARET expr #power
| expr (op=STAR | op=SLASH | op=MOD | op=DSTAR) expr #mulDivMod
| expr (op=PLUS | op=MINUS) expr #addSub
Expand Down Expand Up @@ -76,6 +77,7 @@ EQ: '==';
CARET: '^';
DAND: '&&';
DOR: '||';
BANG: '!';

// comments
LINE_COMMENT: '//' .*? ('\n' | EOF) -> skip;
Expand Down
15 changes: 15 additions & 0 deletions include/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@ enum class BinaryOpType {
OR,
};

enum class UnaryOpType {
MINUS,
NOT,
};

std::string random_name();
std::string qualifier_to_string(Qualifier qualifier);
std::string binary_op_type_to_string(BinaryOpType type);
std::string unary_op_type_to_string(UnaryOpType type);

class Node {
public:
Expand Down Expand Up @@ -202,4 +208,13 @@ class BinaryOperator : public Expression {
Token* token);
void xml(int level) override;
};

class UnaryOperator : public Expression {
public:
UnaryOpType type;
shared_ptr<Expression> rhs;

UnaryOperator(UnaryOpType type, shared_ptr<Expression> rhs, Token* token);
void xml(int level) override;
};
} // namespace ast
1 change: 1 addition & 0 deletions include/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ class Builder : public FusionBaseVisitor {
std::any visitGtLtCond(FusionParser::GtLtCondContext* ctx) override;
std::any visitEqNeCond(FusionParser::EqNeCondContext* ctx) override;
std::any visitAndOrCond(FusionParser::AndOrCondContext* ctx) override;
std::any visitUnary(FusionParser::UnaryContext* ctx) override;
};
1 change: 1 addition & 0 deletions include/ast/passes/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Pass {
virtual void visit_parameter(shared_ptr<ast::Parameter>);
virtual void visit_return(shared_ptr<ast::Return>);
virtual void visit_binary_operator(shared_ptr<ast::BinaryOperator>);
virtual void visit_unary_operator(shared_ptr<ast::UnaryOperator>);
};

namespace pass {
Expand Down
1 change: 1 addition & 0 deletions include/backend/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ class Backend {
mlir::Value visit_parameter(shared_ptr<ast::Parameter>);
mlir::Value visit_return(shared_ptr<ast::Return>);
mlir::Value visit_binary_operator(shared_ptr<ast::BinaryOperator>);
mlir::Value visit_unary_operator(shared_ptr<ast::UnaryOperator>);
};
6 changes: 6 additions & 0 deletions include/backend/expressions/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ mlir::Value lte(mlir::Value lhs, mlir::Value rhs, TypePtr type);
mlir::Value and_(mlir::Value lhs, mlir::Value rhs, TypePtr type);
mlir::Value or_(mlir::Value lhs, mlir::Value rhs, TypePtr type);

mlir::Value not_(mlir::Value value);
mlir::Value negate(mlir::Value value, TypePtr type);

mlir::Value binary_equality(mlir::Value lhs,
mlir::Value rhs,
TypePtr type,
Expand All @@ -31,4 +34,7 @@ mlir::Value binary_operation(mlir::Value lhs,
mlir::Value rhs,
ast::BinaryOpType op_type,
TypePtr type);
mlir::Value unary_operation(mlir::Value rhs,
ast::UnaryOpType op_type,
TypePtr type);
} // namespace arithmetic
29 changes: 29 additions & 0 deletions src/ast/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ std::string ast::binary_op_type_to_string(ast::BinaryOpType type) {
}
}

std::string ast::unary_op_type_to_string(ast::UnaryOpType type) {
switch (type) {
case ast::UnaryOpType::MINUS:
return "MINUS";
case ast::UnaryOpType::NOT:
return "NOT";
}
}

ast::Node::Node(Token* token) {
if (token == nullptr) {
return;
Expand Down Expand Up @@ -348,3 +357,23 @@ void ast::BinaryOperator::xml(int level) {

std::cout << std::string(level * 4, ' ') << "</binary operator>\n";
}

ast::UnaryOperator::UnaryOperator(UnaryOpType type,
shared_ptr<Expression> rhs,
Token* token)
: Expression(rhs->get_type(), token) {
this->type = type;
this->rhs = rhs;
}

void ast::UnaryOperator::xml(int level) {
std::cout << std::string(level * 4, ' ') << "<binop op_type=\""
<< ast::unary_op_type_to_string(type) << "\" type=\""
<< get_type()->get_name() << "\">\n";

std::cout << std::string((level + 1) * 4, ' ') << "<rhs>\n";
this->rhs->xml(level + 2);
std::cout << std::string((level + 1) * 4, ' ') << "</rhs>\n";

std::cout << std::string(level * 4, ' ') << "</binary operator>\n";
}
19 changes: 19 additions & 0 deletions src/ast/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,22 @@ std::any Builder::visitAndOrCond(FusionParser::AndOrCondContext* ctx) {

return to_node(binop);
}

std::any Builder::visitUnary(FusionParser::UnaryContext* ctx) {
Token* token;
ast::UnaryOpType type;

if (ctx->MINUS() != nullptr) {
type = ast::UnaryOpType::MINUS;
token = ctx->MINUS()->getSymbol();
} else if (ctx->BANG() != nullptr) {
type = ast::UnaryOpType::NOT;
token = ctx->BANG()->getSymbol();
} else {
throw std::runtime_error("Unrecognized unary operator");
}
auto rhs = cast_node(ast::Expression, visit(ctx->expr()));
auto binop = make_shared<ast::UnaryOperator>(type, rhs, token);

return to_node(binop);
}
7 changes: 7 additions & 0 deletions src/ast/passes/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ void Pass::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::Parameter, this->visit_parameter);
try_visit(node, ast::Return, this->visit_return);
try_visit(node, ast::BinaryOperator, this->visit_binary_operator);
try_visit(node, ast::UnaryOperator, this->visit_unary_operator);

throw std::runtime_error("node not added to pass manager");
}

void Pass::visit_block(shared_ptr<ast::Block> node) {
Expand Down Expand Up @@ -88,3 +91,7 @@ void Pass::visit_binary_operator(shared_ptr<ast::BinaryOperator> node) {
visit(node->lhs);
visit(node->rhs);
}

void Pass::visit_unary_operator(shared_ptr<ast::UnaryOperator> node) {
visit(node->rhs);
}
25 changes: 24 additions & 1 deletion src/backend/expressions/arithmetic.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "backend/expressions/arithmetic.h"
#include "ast/ast.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "backend/types/boolean.h"
#include "backend/types/integer.h"
#include "shared/context.h"

#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

namespace {
Expand Down Expand Up @@ -75,6 +77,16 @@ mlir::Value arithmetic::lte(mlir::Value lhs, mlir::Value rhs, TypePtr type) {
return binary_equality(lhs, rhs, type, mlir::LLVM::ICmpPredicate::sle);
}

mlir::Value arithmetic::not_(mlir::Value value) {
mlir::Value f = boolean::create_bool(false);
return eq(f, value, ctx::t_bool);
}

mlir::Value arithmetic::negate(mlir::Value value, TypePtr type) {
mlir::Value neg_one = integer::create_i32(-1);
return mul(neg_one, value, type);
}

mlir::Value arithmetic::binary_equality(mlir::Value lhs,
mlir::Value rhs,
TypePtr type,
Expand Down Expand Up @@ -120,3 +132,14 @@ mlir::Value arithmetic::binary_operation(mlir::Value lhs,
return or_(lhs, rhs, type);
}
}

mlir::Value arithmetic::unary_operation(mlir::Value rhs,
ast::UnaryOpType op_type,
TypePtr type) {
switch (op_type) {
case ast::UnaryOpType::MINUS:
return negate(rhs, type);
case ast::UnaryOpType::NOT:
return not_(rhs);
}
}
7 changes: 7 additions & 0 deletions src/backend/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mlir::Value Backend::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::Parameter, this->visit_parameter);
try_visit(node, ast::Return, this->visit_return);
try_visit(node, ast::BinaryOperator, this->visit_binary_operator);
try_visit(node, ast::UnaryOperator, this->visit_unary_operator);

throw std::runtime_error("node not added to backend visit function");
}
Expand Down Expand Up @@ -134,3 +135,9 @@ mlir::Value Backend::visit_binary_operator(

return arithmetic::binary_operation(lhs, rhs, node->type, node->get_type());
}

mlir::Value Backend::visit_unary_operator(shared_ptr<ast::UnaryOperator> node) {
mlir::Value rhs = visit(node->rhs);

return arithmetic::unary_operation(rhs, node->type, node->get_type());
}
7 changes: 7 additions & 0 deletions tests/input/operators/unary.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
fn main(): i32 {
println(!true);
println(!false);
println(-5);

return 0;
}
13 changes: 1 addition & 12 deletions tests/lli-temp.out
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
8
5
1
6
2
0
1
0
0
1
0
1
1
0
-5
3 changes: 3 additions & 0 deletions tests/output/operators/unary.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
0
1
-5

0 comments on commit 14cc99a

Please sign in to comment.