Skip to content

Commit

Permalink
improve numeric check on types
Browse files Browse the repository at this point in the history
  • Loading branch information
jackparsonss committed Jul 2, 2024
1 parent b79a04f commit 3c65976
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 14 deletions.
5 changes: 3 additions & 2 deletions include/ast/passes/type_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
#include <stack>
#include "ast/ast.h"
#include "ast/passes/pass.h"
#include "shared/type/type.h"

class TypeCheck : public Pass {
private:
std::stack<shared_ptr<ast::Function>> func_stack;
void check_numeric(Type type, size_t line);
void check_bool(Type type, size_t line);
void check_numeric(TypePtr type, size_t line);
void check_bool(TypePtr type, size_t line);

public:
explicit TypeCheck();
Expand Down
1 change: 1 addition & 0 deletions include/shared/type/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ class I32 : public Type {
I32();
std::string get_specifier() const override;
mlir::Type get_mlir() const override;
bool is_numeric() const override;
};
6 changes: 5 additions & 1 deletion include/shared/type/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <memory>
#include <string>

#include "mlir/IR/Types.h"

class Type {
Expand All @@ -10,10 +11,13 @@ class Type {

public:
explicit Type(std::string name);

std::string get_name() const;
mlir::Type get_pointer();

virtual std::string get_specifier() const;
virtual mlir::Type get_mlir() const;
mlir::Type get_pointer();
virtual bool is_numeric() const;

bool operator==(const Type rhs) const;
};
Expand Down
24 changes: 13 additions & 11 deletions src/ast/passes/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ void TypeCheck::visit_binary_operator(shared_ptr<ast::BinaryOperator> node) {
visit(node->rhs);

size_t line = node->token->getLine();
Type lhs = *node->lhs->get_type();
Type rhs = *node->rhs->get_type();
if (lhs != rhs) {
throw TypeError(line, "mismatched lhs(" + lhs.get_name() +
") and rhs(" + rhs.get_name() +
TypePtr lhs = node->lhs->get_type();
TypePtr rhs = node->rhs->get_type();

if (*lhs != *rhs) {
throw TypeError(line, "mismatched lhs(" + lhs->get_name() +
") and rhs(" + rhs->get_name() +
") types on binary operator: " +
ast::binary_op_type_to_string(node->type));
}
Expand Down Expand Up @@ -116,6 +117,7 @@ void TypeCheck::visit_binary_operator(shared_ptr<ast::BinaryOperator> node) {
case ast::BinaryOpType::EQ:
case ast::BinaryOpType::NE:
node->set_type(ctx::t_bool);
break;
}
}

Expand All @@ -137,14 +139,14 @@ void TypeCheck::visit_unary_operator(shared_ptr<ast::UnaryOperator> node) {
}
}

void TypeCheck::check_numeric(Type type, size_t line) {
if (type != *ctx::i32) {
throw TypeError(line, "type(" + type.get_name() + ") is not numeric");
void TypeCheck::check_numeric(TypePtr type, size_t line) {
if (!type->is_numeric()) {
throw TypeError(line, "type(" + type->get_name() + ") is not numeric");
}
}

void TypeCheck::check_bool(Type type, size_t line) {
if (type != *ctx::t_bool) {
throw TypeError(line, "type(" + type.get_name() + ") is not boolean");
void TypeCheck::check_bool(TypePtr type, size_t line) {
if (*type != *ctx::t_bool) {
throw TypeError(line, "type(" + type->get_name() + ") is not boolean");
}
}
4 changes: 4 additions & 0 deletions src/shared/type/integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ std::string I32::get_specifier() const {
mlir::Type I32::get_mlir() const {
return ctx::builder->getI32Type();
}

bool I32::is_numeric() const {
return true;
}
4 changes: 4 additions & 0 deletions src/shared/type/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ mlir::Type Type::get_mlir() const {
throw std::runtime_error("invalid mlir type found " + name);
}

bool Type::is_numeric() const {
return false;
}

mlir::Type Type::get_pointer() {
return mlir::LLVM::LLVMPointerType::get(get_mlir());
}
Expand Down

0 comments on commit 3c65976

Please sign in to comment.