Fails to run find_package(Torch) on Windows with libtorch package #333
Comments
Yeah, I just noticed this in conda-forge/torchaudio-feedstock#14. The issue is that we move
In this way, this is a follow-up to #327 & #328. I think the easiest might be to patch the CMake metadata here? Otherwise we'd have to patch the build system itself to ensure DLLs get installed into |
Actually, https://github.com/pytorch/pytorch/blob/v2.5.1/cmake/TorchConfig.cmake.in looks patchable |
I think (but did not test) that to install the c10 DLLs in bin it should be sufficient to remove the |
As a note, c10 is not the only issue; other targets' DLLs should also be set to bin, such as |
I'm working on a patch + test. |
Indeed, simply removing the destination does not work:
Default DESTINATION argument is available for |
@traversaro, I've switched things to (only for TARGETS)

-install(TARGETS torch_python DESTINATION "${TORCH_INSTALL_LIB_DIR}")
+install(TARGETS torch_python
+    LIBRARY DESTINATION lib
+    RUNTIME DESTINATION bin)

(and so on) but looking at the logs, it seems that the import libraries do not get installed anymore. You're right though, before I made the mistake of also changing some |
Either we delete the DESTINATION arguments from install(TARGETS ...), or we add an ARCHIVE DESTINATION for installing the import libraries. |
gah, why would import libraries be handled by |
I get your pain, but just for reference the related docs is in https://cmake.org/cmake/help/latest/command/install.html#signatures . |
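To make that signature concrete: per those docs, on Windows the import library of a DLL is an ARCHIVE artifact, while the DLL itself is a RUNTIME artifact. A sketch covering all three artifact kinds (using c10 as the example target from this discussion; this is not the exact pytorch patch) would be:

```cmake
# Sketch, not the actual pytorch change: give each artifact kind its
# own destination so that on Windows the DLL (RUNTIME) lands in bin/
# and the import library (ARCHIVE) in lib/, while on Unix the shared
# library (LIBRARY) still goes to lib/.
install(TARGETS c10
  RUNTIME DESTINATION bin   # Windows: c10.dll
  LIBRARY DESTINATION lib   # Unix: libc10.so / libc10.dylib
  ARCHIVE DESTINATION lib)  # Windows: c10.lib (import library)
```

With all three keywords present, nothing is silently dropped: each platform installs whichever artifacts the target actually produces.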
Even with the fixes from #318, there's still a problem with CMake, in that the CUDA-enabled builds (at least on linux), insist on finding actual CUDA drivers(?):
This will obviously affect all feedstocks that want to build on top of a CUDA-enabled pytorch. Someone else raised this in conda-forge/cuda-feedstock#59 already; the issue is that the Caffe2 CMake files still use the long-deprecated |
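For context on the distinction: the modern FindCUDAToolkit module locates headers and libraries from the toolkit installation alone and exposes imported targets, whereas the legacy FindCUDA machinery is what tends to drag in driver-side lookups. A minimal consumer-side sketch (hypothetical project, not from this repo) looks like:

```cmake
# Sketch: consumer project using FindCUDAToolkit (CMake >= 3.17)
# instead of the deprecated find_package(CUDA).
cmake_minimum_required(VERSION 3.17)
project(cuda_consumer LANGUAGES CXX)

find_package(CUDAToolkit REQUIRED)

add_executable(app main.cpp)
# Imported targets such as CUDA::cudart and CUDA::nvrtc replace the
# old CUDA_LIBRARIES / CUDA_NVRTC_LIB variables from FindCUDA.
target_link_libraries(app PRIVATE CUDA::cudart CUDA::nvrtc)
```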
@danpetry, thanks for the offer to help. Here's my draft of the replacement:

diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index b51c7cc637b..6e107b5b02a 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -906,25 +906,25 @@ if(USE_ROCM)
"$<$<COMPILE_LANGUAGE:CXX>:ATen/core/ATen_pch.h>")
endif()
elseif(USE_CUDA)
- set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
+ set(CUDAToolkit_LINK_LIBRARIES_KEYWORD PRIVATE)
list(APPEND Caffe2_GPU_SRCS ${GENERATED_CXX_TORCH_CUDA})
- if(CUDA_SEPARABLE_COMPILATION)
+ if(CUDAToolkit_SEPARABLE_COMPILATION)
# Separate compilation fails when kernels using `thrust::sort_by_key`
# are linked with the rest of CUDA code. Workaround by linking them separately.
add_library(torch_cuda ${Caffe2_GPU_SRCS} ${Caffe2_GPU_CU_SRCS})
- set_property(TARGET torch_cuda PROPERTY CUDA_SEPARABLE_COMPILATION ON)
+ set_property(TARGET torch_cuda PROPERTY CUDAToolkit_SEPARABLE_COMPILATION ON)
add_library(torch_cuda_w_sort_by_key OBJECT
${Caffe2_GPU_SRCS_W_SORT_BY_KEY}
${Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY})
- set_property(TARGET torch_cuda_w_sort_by_key PROPERTY CUDA_SEPARABLE_COMPILATION OFF)
+ set_property(TARGET torch_cuda_w_sort_by_key PROPERTY CUDAToolkit_SEPARABLE_COMPILATION OFF)
target_link_libraries(torch_cuda PRIVATE torch_cuda_w_sort_by_key)
else()
add_library(torch_cuda
${Caffe2_GPU_SRCS} ${Caffe2_GPU_SRCS_W_SORT_BY_KEY}
${Caffe2_GPU_CU_SRCS} ${Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY})
endif()
- set(CUDA_LINK_LIBRARIES_KEYWORD)
+ set(CUDAToolkit_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
@@ -973,12 +973,12 @@ elseif(USE_CUDA)
torch_cuda
)
if($ENV{ATEN_STATIC_CUDA})
- if(CUDA_VERSION_MAJOR LESS_EQUAL 11)
+ if(CUDAToolkit_VERSION_MAJOR LESS_EQUAL 11)
target_link_libraries(torch_cuda_linalg PRIVATE
CUDA::cusolver_static
${CUDAToolkit_LIBRARY_DIR}/liblapack_static.a # needed for libcusolver_static
)
- elseif(CUDA_VERSION_MAJOR GREATER_EQUAL 12)
+ elseif(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 12)
target_link_libraries(torch_cuda_linalg PRIVATE
CUDA::cusolver_static
${CUDAToolkit_LIBRARY_DIR}/libcusolver_lapack_static.a # needed for libcusolver_static
diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake
index d51c451589c..154f04a89dd 100644
--- a/cmake/Summary.cmake
+++ b/cmake/Summary.cmake
@@ -76,7 +76,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
message(STATUS " USE_CUDSS : ${USE_CUDSS}")
message(STATUS " USE_CUFILE : ${USE_CUFILE}")
- message(STATUS " CUDA version : ${CUDA_VERSION}")
+ message(STATUS " CUDA version : ${CUDAToolkit_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
if(${USE_CUDNN})
@@ -88,7 +88,7 @@ function(caffe2_print_configuration_summary)
if(${USE_CUFILE})
message(STATUS " cufile library : ${CUDA_cuFile_LIBRARY}")
endif()
- message(STATUS " CUDA root directory : ${CUDA_TOOLKIT_ROOT_DIR}")
+ message(STATUS " CUDA root directory : ${CUDAToolkit_ROOT}")
message(STATUS " CUDA library : ${CUDA_cuda_driver_LIBRARY}")
message(STATUS " cudart library : ${CUDA_cudart_LIBRARY}")
message(STATUS " cublas library : ${CUDA_cublas_LIBRARY}")
@@ -108,12 +108,12 @@ function(caffe2_print_configuration_summary)
message(STATUS " cuDSS library : ${__tmp}")
endif()
message(STATUS " nvrtc : ${CUDA_nvrtc_LIBRARY}")
- message(STATUS " CUDA include path : ${CUDA_INCLUDE_DIRS}")
- message(STATUS " NVCC executable : ${CUDA_NVCC_EXECUTABLE}")
+ message(STATUS " CUDA include path : ${CUDAToolkit_INCLUDE_DIRS}")
+ message(STATUS " NVCC executable : ${CUDAToolkit_NVCC_EXECUTABLE}")
message(STATUS " CUDA compiler : ${CMAKE_CUDA_COMPILER}")
message(STATUS " CUDA flags : ${CMAKE_CUDA_FLAGS}")
message(STATUS " CUDA host compiler : ${CMAKE_CUDA_HOST_COMPILER}")
- message(STATUS " CUDA --device-c : ${CUDA_SEPARABLE_COMPILATION}")
+ message(STATUS " CUDA --device-c : ${CUDAToolkit_SEPARABLE_COMPILATION}")
message(STATUS " USE_TENSORRT : ${USE_TENSORRT}")
if(${USE_TENSORRT})
message(STATUS " TensorRT runtime library: ${TENSORRT_LIBRARY}")
diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in
index cba4d929855..da904fc6a18 100644
--- a/cmake/TorchConfig.cmake.in
+++ b/cmake/TorchConfig.cmake.in
@@ -125,7 +125,7 @@ if(@USE_CUDA@)
find_library(CAFFE2_NVRTC_LIBRARY caffe2_nvrtc PATHS "${TORCH_INSTALL_PREFIX}/lib")
list(APPEND TORCH_CUDA_LIBRARIES ${CAFFE2_NVRTC_LIBRARY})
else()
- set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
+ set(TORCH_CUDA_LIBRARIES CUDA::nvrtc)
endif()
if(TARGET torch::nvtoolsext)
list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake
index 152fbdbe6dd..13bae9b6227 100644
--- a/cmake/public/cuda.cmake
+++ b/cmake/public/cuda.cmake
@@ -26,8 +26,8 @@ if(NOT MSVC)
endif()
# Find CUDA.
-find_package(CUDA)
-if(NOT CUDA_FOUND)
+find_package(CUDAToolkit)
+if(NOT CUDAToolkit_FOUND)
message(WARNING
"Caffe2: CUDA cannot be found. Depending on whether you are building "
"Caffe2 or a Caffe2 dependent library, the next warning / error will "
@@ -36,8 +36,6 @@ if(NOT CUDA_FOUND)
return()
endif()
-# Enable CUDA language support
-set(CUDAToolkit_ROOT "${CUDA_TOOLKIT_ROOT_DIR}")
# Pass clang as host compiler, which according to the docs
# Must be done before CUDA language is enabled, see
# https://cmake.org/cmake/help/v3.15/variable/CMAKE_CUDA_HOST_COMPILER.html
@@ -56,24 +54,18 @@ if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.12.0)
cmake_policy(SET CMP0074 NEW)
endif()
-find_package(CUDAToolkit REQUIRED)
+find_package(CUDAToolkit REQUIRED COMPONENTS cudart nvrtc)
cmake_policy(POP)
-if(NOT CMAKE_CUDA_COMPILER_VERSION VERSION_EQUAL CUDAToolkit_VERSION)
- message(FATAL_ERROR "Found two conflicting CUDA versions:\n"
- "V${CMAKE_CUDA_COMPILER_VERSION} in '${CUDA_INCLUDE_DIRS}' and\n"
- "V${CUDAToolkit_VERSION} in '${CUDAToolkit_INCLUDE_DIRS}'")
-endif()
-
-message(STATUS "Caffe2: CUDA detected: " ${CUDA_VERSION})
-message(STATUS "Caffe2: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE})
-message(STATUS "Caffe2: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR})
-if(CUDA_VERSION VERSION_LESS 11.0)
+message(STATUS "Caffe2: CUDA detected: " ${CUDAToolkit_VERSION})
+message(STATUS "Caffe2: CUDA nvcc is: " ${CUDAToolkit_NVCC_EXECUTABLE})
+message(STATUS "Caffe2: CUDA toolkit directory: " ${CUDAToolkit_ROOT})
+if(CUDAToolkit_VERSION VERSION_LESS 11.0)
message(FATAL_ERROR "PyTorch requires CUDA 11.0 or above.")
endif()
-if(CUDA_FOUND)
+if(CUDAToolkit_FOUND)
# Sometimes, we may mismatch nvcc with the CUDA headers we are
# compiling with, e.g., if a ccache nvcc is fed to us by CUDA_NVCC_EXECUTABLE
# but the PATH is not consistent with CUDA_HOME. It's better safe
@@ -97,8 +89,8 @@ if(CUDA_FOUND)
)
if(NOT CMAKE_CROSSCOMPILING)
try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
- CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
- LINK_LIBRARIES ${CUDA_LIBRARIES}
+ CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDAToolkit_INCLUDE_DIRS}"
+ LINK_LIBRARIES ${CUDAToolkit_LIBRARIES}
RUN_OUTPUT_VARIABLE cuda_version_from_header
COMPILE_OUTPUT_VARIABLE output_var
)
@@ -106,20 +98,20 @@ if(CUDA_FOUND)
message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " ${output_var})
endif()
message(STATUS "Caffe2: Header version is: " ${cuda_version_from_header})
- if(NOT cuda_version_from_header STREQUAL ${CUDA_VERSION_STRING})
+ if(NOT cuda_version_from_header STREQUAL ${CUDAToolkit_VERSION_STRING})
# Force CUDA to be processed for again next time
# TODO: I'm not sure if this counts as an implementation detail of
# FindCUDA
- set(${cuda_version_from_findcuda} ${CUDA_VERSION_STRING})
+ set(${cuda_version_from_findcuda} ${CUDAToolkit_VERSION_STRING})
unset(CUDA_TOOLKIT_ROOT_DIR_INTERNAL CACHE)
# Not strictly necessary, but for good luck.
- unset(CUDA_VERSION CACHE)
+ unset(CUDAToolkit_VERSION CACHE)
# Error out
message(FATAL_ERROR "FindCUDA says CUDA version is ${cuda_version_from_findcuda} (usually determined by nvcc), "
"but the CUDA headers say the version is ${cuda_version_from_header}. This often occurs "
"when you set both CUDA_HOME and CUDA_NVCC_EXECUTABLE to "
"non-standard locations, without also setting PATH to point to the correct nvcc. "
- "Perhaps, try re-running this command again with PATH=${CUDA_TOOLKIT_ROOT_DIR}/bin:$PATH. "
+ "Perhaps, try re-running this command again with PATH=${CUDAToolkit_ROOT}/bin:$PATH. "
"See above log messages for more diagnostics, and see https://github.com/pytorch/pytorch/issues/8092 for more details.")
endif()
endif()
@@ -128,8 +120,8 @@ endif()
# ---[ CUDA libraries wrapper
# find lbnvrtc.so
-set(CUDA_NVRTC_LIB "${CUDA_nvrtc_LIBRARY}" CACHE FILEPATH "")
-if(CUDA_NVRTC_LIB AND NOT CUDA_NVRTC_SHORTHASH)
+get_target_property(CUDA_NVRTC_LIB CUDA::nvrtc INTERFACE_LINK_LIBRARIES)
+if(NOT CUDA_NVRTC_SHORTHASH)
find_package(Python COMPONENTS Interpreter)
execute_process(
COMMAND Python::Interpreter -c |
looks like you've removed |
That should be set by |
Solution to issue cannot be found in the documentation.
Issue
Description:
The libtorch package for Windows is currently not working as expected with CMake. When using find_package(Torch) in a CMakeLists.txt file, it returns an error due to missing .dll files.

Reproduction Steps:
Create a new CMake project with the following CMakeLists.txt:
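(The original snippet was not preserved in this copy of the issue; a minimal reproduction presumably looked something like the following hypothetical CMakeLists.txt.)

```cmake
# Hypothetical minimal reproduction: any project that calls
# find_package(Torch) against the Windows libtorch package.
cmake_minimum_required(VERSION 3.18)
project(torch_repro LANGUAGES CXX)

# Fails on Windows because the DLLs are shipped in <prefix>/Library/bin/
# while the Torch CMake config expects them in <prefix>/Library/lib/.
find_package(Torch REQUIRED)

add_executable(main main.cpp)
target_link_libraries(main PRIVATE ${TORCH_LIBRARIES})
```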
Run cmake on the project and observe the error message.
Note: This issue only occurs on Windows, while macOS and Linux work fine.
The .dll files are located in <prefix>/Library/bin/*, but the configuration is set to find them in <prefix>/Library/lib/*. I'm unsure where to look to resolve this issue. Any help would be appreciated!

Installed packages
Environment info