[Bug]: PyNcclCommunicator error when inferencing #8420

Closed
pspdada opened this issue Sep 12, 2024 · 5 comments · Fixed by #8446
Labels
bug Something isn't working

Comments

pspdada commented Sep 12, 2024

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.27

Python version: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-150-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090
GPU 2: NVIDIA GeForce RTX 4090
GPU 3: NVIDIA GeForce RTX 4090

Nvidia driver version: 535.54.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn.so.8.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              72
On-line CPU(s) list: 0-71
Thread(s) per core:  2
Core(s) per socket:  18
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) Gold 6240 CPU @ 2.60GHz
Stepping:            7
CPU MHz:             3300.000
CPU max MHz:         3900.0000
CPU min MHz:         1000.0000
BogoMIPS:            5200.00
Virtualization:      VT-x
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            25344K
NUMA node0 CPU(s):   0-17,36-53
NUMA node1 CPU(s):   18-35,54-71
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-ml-py==12.560.30
[pip3] pyzmq==26.2.0
[pip3] torch==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.44.2
[pip3] triton==3.0.0
[conda] blas                      1.0                         mkl    defaults
[conda] cuda-cudart               12.1.105                      0    nvidia
[conda] cuda-cupti                12.1.105                      0    nvidia
[conda] cuda-libraries            12.1.0                        0    nvidia
[conda] cuda-nvrtc                12.1.105                      0    nvidia
[conda] cuda-nvtx                 12.1.105                      0    nvidia
[conda] cuda-opencl               12.6.68                       0    nvidia
[conda] cuda-runtime              12.1.0                        0    nvidia
[conda] cuda-version              12.6                          3    nvidia
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libcublas                 12.1.0.26                     0    nvidia
[conda] libcufft                  11.0.2.4                      0    nvidia
[conda] libcufile                 1.11.1.6                      0    nvidia
[conda] libcurand                 10.3.7.68                     0    nvidia
[conda] libcusolver               11.4.4.55                     0    nvidia
[conda] libcusparse               12.0.2.55                     0    nvidia
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] libnpp                    12.0.2.50                     0    nvidia
[conda] libnvjitlink              12.1.105                      0    nvidia
[conda] libnvjpeg                 12.1.1.14                     0    nvidia
[conda] mkl                       2023.1.0         h213fc3f_46344    defaults
[conda] mkl-service               2.4.0           py310h5eee18b_1    defaults
[conda] mkl_fft                   1.3.10          py310h5eee18b_0    defaults
[conda] mkl_random                1.2.7           py310h1128e8f_0    defaults
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-ml-py              12.560.30                pypi_0    pypi
[conda] pytorch                   2.4.0           py3.10_cuda12.1_cudnn9.1.0_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] pyzmq                     26.2.0                   pypi_0    pypi
[conda] torchaudio                2.4.0               py310_cu121    pytorch
[conda] torchtriton               3.0.0                     py310    pytorch
[conda] torchvision               0.19.0              py310_cu121    pytorch
[conda] transformers              4.44.2                   pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.0@32e7db25365415841ebc7c4215851743fbb1bad1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    SYS     SYS     0-17,36-53      0               N/A
GPU1    NODE     X      SYS     SYS     0-17,36-53      0               N/A
GPU2    SYS     SYS      X      NODE    18-35,54-71     1               N/A
GPU3    SYS     SYS     NODE     X      18-35,54-71     1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Model Input Dumps

No response

🐛 Describe the bug

When I create an instance of the `LLM` class, initialization fails with `AttributeError: 'PyNcclCommunicator' object has no attribute 'device'`, and I don't know how to proceed.
The code to test it:

```python
from vllm import LLM


def init_model() -> LLM:
    llm = LLM(
        model="Qwen/Qwen2-7B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=True,
        download_dir="./.cache",
        tensor_parallel_size=2,  # How many GPUs to use
        gpu_memory_utilization=0.85,
        pipeline_parallel_size=1,
        dtype="bfloat16",
        # max_model_len=20480,  # Model context length
        enable_prefix_caching=True,
        enable_chunked_prefill=False,
        num_scheduler_steps=8,
    )
    return llm


if __name__ == "__main__":
    llm = init_model()
    print(llm.generate("Hello, world!"))
```

Output:
INFO 09-13 00:13:27 config.py:890] Defaulting to use mp for distributed inference
WARNING 09-13 00:13:27 arg_utils.py:880] Enabled BlockSpaceManagerV2 because it is required for multi-step (--num-scheduler-steps > 1)
INFO 09-13 00:13:27 llm_engine.py:213] Initializing an LLM engine (v0.6.0) with config: model='Qwen/Qwen2-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=32768, download_dir='./.cache', load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2-7B-Instruct, use_v2_block_manager=True, num_scheduler_steps=8, enable_prefix_caching=True, use_async_output_proc=True)
WARNING 09-13 00:13:28 multiproc_gpu_executor.py:56] Reducing Torch parallelism from 36 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 09-13 00:13:29 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
WARNING 09-13 00:13:29 registry.py:190] `mm_limits` has already been set for model=Qwen/Qwen2-7B-Instruct, and will be overwritten by the new values.
(VllmWorkerProcess pid=18643) WARNING 09-13 00:13:29 registry.py:190] `mm_limits` has already been set for model=Qwen/Qwen2-7B-Instruct, and will be overwritten by the new values.
(VllmWorkerProcess pid=18643) INFO 09-13 00:13:29 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
INFO 09-13 00:13:29 utils.py:977] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=18643) INFO 09-13 00:13:29 utils.py:977] Found nccl from library libnccl.so.2
ERROR 09-13 00:13:29 pynccl_wrapper.py:196] Failed to load NCCL library from libnccl.so.2 .It is expected if you are not running on NVIDIA/AMD GPUs.Otherwise, the nccl library might not exist, be corrupted or it does not support the current platform Linux-5.4.0-150-generic-x86_64-with-glibc2.27.If you already have the library, please set the environment variable VLLM_NCCL_SO_PATH to point to the correct nccl library path.
(VllmWorkerProcess pid=18643) ERROR 09-13 00:13:29 pynccl_wrapper.py:196] Failed to load NCCL library from libnccl.so.2 .It is expected if you are not running on NVIDIA/AMD GPUs.Otherwise, the nccl library might not exist, be corrupted or it does not support the current platform Linux-5.4.0-150-generic-x86_64-with-glibc2.27.If you already have the library, please set the environment variable VLLM_NCCL_SO_PATH to point to the correct nccl library path.
(VllmWorkerProcess pid=18643) INFO 09-13 00:13:29 custom_all_reduce_utils.py:242] reading GPU P2P access cache from ~/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 09-13 00:13:29 custom_all_reduce_utils.py:242] reading GPU P2P access cache from ~/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
WARNING 09-13 00:13:29 custom_all_reduce.py:131] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=18643) WARNING 09-13 00:13:29 custom_all_reduce.py:131] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
INFO 09-13 00:13:29 shm_broadcast.py:235] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7fac3ec98730>, local_subscribe_port=46869, remote_subscribe_port=None)
INFO 09-13 00:13:29 model_runner.py:915] Starting to load model Qwen/Qwen2-7B-Instruct...
(VllmWorkerProcess pid=18643) INFO 09-13 00:13:29 model_runner.py:915] Starting to load model Qwen/Qwen2-7B-Instruct...
INFO 09-13 00:13:30 weight_utils.py:236] Using model weights format ['*.safetensors']
(VllmWorkerProcess pid=18643) INFO 09-13 00:13:31 weight_utils.py:236] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:04,  1.54s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:03<00:03,  1.71s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:05<00:01,  1.75s/it]
(VllmWorkerProcess pid=18643) INFO 09-13 00:13:38 model_runner.py:926] Loading model weights took 7.1216 GB
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  1.80s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  1.76s/it]

INFO 09-13 00:13:38 model_runner.py:926] Loading model weights took 7.1216 GB
INFO 09-13 00:13:44 distributed_gpu_executor.py:57] # GPU blocks: 22639, # CPU blocks: 9362
INFO 09-13 00:13:50 model_runner.py:1217] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 09-13 00:13:50 model_runner.py:1221] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=18643) INFO 09-13 00:13:50 model_runner.py:1217] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(VllmWorkerProcess pid=18643) INFO 09-13 00:13:50 model_runner.py:1221] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
[rank0]: Traceback (most recent call last):
[rank0]:   File "~/psp/Reasoning-Carefully/test.py", line 25, in <module>
[rank0]:     llm = init_model()
[rank0]:   File "~/psp/Reasoning-Carefully/test.py", line 6, in init_model
[rank0]:     llm = LLM(
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 177, in __init__
[rank0]:     self.llm_engine = LLMEngine.from_engine_args(
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 538, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 319, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 461, in _initialize_kv_caches
[rank0]:     self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/executor/distributed_gpu_executor.py", line 63, in initialize_cache
[rank0]:     self._run_workers("initialize_cache",
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/executor/multiproc_gpu_executor.py", line 199, in _run_workers
[rank0]:     driver_worker_output = driver_worker_method(*args, **kwargs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/worker/worker.py", line 265, in initialize_cache
[rank0]:     self._warm_up_model()
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/worker/worker.py", line 281, in _warm_up_model
[rank0]:     self.model_runner.capture_model(self.gpu_cache)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/worker/multi_step_model_runner.py", line 543, in capture_model
[rank0]:     return self._base_model_runner.capture_model(kv_caches)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1327, in capture_model
[rank0]:     graph_runner.capture(**capture_inputs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1569, in capture
[rank0]:     self.model(
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 361, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 269, in forward
[rank0]:     hidden_states = self.embed_tokens(input_ids)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py", line 406, in forward
[rank0]:     output = tensor_model_parallel_all_reduce(output_parallel)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/distributed/communication_op.py", line 11, in tensor_model_parallel_all_reduce
[rank0]:     return get_tp_group().all_reduce(input_)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/distributed/parallel_state.py", line 288, in all_reduce
[rank0]:     pynccl_comm.all_reduce(input_)
[rank0]:   File "~/anaconda3/envs/psp/lib/python3.10/site-packages/vllm/distributed/device_communicators/pynccl.py", line 113, in all_reduce
[rank0]:     assert tensor.device == self.device, (
[rank0]: AttributeError: 'PyNcclCommunicator' object has no attribute 'device'
INFO 09-13 00:13:54 multiproc_worker_utils.py:123] Killing local vLLM worker processes
Fatal Python error: _enter_buffered_busy: could not acquire lock for <_io.BufferedWriter name='<stdout>'> at interpreter shutdown, possibly due to daemon threads
Python runtime state: finalizing (tstate=0x0000000000fadf60)

Current thread 0x00007facf08ba100 (most recent call first):
  <no Python frame>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, _brotli, yaml._yaml, msgspec._core, psutil._psutil_linux, psutil._psutil_posix, sentencepiece._sentencepiece, PIL._imaging, PIL._imagingft, gmpy2.gmpy2, regex._regex, msgpack._cmsgpack, google._upb._message, setproctitle, uvloop.loop, ray._raylet, multidict._multidict, yarl._helpers_c, yarl._quoting_c, aiohttp._helpers, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket, frozenlist._frozenlist, zmq.backend.cython._zmq (total: 44)
~/anaconda3/envs/psp/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
Aborted (core dumped)

@pspdada pspdada added the bug Something isn't working label Sep 12, 2024
@DarkLight1337
Member

May be related to #5484; please take a look at that thread.

@youkaichao
Member

> ERROR 09-13 00:13:29 pynccl_wrapper.py:196] Failed to load NCCL library from libnccl.so.2 .It is expected if you are not running on NVIDIA/AMD GPUs.Otherwise, the nccl library might not exist, be corrupted or it does not support the current platform Linux-5.4.0-150-generic-x86_64-with-glibc2.27.If you already have the library, please set the environment variable VLLM_NCCL_SO_PATH to point to the correct nccl library path.

I think this is the root cause.
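To make the link between the NCCL load error and the `AttributeError` concrete, here is a toy model in Python. It is hypothetical and deliberately simplified, not vLLM's `PyNcclCommunicator`: it assumes only what the log and traceback show, namely that the library failed to load and that `all_reduce` later reads `self.device`. If the constructor returns early before assigning `device`, any later access to it fails with the same kind of `AttributeError`; the guarded variant checks an `available` flag first.

```python
import ctypes


class ToyCommunicator:
    """Hypothetical stand-in, not vLLM's PyNcclCommunicator."""

    def __init__(self, library_path: str, device: str) -> None:
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError:
            # Library failed to load: flag it and return early,
            # so self.device is never assigned.
            self.available = False
            return
        self.available = True
        self.device = device

    def all_reduce(self, tensor_device: str) -> str:
        # Mirrors the traceback: the device check reads self.device,
        # which raises AttributeError if __init__ returned early.
        if tensor_device != self.device:
            raise ValueError("tensor is on the wrong device")
        return "reduced"


class GuardedToyCommunicator(ToyCommunicator):
    def all_reduce(self, tensor_device: str) -> str:
        # Check availability before touching any attribute that
        # only exists after a successful load.
        if not self.available:
            return "skipped"
        return super().all_reduce(tensor_device)


if __name__ == "__main__":
    comm = ToyCommunicator("/nonexistent/libnccl.so.2", "cuda:0")
    try:
        comm.all_reduce("cuda:0")
    except AttributeError as exc:
        print("AttributeError:", exc)
    guarded = GuardedToyCommunicator("/nonexistent/libnccl.so.2", "cuda:0")
    print(guarded.all_reduce("cuda:0"))
```

The point of the toy is only that the visible error (`no attribute 'device'`) is a symptom; the cause is the earlier failed library load.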

You can follow the logic of `find_nccl_library() -> str` in vLLM's `vllm/utils.py` to find out why NCCL is not found in your case.
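A minimal standalone sketch of the same kind of check, useful outside vLLM. It assumes, as the error message suggests, that `VLLM_NCCL_SO_PATH` overrides the library path and that the default on CUDA is `libnccl.so.2` (the name in the log). `resolve_nccl_path`, `try_load_nccl` and `decode_nccl_version` are hypothetical helpers, not vLLM's `find_nccl_library()`. The only NCCL call used is `ncclGetVersion(int *version)`, whose version code is `major*10000 + minor*100 + patch` for NCCL 2.9 and later.

```python
import ctypes
import os
from typing import Mapping, Tuple

# Assumed default soname on CUDA builds (the name shown in the log).
DEFAULT_NCCL_SO = "libnccl.so.2"


def resolve_nccl_path(env: Mapping[str, str]) -> str:
    """Prefer VLLM_NCCL_SO_PATH if set, else fall back to the default soname."""
    return env.get("VLLM_NCCL_SO_PATH") or DEFAULT_NCCL_SO


def decode_nccl_version(code: int) -> str:
    """Decode NCCL's integer version code (scheme used since NCCL 2.9)."""
    major, rest = divmod(code, 10000)
    minor, patch = divmod(rest, 100)
    return f"{major}.{minor}.{patch}"


def try_load_nccl(path: str) -> Tuple[bool, str]:
    """Try to dlopen the library and query its version."""
    try:
        lib = ctypes.CDLL(path)
    except OSError as exc:
        return False, f"could not load {path}: {exc}"
    version = ctypes.c_int()
    # ncclResult_t ncclGetVersion(int *version); 0 means ncclSuccess.
    rc = lib.ncclGetVersion(ctypes.byref(version))
    if rc != 0:
        return False, f"ncclGetVersion failed with code {rc}"
    return True, decode_nccl_version(version.value)


if __name__ == "__main__":
    path = resolve_nccl_path(os.environ)
    ok, info = try_load_nccl(path)
    print(f"{path}: NCCL {info}" if ok else info)
```

Running it in the same environment as vLLM shows whether the default name resolves at all; if it does not, the `OSError` text from the dynamic loader usually says why.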

@youkaichao
Member

Might be related: pytorch/pytorch#132617

If you install PyTorch via conda, it might cause problems. The recommended way is `pip install torch`.
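A quick, hedged way to see which situation you are in. The sketch assumes that the pip wheel of torch 2.4 for CUDA 12 pulls in the separate `nvidia-nccl-cu12` distribution, which ships `libnccl.so.2`, and simply reports whether that distribution and its library files are present. `NCCL_DIST`, `installed_version` and `nccl_shared_libs` are illustrative names, not part of torch or vLLM; only `importlib.metadata` from the standard library is used.

```python
import importlib.metadata as md
from typing import List, Optional

# Assumption: pip wheels of torch 2.4 built for CUDA 12 depend on this
# distribution, which ships libnccl.so.2 as a separate shared library.
NCCL_DIST = "nvidia-nccl-cu12"


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return md.version(dist_name)
    except md.PackageNotFoundError:
        return None


def nccl_shared_libs(dist_name: str = NCCL_DIST) -> List[str]:
    """Return paths of libnccl* files recorded in the distribution, if any."""
    try:
        files = md.files(dist_name) or []
    except md.PackageNotFoundError:
        return []
    return [str(f.locate()) for f in files if f.name.startswith("libnccl")]


if __name__ == "__main__":
    print("torch:", installed_version("torch"))
    print(f"{NCCL_DIST}:", installed_version(NCCL_DIST))
    for path in nccl_shared_libs():
        print("  ", path)
```

If a path is listed, it is a candidate value for `VLLM_NCCL_SO_PATH`; if nothing is listed, reinstalling torch with pip (as suggested here) is the simpler route.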

@pspdada
Author

pspdada commented Sep 13, 2024

> might be related: pytorch/pytorch#132617
>
> if you install pytorch via conda, it might cause problems. the recommended way is to use pip install torch

Thank you so much! This method did indeed solve my problem directly. At first, I was installing PyTorch using conda, but switching to installing it with pip resolved the issue.

@iamthebot

BTW, torch installed from conda-forge does not statically link NCCL, so it's fine to use.
