Skip to content

Commit

Permalink
[cuda] manually import the correct pynvml module (vllm-project#12679)
Browse files Browse the repository at this point in the history
fixes problems like vllm-project#12635 and
vllm-project#12636 and
vllm-project#12565

---------

Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Feb 3, 2025
1 parent b998645 commit ad4a9dc
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
3 changes: 2 additions & 1 deletion vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def cuda_platform_plugin() -> Optional[str]:
is_cuda = False

try:
import pynvml
from vllm.utils import import_pynvml
pynvml = import_pynvml()
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
Expand Down
10 changes: 2 additions & 8 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
Union)

import pynvml
import torch
from typing_extensions import ParamSpec

# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import import_pynvml

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

Expand All @@ -29,13 +29,7 @@
_P = ParamSpec("_P")
_R = TypeVar("_R")

if pynvml.__file__.endswith("__init__.py"):
logger.warning(
"You are using a deprecated `pynvml` package. Please install"
" `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
" When both of them are installed, `pynvml` will take precedence"
" and cause errors. See https://pypi.org/project/pynvml "
"for more information.")
pynvml = import_pynvml()

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
Expand Down
52 changes: 52 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,3 +2208,55 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)


def import_pynvml():
"""
Historical comments:
libnvml.so is the library behind nvidia-smi, and
pynvml is a Python wrapper around it. We use it to get GPU
status without initializing CUDA context in the current process.
Historically, there are two packages that provide pynvml:
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
wrapper. It is a dependency of vLLM, and is installed when users
install vLLM. It provides a Python module named `pynvml`.
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
Prior to version 12.0, it also provides a Python module `pynvml`,
and therefore conflicts with the official one. What's worse,
the module is a Python package, and has higher priority than
the official one which is a standalone Python file.
This causes errors when both of them are installed.
Starting from version 12.0, it migrates to a new module
named `pynvml_utils` to avoid the conflict.
TL;DR: if users have pynvml<12.0 installed, it will cause problems.
Otherwise, `import pynvml` will import the correct module.
We take the safest approach here, to manually import the correct
`pynvml.py` module from the `nvidia-ml-py` package.
"""
if TYPE_CHECKING:
import pynvml
return pynvml
if "pynvml" in sys.modules:
import pynvml
if pynvml.__file__.endswith("__init__.py"):
# this is pynvml < 12.0
raise RuntimeError(
"You are using a deprecated `pynvml` package. "
"Please uninstall `pynvml` or upgrade to at least"
" version 12.0. See https://pypi.org/project/pynvml "
"for more information.")
return sys.modules["pynvml"]
import importlib.util
import os
import site
for site_dir in site.getsitepackages():
pynvml_path = os.path.join(site_dir, "pynvml.py")
if os.path.exists(pynvml_path):
spec = importlib.util.spec_from_file_location(
"pynvml", pynvml_path)
pynvml = importlib.util.module_from_spec(spec)
sys.modules["pynvml"] = pynvml
spec.loader.exec_module(pynvml)
return pynvml

0 comments on commit ad4a9dc

Please sign in to comment.