Skip to content

Commit

Permalink
[LLVM] Make compilable with LLVM-20 (#17547)
Browse files Browse the repository at this point in the history
* [LLVM] Make compilable with LLVM-20

---------

Co-authored-by: Renat Idrisov <[email protected]>
  • Loading branch information
parsifal-47 and parsifal-47 authored Dec 2, 2024
1 parent d94f115 commit 513c2be
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 71 deletions.
13 changes: 12 additions & 1 deletion src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}

buf = builder_->CreatePointerCast(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
buf,
llvmGetPointerTo(DTypeToLLVMType(op->dtype), buf->getType()->getPointerAddressSpace()));
ICHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
Expand Down Expand Up @@ -187,7 +188,12 @@ class CodeGenAMDGPU : public CodeGenLLVM {
LOG(FATAL) << "unknown workgroup idx";
}
}
#if TVM_LLVM_VERSION >= 200
llvm::Function* f = llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module_.get(), intrin_id, {}));
#else
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
#endif
llvm::Value* result = builder_->CreateCall(f, {});
return this->CreateCast(DataType::Int(32), iv->var->dtype, result);
}
Expand All @@ -197,8 +203,13 @@ class CodeGenAMDGPU : public CodeGenLLVM {
if (sync == "warp") {
return nullptr;
} else if (sync == "shared") {
#if TVM_LLVM_VERSION >= 200
llvm::Function* f = llvm::cast<llvm::Function>(llvm::Intrinsic::getOrInsertDeclaration(
module_.get(), llvm::Intrinsic::amdgcn_s_barrier, {}));
#else
llvm::Function* f =
llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::amdgcn_s_barrier);
#endif
return builder_->CreateCall(f, {});
} else {
LOG(FATAL) << "Do not support sync " << sync;
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
auto void_ty = llvm::Type::getVoidTy(*ctx);
auto int32_ty = llvm::Type::getInt32Ty(*ctx);
auto int8_ty = llvm::Type::getInt8Ty(*ctx);
auto int8_ptr_ty = int8_ty->getPointerTo(0);
auto int8_ptr_ty = llvmGetPointerTo(int8_ty, 0);

llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty);
auto* tvm_dev_mblob_reg =
Expand Down
113 changes: 63 additions & 50 deletions src/target/llvm/codegen_cpu.cc

Large diffs are not rendered by default.

37 changes: 27 additions & 10 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,12 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) {
op->op.same_as(builtin::end_profile_intrinsic())) {
llvm::Value* id = MakeValue(op->args[0]);
auto instrprof_id = llvm::Intrinsic::hexagon_instrprof_custom;
#if TVM_LLVM_VERSION >= 200
llvm::Function* func = llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module_.get(), instrprof_id, {}));
#else
llvm::Function* func = llvm::Intrinsic::getDeclaration(module_.get(), instrprof_id);
#endif
llvm::GlobalVariable* name_var = module_->getGlobalVariable("handler_name");
if (!name_var) {
llvm::StringRef init_str = "lwp_handler";
Expand All @@ -220,7 +225,7 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) {
name_var = new llvm::GlobalVariable(*module_, init->getType(), true,
llvm::GlobalValue::InternalLinkage, init, "handler_name");
}
llvm::Type* t_int8_p_ = t_int8_->getPointerTo();
llvm::Type* t_int8_p_ = llvmGetPointerTo(t_int8_, 0);
return builder_->CreateCall(func, {llvm::ConstantExpr::getBitCast(name_var, t_int8_p_), id});
}
#endif
Expand All @@ -237,17 +242,23 @@ void CodeGenHexagon::CreatePrintf(const std::string& format,
llvm::Function* func = module_->getFunction(func_name);
if (func == nullptr) {
llvm::FunctionType* ftype = llvm::FunctionType::get(
t_void_, {t_int32_, t_char_->getPointerTo(), t_int32_, t_char_->getPointerTo()}, true);
t_void_, {t_int32_, llvmGetPointerTo(t_char_, 0), t_int32_, llvmGetPointerTo(t_char_, 0)},
true);
func = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, func_name, module_.get());
}

// There is no such filename/line number for this print statement
#if TVM_LLVM_VERSION >= 200
llvm::Value* filename = builder_->CreateGlobalString("generated-LLVM-code", "dummy_filename");
llvm::Value* format_str = builder_->CreateGlobalString(format, "printf_format_str");
#else
llvm::Value* filename = builder_->CreateGlobalStringPtr("generated-LLVM-code", "dummy_filename");
llvm::Value* format_str = builder_->CreateGlobalStringPtr(format, "printf_format_str");
#endif

// The value of FARF_ALWAYS_LEVEL, defined as HAP_LEVEL_HIGH
llvm::Value* level = ConstInt32(2);

// There is no such filename/line number for this print statement
llvm::Value* filename = builder_->CreateGlobalStringPtr("generated-LLVM-code", "dummy_filename");
llvm::Value* line_number = ConstInt32(1);

