Skip to content

Commit

Permalink
feat: default values for functions now fully working!
Browse files Browse the repository at this point in the history
  • Loading branch information
alinalihassan committed Apr 23, 2023
1 parent 1e9aa05 commit 118f179
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 71 deletions.
1 change: 1 addition & 0 deletions ROADMAP.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- [x] Improve Visitor Pattern
- [x] Migrate to LLVM 16 with custom Value and Type classes
- [ ] Fix multiple imports not working
- [ ] Add default values in function declarations
- [ ] Add operator overloading
- [ ] Add generics
Expand Down
115 changes: 75 additions & 40 deletions src/liblesma/Backend/Codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,16 @@ void Codegen::CompileModule(llvm::SMRange span, const std::string &filepath, boo

if (isJIT) {
// Insert the function declaration, since we linked the modules earlier
F = llvm::cast<Function>(TheModule->getOrInsertFunction(sym.first, FTy).getCallee());
F = llvm::cast<Function>(TheModule->getOrInsertFunction(sym.second->getMangledName(), FTy).getCallee());
}

auto name = std::string{sym.first};
std::vector<llvm::Type *> paramTypes;
for (unsigned param_i = 0; param_i < FTy->getNumParams(); param_i++)
paramTypes.push_back(FTy->getParamType(param_i));
std::vector<lesma::Type *> paramTypes;
for (auto field: sym.second->getType()->getFields()) {
paramTypes.push_back(field->type);
}

Value *func_symbol = codegen->Scope->lookup(name);
Value *func_symbol = codegen->Scope->lookupFunction(name, paramTypes);

// Only import if it's exported
auto demangled_name = getDemangledName(name);
Expand All @@ -234,12 +235,14 @@ void Codegen::CompileModule(llvm::SMRange span, const std::string &filepath, boo
symbol->getType()->setLLVMType(FTy);
symbol->setLLVMValue(F);
symbol->setExported(false);
symbol->setMangledName(sym.second->getMangledName());
} else {
// If it's compiled, we need to make a new Function declaration in the importing file
auto new_func = Function::Create(FTy, Function::ExternalLinkage, name, *TheModule);
symbol->getType()->setLLVMType(new_func->getFunctionType());
symbol->setLLVMValue(new_func);
symbol->setExported(false);
symbol->setMangledName(sym.second->getMangledName());
}
Scope->insertSymbol(symbol);
}
Expand Down Expand Up @@ -406,19 +409,21 @@ void Codegen::visit(const TypeExpr *node) {
} else if (node->getType() == TokenType::FUNC_TYPE) {
node->getReturnType()->accept(*this);
auto ret_type = result;
std::vector<std::unique_ptr<Field>> paramTypes;
std::vector<Field *> fields;
std::vector<lesma::Type *> paramTypes;
std::vector<llvm::Type *> paramLLVMTypes;
for (auto param_type: node->getParams()) {
param_type->accept(*this);
paramLLVMTypes.push_back(result->getType()->getLLVMType());
paramTypes.push_back(std::make_unique<Field>(Field{result->getName(), result->getType()}));
paramTypes.push_back(result->getType());
fields.push_back(new Field{result->getName(), result->getType()});
}

llvm::Type *funcType = FunctionType::get(ret_type->getType()->getLLVMType(), paramLLVMTypes, false)->getPointerTo();
result = new lesma::Value(new lesma::Type(TY_FUNCTION, funcType, std::move(paramTypes)));
result = new lesma::Value(new lesma::Type(TY_FUNCTION, funcType, std::move(fields)));
} else if (node->getType() == TokenType::CUSTOM_TYPE) {
auto typ = Scope->lookupType(node->getName());
auto sym = Scope->lookup(node->getName());
auto sym = Scope->lookupStruct(node->getName());
if (typ == nullptr || sym->getType()->getLLVMType() == nullptr)
throw CodegenError(node->getSpan(), "Type not found: {}", node->getName());

Expand Down Expand Up @@ -559,14 +564,14 @@ void Codegen::visit(const FuncDecl *node) {
if (selfSymbol != nullptr && node->getName() == "new" && node->getReturnType()->getType() != TokenType::VOID_TYPE)
throw CodegenError(node->getSpan(), "Cannot create class method new with return type {}", node->getReturnType()->getName());

std::vector<std::unique_ptr<Field>> fields;
std::vector<Field *> fields;
std::vector<lesma::Type *> paramTypes;
std::vector<llvm::Type *> paramLLVMTypes;

if (selfSymbol != nullptr) {
paramTypes.push_back(selfSymbol->getType());
paramLLVMTypes.push_back(selfSymbol->getType()->getLLVMType()->getPointerTo());
fields.push_back(std::make_unique<Field>(Field{"self", selfSymbol->getType()}));
fields.push_back(new Field{"self", selfSymbol->getType()});
}

for (auto param: node->getParameters()) {
Expand Down Expand Up @@ -597,20 +602,21 @@ void Codegen::visit(const FuncDecl *node) {

paramTypes.push_back(typeResult->getType());
paramLLVMTypes.push_back(typeResult->getType()->getLLVMType());
fields.push_back(std::make_unique<Field>(Field{param->name, typeResult->getType(), defaultValResult}));
fields.push_back(new Field{param->name, typeResult->getType(), defaultValResult});
}

auto name = getMangledName(node->getSpan(), node->getName(), paramTypes, selfSymbol != nullptr);
auto mangledName = getMangledName(node->getSpan(), node->getName(), paramTypes, selfSymbol != nullptr);
auto linkage = node->isExported() ? Function::ExternalLinkage : Function::PrivateLinkage;

node->getReturnType()->accept(*this);

llvm::FunctionType *funcType = FunctionType::get(result->getType()->getLLVMType(), paramLLVMTypes, node->getVarArgs());
Function *F = Function::Create(funcType, linkage, name, *TheModule);
Function *F = Function::Create(funcType, linkage, mangledName, *TheModule);

auto func_symbol = new Value(name, new Type(BaseType::TY_FUNCTION, funcType, std::move(fields)), F);
auto func_symbol = new Value(node->getName(), new Type(BaseType::TY_FUNCTION, funcType, std::move(fields)), F);
func_symbol->getType()->setReturnType(result->getType());
func_symbol->setExported(node->isExported());
func_symbol->setMangledName(mangledName);
Scope->insertSymbol(func_symbol);

Prototypes.emplace_back(func_symbol, node, selfSymbol);
Expand All @@ -619,25 +625,46 @@ void Codegen::visit(const FuncDecl *node) {
}

void Codegen::visit(const ExternFuncDecl *node) {
std::vector<std::unique_ptr<Field>> paramTypes;
std::vector<Field *> fields;
std::vector<lesma::Type *> paramTypes;
std::vector<llvm::Type *> paramLLVMTypes;

if (selfSymbol != nullptr) {
paramTypes.push_back(std::make_unique<Field>(Field{"self", selfSymbol->getType()}));
paramLLVMTypes.push_back(selfSymbol->getType()->getLLVMType()->getPointerTo());
}

for (auto param: node->getParameters()) {
param->type->accept(*this);
paramLLVMTypes.push_back(result->getType()->getLLVMType());
paramTypes.push_back(std::make_unique<Field>(Field{result->getName(), result->getType()}));
lesma::Value *typeResult = nullptr;
lesma::Value *defaultValResult = nullptr;

// Check if it has either a type or a value or both
if (param->type) {
param->type->accept(*this);
typeResult = result;
}
if (param->default_val) {
param->default_val->accept(*this);
defaultValResult = result;
if (!typeResult) {
typeResult = defaultValResult;
}
}

// If it's a class type, we mean to pass a pointer to a class
if (typeResult->getType()->is(TY_CLASS)) {
typeResult = new Value("", new Type(TY_PTR, Builder->getPtrTy(), result->getType()));
}

if (defaultValResult != nullptr && *typeResult->getType() != *defaultValResult->getType()) {
throw CodegenError(node->getSpan(), "Declared parameter type and default value do not match for {}", param->name);
}

paramTypes.push_back(typeResult->getType());
paramLLVMTypes.push_back(typeResult->getType()->getLLVMType());
fields.push_back(new Field{param->name, typeResult->getType(), defaultValResult});
}

node->getReturnType()->accept(*this);
auto ret_type = result->getType();

Function *F;
if (TheModule->getFunction(node->getName()) != nullptr && Scope->lookup(node->getName()) != nullptr)
if (TheModule->getFunction(node->getName()) != nullptr && Scope->lookupFunction(node->getName(), paramTypes) != nullptr)
return;
else if (TheModule->getFunction(node->getName()) != nullptr) {
F = TheModule->getFunction(node->getName());
Expand All @@ -649,7 +676,7 @@ void Codegen::visit(const ExternFuncDecl *node) {
}
}

auto func_symbol = new Value(node->getName(), new Type(BaseType::TY_FUNCTION, F->getFunctionType(), std::move(paramTypes)), F);
auto func_symbol = new Value(node->getName(), new Type(BaseType::TY_FUNCTION, F->getFunctionType(), fields), F);
func_symbol->getType()->setReturnType(ret_type);
func_symbol->setExported(node->isExported());
Scope->insertSymbol(func_symbol);
Expand Down Expand Up @@ -810,7 +837,7 @@ void Codegen::visit(const Import *node) {
}

void Codegen::visit(const Class *node) {
std::vector<std::unique_ptr<Field>> fields;
std::vector<Field *> fields;
std::vector<llvm::Type *> elementLLVMTypes;

for (auto field: node->getFields()) {
Expand All @@ -821,7 +848,7 @@ void Codegen::visit(const Class *node) {
}

elementLLVMTypes.push_back(result->getType()->getLLVMType());
fields.push_back(std::make_unique<Field>(Field{field->getIdentifier()->getValue(), result->getType()}));
fields.push_back(new Field{field->getIdentifier()->getValue(), result->getType(), field->getValue().has_value() ? result : nullptr});
}

llvm::StructType *structType = llvm::StructType::create(*TheContext->getContext(), elementLLVMTypes, node->getIdentifier());
Expand Down Expand Up @@ -852,10 +879,10 @@ void Codegen::visit(const Class *node) {
void Codegen::visit(const Enum *node) {
std::vector<llvm::Type *> elementTypes = {Builder->getInt8Ty()};
llvm::StructType *structType = llvm::StructType::create(*TheContext->getContext(), elementTypes, node->getIdentifier());
std::vector<std::unique_ptr<Field>> fields;
std::vector<Field *> fields;

for (const auto &field: node->getValues())
fields.push_back(std::make_unique<Field>(Field{field, new Type(TY_VOID, Builder->getVoidTy())}));
fields.push_back(new Field{field, new Type(TY_VOID, Builder->getVoidTy())});

auto *type = new Type(TY_ENUM, structType, std::move(fields));
auto *structSymbol = new Value(node->getIdentifier(), type);
Expand Down Expand Up @@ -1121,7 +1148,7 @@ void Codegen::visit(const DotOp *node) {
if (val == -1)
throw CodegenError(node->getLeft()->getSpan(), "Identifier {} not in {}", right->getValue(), left->getValue());

auto struct_val = Scope->lookup(left->getValue());
auto struct_val = Scope->lookupStruct(left->getValue());
auto enum_ptr = Builder->CreateAlloca(struct_val->getType()->getLLVMType());
auto field = Builder->CreateStructGEP(struct_val->getType()->getLLVMType(), enum_ptr, 0);
Builder->CreateStore(Builder->getInt8(val), field);
Expand Down Expand Up @@ -1177,7 +1204,7 @@ void Codegen::visit(const DotOp *node) {

// TODO: Somehow, when we call a class method with a variable x,
// we lose the class name from cls, so we set it again
auto cls = Scope->lookupStructByName(lesma_type->getLLVMType()->getStructName().str());
auto cls = Scope->lookupStruct(lesma_type->getLLVMType()->getStructName().str());
cls->setName(lesma_type->getLLVMType()->getStructName().str());

if (cls->getType()->is(TY_CLASS)) {
Expand Down Expand Up @@ -1470,9 +1497,10 @@ lesma::Value *Codegen::genFuncCall(const FuncCall *node, const std::vector<lesma
paramsLLVM.push_back(result->getLLVMValue());
}

std::string name;
Value *symbol;
// Check if it's a constructor like `Classname()`
auto selfSymbolTmp = selfSymbol;
auto class_sym = Scope->lookup(node->getName());
auto class_sym = Scope->lookupStruct(node->getName());
llvm::Value *class_ptr = nullptr;
if (class_sym != nullptr && class_sym->getType()->is(TY_CLASS)) {
// It's a class constructor, allocate and add self param
Expand All @@ -1481,22 +1509,29 @@ lesma::Value *Codegen::genFuncCall(const FuncCall *node, const std::vector<lesma
paramTypes.insert(paramTypes.begin(), new Type(TY_PTR, Builder->getPtrTy(), class_sym->getType()));

selfSymbol = class_sym;
name = getMangledName(node->getSpan(), "new", paramTypes, true);
symbol = Scope->lookupFunction("new", paramTypes);
} else {
name = getMangledName(node->getSpan(), node->getName(), paramTypes, !extra_params.empty());
symbol = Scope->lookupFunction(node->getName(), paramTypes);
}

auto symbol = Scope->lookup(name);
// Get function without name mangling in case of extern C functions
symbol = symbol == nullptr ? Scope->lookup(node->getName()) : symbol;

if (symbol == nullptr) {
throw CodegenError(node->getSpan(), "Function {} not in current scope.", node->getName());
}

if (!symbol->getType()->isOneOf({TY_CLASS, TY_FUNCTION}))
throw CodegenError(node->getSpan(), "Symbol {} is not a function or constructor.", node->getName());

if (symbol->getType()->getFields().size() > paramsLLVM.size()) {
auto fields = symbol->getType()->getFields();
size_t start = paramsLLVM.size();

for (auto it = fields.begin() + start; it != fields.end(); ++it) {
if (!(*it)->defaultValue) {
throw CodegenError(node->getSpan(), "Something bad happened, lookup found a function with incorrect defaults", node->getName());
}
paramsLLVM.push_back((*it)->defaultValue->getLLVMValue());
}
}
auto *func = cast<Function>(symbol->getType()->is(TY_CLASS) ? symbol->getConstructor()->getLLVMValue() : symbol->getLLVMValue());
if (class_sym != nullptr && class_sym->getType()->is(TY_CLASS)) {
Builder->CreateCall(func, paramsLLVM);
Expand Down
73 changes: 65 additions & 8 deletions src/liblesma/Symbol/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace lesma;
* @param entry Symbol Table Entry
*/
void SymbolTable::insertSymbol(Value *entry) {
symbols.insert_or_assign(entry->getName(), entry);
symbols.emplace(entry->getName(), entry);
}

/**
Expand All @@ -21,19 +21,73 @@ void SymbolTable::insertType(const std::string &name, Type *type) {
types.insert_or_assign(name, type);
}

/**
* Check if a symbol exists in the current or any parent scope and return it if possible
*
* @param name Name of the desired symbol
* @return Desired symbol / nullptr if the symbol was not found
*/
Value *SymbolTable::lookupFunction(const std::string &name, std::vector<lesma::Type *> paramTypes) {
auto range = symbols.equal_range(name);
for (auto it = range.first; it != range.second; ++it) {
if (!it->second->getType()->is(TY_FUNCTION))
continue;
// Check if the parameter types match
std::vector<Field *> funcParamTypes = it->second->getType()->getFields();
std::vector<llvm::Value *> tmpValues;

bool paramsMatch = true;
size_t numParams = std::max(funcParamTypes.size(), paramTypes.size());
for (size_t i = 0; i < numParams; ++i) {
if (i < funcParamTypes.size() && i < paramTypes.size()) {
if (*funcParamTypes[i]->type != *paramTypes[i]) {
paramsMatch = false;
break;
}
} else if (i < funcParamTypes.size() && funcParamTypes[i]->defaultValue != nullptr) {
// Use default value for missing parameter
paramTypes.push_back(funcParamTypes[i]->type);
} else if (i >= funcParamTypes.size() && it->second->getType()->getLLVMType()->isFunctionVarArg()) {
// Varargs
break;
} else {
paramsMatch = false;
break;
}
}

if (!paramsMatch) {
continue;// Parameter types don't match
}

return it->second;
}

if (parent == nullptr) {
return nullptr;
}

return parent->lookupFunction(name, paramTypes);
}

/**
* Check if a symbol exists in the current or any parent scope and return it if possible
*
* @param name Name of the desired symbol
* @return Desired symbol / nullptr if the symbol was not found
*/
Value *SymbolTable::lookup(const std::string &name) {
if (symbols.find(name) == symbols.end()) {
if (parent == nullptr) return nullptr;
return parent->lookup(name);
for (const auto &sym: symbols) {
if (sym.first == name) {
return sym.second;
}
}

if (parent == nullptr) {
return nullptr;
}

return symbols.at(name);
return parent->lookup(name);
}

/**
Expand All @@ -42,15 +96,18 @@ Value *SymbolTable::lookup(const std::string &name) {
* @param name Name of the desired symbol
* @return Desired symbol / nullptr if the symbol was not found
*/
Value *SymbolTable::lookupStructByName(const std::string &name) {
Value *SymbolTable::lookupStruct(const std::string &name) {
for (auto sym: symbols) {
if (sym.second->getType()->getLLVMType() != nullptr && sym.second->getType()->isOneOf({TY_CLASS, TY_ENUM}) &&
llvm::cast<llvm::StructType>(sym.second->getType()->getLLVMType())->getName() == name)
return sym.second;
}

if (parent == nullptr) return nullptr;
return parent->lookupStructByName(name);
if (parent == nullptr) {
return nullptr;
}

return parent->lookupStruct(name);
}

/**
Expand Down
Loading

0 comments on commit 118f179

Please sign in to comment.