Skip to content

Commit

Permalink
Type Refactor (#43)
Browse files Browse the repository at this point in the history
* refactor types

* improve numeric check on types
  • Loading branch information
jackparsonss committed Jul 2, 2024
1 parent 54c9406 commit 6ebe5f3
Show file tree
Hide file tree
Showing 27 changed files with 242 additions and 138 deletions.
2 changes: 1 addition & 1 deletion include/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "Token.h"

#include "shared/type.h"
#include "shared/type/type.h"

using antlr4::Token, std::make_shared, std::shared_ptr;

Expand Down
5 changes: 3 additions & 2 deletions include/ast/passes/type_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
#include <stack>
#include "ast/ast.h"
#include "ast/passes/pass.h"
#include "shared/type/type.h"

class TypeCheck : public Pass {
private:
std::stack<shared_ptr<ast::Function>> func_stack;
void check_numeric(Type type, size_t line);
void check_bool(Type type, size_t line);
void check_numeric(TypePtr type, size_t line);
void check_bool(TypePtr type, size_t line);

public:
explicit TypeCheck();
Expand Down
6 changes: 3 additions & 3 deletions include/ast/symbol/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <string>

#include "ast/ast.h"
#include "shared/type.h"
#include "shared/type/type.h"

using std::shared_ptr;

Expand All @@ -23,9 +23,9 @@ class Symbol {

typedef shared_ptr<Symbol> SymbolPtr;

class BuiltinTypeSymbol : public Symbol, public Type {
class BuiltinTypeSymbol : public Symbol {
public:
BuiltinTypeSymbol(std::string name);
BuiltinTypeSymbol(TypePtr type);
std::string get_name() override;
};

Expand Down
2 changes: 1 addition & 1 deletion include/backend/builtin/print.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include "shared/type.h"
#include "shared/type/type.h"

namespace builtin {
void define_all_print();
Expand Down
2 changes: 1 addition & 1 deletion include/backend/expressions/arithmetic.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "ast/ast.h"
#include "shared/type.h"
#include "shared/type/type.h"

#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/IR/Value.h"
Expand Down
2 changes: 1 addition & 1 deletion include/backend/io.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "mlir/IR/Value.h"
#include "shared/type.h"
#include "shared/type/type.h"

namespace io {
void printf(mlir::Value value, TypePtr type);
Expand Down
2 changes: 1 addition & 1 deletion include/shared/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <mlir/IR/Types.h>
#include <memory>

#include "shared/type.h"
#include "shared/type/type.h"

using std::shared_ptr, std::make_shared, std::unique_ptr;

Expand Down
29 changes: 0 additions & 29 deletions include/shared/type.h

This file was deleted.

10 changes: 10 additions & 0 deletions include/shared/type/boolean.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once

#include "shared/type/type.h"

class Boolean : public Type {
public:
Boolean();
std::string get_specifier() const override;
mlir::Type get_mlir() const override;
};
10 changes: 10 additions & 0 deletions include/shared/type/character.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once

#include "shared/type/type.h"

class Character : public Type {
public:
Character();
std::string get_specifier() const override;
mlir::Type get_mlir() const override;
};
10 changes: 10 additions & 0 deletions include/shared/type/float.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once

#include "shared/type/type.h"

class F32 : public Type {
public:
F32();
std::string get_specifier() const override;
mlir::Type get_mlir() const override;
};
11 changes: 11 additions & 0 deletions include/shared/type/integer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include "shared/type/type.h"

class I32 : public Type {
public:
I32();
std::string get_specifier() const override;
mlir::Type get_mlir() const override;
bool is_numeric() const override;
};
44 changes: 44 additions & 0 deletions include/shared/type/type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once

#include <memory>
#include <string>

#include "mlir/IR/Types.h"

class Type {
protected:
std::string name;

public:
explicit Type(std::string name);

std::string get_name() const;
mlir::Type get_pointer();

virtual std::string get_specifier() const;
virtual mlir::Type get_mlir() const;
virtual bool is_numeric() const;

bool operator==(const Type rhs) const;
};

class Any : public Type {
public:
Any();
std::string get_specifier() const override;
};

class None : public Type {
public:
None();
std::string get_specifier() const override;
mlir::Type get_mlir() const override;
};

class Unset : public Type {
public:
Unset();
std::string get_specifier() const override;
};

typedef std::shared_ptr<Type> TypePtr;
7 changes: 6 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ set(
"${CMAKE_CURRENT_SOURCE_DIR}/ast/passes/def_ref.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/ast/passes/type_check.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/ast/passes/builtin.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/type.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/type/type.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/type/character.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/type/integer.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/type/float.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/type/boolean.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/type/type.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shared/context.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/errors/syntax.cpp"
)
Expand Down
2 changes: 1 addition & 1 deletion src/ast/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ ast::Call::Call(std::string name,
ast::Call::Call(std::string name,
std::vector<shared_ptr<Expression>> args,
Token* token)
: Expression(make_shared<Type>(Type::unset), token) {
: Expression(make_shared<Unset>(), token) {
this->name = name;
this->arguments = args;
this->function = nullptr;
Expand Down
7 changes: 4 additions & 3 deletions src/ast/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "ast/ast.h"
#include "ast/builder.h"
#include "shared/context.h"
#include "shared/type/type.h"

using std::any_cast;

Expand Down Expand Up @@ -81,7 +82,7 @@ std::any Builder::visitType(FusionParser::TypeContext* ctx) {
throw std::runtime_error("invalid type found");
}

return dynamic_pointer_cast<Type>(type.value());
return dynamic_pointer_cast<Type>(type.value()->get_type());
}

std::any Builder::visitQualifier(FusionParser::QualifierContext* ctx) {
Expand Down Expand Up @@ -144,7 +145,7 @@ std::any Builder::visitLiteralChar(FusionParser::LiteralCharContext* ctx) {
std::any Builder::visitIdentifier(FusionParser::IdentifierContext* ctx) {
Token* token = ctx->ID()->getSymbol();
std::string name = ctx->ID()->getText();
TypePtr type = make_shared<Type>(Type::unset);
TypePtr type = make_shared<Unset>();

auto var =
make_shared<ast::Variable>(ast::Qualifier::Let, type, name, token);
Expand Down Expand Up @@ -361,7 +362,7 @@ std::any Builder::visitUnary(FusionParser::UnaryContext* ctx) {
std::any Builder::visitAssignment(FusionParser::AssignmentContext* ctx) {
Token* token = ctx->ID()->getSymbol();
std::string name = ctx->ID()->getText();
TypePtr type = make_shared<Type>(Type::unset);
TypePtr type = make_shared<Unset>();

auto var =
make_shared<ast::Variable>(ast::Qualifier::Let, type, name, token);
Expand Down
24 changes: 13 additions & 11 deletions src/ast/passes/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ void TypeCheck::visit_binary_operator(shared_ptr<ast::BinaryOperator> node) {
visit(node->rhs);

size_t line = node->token->getLine();
Type lhs = *node->lhs->get_type();
Type rhs = *node->rhs->get_type();
if (lhs != rhs) {
throw TypeError(line, "mismatched lhs(" + lhs.get_name() +
") and rhs(" + rhs.get_name() +
TypePtr lhs = node->lhs->get_type();
TypePtr rhs = node->rhs->get_type();

if (*lhs != *rhs) {
throw TypeError(line, "mismatched lhs(" + lhs->get_name() +
") and rhs(" + rhs->get_name() +
") types on binary operator: " +
ast::binary_op_type_to_string(node->type));
}
Expand Down Expand Up @@ -116,6 +117,7 @@ void TypeCheck::visit_binary_operator(shared_ptr<ast::BinaryOperator> node) {
case ast::BinaryOpType::EQ:
case ast::BinaryOpType::NE:
node->set_type(ctx::t_bool);
break;
}
}

Expand All @@ -137,14 +139,14 @@ void TypeCheck::visit_unary_operator(shared_ptr<ast::UnaryOperator> node) {
}
}

void TypeCheck::check_numeric(Type type, size_t line) {
if (type != *ctx::i32) {
throw TypeError(line, "type(" + type.get_name() + ") is not numeric");
void TypeCheck::check_numeric(TypePtr type, size_t line) {
if (!type->is_numeric()) {
throw TypeError(line, "type(" + type->get_name() + ") is not numeric");
}
}

void TypeCheck::check_bool(Type type, size_t line) {
if (type != *ctx::t_bool) {
throw TypeError(line, "type(" + type.get_name() + ") is not boolean");
void TypeCheck::check_bool(TypePtr type, size_t line) {
if (*type != *ctx::t_bool) {
throw TypeError(line, "type(" + type->get_name() + ") is not boolean");
}
}
4 changes: 2 additions & 2 deletions src/ast/symbol/symbol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ TypePtr Symbol::get_type() {
return this->type;
}

BuiltinTypeSymbol::BuiltinTypeSymbol(std::string name)
: Symbol(name), Type(name) {}
BuiltinTypeSymbol::BuiltinTypeSymbol(TypePtr type)
: Symbol(type->get_name(), type) {}

std::string BuiltinTypeSymbol::get_name() {
return Symbol::get_name();
Expand Down
4 changes: 2 additions & 2 deletions src/ast/symbol/symbol_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ SymbolTable::SymbolTable() {

void SymbolTable::init_types() {
for (const auto& ty : ctx::primitives) {
define(make_shared<BuiltinTypeSymbol>(ty->get_name()));
define(make_shared<BuiltinTypeSymbol>(ty));
}
define(make_shared<BuiltinTypeSymbol>(ctx::any->get_name()));
define(make_shared<BuiltinTypeSymbol>(ctx::any));
}

void SymbolTable::init_builtins() {
Expand Down
2 changes: 1 addition & 1 deletion src/backend/builtin/print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace {
void create_type_str(TypePtr type) {
std::string ty = type->get_specifier();
auto gvalue = mlir::StringRef(ty.c_str(), 3);
auto gvalue = mlir::StringRef(ty.c_str(), ty.size() + 1);

auto gtype =
mlir::LLVM::LLVMArrayType::get(ctx::ch->get_mlir(), gvalue.size());
Expand Down
17 changes: 11 additions & 6 deletions src/shared/context.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#include "shared/context.h"
#include "shared/type/boolean.h"
#include "shared/type/character.h"
#include "shared/type/float.h"
#include "shared/type/integer.h"
#include "shared/type/type.h"

mlir::MLIRContext ctx::context;
std::unique_ptr<mlir::Location> ctx::loc =
Expand All @@ -19,12 +24,12 @@ void ctx::initialize_context() {
context.loadDialect<mlir::LLVM::LLVMDialect>();
builder = std::make_shared<mlir::OpBuilder>(&context);

ctx::ch = std::make_shared<Type>(Type::ch);
ctx::any = std::make_shared<Type>(Type::any);
ctx::i32 = std::make_shared<Type>(Type::i32);
ctx::f32 = std::make_shared<Type>(Type::f32);
ctx::none = std::make_shared<Type>(Type::none);
ctx::t_bool = std::make_shared<Type>(Type::t_bool);
ctx::any = make_shared<Any>();
ctx::i32 = make_shared<I32>();
ctx::f32 = make_shared<F32>();
ctx::none = make_shared<None>();
ctx::ch = make_shared<Character>();
ctx::t_bool = make_shared<Boolean>();

ctx::primitives = {
ctx::ch,
Expand Down
Loading

0 comments on commit 6ebe5f3

Please sign in to comment.