std::vector<llvm::Value*> func_args = {level, filename, line_number, format_str};
Expand Down Expand Up @@ -295,9 +306,9 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V

if (kind < builtin::kArrKindBound_) {
if (buf->getType() == t_void_p_) {
buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_array_, 0));
} else {
ICHECK_EQ(buf->getType(), t_tvm_array_->getPointerTo());
ICHECK_EQ(buf->getType(), llvmGetPointerTo(t_tvm_array_, 0));
}
/* The following "kinds" are accessing the members of DLTensor:
typedef struct {
Expand Down Expand Up @@ -350,16 +361,17 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V
ICHECK_EQ(t.lanes(), 1);
ICHECK(t.is_handle() || t.bits() == 64);
if (t.is_int()) {
buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_int64_, 0));
return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index));
} else if (t.is_float()) {
buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_float64_, 0));
return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index));
} else {
ICHECK(t.is_handle());
buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_value_, 0));
buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index);
return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()));
return TypedPointer(t_void_p_,
builder_->CreatePointerCast(buf, llvmGetPointerTo(t_void_p_, 0)));
}
}

Expand All @@ -369,7 +381,12 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V

llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID,
llvm::ArrayRef<llvm::Value*> args) {
#if TVM_LLVM_VERSION >= 200
llvm::Function* intf =
llvm::cast<llvm::Function>(llvm::Intrinsic::getOrInsertDeclaration(module_.get(), IntID, {}));
#else
llvm::Function* intf = llvm::Intrinsic::getDeclaration(module_.get(), IntID);
#endif
#if TVM_LLVM_VERSION >= 90
auto intf_callee = llvm::FunctionCallee(intf);
#else
Expand Down
47 changes: 40 additions & 7 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target,
md_builder_.reset(new llvm::MDBuilder(*ctx));
// types
t_void_ = llvm::Type::getVoidTy(*ctx);
t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(GetGlobalAddressSpace());
t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), GetGlobalAddressSpace());
t_int_ = llvm::Type::getInt32Ty(*ctx);
t_char_ = llvm::Type::getInt8Ty(*ctx);
t_int8_ = llvm::Type::getInt8Ty(*ctx);
Expand All @@ -169,7 +169,11 @@ void CodeGenLLVM::InitTarget() {
llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
module_->setTargetTriple(tm->getTargetTriple().str());
module_->setDataLayout(tm->createDataLayout());
#if TVM_LLVM_VERSION >= 200
data_layout_.reset(new llvm::DataLayout(module_.get()->getDataLayout()));
#else
data_layout_.reset(new llvm::DataLayout(module_.get()));
#endif
if (native_vector_bits_ == 0) {
const auto& arch = tm->getTargetTriple().getArch();
if (arch == llvm::Triple::x86_64) {
Expand Down Expand Up @@ -624,7 +628,7 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
}
}
// TODO(tvm-team) consider put storage scope into the pointer type.
return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace());
return llvmGetPointerTo(GetLLVMType(ptr->element_type), GetGlobalAddressSpace());
} else if (IsVoidType(type)) {
return t_void_;
} else {
Expand Down Expand Up @@ -967,9 +971,9 @@ CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,

llvm::Type* element_type = DTypeToLLVMType(buffer_element_dtype);
llvm::PointerType* element_ptr_type =
DTypeToLLVMType(buffer_element_dtype)->getPointerTo(address_space);
llvmGetPointerTo(DTypeToLLVMType(buffer_element_dtype), address_space);
llvm::Type* value_type = DTypeToLLVMType(value_dtype);
llvm::PointerType* value_ptr_type = value_type->getPointerTo(address_space);
llvm::PointerType* value_ptr_type = llvmGetPointerTo(value_type, address_space);

ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer";

Expand Down Expand Up @@ -1012,7 +1016,11 @@ void CodeGenLLVM::CreatePrintf(const std::string& format,
llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "fflush", module_.get());
}

#if TVM_LLVM_VERSION >= 200
llvm::Value* str = builder_->CreateGlobalString(format);
#else
llvm::Value* str = builder_->CreateGlobalStringPtr(format);
#endif
str->setName("printf_format_str");

std::vector<llvm::Value*> printf_args = {str};
Expand All @@ -1030,8 +1038,13 @@ void CodeGenLLVM::CreatePrintf(const std::string& format,
llvm::Value* CodeGenLLVM::CreateLookupReturnAddress(unsigned int level) {
EmitDebugLocation();
llvm::Value* level_val = llvm::ConstantInt::get(t_int32_, level);
#if TVM_LLVM_VERSION >= 200
llvm::Function* builtin = llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module_.get(), llvm::Intrinsic::returnaddress, {}));
#else
llvm::Function* builtin =
llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::returnaddress);
#endif
llvm::Value* call = builder_->CreateCall(builtin, level_val);
call->setName("return_addr");

Expand Down Expand Up @@ -1061,7 +1074,11 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
llvm::Module* module = module_.get();

if (!llvm::Intrinsic::isOverloaded(id)) {
#if TVM_LLVM_VERSION >= 200
return llvm::cast<llvm::Function>(llvm::Intrinsic::getOrInsertDeclaration(module, id, {}));
#else
return llvm::Intrinsic::getDeclaration(module, id, {});
#endif
}

llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
Expand Down Expand Up @@ -1089,7 +1106,12 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
// The return type doesn't match, there is nothing else to do.
return nullptr;
case llvm::Intrinsic::MatchIntrinsicTypes_Match:
#if TVM_LLVM_VERSION >= 200
return llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module, id, overload_types));
#else
return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif
case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
break;
}
Expand All @@ -1101,13 +1123,18 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
if (i > 0) var_types.push_back(arg_types[i - 1]);
auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
#if TVM_LLVM_VERSION >= 200
return llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module, id, overload_types));
#else
return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif
}
}
// Failed to identify the type.
return nullptr;

#else // TVM_LLVM_VERSION
#else // TVM_LLVM_VERSION
llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
// matchIntrinsicType returns true on error.
if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
Expand All @@ -1118,7 +1145,12 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
return nullptr;
}
}
#if TVM_LLVM_VERSION >= 200
return llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module, id, overload_types));
#else
return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif
#endif // TVM_LLVM_VERSION
}

Expand Down Expand Up @@ -1354,7 +1386,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (param_type != arg_value[0]->getType()) {
unsigned addrspace =
llvm::dyn_cast<llvm::PointerType>(arg_value[0]->getType())->getAddressSpace();
arg_value[0] = builder_->CreatePointerCast(arg_value[0], t_char_->getPointerTo(addrspace));
arg_value[0] =
builder_->CreatePointerCast(arg_value[0], llvmGetPointerTo(t_char_, addrspace));
}
}

Expand Down Expand Up @@ -2064,7 +2097,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
buf = alloca;

buf = builder_->CreatePointerCast(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
buf, llvmGetPointerTo(DTypeToLLVMType(op->dtype), buf->getType()->getPointerAddressSpace()));
AddDebugInformation(buf, op->buffer_var);

ICHECK(!var_map_.count(op->buffer_var.get()));
Expand Down
13 changes: 12 additions & 1 deletion src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
}

buf = builder_->CreatePointerCast(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
buf,
llvmGetPointerTo(DTypeToLLVMType(op->dtype), buf->getType()->getPointerAddressSpace()));
ICHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
Expand Down Expand Up @@ -170,7 +171,12 @@ class CodeGenNVPTX : public CodeGenLLVM {
LOG(FATAL) << "unknown thread idx";
}
}
#if TVM_LLVM_VERSION >= 200
llvm::Function* f = llvm::cast<llvm::Function>(
llvm::Intrinsic::getOrInsertDeclaration(module_.get(), intrin_id, {}));
#else
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
#endif
return builder_->CreateCall(f, {});
}

Expand All @@ -180,8 +186,13 @@ class CodeGenNVPTX : public CodeGenLLVM {
// TODO(tqchen) warp sync in CUDA9
return nullptr;
} else if (sync == "shared" || sync == "shared.dyn") {
#if TVM_LLVM_VERSION >= 200
llvm::Function* f = llvm::cast<llvm::Function>(llvm::Intrinsic::getOrInsertDeclaration(
module_.get(), llvm::Intrinsic::nvvm_barrier0, {}));
#else
llvm::Function* f =
llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::nvvm_barrier0);
#endif
return builder_->CreateCall(f, {});
} else {
LOG(FATAL) << "Do not support sync " << sync;
Expand Down
7 changes: 6 additions & 1 deletion src/target/llvm/codegen_x86_64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes,
llvm::Type* result_ty,
const std::vector<llvm::Value*>& args) {
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
#if TVM_LLVM_VERSION >= 200
llvm::Function* f =
llvm::cast<llvm::Function>(llvm::Intrinsic::getOrInsertDeclaration(module_.get(), id, {}));
#else
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id);
#endif
#if TVM_LLVM_VERSION >= 120
size_t num_elems = llvm::cast<llvm::FixedVectorType>(result_ty)->getNumElements();
#else
Expand Down
7 changes: 7 additions & 0 deletions src/target/llvm/llvm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@
#include <utility>
#include <vector>

// LLVM compatibility macro
#if TVM_LLVM_VERSION >= 200
#define llvmGetPointerTo(arg, offset) (llvm::PointerType::get((arg), (offset)))
#else
#define llvmGetPointerTo(arg, offset) (arg->getPointerTo(offset))
#endif

namespace llvm {
class LLVMContext;
class MemoryBuffer;
Expand Down
4 changes: 4 additions & 0 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,11 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")

TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
.set_body_typed([](std::string name) -> int64_t {
#if TVM_LLVM_VERSION >= 200
return static_cast<int64_t>(llvm::Intrinsic::lookupIntrinsicID(name));
#else
return static_cast<int64_t>(llvm::Function::lookupIntrinsicID(name));
#endif
});

TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t id) -> String {
Expand Down

0 comments on commit 513c2be

Please sign in to comment.