diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f3031502668..42cb46547614 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,6 +67,8 @@ option(USE_ASAN "Enable Clang/GCC ASAN sanitizers." OFF) option(ENABLE_TESTCOVERAGE "Enable compilation with test coverage metric output" OFF) option(USE_INT64_TENSOR_SIZE "Use int64_t to represent the total number of elements in a tensor" OFF) option(BUILD_CYTHON_MODULES "Build cython modules." OFF) +cmake_dependent_option(USE_SPLIT_ARCH_DLL "Build a separate DLL for each Cuda arch (Windows only)." ON "MSVC" OFF) + message(STATUS "CMAKE_CROSSCOMPILING ${CMAKE_CROSSCOMPILING}") message(STATUS "CMAKE_HOST_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR}") @@ -100,6 +102,7 @@ endif() if(MSVC) set(SYSTEM_ARCHITECTURE x86_64) + enable_language(ASM_MASM) else() execute_process(COMMAND uname -m COMMAND tr -d '\n' OUTPUT_VARIABLE SYSTEM_ARCHITECTURE) endif() @@ -192,9 +195,11 @@ else() add_definitions(-DDMLC_USE_CXX11=1) add_definitions(-DDMLC_USE_CXX14=1) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") + set(CMAKE_CUDA_STANDARD 14) elseif(SUPPORT_CXX11) add_definitions(-DDMLC_USE_CXX11=1) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + set(CMAKE_CUDA_STANDARD 11) elseif(SUPPORT_CXX0X) add_definitions(-DDMLC_USE_CXX11=1) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x") @@ -678,30 +683,66 @@ if(UNIX) target_link_libraries(mxnet_static PUBLIC ${CMAKE_DL_LIBS}) target_compile_options(sample_lib PUBLIC -shared) set_target_properties(mxnet_static PROPERTIES OUTPUT_NAME mxnet) -else() - add_library(mxnet SHARED ${SOURCE}) +elseif(MSVC) target_compile_options(sample_lib PUBLIC /LD) set_target_properties(sample_lib PROPERTIES PREFIX "lib") -endif() -if(USE_CUDA AND MSVC) - target_compile_options(mxnet PUBLIC "$<$:-Xcompiler=-MTd -Gy>") - target_compile_options(mxnet PUBLIC "$<$:-Xcompiler=-MT -Gy>") + if(USE_CUDA) + if(FIRST_CUDA AND MSVC) + if(USE_SPLIT_ARCH_DLL) + add_executable(gen_warp tools/windowsbuild/gen_warp.cpp) + add_library(mxnet SHARED tools/windowsbuild/warp_dll.cpp ${CMAKE_BINARY_DIR}/warp_gen_cpp.cpp + ${CMAKE_BINARY_DIR}/warp_gen.asm) + target_link_libraries(mxnet PRIVATE cudart Shlwapi) + list(GET cuda_arch 0 mxnet_first_arch) + foreach(arch ${cuda_arch}) + add_library(mxnet_${arch} SHARED ${SOURCE}) + target_compile_options( + mxnet_${arch} + PRIVATE + "$<$:--gpu-architecture=compute_${arch}>" + ) + target_compile_options( + mxnet_${arch} + PRIVATE + "$<$:--gpu-code=sm_${arch},compute_${arch}>" + ) + target_compile_options( + mxnet_${arch} + PRIVATE "$<$,$>:-Xcompiler=-MTd -Gy /bigobj>") + target_compile_options( + mxnet_${arch} + PRIVATE "$<$,$>:-Xcompiler=-MT -Gy /bigobj>") + endforeach() + + add_custom_command( + OUTPUT ${CMAKE_BINARY_DIR}/warp_gen_cpp.cpp ${CMAKE_BINARY_DIR}/warp_gen.asm + COMMAND gen_warp $ WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/ DEPENDS $) + else(USE_SPLIT_ARCH_DLL) + string(REPLACE ";" " " NVCC_FLAGS_ARCH "${NVCC_FLAGS_ARCH}") + set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS_ARCH}") + add_library(mxnet SHARED ${SOURCE}) + target_compile_options( + mxnet + PRIVATE "$<$,$>:-Xcompiler=-MTd -Gy /bigobj>") + target_compile_options( + mxnet + PRIVATE "$<$,$>:-Xcompiler=-MT -Gy /bigobj>") + + endif(USE_SPLIT_ARCH_DLL) + else() + add_library(mxnet SHARED ${SOURCE}) + endif() + else() + add_library(mxnet SHARED ${SOURCE}) + endif() + endif() + if(USE_DIST_KVSTORE) if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/ps-lite/CMakeLists.txt) add_subdirectory("3rdparty/ps-lite") list(APPEND pslite_LINKER_LIBS pslite protobuf) - target_link_libraries(mxnet PUBLIC debug ${pslite_LINKER_LIBS_DEBUG}) - target_link_libraries(mxnet PUBLIC optimized ${pslite_LINKER_LIBS_RELEASE}) - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - list(APPEND mxnet_LINKER_LIBS ${pslite_LINKER_LIBS_DEBUG}) - else() - list(APPEND mxnet_LINKER_LIBS ${pslite_LINKER_LIBS_RELEASE}) - endif() - target_link_libraries(mxnet PUBLIC debug ${pslite_LINKER_LIBS_DEBUG}) - target_link_libraries(mxnet PUBLIC optimized ${pslite_LINKER_LIBS_RELEASE}) - else() set(pslite_LINKER_LIBS protobuf zmq-static) endif() @@ -735,13 +776,24 @@ if(USE_TVM_OP) ) endif() -target_link_libraries(mxnet PUBLIC ${mxnet_LINKER_LIBS}) - if(USE_PLUGINS_WARPCTC) - target_link_libraries(mxnet PUBLIC debug ${WARPCTC_LIB_DEBUG}) - target_link_libraries(mxnet PUBLIC optimized ${WARPCTC_LIB_RELEASE}) + list(APPEND mxnet_LINKER_LIBS ${WARPCTC_LIB}) endif() +if(MSVC) + if(FIRST_CUDA AND USE_SPLIT_ARCH_DLL) + foreach(arch ${cuda_arch}) + target_link_libraries(mxnet_${arch} PUBLIC ${mxnet_LINKER_LIBS}) + target_link_libraries(mxnet_${arch} PUBLIC dmlc) + endforeach() + else() + target_link_libraries(mxnet PUBLIC ${mxnet_LINKER_LIBS}) + target_link_libraries(mxnet PUBLIC dmlc) + endif() +else() + target_link_libraries(mxnet PUBLIC ${mxnet_LINKER_LIBS}) + target_link_libraries(mxnet PUBLIC dmlc) +endif() if(USE_OPENCV AND OpenCV_VERSION_MAJOR GREATER 2) add_executable(im2rec "tools/im2rec.cc") @@ -761,7 +813,6 @@ else() is required for im2rec, im2rec will not be available") endif() -target_link_libraries(mxnet PUBLIC dmlc) if(MSVC AND USE_MXNET_LIB_NAMING) set_target_properties(mxnet PROPERTIES OUTPUT_NAME "libmxnet") diff --git a/tools/windowsbuild/README.md b/tools/windowsbuild/README.md new file mode 100644 index 000000000000..7d8e7cf331cf --- /dev/null +++ b/tools/windowsbuild/README.md @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + +Due to dll size limitation under windows. Split dll into different dlls according to arch +Reference https://github.com/apache/incubator-mxnet/pull/16980 \ No newline at end of file diff --git a/tools/windowsbuild/gen_warp.cpp b/tools/windowsbuild/gen_warp.cpp new file mode 100644 index 000000000000..2d90eaf364f3 --- /dev/null +++ b/tools/windowsbuild/gen_warp.cpp @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define IMAGE_SIZEOF_SIGNATURE 4 + + +DWORD rva_to_foa(IN DWORD RVA, IN PIMAGE_SECTION_HEADER section_header) +{ + + size_t count = 0; + for (count = 1; RVA > (section_header->VirtualAddress + section_header->Misc.VirtualSize); count++, section_header++); + + DWORD FOA = RVA - section_header->VirtualAddress + section_header->PointerToRawData; + + return FOA; +} + +std::string format(const char* format, ...) +{ + va_list args; + va_start(args, format); +#ifndef _MSC_VER + size_t size = std::snprintf(nullptr, 0, format, args) + 1; // Extra space for '\0' + std::unique_ptr buf(new char[size]); + std::vsnprintf(buf.get(), size, format, args); + return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside +#else + int size = _vscprintf(format, args) +1; + std::unique_ptr buf(new char[size]); + vsnprintf_s(buf.get(), size, _TRUNCATE, format, args); + return std::string(buf.get()); +#endif + va_end(args); +} + +int main(int argc, char* argv[]) +{ + + if (argc != 2) + { + return 0; + } + + //open file + const HANDLE h_file = CreateFile( + argv[1], + GENERIC_READ , + FILE_SHARE_READ , + nullptr, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + nullptr); + + + DWORD size_high; + const DWORD size_low = GetFileSize(h_file, &size_high); + + uint64_t dll_size = ((uint64_t(size_high)) << 32) + (uint64_t)size_low; + + // Create File Mapping + const HANDLE h_map_file = CreateFileMapping( + h_file, + nullptr, + PAGE_READONLY, + size_high, + size_low, + nullptr); + if (h_map_file == INVALID_HANDLE_VALUE || h_map_file == nullptr) + { + std::cout << "error"; + CloseHandle(h_file); + return 0; + } + + //Map File to memory + void* pv_file = MapViewOfFile( + h_map_file, + FILE_MAP_READ, + 0, + 0, + 0); + + if (pv_file == nullptr) + { + std::cout << "error"; + CloseHandle(h_file); + return 0; + } + + uint8_t* p = static_cast(pv_file); + + + PIMAGE_DOS_HEADER dos_header = reinterpret_cast(p); + + const PIMAGE_NT_HEADERS nt_headers = reinterpret_cast(p + dos_header->e_lfanew); + + const PIMAGE_FILE_HEADER file_header = &nt_headers->FileHeader; + + PIMAGE_OPTIONAL_HEADER optional_header = (PIMAGE_OPTIONAL_HEADER)(&nt_headers->OptionalHeader); + + const DWORD file_alignment = optional_header->FileAlignment; + + + PIMAGE_SECTION_HEADER section_table = + reinterpret_cast(p + dos_header->e_lfanew + + IMAGE_SIZEOF_SIGNATURE + + IMAGE_SIZEOF_FILE_HEADER + + file_header->SizeOfOptionalHeader); + + DWORD export_foa = rva_to_foa(optional_header->DataDirectory[0].VirtualAddress, section_table); + + PIMAGE_EXPORT_DIRECTORY export_directory = (PIMAGE_EXPORT_DIRECTORY)(p + export_foa); + + + DWORD name_list_foa = rva_to_foa(export_directory->AddressOfNames, section_table); + + PDWORD name_list = (PDWORD)(p + name_list_foa); + + + + + std::vector func_list; + + for (size_t i = 0; i < export_directory->NumberOfNames; i++, name_list++) + { + + DWORD name_foa = rva_to_foa(* name_list, section_table); + char* name = (char*)(p + name_foa); + func_list.emplace_back(name); + + } + + + UnmapViewOfFile(pv_file); + CloseHandle(h_map_file); + CloseHandle(h_file); + + + std::ofstream gen_cpp_obj; + gen_cpp_obj.open("warp_gen_cpp.cpp", std::ios::out | std::ios::trunc); + gen_cpp_obj << "#include \n"; + gen_cpp_obj << "extern \"C\" \n{\n"; + + + for (size_t i = 0; i < func_list.size(); ++i) + { + auto fun = func_list[i]; + gen_cpp_obj << format("void * warp_point_%d;\n", i); + gen_cpp_obj << format("#pragma comment(linker, \"/export:%s=warp_func_%d\")\n", fun.c_str(), i); + gen_cpp_obj << format("void warp_func_%d();\n", i); + gen_cpp_obj << ("\n"); + } + gen_cpp_obj << ("}\n"); + + + gen_cpp_obj << ("void load_function(HMODULE hm)\n{\n"); + for (size_t i = 0; i < func_list.size(); ++i) + { + auto fun = func_list[i]; + gen_cpp_obj << format("warp_point_%d = (void*)GetProcAddress(hm, \"%s\");\n", i, fun.c_str()); + } + gen_cpp_obj << ("}\n"); + + gen_cpp_obj.close(); + + + + std::ofstream gen_asm_obj; + gen_asm_obj.open("warp_gen.asm", std::ios::out | std::ios::trunc); + for (size_t i = 0; i < func_list.size(); ++i) + { + auto fun = func_list[i]; + gen_asm_obj << format("EXTERN warp_point_%d:QWORD;\n", i); + } + gen_asm_obj << ".CODE\n"; + for (size_t i = 0; i < func_list.size(); ++i) + { + auto fun = func_list[i]; + gen_asm_obj << format("warp_func_%d PROC\njmp warp_point_%d;\nwarp_func_%d ENDP\n", i,i,i); + } + gen_asm_obj << "END\n"; + gen_asm_obj.close(); +} diff --git a/tools/windowsbuild/warp_dll.cpp b/tools/windowsbuild/warp_dll.cpp new file mode 100644 index 000000000000..6a89a4e189de --- /dev/null +++ b/tools/windowsbuild/warp_dll.cpp @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +extern "C" IMAGE_DOS_HEADER __ImageBase; + + +std::vector find_mxnet_dll() +{ + std::vector version; + intptr_t handle; + + _wfinddata_t findData{}; + std::wregex reg(L".*?mxnet_([0-9]+)\\.dll"); + + HMODULE hModule = reinterpret_cast(&__ImageBase); + WCHAR szPathBuffer[MAX_PATH] = { 0 }; + GetModuleFileNameW(hModule, szPathBuffer, MAX_PATH); + + PathRemoveFileSpecW(szPathBuffer); + wcscat_s(szPathBuffer, L"\\mxnet_*.dll"); + + handle = _wfindfirst(szPathBuffer, &findData); + if (handle == -1) + { + return version; + } + + do + { + if (!(findData.attrib & _A_SUBDIR) || wcscmp(findData.name, L".") != 0 || wcscmp(findData.name, L"..") != 0) + { + std::wstring str(findData.name); + std::wsmatch base_match; + if(std::regex_match(str, base_match, reg)) + { + if (base_match.size() == 2) { + std::wssub_match base_sub_match = base_match[1]; + std::wstring base = base_sub_match.str(); + version.push_back(std::stoi(base)) ; + } + } + } + } while (_wfindnext(handle, &findData) == 0); + + _findclose(handle); + std::sort(version.begin(), version.end()); + return version; +} + +int find_version() +{ + std::vector known_sm = find_mxnet_dll(); + int count = 0; + int version = 75; + if (cudaSuccess != cudaGetDeviceCount(&count)) + { + return 30; + } + if (count == 0) + { + return 30; + } + + + for (int device = 0; device < count; ++device) + { + cudaDeviceProp prop{}; + if (cudaSuccess == cudaGetDeviceProperties(&prop, device)) + { + version = std::min(version, prop.major * 10 + prop.minor); + } + } + + for (int i = known_sm.size() -1 ; i >=0; --i) + { + if(known_sm[i]<= version) + { + return known_sm[i]; + } + } + + return version; +} + +void load_function(HMODULE hm); + +void mxnet_init() +{ + int version = find_version(); + WCHAR dll_name[MAX_PATH]; + wsprintfW(dll_name, L"mxnet_%d.dll", version); + HMODULE hm = LoadLibraryW(dll_name); + load_function(hm); +} + + +extern "C" BOOL WINAPI DllMain( + HINSTANCE const instance, // handle to DLL module + DWORD const reason, // reason for calling function + LPVOID const reserved) // reserved +{ + // Perform actions based on the reason for calling. + switch (reason) + { + case DLL_PROCESS_ATTACH: + mxnet_init(); + // Initialize once for each new process. + // Return FALSE to fail DLL load. + break; + + case DLL_THREAD_ATTACH: + // Do thread-specific initialization. + break; + + case DLL_THREAD_DETACH: + // Do thread-specific cleanup. + break; + + case DLL_PROCESS_DETACH: + // Perform any necessary cleanup. + break; + } + return TRUE; // Successful DLL_PROCESS_ATTACH. +} \ No newline at end of file