Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assignment #39

Merged
merged 6 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions grammar/Fusion.g4
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ statement
: function
| call SEMI
| declaration
| assignment
| block
| return
;
Expand All @@ -20,6 +21,8 @@ function: FUNCTION ID L_PAREN variable? (COMMA variable)* R_PAREN COLON type blo

variable: qualifier ID COLON type;

assignment: ID EQUAL expr SEMI;

call: ID L_PAREN expr? (COMMA expr)* R_PAREN;

return: RETURN expr SEMI;
Expand Down
13 changes: 13 additions & 0 deletions include/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class Expression : public Node {

void set_type(TypePtr type);
TypePtr get_type() const;
virtual bool is_l_value();
};

class IntegerLiteral : public Expression {
Expand Down Expand Up @@ -121,6 +122,7 @@ class Variable : public Expression {
void set_qualifier(Qualifier qualifer);
void set_ref_name(std::string name);
void xml(int level) override;
bool is_l_value() override;
};

class Declaration : public Node {
Expand All @@ -134,6 +136,17 @@ class Declaration : public Node {
void xml(int level) override;
};

class Assignment : public Node {
public:
shared_ptr<Variable> var;
shared_ptr<Expression> expr;
explicit Assignment(shared_ptr<Variable> var,
shared_ptr<Expression> expr,
Token* token);

void xml(int level) override;
};

class Parameter : public Node {
public:
shared_ptr<Variable> var;
Expand Down
1 change: 1 addition & 0 deletions include/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ class Builder : public FusionBaseVisitor {
std::any visitEqNeCond(FusionParser::EqNeCondContext* ctx) override;
std::any visitAndOrCond(FusionParser::AndOrCondContext* ctx) override;
std::any visitUnary(FusionParser::UnaryContext* ctx) override;
std::any visitAssignment(FusionParser::AssignmentContext* ctx) override;
};
1 change: 1 addition & 0 deletions include/ast/passes/def_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class DefRef : public Pass {
explicit DefRef(shared_ptr<SymbolTable> symbol_table);
void visit_block(shared_ptr<ast::Block>) override;
void visit_declaration(shared_ptr<ast::Declaration>) override;
void visit_assignment(shared_ptr<ast::Assignment>) override;
void visit_parameter(shared_ptr<ast::Parameter>) override;
void visit_function(shared_ptr<ast::Function>) override;
void visit_variable(shared_ptr<ast::Variable>) override;
Expand Down
1 change: 1 addition & 0 deletions include/ast/passes/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class Pass {
virtual void visit_character_literal(shared_ptr<ast::CharacterLiteral>);
virtual void visit_boolean_literal(shared_ptr<ast::BooleanLiteral>);
virtual void visit_declaration(shared_ptr<ast::Declaration>);
virtual void visit_assignment(shared_ptr<ast::Assignment>);
virtual void visit_variable(shared_ptr<ast::Variable>);
virtual void visit_function(shared_ptr<ast::Function>);
virtual void visit_call(shared_ptr<ast::Call>);
Expand Down
1 change: 1 addition & 0 deletions include/ast/passes/type_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TypeCheck : public Pass {
public:
explicit TypeCheck();
void visit_declaration(shared_ptr<ast::Declaration>) override;
void visit_assignment(shared_ptr<ast::Assignment>) override;
void visit_function(shared_ptr<ast::Function>) override;
void visit_call(shared_ptr<ast::Call>) override;
void visit_return(shared_ptr<ast::Return>) override;
Expand Down
1 change: 1 addition & 0 deletions include/backend/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Backend {
mlir::Value visit_boolean_literal(shared_ptr<ast::BooleanLiteral>);
mlir::Value visit_variable(shared_ptr<ast::Variable>);
mlir::Value visit_declaration(shared_ptr<ast::Declaration>);
mlir::Value visit_assignment(shared_ptr<ast::Assignment>);
mlir::Value visit_function(shared_ptr<ast::Function>);
mlir::Value visit_call(shared_ptr<ast::Call>);
mlir::Value visit_parameter(shared_ptr<ast::Parameter>);
Expand Down
3 changes: 3 additions & 0 deletions include/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ast/builder.h"
#include "ast/symbol/symbol_table.h"
#include "backend/backend.h"
#include "errors/syntax.h"

using std::shared_ptr, std::unique_ptr;

Expand All @@ -22,6 +23,8 @@ class Compiler {
antlr4::tree::ParseTree* tree;
antlr4::CommonTokenStream* tokens;
fusion::FusionParser* parser;
LexerErrorListener* lexer_error;
SyntaxErrorListener* syntax_error;

public:
Compiler(std::string filename,
Expand Down
4 changes: 4 additions & 0 deletions include/errors.h → include/errors/errors.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ DEF_COMPILE_TIME_EXCEPTION(MainError);

DEF_COMPILE_TIME_EXCEPTION(SymbolError);

DEF_COMPILE_TIME_EXCEPTION(LexerError);

DEF_COMPILE_TIME_EXCEPTION(SyntaxError);

DEF_COMPILE_TIME_EXCEPTION(TypeError);

DEF_COMPILE_TIME_EXCEPTION(AssignError);
36 changes: 36 additions & 0 deletions include/errors/syntax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include "BaseErrorListener.h"
#include "antlr4-runtime.h"

#define _NC "\033[0m"
#define _RED "\033[0;31m"
#define _BLUE "\033[0;34m"

class LexerErrorListener : public antlr4::BaseErrorListener {
public:
void syntaxError(antlr4::Recognizer* recognizer,
antlr4::Token* offending_symbol,
size_t line,
size_t char_position_in_line,
const std::string& msg,
std::exception_ptr e) override;
};

class SyntaxErrorListener : public antlr4::BaseErrorListener {
public:
void syntaxError(antlr4::Recognizer* recognizer,
antlr4::Token* offending_symbol,
size_t line,
size_t char_position_in_line,
const std::string& msg,
std::exception_ptr e) override;
};

void underline_error(antlr4::Recognizer* recognizer,
antlr4::Token* offending_symbol,
size_t line,
size_t char_position_in_line,
std::ostream& out);

void show_rule_stack(antlr4::Recognizer* recognizer, std::ostream& out);
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ set(
"${CMAKE_CURRENT_SOURCE_DIR}/ast/passes/builtin.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/type.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/context.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/errors/syntax.cpp"
)

# Build our executable from the source files.
Expand Down
27 changes: 27 additions & 0 deletions src/ast/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ ast::Expression::Expression(TypePtr type, Token* token) : Node(token) {
this->type = type;
}

bool ast::Expression::is_l_value() {
return false;
}

void ast::Expression::set_type(TypePtr type) {
this->type = type;
}
Expand Down Expand Up @@ -179,6 +183,10 @@ std::string ast::Variable::get_ref_name() {
return this->ref_name;
}

bool ast::Variable::is_l_value() {
return this->qualifier == ast::Qualifier::Let;
}

void ast::Variable::xml(int level) {
std::cout << std::string(level * 4, ' ');
std::cout << "<variable qualifier=\"" << ast::qualifier_to_string(qualifier)
Expand All @@ -205,6 +213,25 @@ void ast::Declaration::xml(int level) {
std::cout << std::string(level * 4, ' ') << "</declaration>\n";
}

ast::Assignment::Assignment(shared_ptr<Variable> var,
shared_ptr<Expression> expr,
Token* token)
: Node(token) {
this->var = var;
this->expr = expr;
}

void ast::Assignment::xml(int level) {
std::cout << std::string(level * 4, ' ') << "<assignment>\n";
this->var->xml(level + 1);

std::cout << std::string((level + 1) * 4, ' ') << "<rhs>\n";
this->expr->xml(level + 2);
std::cout << std::string((level + 1) * 4, ' ') << "</rhs>\n";

std::cout << std::string(level * 4, ' ') << "</assignment>\n";
}

ast::Parameter::Parameter(shared_ptr<Variable> var, Token* token)
: Node(token) {
this->var = var;
Expand Down
17 changes: 17 additions & 0 deletions src/ast/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ std::any Builder::visitStatement(FusionParser::StatementContext* ctx) {
return visit(ctx->declaration());
}

if (ctx->assignment() != nullptr) {
return visit(ctx->assignment());
}

if (ctx->function() != nullptr) {
return visit(ctx->function());
}
Expand Down Expand Up @@ -353,3 +357,16 @@ std::any Builder::visitUnary(FusionParser::UnaryContext* ctx) {

return to_node(binop);
}

std::any Builder::visitAssignment(FusionParser::AssignmentContext* ctx) {
Token* token = ctx->ID()->getSymbol();
std::string name = ctx->ID()->getText();
TypePtr type = make_shared<Type>(Type::unset);

auto var =
make_shared<ast::Variable>(ast::Qualifier::Let, type, name, token);
auto expr = cast_node(ast::Expression, visit(ctx->expr()));

auto assn = make_shared<ast::Assignment>(var, expr, token);
return to_node(assn);
}
13 changes: 12 additions & 1 deletion src/ast/passes/def_ref.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "ast/passes/def_ref.h"
#include "ast/ast.h"
#include "ast/symbol/function_symbol.h"
#include "errors.h"
#include "errors/errors.h"

DefRef::DefRef(shared_ptr<SymbolTable> symbol_table) : Pass("DefRef") {
this->symbol_table = symbol_table;
Expand Down Expand Up @@ -30,6 +30,17 @@ void DefRef::visit_declaration(shared_ptr<ast::Declaration> node) {
visit(node->expr);
}

void DefRef::visit_assignment(shared_ptr<ast::Assignment> node) {
visit(node->var);
if (!node->var->is_l_value()) {
throw AssignError(
node->token->getLine(),
"Cannot assign to const variable " + node->var->get_name());
}

visit(node->expr);
}

void DefRef::visit_parameter(shared_ptr<ast::Parameter> node) {
shared_ptr<ast::Variable> var = node->var;
if (symbol_table->resolve_local(var->get_name()).has_value()) {
Expand Down
6 changes: 6 additions & 0 deletions src/ast/passes/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ void Pass::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::BooleanLiteral, this->visit_boolean_literal);
try_visit(node, ast::Variable, this->visit_variable);
try_visit(node, ast::Declaration, this->visit_declaration);
try_visit(node, ast::Assignment, this->visit_assignment);
try_visit(node, ast::Function, this->visit_function);
try_visit(node, ast::Call, this->visit_call);
try_visit(node, ast::Parameter, this->visit_parameter);
Expand Down Expand Up @@ -71,6 +72,11 @@ void Pass::visit_declaration(shared_ptr<ast::Declaration> node) {
visit(node->expr);
}

void Pass::visit_assignment(shared_ptr<ast::Assignment> node) {
visit(node->var);
visit(node->expr);
}

void Pass::visit_function(shared_ptr<ast::Function> node) {
for (const auto& param : node->params) {
visit(param);
Expand Down
15 changes: 14 additions & 1 deletion src/ast/passes/type_check.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "ast/passes/type_check.h"
#include "ast/ast.h"
#include "errors.h"
#include "errors/errors.h"
#include "shared/context.h"

TypeCheck::TypeCheck() : Pass("Typecheck") {}
Expand All @@ -18,6 +18,19 @@ void TypeCheck::visit_declaration(shared_ptr<ast::Declaration> node) {
}
}

void TypeCheck::visit_assignment(shared_ptr<ast::Assignment> node) {
visit(node->var);
visit(node->expr);

Type var = *node->var->get_type();
Type expr = *node->expr->get_type();
if (var != expr) {
throw TypeError(node->token->getLine(),
"mismatched lhs(" + var.get_name() + ") and rhs(" +
expr.get_name() + ") types on assignment");
}
}

void TypeCheck::visit_function(shared_ptr<ast::Function> node) {
for (const auto& param : node->params) {
visit(param);
Expand Down
11 changes: 11 additions & 0 deletions src/backend/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mlir::Value Backend::visit(shared_ptr<ast::Node> node) {
try_visit(node, ast::BooleanLiteral, this->visit_boolean_literal);
try_visit(node, ast::Variable, this->visit_variable);
try_visit(node, ast::Declaration, this->visit_declaration);
try_visit(node, ast::Assignment, this->visit_assignment);
try_visit(node, ast::Function, this->visit_function);
try_visit(node, ast::Call, this->visit_call);
try_visit(node, ast::Parameter, this->visit_parameter);
Expand Down Expand Up @@ -80,6 +81,16 @@ mlir::Value Backend::visit_declaration(shared_ptr<ast::Declaration> node) {
return nullptr;
}

mlir::Value Backend::visit_assignment(shared_ptr<ast::Assignment> node) {
std::string name = node->var->get_ref_name();
mlir::Value expr = visit(node->expr);
mlir::Value address = variables[name];

utils::store(address, expr);

return nullptr;
}

mlir::Value Backend::visit_parameter(shared_ptr<ast::Parameter> node) {
std::string name = node->var->get_ref_name();
mlir::Value address =
Expand Down
13 changes: 12 additions & 1 deletion src/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#include "FusionParser.h"
#include "ast/passes/pass.h"
#include "compiler.h"
#include "errors.h"
#include "errors/errors.h"
#include "errors/syntax.h"

Compiler::Compiler(std::string filename,
shared_ptr<SymbolTable> symbol_table,
Expand All @@ -19,9 +20,17 @@ Compiler::Compiler(std::string filename,
file = new antlr4::ANTLRFileStream();
file->loadFromFile(filename);

lexer_error = new LexerErrorListener();
lexer = new fusion::FusionLexer(file);
lexer->removeErrorListeners();
lexer->addErrorListener(lexer_error);

tokens = new antlr4::CommonTokenStream(lexer);

syntax_error = new SyntaxErrorListener();
parser = new fusion::FusionParser(tokens);
parser->removeErrorListeners();
parser->addErrorListener(syntax_error);

tree = parser->file();
}
Expand All @@ -31,6 +40,8 @@ Compiler::~Compiler() {
delete lexer;
delete tokens;
delete parser;
delete lexer_error;
delete syntax_error;
}

void Compiler::build_ast() {
Expand Down
Loading
Loading