diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py index ee956c85e0ed..0f72a4694a82 100644 --- a/python/tvm/contrib/rocm.py +++ b/python/tvm/contrib/rocm.py @@ -1,7 +1,8 @@ """Utility for ROCm backend""" import subprocess +from os.path import join from . import util -from ..api import register_func +from ..api import register_func, convert def rocm_link(in_file, out_file): """Link relocatable ELF object to shared ELF object using lld @@ -49,3 +50,32 @@ def callback_rocm_link(obj_bin): rocm_link(tmp_obj, tmp_cobj) cobj_bin = bytearray(open(tmp_cobj, "rb").read()) return cobj_bin + +@register_func("tvm_callback_rocm_bitcode_path") +def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"): + """Utility function to find ROCm device library bitcodes + + Parameters + ---------- + rocdl_dir : str + The path to rocm library directory + The default value is the standard location + """ + # seems link order matters. + bitcode_files = [ + "oclc_daz_opt_on.amdgcn.bc", + "ocml.amdgcn.bc", + "hc.amdgcn.bc", + "irif.amdgcn.bc", + "ockl.amdgcn.bc", + "oclc_correctly_rounded_sqrt_off.amdgcn.bc", + "oclc_correctly_rounded_sqrt_on.amdgcn.bc", + "oclc_daz_opt_off.amdgcn.bc", + "oclc_finite_only_off.amdgcn.bc", + "oclc_finite_only_on.amdgcn.bc", + "oclc_isa_version_803.amdgcn.bc", + "oclc_isa_version_900.amdgcn.bc", + "oclc_unsafe_math_off.amdgcn.bc", + "oclc_unsafe_math_on.amdgcn.bc" + ] + return convert([join(rocdl_dir, bitcode) for bitcode in bitcode_files]) diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index f3d9d811eec1..f49f2283210f 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -161,6 +161,24 @@ runtime::Module BuildAMDGPU(Array funcs, std::string target) { cg->AddFunction(f); } + const auto *find_rocm_bitcodes = + tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); + Array bitcode_files = (*find_rocm_bitcodes)(); + + for (auto &bitcode : bitcode_files) { + std::string path = bitcode.as()->value; + llvm::SMDiagnostic err; + std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); + if (mlib.get() == nullptr) { + std::string msg = err.getMessage(); + LOG(FATAL) << "Fail to load bitcode file " << path << "\n" + << "line " << err.getLineNo() << ":" << msg; + } + mlib->setTargetTriple(tm->getTargetTriple().str()); + mlib->setDataLayout(tm->createDataLayout()); + cg->AddLinkModule(std::move(mlib)); + } + std::unique_ptr module = cg->Finish(); llvm::SmallString<8> dataObj, data_ll, dataAsm; llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index cb2eae40eaeb..2654dee0f7e5 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -516,11 +516,10 @@ llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const { } llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { - CHECK_GE(op->args.size(), 1U); std::vector arg_value; std::vector arg_type; - for (size_t i = 1; i < op->args.size(); ++i) { - arg_value.push_back(MakeValue(op->args[i + 1])); + for (size_t i = 0; i < op->args.size(); ++i) { + arg_value.push_back(MakeValue(op->args[i])); arg_type.push_back(arg_value.back()->getType()); } llvm::FunctionType* ftype = llvm::FunctionType::get( diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc new file mode 100644 index 000000000000..e64a4fce4f27 --- /dev/null +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -0,0 +1,48 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file intrin_rule_llvm.cc + */ +#ifdef TVM_LLVM_VERSION + +#include "./intrin_rule_llvm.h" +#include +#include +#include +#include + +namespace tvm { +namespace codegen { + +inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { + Expr e = args[0]; + using namespace ir; + const Call* call = e.as(); + CHECK(call != nullptr); + std::ostringstream intrinsic_name; + intrinsic_name << "__ocml_" << call->name << "_f" << call->type.bits(); + *rv = Call::make(call->type, intrinsic_name.str(), call->args, + Call::PureExtern); +} + +namespace llvm { + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") +.set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma") +.set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log") +.set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt") +.set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow") +.set_body(DispatchExternOCML); + +} // namespace llvm +} // namespace codegen +} // namespace tvm + +#endif // LLVM_VERSION