Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add dynamic read mxnet dll
Browse files Browse the repository at this point in the history
  • Loading branch information
yajiedesign committed Dec 27, 2019
1 parent 5a31b19 commit 6515bf6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ elseif(MSVC)
add_executable(gen_warp tools/windowsbuild/gen_warp.cpp)
add_library(mxnet SHARED tools/windowsbuild/warp_dll.cpp ${CMAKE_BINARY_DIR}/warp_gen_cpp.cpp
${CMAKE_BINARY_DIR}/warp_gen.asm)
target_link_libraries(mxnet PRIVATE cudart)
target_link_libraries(mxnet PRIVATE cudart Shlwapi)
list(GET cuda_arch 0 mxnet_first_arch)
foreach(arch ${cuda_arch})
add_library(mxnet_${arch} SHARED ${SOURCE})
Expand Down
57 changes: 53 additions & 4 deletions tools/windowsbuild/warp_dll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,61 @@
#include <cuda_runtime.h>
#include <algorithm>
#include <Windows.h>
#include <io.h>
#include <vector>
#include <regex>
#include <shlwapi.h>


extern "C" IMAGE_DOS_HEADER __ImageBase;


std::vector<int> find_mxnet_dll()
{
std::vector<int> version;
intptr_t handle;

_wfinddata_t findData{};
std::wregex reg(L".*?mxnet_([0-9]+)\\.dll");

HMODULE hModule = reinterpret_cast<HMODULE>(&__ImageBase);
WCHAR szPathBuffer[MAX_PATH] = { 0 };
GetModuleFileNameW(hModule, szPathBuffer, MAX_PATH);

PathRemoveFileSpecW(szPathBuffer);
wcscat_s(szPathBuffer, L"\\mxnet_*.dll");

handle = _wfindfirst(szPathBuffer, &findData);
if (handle == -1)
{
return version;
}

do
{
if (!(findData.attrib & _A_SUBDIR) || wcscmp(findData.name, L".") != 0 || wcscmp(findData.name, L"..") != 0)
{
std::wstring str(findData.name);
std::wsmatch base_match;
if(std::regex_match(str, base_match, reg))
{
if (base_match.size() == 2) {
std::wssub_match base_sub_match = base_match[1];
std::wstring base = base_sub_match.str();
version.push_back(std::stoi(base)) ;
}
}
}
} while (_wfindnext(handle, &findData) == 0);

_findclose(handle);
std::sort(version.begin(), version.end());
return version;
}

int find_version()
{
int known_sm[] = { 30,35,37,50,52,60,61,70,75 };
std::vector<int> known_sm = find_mxnet_dll();
int count = 0;
int version = 75;
if (cudaSuccess != cudaGetDeviceCount(&count))
Expand Down Expand Up @@ -65,9 +114,9 @@ void load_function(HMODULE hm);
void mxnet_init()
{
int version = find_version();
char dll_name[256];
sprintf(dll_name, "mxnet_%d.dll", version);
HMODULE hm = LoadLibrary(dll_name);
WCHAR dll_name[MAX_PATH];
wsprintfW(dll_name, L"mxnet_%d.dll", version);
HMODULE hm = LoadLibraryW(dll_name);
load_function(hm);
}

Expand Down

0 comments on commit 6515bf6

Please sign in to comment.