Skip to content

Commit

Permalink
Add manual cuda deps search logic (#90411) (#90426)
Browse files Browse the repository at this point in the history
If PyTorch is package into a wheel with [nvidia-cublas-cu11](https://pypi.org/project/nvidia-cublas-cu11/), which is designated as PureLib, but `torch` wheel is not, can cause a torch_globals loading problem.

Fix that by searching for `nvidia/cublas/lib/libcublas.so.11` an `nvidia/cudnn/lib/libcudnn.so.8` across all `sys.path` folders.

Test plan:
```
docker pull amazonlinux:2
docker run --rm -t amazonlinux:2 bash -c 'yum install -y python3 python3-devel python3-distutils patch;python3 -m pip install torch==1.13.0;curl -OL https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/90411.diff; pushd /usr/local/lib64/python3.7/site-packages; patch -p1 </90411.diff; popd; python3 -c "import torch;print(torch.__version__, torch.cuda.is_available())"'
```

Fixes #88869

Pull Request resolved: #90411
Approved by: https://github.com/atalman

Co-authored-by: Nikita Shulga <[email protected]>
  • Loading branch information
atalman and malfet authored Dec 8, 2022
1 parent a4d16e0 commit 56de8a3
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@
kernel32.SetErrorMode(prev_error_mode)


def _preload_cuda_deps():
""" Preloads cudnn/cublas deps if they could not be found otherwise """
# Should only be called on Linux if default path resolution have failed
assert platform.system() == 'Linux', 'Should only be called on Linux'
for path in sys.path:
nvidia_path = os.path.join(path, 'nvidia')
if not os.path.exists(nvidia_path):
continue
cublas_path = os.path.join(nvidia_path, 'cublas', 'lib', 'libcublas.so.11')
cudnn_path = os.path.join(nvidia_path, 'cudnn', 'lib', 'libcudnn.so.8')
if not os.path.exists(cublas_path) or not os.path.exists(cudnn_path):
continue
break

ctypes.CDLL(cublas_path)
ctypes.CDLL(cudnn_path)


# See Note [Global dependencies]
def _load_global_deps():
if platform.system() == 'Windows' or sys.executable == 'torch_deploy':
Expand All @@ -150,7 +168,15 @@ def _load_global_deps():
here = os.path.abspath(__file__)
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)

ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
try:
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
except OSError as err:
# Can only happen of wheel with cublas as PYPI deps
# As PyTorch is not purelib, but nvidia-cublas-cu11 is
if 'libcublas.so.11' not in err.args[0]:
raise err
_preload_cuda_deps()
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)


if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
Expand Down

0 comments on commit 56de8a3

Please sign in to comment.