Skip to content

Commit

Permalink
this is a huge mess but i have to compile llvm myself :(
Browse files Browse the repository at this point in the history
  • Loading branch information
kem0x committed Jan 24, 2023
1 parent 279c0c5 commit 05b0f72
Show file tree
Hide file tree
Showing 13 changed files with 472 additions and 67 deletions.
181 changes: 163 additions & 18 deletions Custom-lang/Compile/Compiler.ixx
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,59 @@ import AST.Statements;
import AST.Expressions;
import Reflection;
import <format>;
import <memory>;

static Shared<llvm::LLVMContext> TheContext = std::make_unique<llvm::LLVMContext>();
static Shared<llvm::Module> TheModule = std::make_unique<llvm::Module>("my cool jit", *TheContext);
static Shared<llvm::IRBuilder<>> Builder = std::make_unique<llvm::IRBuilder<>>(*TheContext);
static UnorderedMap<String, llvm::Value*> NamedValues;
import "KaleidoscopeJIT.h";

export Unique<llvm::LLVMContext> TheContext;
export Unique<llvm::Module> TheModule;
export Unique<llvm::IRBuilder<>> Builder;
export Unique<llvm::legacy::FunctionPassManager> TheFPM;

export Unique<llvm::orc::KaleidoscopeJIT> TheJIT;

UnorderedMap<String, llvm::Value*> NamedValues;

UnorderedMap<String, Shared<ExternDeclaration>> Externs;

llvm::Type* TypeNameToLLVMType(String typeName)
{
if (typeName == "Int")
{
return llvm::Type::getInt64Ty(*TheContext);
}
else if (typeName == "Float")
{
return llvm::Type::getDoubleTy(*TheContext);
}
else if (typeName == "Bool")
{
return llvm::Type::getInt1Ty(*TheContext);
}
else if (typeName == "String")
{
return llvm::Type::getInt8PtrTy(*TheContext);
}
else
{
Safety::Throw(std::format("Unknown type: '{}'", typeName));
}

return llvm::Type::getInt8PtrTy(*TheContext);
}

llvm::Value* Compile(Shared<Statement> node);

llvm::Value* CompileProgram(Shared<Program> node)
{
llvm::Value* lastValue = nullptr;
llvm::Value* LastValue = nullptr;

for (auto&& Stmt : node->Body)
{
lastValue = Compile(Stmt);
LastValue = Compile(Stmt);
}

return lastValue;
return LastValue;
}

llvm::Value* CompileIntLiteral(Shared<IntLiteral> node)
Expand All @@ -46,14 +81,14 @@ llvm::Value* CompileStringLiteral(Shared<StringLiteral> node)

llvm::Value* CompileIdentifier(Shared<Identifier> node)
{
auto value = NamedValues[node->Name];
auto Value = NamedValues[node->Name];

if (!value)
if (!Value)
{
Safety::Throw(std::format("Unknown variable name '{0}'", node->Name));
}

return value;
return Value;
}

llvm::Value* CompileBinaryExpr(Shared<BinaryExpr> node)
Expand All @@ -66,33 +101,137 @@ llvm::Value* CompileBinaryExpr(Shared<BinaryExpr> node)
Safety::Throw("Invalid binary expression");
}

if (node->Operator == "+")
switch (node->Operator)
{
return Builder->CreateFAdd(Left, Right, "addtmp");
case '+':
return Builder->CreateAdd(Left, Right, "addtmp");
case '-':
return Builder->CreateSub(Left, Right, "subtmp");
case '*':
return Builder->CreateMul(Left, Right, "multmp");
case '/':
return Builder->CreateUDiv(Left, Right, "divtmp");
default:
Safety::Throw(std::format("Invalid binary operator: '{0}'", node->Operator));
}
else if (node->Operator == "-")

return nullptr;
}

llvm::Value* CompileFunctionDeclaration(Shared<FunctionDeclaration> node)
{
Vector<llvm::Type*> ArgsTypes;

for (auto&& Arg : node->Parameters)
{
return Builder->CreateFSub(Left, Right, "subtmp");
ArgsTypes.push_back(TypeNameToLLVMType(Arg.first));
}
else if (node->Operator == "*")

auto FunctionType = llvm::FunctionType::get(TypeNameToLLVMType(node->ReturnType), ArgsTypes, false);

auto Function = llvm::Function::Create(FunctionType, llvm::Function::ExternalLinkage, node->Name, *TheModule);

auto BB = llvm::BasicBlock::Create(*TheContext, "entry", Function);

Builder->SetInsertPoint(BB);

for (auto i = 0; i < node->Parameters.size(); i++)
{
return Builder->CreateFMul(Left, Right, "multmp");
auto Param = node->Parameters[i];

auto Arg = Function->arg_begin() + i;

Arg->setName(Param.second);

NamedValues[Param.second] = Arg;
}
else if (node->Operator == "/")

llvm::Value* ReturnValue = nullptr;

for (auto&& Stmt : node->Body)
{
ReturnValue = Compile(Stmt);
}

if (node->ReturnType == "Void" || !ReturnValue)
{
return Builder->CreateFDiv(Left, Right, "divtmp");
Builder->CreateRetVoid();
}
else
{
Builder->CreateRet(ReturnValue);
}

llvm::verifyFunction(*Function);

TheFPM->run(*Function);

return Function;
}

