Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yajiedesign committed Dec 30, 2019
1 parent f366d9e commit fc27ce9
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -598,8 +598,17 @@ if(USE_CUDA)
include(${CMAKE_ROOT}/Modules/FindCUDA/select_compute_arch.cmake)
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS ${MXNET_CUDA_ARCH})
message("-- CUDA: Using the following NVCC architecture flags ${CUDA_ARCH_FLAGS}")
set(arch_code_list)
foreach(arch_str ${CUDA_ARCH_FLAGS})
if((arch_str MATCHES ".*sm_[0-9]+"))
string( REGEX REPLACE ".*sm_([0-9]+)" "\\1" arch_code ${arch_str} )
list(APPEND arch_code_list ${arch_code})
endif()
endforeach()
MESSAGE( "${arch_code_list}")

string(REPLACE ";" " " CUDA_ARCH_FLAGS_SPACES "${CUDA_ARCH_FLAGS}")
string(APPEND CMAKE_CUDA_FLAGS " ${CUDA_ARCH_FLAGS_SPACES}")


find_package(CUDAToolkit REQUIRED cublas cufft cusolver curand
OPTIONAL_COMPONENTS nvToolsExt nvrtc)
Expand Down Expand Up @@ -672,6 +681,7 @@ add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib
target_include_directories(sample_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
set(MXNET_INSTALL_TARGETS mxnet)
if(UNIX)
string(APPEND CMAKE_CUDA_FLAGS "${CUDA_ARCH_FLAGS_SPACES}")
# Create dummy file since we want an empty shared library before linking
set(DUMMY_SOURCE ${CMAKE_BINARY_DIR}/dummy.c)
file(WRITE ${DUMMY_SOURCE} "")
Expand All @@ -688,14 +698,14 @@ elseif(MSVC)
set_target_properties(sample_lib PROPERTIES PREFIX "lib")

if(USE_CUDA)
if(FIRST_CUDA AND MSVC)
if(MSVC)
if(USE_SPLIT_ARCH_DLL)
add_executable(gen_warp tools/windowsbuild/gen_warp.cpp)
add_library(mxnet SHARED tools/windowsbuild/warp_dll.cpp ${CMAKE_BINARY_DIR}/warp_gen_cpp.cpp
${CMAKE_BINARY_DIR}/warp_gen.asm)
target_link_libraries(mxnet PRIVATE cudart Shlwapi)
list(GET cuda_arch 0 mxnet_first_arch)
foreach(arch ${cuda_arch})
list(GET arch_code_list 0 mxnet_first_arch)
foreach(arch ${arch_code_list})
add_library(mxnet_${arch} SHARED ${SOURCE})
target_compile_options(
mxnet_${arch}
Expand All @@ -720,7 +730,7 @@ elseif(MSVC)
COMMAND gen_warp $<TARGET_FILE:mxnet_${mxnet_first_arch}> WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/ DEPENDS $<TARGET_FILE:mxnet_${mxnet_first_arch}>)
else(USE_SPLIT_ARCH_DLL)
string(REPLACE ";" " " NVCC_FLAGS_ARCH "${NVCC_FLAGS_ARCH}")
set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS_ARCH}")
set(CMAKE_CUDA_FLAGS "${CUDA_ARCH_FLAGS_SPACES}")
add_library(mxnet SHARED ${SOURCE})
target_compile_options(
mxnet
Expand Down Expand Up @@ -781,7 +791,7 @@ if(USE_PLUGINS_WARPCTC)
endif()

if(MSVC)
if(FIRST_CUDA AND USE_SPLIT_ARCH_DLL)
if(USE_SPLIT_ARCH_DLL)
foreach(arch ${cuda_arch})
target_link_libraries(mxnet_${arch} PUBLIC ${mxnet_LINKER_LIBS})
target_link_libraries(mxnet_${arch} PUBLIC dmlc)
Expand Down

0 comments on commit fc27ce9

Please sign in to comment.