Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixing issue #17840 #18526

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions contrib/tvmop/compile.py
Original file line number Diff line number Diff line change
@@ -152,6 +152,12 @@ def get_cuda_arch(arch):
# we create libtvmop.o first, which gives us chance to link tvm_runtime together with the libtvmop
# to allow mxnet find external helper functions in libtvm_runtime
func_binary.save(arguments.target_path + "/libtvmop.o")
try:
func_binary.imported_modules
except NameError:
func_binary.imported_modules = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from https://github.com/apache/incubator-tvm/blob/master/python/tvm/runtime/module.py#L136 we can see func_binary.imported_modules should always exist.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your review. I have deleted these lines.

if len(func_binary.imported_modules):
func_binary.imported_modules[0].save(arguments.target_path + "/libtvmop.cubin")
ld_path = arguments.target_path if arguments.ld_path is None else arguments.ld_path
create_shared(arguments.target_path + "/libtvmop.so",
arguments.target_path + "/libtvmop.o",
10 changes: 9 additions & 1 deletion src/c_api/c_api.cc
Original file line number Diff line number Diff line change
@@ -1363,7 +1363,15 @@ int MXGetVersion(int *out) {
#if MXNET_USE_TVM_OP
int MXLoadTVMOp(const char *libpath) {
API_BEGIN();
tvm::runtime::TVMOpModule::Get()->Load(libpath);
tvm::runtime::TVMOpModule *libpath_module = tvm::runtime::TVMOpModule::Get();
libpath_module->Load(libpath);
#if MXNET_USE_CUDA
std::string libpathstr(libpath);
std::string cubinpath = libpathstr.substr(0, libpathstr.size() - 11) + "libtvmop.cubin";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be better to pass libpath as dir, and do libpath + "libtvmop.so" as well to keep consistency.

Copy link
Contributor Author

@jinboci jinboci Jun 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrite it in a more elegant way :)

Yes, but MXLoadTVMOp is called at:
https://github.com/apache/incubator-mxnet/blob/1bf881f381f91b157a26d9beddcaa8f4960cc038/python/mxnet/tvmop.py#L31-L32
where _LIB_TVM_OP is returned from the
https://github.com/apache/incubator-mxnet/blob/1bf881f381f91b157a26d9beddcaa8f4960cc038/python/mxnet/libinfo.py#L25
, and _LIB_TVM_OP[0] is the path of libtvmop.so.
We may need to modify find_lib_path or write a new function to get the directory that libtvmop.so locates.

tvm::runtime::TVMOpModule cubin_module;
cubin_module.Load(cubinpath);
libpath_module->Import(cubin_module);
#endif
API_END();
}

6 changes: 6 additions & 0 deletions src/operator/tvmop/op_module.cc
Original file line number Diff line number Diff line change
@@ -46,6 +46,12 @@ void TVMOpModule::Load(const std::string &filepath) {
*module_ptr_ = module;
}

void TVMOpModule::Import(const TVMOpModule& module) {
CHECK(module_ptr_ != nullptr) << "module_ptr_ is not initialized.";
std::lock_guard<std::mutex> lock(mutex_);
module_ptr_->Import(*(module.module_ptr_));
}

PackedFunc GetFunction(const std::shared_ptr<Module> &module,
const std::string &op_name,
const std::vector<mxnet::TBlob> &args) {
2 changes: 2 additions & 0 deletions src/operator/tvmop/op_module.h
Original file line number Diff line number Diff line change
@@ -44,6 +44,8 @@ class TVMOpModule {
// Load TVM operators binary
void Load(const std::string& filepath);

void Import(const TVMOpModule& module);

void Call(const std::string& func_name,
const mxnet::OpContext& ctx,
const std::vector<mxnet::TBlob>& args) const;