llvm::Value* CompileCallExpr(Shared<CallExpr> node)
{
auto Name = node->Callee->As<Identifier>()->Name;

if (auto CalleeF = TheModule->getFunction(Name))
{
if (CalleeF->arg_size() != node->Arguments.size())
{
Safety::Throw(std::format("Incorrect number of arguments passed to '{0}'", Name));
}

Vector<llvm::Value*> ArgsV;

for (auto&& Arg : node->Arguments)
{
ArgsV.push_back(Compile(Arg));
}

return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}
}

llvm::Value* CompileExternDeclaration(Shared<ExternDeclaration> node)
{
Externs[node->Identifier] = node;

return nullptr;
}

export llvm::Value* Compile(Shared<Statement> node)
{
// auto expr = std::dynamic_pointer_cast<Expression>(node);

// if (expr)
{
// if (expr->bIsTopLevel)
{
auto RT = TheJIT->getMainJITDylib().createResourceTracker();

auto TSM = llvm::orc::ThreadSafeModule(std::move(TheModule), std::move(TheContext));
TheJIT->addModule(std::move(TSM), RT);

// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = TheJIT->lookup("__anon_expr");

// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
double (*FP)() = (double (*)())(intptr_t)ExprSymbol->getAddress();
fprintf(stderr, "Evaluated to %f\n", FP());

// Delete the anonymous expression module from the JIT.
RT->remove();
}
}

switch (node->Type)
{
case ASTNodeType::Program:
return CompileProgram(node->As<Program>());

case ASTNodeType::Identifier:
return CompileIdentifier(node->As<Identifier>());

case ASTNodeType::IntLiteral:
return CompileIntLiteral(node->As<IntLiteral>());

Expand All @@ -105,6 +244,12 @@ export llvm::Value* Compile(Shared<Statement> node)
case ASTNodeType::BinaryExpr:
return CompileBinaryExpr(node->As<BinaryExpr>());

case ASTNodeType::FunctionDeclaration:
return CompileFunctionDeclaration(node->As<FunctionDeclaration>());

case ASTNodeType::CallExpr:
return CompileCallExpr(node->As<CallExpr>());

default:
return Safety::Throw<llvm::Value*>(std::format("Unknown AST node type: {}", Reflection::EnumToString(node->Type)));
}
Expand Down
116 changes: 116 additions & 0 deletions Custom-lang/Compile/KaleidoscopeJIT.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains a simple JIT definition for use in the kaleidoscope tutorials.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H

#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/LLVMContext.h"
#include <memory>

namespace llvm
{
namespace orc
{

class KaleidoscopeJIT
{
private:
std::unique_ptr<ExecutionSession> ES;

DataLayout DL;
MangleAndInterner Mangle;

RTDyldObjectLinkingLayer ObjectLayer;
IRCompileLayer CompileLayer;

JITDylib& MainJD;

public:
KaleidoscopeJIT(std::unique_ptr<ExecutionSession> ES,
JITTargetMachineBuilder JTMB, DataLayout DL)
: ES(std::move(ES))
, DL(std::move(DL))
, Mangle(*this->ES, this->DL)
, ObjectLayer(*this->ES,
[]()
{ return std::make_unique<SectionMemoryManager>(); })
, CompileLayer(*this->ES, ObjectLayer,
std::make_unique<ConcurrentIRCompiler>(std::move(JTMB)))
, MainJD(this->ES->createBareJITDylib("<main>"))
{
MainJD.addGenerator(
cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
DL.getGlobalPrefix())));
if (JTMB.getTargetTriple().isOSBinFormatCOFF())
{
ObjectLayer.setOverrideObjectFlagsWithResponsibilityFlags(true);
ObjectLayer.setAutoClaimResponsibilityForObjectSymbols(true);
}
}

~KaleidoscopeJIT()
{
if (auto Err = ES->endSession())
ES->reportError(std::move(Err));
}

static Expected<std::unique_ptr<KaleidoscopeJIT>> Create()
{
auto EPC = SelfExecutorProcessControl::Create();
if (!EPC)
return EPC.takeError();

auto ES = std::make_unique<ExecutionSession>(std::move(*EPC));

JITTargetMachineBuilder JTMB(
ES->getExecutorProcessControl().getTargetTriple());

auto DL = JTMB.getDefaultDataLayoutForTarget();
if (!DL)
return DL.takeError();

return std::make_unique<KaleidoscopeJIT>(std::move(ES), std::move(JTMB),
std::move(*DL));
}

const DataLayout& getDataLayout() const { return DL; }

JITDylib& getMainJITDylib() { return MainJD; }

Error addModule(ThreadSafeModule TSM, ResourceTrackerSP RT = nullptr)
{
if (!RT)
RT = MainJD.getDefaultResourceTracker();
return CompileLayer.add(RT, std::move(TSM));
}

Expected<JITEvaluatedSymbol> lookup(StringRef Name)
{
return ES->lookup({ &MainJD }, Mangle(Name.str()));
}
};

} // end namespace orc
} // end namespace llvm

#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
15 changes: 14 additions & 1 deletion Custom-lang/Compile/LLVM.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
#pragma once

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
Loading

0 comments on commit 05b0f72

Please sign in to comment.