[core][bugfix] configure env var during import vllm (vllm-project#12209)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Matthew Hendrey <[email protected]>
youkaichao authored and mhendrey committed Jan 23, 2025
1 parent 957ca23 commit 399d224
Showing 4 changed files with 37 additions and 45 deletions.
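The net effect of the commit is that the common environment variables are applied as a side effect of `import vllm`, instead of requiring callers to invoke `configure_as_vllm_process()`. A minimal sketch of the resulting behavior (values taken from the assignments in the vllm/__init__.py diff below):

    # Sketch: importing vllm is now enough to pick up the common settings.
    import os

    import vllm  # noqa: F401 -- the import itself sets the env vars below

    assert os.environ["NCCL_CUMEM_ENABLE"] == "0"
    assert os.environ["TORCHINDUCTOR_COMPILE_THREADS"] == "1"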
7 changes: 1 addition & 6 deletions examples/offline_inference/rlhf.py
@@ -19,7 +19,7 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams, configure_as_vllm_process
from vllm import LLM, SamplingParams
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker

@@ -98,12 +98,7 @@ def __init__(self, *args, **kwargs):
"""
Start the training process, here we use huggingface transformers
as an example to hold a model on GPU 0.
It is important for all the processes outside of vLLM to call
`configure_as_vllm_process` to set some common environment variables
the same as vLLM workers.
"""
configure_as_vllm_process()

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
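With the import-time configuration in place, an external training process like the one in this example needs nothing beyond a normal import. A hedged sketch of the simplified pattern, reusing the facebook/opt-125m model already used in the example:

    # Sketch: no explicit configure_as_vllm_process() call is needed anymore;
    # importing vllm applies the common env vars for this process as well.
    from transformers import AutoModelForCausalLM

    from vllm import LLM, SamplingParams  # noqa: F401 -- import alone configures the process

    train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    train_model.to("cuda:0")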
49 changes: 13 additions & 36 deletions vllm/__init__.py
@@ -1,4 +1,7 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
import os

import torch

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -17,43 +20,18 @@

from .version import __version__, __version_tuple__

# set some common config/environment variables that should be set
# for all processes created by vllm and all processes
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.

def configure_as_vllm_process():
"""
set some common config/environment variables that should be set
for all processes created by vllm and all processes
that interact with vllm workers.
"""
import os

import torch

# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1

from vllm.platforms import current_platform

if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
elif current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1

__all__ = [
"__version__",
@@ -80,5 +58,4 @@ def configure_as_vllm_process():
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
"configure_as_vllm_process",
]
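Because the helper is deleted rather than deprecated, code that imported it will fail against this commit. A small sketch of the public API surface after the change:

    # Sketch: configure_as_vllm_process is gone from the public API.
    import vllm

    assert "configure_as_vllm_process" not in vllm.__all__
    assert not hasattr(vllm, "configure_as_vllm_process")

    # Anything still doing `from vllm import configure_as_vllm_process`
    # now raises ImportError.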
23 changes: 23 additions & 0 deletions vllm/plugins/__init__.py
@@ -1,6 +1,9 @@
import logging
import os
from typing import Callable, Dict

import torch

import vllm.envs as envs

logger = logging.getLogger(__name__)
@@ -51,6 +54,26 @@ def load_general_plugins():
if plugins_loaded:
return
plugins_loaded = True

# some platform-specific configurations
from vllm.platforms import current_platform

if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
elif current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'

plugins = load_plugins_by_group(group='vllm.general_plugins')
# general plugins, we only need to execute the loaded functions
for func in plugins.values():
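The platform-specific knobs move from the removed helper into load_general_plugins(), which workers already call during initialization (see the worker_base.py change below). A sketch of triggering them explicitly, with the behavior described per the comments in the hunk above:

    # Sketch: loading general plugins now also applies platform-specific config.
    from vllm.plugins import load_general_plugins

    load_general_plugins()  # idempotent thanks to the plugins_loaded guard
    # On XPU, and on HPU with PT_HPU_LAZY_MODE=1, this disables torch.compile
    # (torch._dynamo.config.disable = True); on lazy HPU it also sets
    # PT_HPU_ENABLE_LAZY_COLLECTIVES=true, as shown in the diff above.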
3 changes: 0 additions & 3 deletions vllm/worker/worker_base.py
@@ -535,9 +535,6 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
kwargs = all_kwargs[self.rpc_rank]
enable_trace_function_call_for_thread(self.vllm_config)

from vllm import configure_as_vllm_process
configure_as_vllm_process()

from vllm.plugins import load_general_plugins
load_general_plugins()

