Skip to content

Commit

Permalink
feat: Added storing values into class fields, added class constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
alinalihassan committed May 24, 2022
1 parent 55e86bc commit 3675e51
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 62 deletions.
14 changes: 7 additions & 7 deletions src/liblesma/AST/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,18 +345,18 @@ namespace lesma {
};

class Assignment : public Statement {
Literal *var;
Expression* lhs;
TokenType op;
Expression *expr;
Expression *rhs;

public:
Assignment(llvm::SMRange Loc, Literal *var, TokenType op, Expression *expr) : Statement(Loc), var(var), op(op), expr(expr) {}
Assignment(llvm::SMRange Loc, Expression* lhs, TokenType op, Expression *rhs) : Statement(Loc), lhs(lhs), op(op), rhs(rhs) {}

~Assignment() override = default;

[[nodiscard]] [[maybe_unused]] Literal *getIdentifier() const { return var; }
[[nodiscard]] [[maybe_unused]] Expression *getLeftHandSide() const { return lhs; }
[[nodiscard]] [[maybe_unused]] TokenType getOperator() const { return op; }
[[nodiscard]] [[maybe_unused]] Expression *getExpression() const { return expr; }
[[nodiscard]] [[maybe_unused]] Expression *getRightHandSide() const { return rhs; }

std::string toString(llvm::SourceMgr *srcMgr, int ind) override {
return fmt::format("{}Assignment[Line({}-{}):Col({}-{})]: {} {} {}\n",
Expand All @@ -365,9 +365,9 @@ namespace lesma {
srcMgr->getLineAndColumn(getEnd()).first,
srcMgr->getLineAndColumn(getStart()).second,
srcMgr->getLineAndColumn(getEnd()).second,
var->toString(srcMgr, ind),
lhs->toString(srcMgr, ind),
std::string{NAMEOF_ENUM(op)},
expr->toString(srcMgr, ind));
rhs->toString(srcMgr, ind));
}
};

Expand Down
117 changes: 79 additions & 38 deletions src/liblesma/Backend/Codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,72 +528,87 @@ void Codegen::visit(ExternFuncDecl *node) {
}

void Codegen::visit(Assignment *node) {
auto symbol = Scope->lookup(node->getIdentifier()->getValue());
if (symbol == nullptr)
throw CodegenError(node->getSpan(), "Variable not found: {}", node->getIdentifier()->getValue());
if (!symbol->getMutability())
throw CodegenError(node->getSpan(), "Assigning immutable variable a new value");
llvm::Type *lhs_type;
llvm::Value *lhs_val;
isAssignment = true;
if (dynamic_cast<Literal *>(node->getLeftHandSide())) {
auto lit = dynamic_cast<Literal *>(node->getLeftHandSide());
auto symbol = Scope->lookup(lit->getValue());
if (symbol == nullptr)
throw CodegenError(node->getSpan(), "Variable not found: {}", lit->getValue());
if (!symbol->getMutability())
throw CodegenError(node->getSpan(), "Assigning immutable variable a new value");

lhs_type = symbol->getLLVMType();
lhs_val = symbol->getLLVMValue();
} else if (dynamic_cast<DotOp *>(node->getLeftHandSide())) {
lhs_val = visit(node->getLeftHandSide());
lhs_type = lhs_val->getType();
} else {
throw CodegenError(node->getSpan(), "Unable to assign {} to {}", node->getRightHandSide()->toString(SourceManager.get(), 0), node->getLeftHandSide()->toString(SourceManager.get(), 0));
}
isAssignment = false;

auto value = visit(node->getExpression());
value = Cast(node->getSpan(), value, symbol->getLLVMType(), true);
auto value = visit(node->getRightHandSide());
value = Cast(node->getSpan(), value, lhs_type, true);
llvm::Value *var_val;

switch (node->getOperator()) {
case TokenType::EQUAL:
Builder->CreateStore(value, symbol->getLLVMValue());
Builder->CreateStore(value, lhs_val);
break;
case TokenType::PLUS_EQUAL:
var_val = Builder->CreateLoad(symbol->getLLVMType(), symbol->getLLVMValue(), ".tmp");
if (symbol->getLLVMType()->isFloatingPointTy()) {
var_val = Builder->CreateLoad(lhs_type, lhs_val, ".tmp");
if (lhs_type->isFloatingPointTy()) {
auto new_val = Builder->CreateFAdd(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
} else if (symbol->getLLVMType()->isIntegerTy()) {
Builder->CreateStore(new_val, lhs_val);
} else if (lhs_type->isIntegerTy()) {
auto new_val = Builder->CreateAdd(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
Builder->CreateStore(new_val, lhs_val);
} else
throw CodegenError(node->getSpan(), "Invalid operator: {}", NAMEOF_ENUM(node->getOperator()));
break;
case TokenType::MINUS_EQUAL:
var_val = Builder->CreateLoad(symbol->getLLVMType(), symbol->getLLVMValue(), ".tmp");
if (symbol->getLLVMType()->isFloatingPointTy()) {
var_val = Builder->CreateLoad(lhs_type, lhs_val, ".tmp");
if (lhs_type->isFloatingPointTy()) {
auto new_val = Builder->CreateFSub(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
} else if (symbol->getLLVMType()->isIntegerTy()) {
Builder->CreateStore(new_val, lhs_val);
} else if (lhs_type->isIntegerTy()) {
auto new_val = Builder->CreateSub(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
Builder->CreateStore(new_val, lhs_val);
} else
throw CodegenError(node->getSpan(), "Invalid operator: {}", NAMEOF_ENUM(node->getOperator()));
break;
case TokenType::SLASH_EQUAL:
var_val = Builder->CreateLoad(symbol->getLLVMType(), symbol->getLLVMValue(), ".tmp");
if (symbol->getLLVMType()->isFloatingPointTy()) {
var_val = Builder->CreateLoad(lhs_type, lhs_val, ".tmp");
if (lhs_type->isFloatingPointTy()) {
auto new_val = Builder->CreateFDiv(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
} else if (symbol->getLLVMType()->isIntegerTy()) {
Builder->CreateStore(new_val, lhs_val);
} else if (lhs_type->isIntegerTy()) {
auto new_val = Builder->CreateSDiv(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
Builder->CreateStore(new_val, lhs_val);
} else
throw CodegenError(node->getSpan(), "Invalid operator: {}", NAMEOF_ENUM(node->getOperator()));
break;
case TokenType::STAR_EQUAL:
var_val = Builder->CreateLoad(symbol->getLLVMType(), symbol->getLLVMValue(), ".tmp");
if (symbol->getLLVMType()->isFloatingPointTy()) {
var_val = Builder->CreateLoad(lhs_type, lhs_val, ".tmp");
if (lhs_type->isFloatingPointTy()) {
auto new_val = Builder->CreateFMul(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
} else if (symbol->getLLVMType()->isIntegerTy()) {
Builder->CreateStore(new_val, lhs_val);
} else if (lhs_type->isIntegerTy()) {
auto new_val = Builder->CreateMul(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
Builder->CreateStore(new_val, lhs_val);
} else
throw CodegenError(node->getSpan(), "Invalid operator: {}", NAMEOF_ENUM(node->getOperator()));
break;
case TokenType::MOD_EQUAL:
var_val = Builder->CreateLoad(symbol->getLLVMType(), symbol->getLLVMValue(), ".tmp");
if (symbol->getLLVMType()->isFloatingPointTy()) {
var_val = Builder->CreateLoad(lhs_type, lhs_val, ".tmp");
if (lhs_type->isFloatingPointTy()) {
auto new_val = Builder->CreateFRem(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
} else if (symbol->getLLVMType()->isIntegerTy()) {
Builder->CreateStore(new_val, lhs_val);
} else if (lhs_type->isIntegerTy()) {
auto new_val = Builder->CreateSRem(value, var_val, ".tmp");
Builder->CreateStore(new_val, symbol->getLLVMValue());
Builder->CreateStore(new_val, lhs_val);
} else
throw CodegenError(node->getSpan(), "Invalid operator: {}", NAMEOF_ENUM(node->getOperator()));
break;
Expand Down Expand Up @@ -910,7 +925,7 @@ llvm::Value *Codegen::visit(DotOp *node) {

auto struct_ty = Scope->lookup(left->getValue());
auto enum_ptr = Builder->CreateAlloca(struct_ty->getLLVMType(), nullptr, ".tmp");
auto field = Builder->CreateInBoundsGEP(struct_ty->getLLVMType(), enum_ptr, {Builder->getInt32(0), Builder->getInt32(0)});
auto field = Builder->CreateStructGEP(struct_ty->getLLVMType(), enum_ptr, 0);
Builder->CreateStore(Builder->getInt8(val), field);

return enum_ptr;
Expand Down Expand Up @@ -939,7 +954,11 @@ llvm::Value *Codegen::visit(DotOp *node) {
if (index == -1)
throw CodegenError(node->getRight()->getSpan(), "Could not find field {} in {}", field, val->getType()->getStructName().str());

return Builder->CreateInBoundsGEP(cls->getLLVMType(), val, {Builder->getInt32(0), Builder->getInt32(index)});
auto ptr = Builder->CreateStructGEP(cls->getLLVMType(), val, index);
if (isAssignment)
return ptr;
auto x = cls->getType()->getFields()[index];
return Builder->CreateLoad(ptr->getType()->getPointerElementType(), ptr);
} else if (method != nullptr) {
classSymbol = cls;
auto ret_val = genFuncCall(method, {val});
Expand Down Expand Up @@ -1102,7 +1121,8 @@ llvm::Value *Codegen::Cast(llvm::SMRange span, llvm::Value *val, llvm::Type *typ
}

llvm::Value *Codegen::Cast(llvm::SMRange span, llvm::Value *val, llvm::Type *type, bool isStore) {
if (type == nullptr || val->getType() == type || (isStore && val->getType() == type->getPointerTo()) || (val->getType()->isPointerTy() && val->getType()->getPointerElementType()->isStructTy()))
// TODO: Fix me pls
if (type == nullptr || val->getType() == type || (isStore && val->getType()->getPointerTo(0) == type) || (isStore && val->getType() == type->getPointerTo()) || (val->getType()->isPointerTy() && val->getType()->getPointerElementType()->isStructTy()))
return val;

if (type->isIntegerTy()) {
Expand All @@ -1120,7 +1140,7 @@ llvm::Value *Codegen::Cast(llvm::SMRange span, llvm::Value *val, llvm::Type *typ
throw CodegenError(span, "Unsupported Cast");
}

llvm::Value *Codegen::genFuncCall(FuncCall *node, std::vector<llvm::Value *> extra_params = {}) {
llvm::Value *Codegen::genFuncCall(FuncCall *node, const std::vector<llvm::Value *>& extra_params = {}) {
std::vector<llvm::Value *> params;
std::vector<llvm::Type *> paramTypes;

Expand All @@ -1134,7 +1154,21 @@ llvm::Value *Codegen::genFuncCall(FuncCall *node, std::vector<llvm::Value *> ext
paramTypes.push_back(params.back()->getType());
}

auto name = getMangledName(node->getSpan(), node->getName(), paramTypes);
std::string name;
auto classSymbolTmp = classSymbol;
auto class_sym = Scope->lookup(node->getName());
Value* class_ptr = nullptr;
if (class_sym != nullptr && class_sym->getType()->is(TY_CLASS)) {
// It's a class constructor, allocate and add self param
class_ptr = Builder->CreateAlloca(class_sym->getLLVMType(), nullptr, ".tmp");
params.insert(params.begin(), class_ptr);
paramTypes.insert(paramTypes.begin(), class_ptr->getType());

classSymbol = class_sym;
name = getMangledName(node->getSpan(), "new", paramTypes);
} else {
name = getMangledName(node->getSpan(), node->getName(), paramTypes);
}
auto symbol = Scope->lookup(name);
// Get function without name mangling in case of extern C functions
symbol = symbol == nullptr ? Scope->lookup(node->getName()) : symbol;
Expand All @@ -1146,6 +1180,13 @@ llvm::Value *Codegen::genFuncCall(FuncCall *node, std::vector<llvm::Value *> ext
throw CodegenError(node->getSpan(), "Symbol {} is not a function.", node->getName());

auto *func = dyn_cast<Function>(symbol->getLLVMValue());
if (class_sym != nullptr && class_sym->getType()->is(TY_CLASS)) {
Builder->CreateCall(func, params, func->getReturnType()->isVoidTy() ? "" : "tmp");
auto val = Builder->CreateLoad(class_sym->getLLVMType(), class_ptr);
classSymbol = classSymbolTmp;

return val;
}
return Builder->CreateCall(func, params, func->getReturnType()->isVoidTy() ? "" : "tmp");
}

Expand Down
9 changes: 5 additions & 4 deletions src/liblesma/Backend/Codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace lesma {
SymbolTableEntry *classSymbol = nullptr;
bool isBreak = false;
bool isReturn = false;
bool isAssignment = false;
bool isJIT = false;
bool isMain = true;

Expand Down Expand Up @@ -97,13 +98,13 @@ namespace lesma {
llvm::Value *visit(Else *node) override;

// TODO: Helper functions, move them out somewhere
SymbolType *getType(llvm::Type *type);
static SymbolType *getType(llvm::Type *type);
llvm::Value *Cast(llvm::SMRange span, llvm::Value *val, llvm::Type *type);
llvm::Value *Cast(llvm::SMRange span, llvm::Value *val, llvm::Type *type, bool isStore);
llvm::Type *GetExtendedType(llvm::Type *left, llvm::Type *right);
static llvm::Type *GetExtendedType(llvm::Type *left, llvm::Type *right);
std::string getMangledName(llvm::SMRange span, std::string func_name, const std::vector<llvm::Type *> &paramTypes);
std::string getTypeMangledName(llvm::SMRange span, llvm::Type *type);
llvm::Value *genFuncCall(FuncCall *node, std::vector<llvm::Value *> extra_params);
int FindIndexInFields(SymbolType *_struct, const std::string &field);
llvm::Value *genFuncCall(FuncCall *node, const std::vector<llvm::Value *>& extra_params);
static int FindIndexInFields(SymbolType *_struct, const std::string &field);
};
}// namespace lesma
28 changes: 21 additions & 7 deletions src/liblesma/Frontend/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ bool Parser::CheckAny() {
return CheckAny<type, remained_types...>(0);
}

template<TokenType type, TokenType... remained_types>
bool Parser::CheckAnyInLine() {
int i = 0;
while (!CheckAny<TokenType::NEWLINE, TokenType::EOF_TOKEN>(i)) {
if (CheckAny<type, remained_types...>(i))
return true;
i++;
}

return false;
}

template<TokenType type, TokenType... remained_types>
bool Parser::CheckAny(unsigned long pos) {
if (!Check(type, pos)) {
Expand Down Expand Up @@ -322,16 +334,19 @@ Statement *Parser::ParseFor() {
}

Statement *Parser::ParseAssignment() {
auto identifier = Consume(TokenType::IDENTIFIER);
auto var = new Literal(identifier->span, identifier->lexeme, identifier->type);

auto identifier = ParseDot();

if (!(dynamic_cast<Literal *>(identifier) && dynamic_cast<Literal *>(identifier)->getType() == TokenType::IDENTIFIER) && !dynamic_cast<DotOp *>(identifier))
throw ParserError(identifier->getSpan(), "Expected either identifier or class field for assignment");

if (AdvanceIfMatchAny<TokenType::EQUAL, TokenType::PLUS_EQUAL, TokenType::MINUS_EQUAL, TokenType::STAR_EQUAL,
TokenType::SLASH_EQUAL, TokenType::MOD_EQUAL, TokenType::POWER_EQUAL>()) {
auto op = Previous()->type;
auto expr = ParseExpression();

ConsumeNewline();
return new Assignment({identifier->getStart(), expr->getEnd()}, var, op, expr);
return new Assignment({identifier->getStart(), expr->getEnd()}, identifier, op, expr);
}

Error(Peek(), fmt::format("Unsupported assignment operator: {}", Peek()->lexeme));
Expand Down Expand Up @@ -399,9 +414,8 @@ Statement *Parser::ParseStatement(bool isTopLevel) {
return ParseReturn();
else if (Check(TokenType::DEFER))
return ParseDefer();
else if (Check(TokenType::IDENTIFIER) &&
CheckAny<TokenType::EQUAL, TokenType::PLUS_EQUAL, TokenType::MINUS_EQUAL, TokenType::STAR_EQUAL,
TokenType::SLASH_EQUAL, TokenType::MOD_EQUAL, TokenType::POWER_EQUAL>(1))
else if (CheckAnyInLine<TokenType::EQUAL, TokenType::PLUS_EQUAL, TokenType::MINUS_EQUAL, TokenType::STAR_EQUAL,
TokenType::SLASH_EQUAL, TokenType::MOD_EQUAL, TokenType::POWER_EQUAL>())
return ParseAssignment();

// Check if it's just an expression statement
Expand Down Expand Up @@ -474,7 +488,7 @@ Statement *Parser::ParseFunctionDeclaration() {

auto body = ParseBlock();

return new FuncDecl({loc.Start, body->getEnd()}, identifier->lexeme, return_type, parameters, body);
return new FuncDecl({loc.Start, return_type->getEnd()}, identifier->lexeme, return_type, parameters, body);
}

Statement *Parser::ParseImport() {
Expand Down
3 changes: 3 additions & 0 deletions src/liblesma/Frontend/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ namespace lesma {
template<TokenType type, TokenType... remained_types>
bool CheckAny();

template<TokenType type, TokenType... remained_types>
bool CheckAnyInLine();

template<TokenType type, TokenType... remained_types>
bool CheckAny(unsigned long pos);

Expand Down
13 changes: 7 additions & 6 deletions tests/lesma/class.les
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ def extern exit(x: int)

class Animal
var x: int
let y: float = 5.5

def hello() -> int
return 101
def new()
self.x = 101

var z: Animal
z.x
exit(z.hello())
def getX() -> int
return self.x

var z = Animal()
exit(z.getX())

0 comments on commit 3675e51

Please sign in to comment.