vLLM: Update vLLM-cpu to v0.6.6-post1 #12728

Merged: 3 commits, merged on Jan 22, 2025. The diff below shows changes from all commits.
docker/llm/serving/cpu/docker/Dockerfile (6 additions & 3 deletions)
```diff
@@ -16,6 +16,8 @@ RUN wget -qO /sbin/tini https://github.com/krallin/tini/releases/download/${TINI
     apt-get update && \
     apt-get install -y --no-install-recommends wrk patch g++ && \
     pip install --pre --upgrade ipex-llm[serving] && \
+    apt-get install -y gcc-12 g++-12 libnuma-dev && \
+    update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 && \
     # Fix Trivy CVE Issues
     pip install Jinja2==3.1.3 transformers==4.36.2 gradio==4.19.2 cryptography==42.0.4 && \
     # Fix Qwen model adapter in fastchat
@@ -24,10 +26,11 @@ RUN wget -qO /sbin/tini https://github.com/krallin/tini/releases/download/${TINI
     # Install vllm
     git clone https://github.com/vllm-project/vllm.git && \
     cd ./vllm && \
-    git checkout v0.4.2 && \
-    pip install wheel packaging ninja setuptools>=49.4.0 numpy && \
+    git checkout v0.6.6.post1 && \
+    pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy && \
     pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu && \
-    VLLM_TARGET_DEVICE=cpu python3 setup.py install
+    VLLM_TARGET_DEVICE=cpu python3 setup.py install && \
+    pip install ray


 COPY ./vllm_offline_inference.py /llm/
```
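The build now tracks vLLM's CPU-backend requirements for v0.6.6.post1: gcc-12 and libnuma for the CPU kernels, cmake and setuptools-scm for the newer build system, and Ray for multi-instance serving. For context, a minimal offline-inference script in the spirit of the `vllm_offline_inference.py` the image copies in might look like the sketch below; the model path is a placeholder, and `device="cpu"` is forwarded through `**kwargs` to vLLM's `EngineArgs`:

```python
# Hedged sketch of an offline run against the updated image; IPEXLLMClass and
# load_in_low_bit come from this PR's engine module, the model path is hypothetical.
from vllm import SamplingParams
from ipex_llm.vllm.cpu.engine import IPEXLLMClass

llm = IPEXLLMClass(model="/llm/models/Llama-2-7b-chat-hf",
                   device="cpu",
                   load_in_low_bit="sym_int4")  # IPEX-LLM low-bit weight format
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
for output in llm.generate(["What is AI?"], sampling_params):
    print(output.outputs[0].text)
```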
python/llm/src/ipex_llm/transformers/convert.py (0 additions & 1 deletion)
```diff
@@ -693,7 +693,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     out_features,
                     mp_group,
                     None,
-                    None,
                     optimize_lm_head,
                     None
                 )
```
python/llm/src/ipex_llm/transformers/low_bit_linear.py (1 addition & 1 deletion)
```diff
@@ -749,7 +749,7 @@ def forward(self, x: torch.Tensor):
             dist.inference_all_reduce(result, group=self.mp_group)
         if self.bias is not None:
             result += self.bias
-        return result
+        return result.to(x.dtype)


 class FP16Linear(nn.Linear):
```
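The one-line change casts the output back to the input's dtype, so callers that feed, say, bfloat16 activations through a layer whose internal math or bias path runs in float32 get their original dtype back. A standalone sketch of the failure mode this guards against (illustrative only, not ipex-llm code):

```python
import torch

# Compute internally in fp32, but hand the caller back tensors in the
# caller's dtype, mirroring the `result.to(x.dtype)` fix above.
def low_bit_forward(x: torch.Tensor, weight: torch.Tensor,
                    bias: torch.Tensor) -> torch.Tensor:
    result = torch.nn.functional.linear(x.float(), weight.float())  # fp32 math
    if bias is not None:
        result += bias.float()
    return result.to(x.dtype)  # the fix: match the input dtype on the way out

x = torch.randn(2, 4, dtype=torch.bfloat16)
w = torch.randn(3, 4, dtype=torch.bfloat16)
b = torch.zeros(3, dtype=torch.bfloat16)
assert low_bit_forward(x, w, b).dtype == torch.bfloat16
```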
python/llm/src/ipex_llm/vllm/cpu/engine/__init__.py (2 additions & 1 deletion)
```diff
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass
+from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
 __all__ = [
     "IPEXLLMAsyncLLMEngine",
     "IPEXLLMLLMEngine",
     "IPEXLLMClass",
+    "run_mp_engine",
 ]
```
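With `run_mp_engine` re-exported here, a serving frontend can pull the whole public surface from the package root. A trivial consumer sketch (hypothetical; no such consumer is part of this diff):

```python
# Everything a serving entrypoint needs is now importable in one place.
from ipex_llm.vllm.cpu.engine import (IPEXLLMAsyncLLMEngine, IPEXLLMClass,
                                      IPEXLLMLLMEngine, run_mp_engine)
```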
python/llm/src/ipex_llm/vllm/cpu/engine/engine.py (159 additions & 75 deletions)
```diff
@@ -13,18 +13,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #

-from typing import List, Optional, Union
 from vllm.logger import init_logger
+from typing import Dict, Optional, Any, Union, Type
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.llm import LLM
-from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
-                                  usage_message)
 from vllm.utils import Counter
+from vllm.config import VllmConfig
+from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
+from vllm.usage.usage_lib import UsageContext
+from vllm.engine.metrics import StatLoggerBase
+from vllm.engine.multiprocessing.engine import MQLLMEngine
+import signal
+from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
+                                   TaskOption)
+from vllm.config import CompilationConfig
+from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+from vllm import envs
+from vllm.v1.engine.async_llm import AsyncLLM
+import os

 from ipex_llm.utils.common import invalidInputError
+logger = init_logger(__name__)


 class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
```
```diff
@@ -35,79 +45,100 @@ def __init__(self, *args, **kwargs):
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[VllmConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        load_in_low_bit: Optional[str] = None,
+        load_in_low_bit: str = "sym_int4",
+        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
-        # Enable ipex-llm optimizations
-        engine_config = engine_args.create_engine_config()
-        from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
+        # Create the engine configs.
         _ipex_llm_convert(load_in_low_bit)
-        if engine_config.device_config.device_type == "neuron":
-            from vllm.executor.neuron_executor import NeuronExecutorAsync
-            executor_class = NeuronExecutorAsync
-        elif engine_config.device_config.device_type == "cpu":
-            invalidInputError(not engine_config.parallel_config.worker_use_ray, (
-                "Ray is not supported with the CPU backend."))
-            from vllm.executor.cpu_executor import CPUExecutorAsync
-            executor_class = CPUExecutorAsync
-        elif engine_config.parallel_config.worker_use_ray:
-            initialize_ray_cluster(engine_config.parallel_config)
-            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
-            executor_class = RayGPUExecutorAsync
-        else:
-            invalidInputError(engine_config.parallel_config.world_size == 1, (
-                "Ray is required if parallel_config.world_size > 1."))
-            from vllm.executor.gpu_executor import GPUExecutorAsync
-            executor_class = GPUExecutorAsync
-        # Create the async LLM engine.
-        engine = cls(
-            engine_config.parallel_config.worker_use_ray,
-            engine_args.engine_use_ray,
-            **engine_config.to_dict(),
-            executor_class=executor_class,
-            log_requests=not engine_args.disable_log_requests,
-            log_stats=not engine_args.disable_log_stats,
-            max_log_len=engine_args.max_log_len,
-            start_engine_loop=start_engine_loop,
-            usage_context=usage_context,
-        )
-        return engine
+        return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
+                                        start_engine_loop=start_engine_loop,
+                                        usage_context=usage_context, stat_loggers=stat_loggers)


-class IPEXLLMClass(LLM):
+class IPEXLLMAsyncV1Engine(AsyncLLM):

+    def __init__(self, *args, **kwargs):
+        print("IPEX-LLM V1 engine get started...")
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        engine_args: AsyncEngineArgs,
+        engine_config: Optional[VllmConfig] = None,
+        start_engine_loop: bool = True,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        load_in_low_bit: str = "sym_int4",
+        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
+    ) -> "AsyncLLM":
+        _ipex_llm_convert(load_in_low_bit)
+        return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
+                                        start_engine_loop=start_engine_loop,
+                                        usage_context=usage_context, stat_loggers=stat_loggers)
+
+
+class IPEXLLMClass(LLM):
     def __init__(
         self,
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
         skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
+        allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
-        swap_space: int = 4,
-        enforce_eager: bool = False,
-        max_context_len_to_capture: Optional[int] = None,
+        swap_space: float = 4,
+        cpu_offload_gb: float = 0,
+        enforce_eager: Optional[bool] = None,
+        max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
-        load_in_low_bit: Optional[str] = None,
+        disable_async_output_proc: bool = True,
+        hf_overrides: Optional[HfOverrides] = None,
+        mm_processor_kwargs: Optional[Dict[str, Any]]=None,
+        # After positional args are removed, move this right below `model`
+        task: TaskOption = "auto",
+        override_pooler_config: Optional[PoolerConfig] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]]=None,
+        load_in_low_bit: str = "sym_int4",
         **kwargs,
     ) -> None:
+        '''
+        LLM constructor.
+
+        Note: if enforce_eager is unset (enforce_eager is None)
+        it defaults to False.
+        '''
+
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True

+        if compilation_config is not None:
+            if isinstance(compilation_config, (int, dict)):
+                compilation_config_instance = CompilationConfig.from_cli(
+                    str(compilation_config))
+            else:
+                compilation_config_instance = compilation_config
+        else:
+            compilation_config_instance = None
+
         engine_args = EngineArgs(
             model=model,
+            task=task,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
             skip_tokenizer_init=skip_tokenizer_init,
             trust_remote_code=trust_remote_code,
+            allowed_local_media_path=allowed_local_media_path,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             quantization=quantization,
```
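Both async engines now follow the same shape: apply `_ipex_llm_convert(load_in_low_bit)` to patch model loading, then delegate construction to vLLM's own `from_engine_args`, instead of re-implementing executor selection as the v0.4.2 code did. A hedged usage sketch (model path hypothetical):

```python
# Hypothetical async-engine construction on the CPU backend; the low-bit
# conversion is applied before vLLM builds the engine internals.
from vllm.engine.arg_utils import AsyncEngineArgs
from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine

engine_args = AsyncEngineArgs(model="/llm/models/Qwen2-7B-Instruct",
                              device="cpu")
engine = IPEXLLMAsyncLLMEngine.from_engine_args(engine_args,
                                                load_in_low_bit="sym_int4")
```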
```diff
@@ -116,16 +147,60 @@ def __init__(
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
+            cpu_offload_gb=cpu_offload_gb,
             enforce_eager=enforce_eager,
-            max_context_len_to_capture=max_context_len_to_capture,
+            max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
+            disable_async_output_proc=disable_async_output_proc,
+            hf_overrides=hf_overrides,
+            mm_processor_kwargs=mm_processor_kwargs,
+            override_pooler_config=override_pooler_config,
+            compilation_config=compilation_config_instance,
             **kwargs,
         )
-        self.llm_engine = IPEXLLMLLMEngine.from_engine_args(engine_args,
-                                                            load_in_low_bit=load_in_low_bit)
+        # Logic to switch between engines is done at runtime instead of import
+        # to avoid import order issues
+        # TODO(gc): we will need to override this function
+        self.engine_class = self.get_engine_class()
+        self.llm_engine = self.engine_class.from_engine_args(
+            engine_args, usage_context=UsageContext.LLM_CLASS,
+            load_in_low_bit=load_in_low_bit)

         self.request_counter = Counter()

+    @staticmethod
+    def get_engine_class() -> Type[LLMEngine]:
+        if envs.VLLM_USE_V1:
+            # Lazy import: the v1 package isn't distributed
+            # from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+            return IPEXLLMLLMV1Engine  # type: ignore
+        return IPEXLLMLLMEngine
+
+
+# TODO(gc): implement this later...
+class IPEXLLMLLMV1Engine(V1LLMEngine):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        engine_args: EngineArgs,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
+        enable_multiprocessing: bool = False,
+        load_in_low_bit: str = "sym_int4",
+    ) -> "LLMEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+
+        # TODO(gc): delete this later
+        print("IPEXLLM V1 Engine")
+        # This does not work as it is in the seperate process...
+        _ipex_llm_convert(load_in_low_bit)
+        return super().from_engine_args(engine_args, usage_context,
+                                        stat_loggers, enable_multiprocessing)


 class IPEXLLMLLMEngine(LLMEngine):
     def __init__(self, *args, **kwargs):
```
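`IPEXLLMClass` now resolves its engine class at call time rather than at import time. The pattern in isolation looks like the sketch below (toy stand-ins for the two engines; the real selection reads `vllm.envs.VLLM_USE_V1`, which is backed by this environment variable):

```python
import os

# Toy stand-ins for the v0 and v1 engine classes.
class V0Engine: ...
class V1Engine: ...

def get_engine_class():
    # Resolving at call time (not at import) avoids import-order issues,
    # as the comment in the diff notes.
    return V1Engine if os.getenv("VLLM_USE_V1") else V0Engine

print(get_engine_class().__name__)  # V0Engine unless VLLM_USE_V1 is set
```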
```diff
@@ -136,35 +211,44 @@ def from_engine_args(
         cls,
         engine_args: EngineArgs,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        load_in_low_bit: Optional[str] = None,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
+        load_in_low_bit: str = "sym_int4",
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
-        from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
+        # TODO(gc): Delete
+        print("Use vLLM v0 engine")
         _ipex_llm_convert(load_in_low_bit)
+        return super().from_engine_args(engine_args, usage_context, stat_loggers)

-        # Initialize the cluster and specify the executor class.
-        if engine_config.device_config.device_type == "neuron":
-            from vllm.executor.neuron_executor import NeuronExecutor
-            executor_class = NeuronExecutor
-        elif engine_config.device_config.device_type == "cpu":
-            from vllm.executor.cpu_executor import CPUExecutor
-            executor_class = CPUExecutor
-        elif engine_config.parallel_config.worker_use_ray:
-            initialize_ray_cluster(engine_config.parallel_config)
-            from vllm.executor.ray_gpu_executor import RayGPUExecutor
-            executor_class = RayGPUExecutor
-        else:
-            invalidInputError(engine_config.parallel_config.world_size == 1, (
-                "Ray is required if parallel_config.world_size > 1."))
-            from vllm.executor.gpu_executor import GPUExecutor
-            executor_class = GPUExecutor
-
-        # Create the LLM engine.
-        engine = cls(**engine_config.to_dict(),
-                     executor_class=executor_class,
-                     log_stats=not engine_args.disable_log_stats,
-                     usage_context=usage_context,
-                     )
-        return engine

+class IPEXLLMMQLLMEngine(MQLLMEngine):
+    @classmethod
+    def from_engine_args(cls, engine_args: AsyncEngineArgs,
+                         usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
+        _ipex_llm_convert(load_in_low_bit)
+        return super().from_engine_args(engine_args, usage_context, ipc_path)
+
+
+def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
+                  ipc_path: str, load_in_low_bit: str, engine_alive):
+
+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm
+        raise KeyboardInterrupt("MQLLMEngine terminated")  # noqa
+
+    try:
+        signal.signal(signal.SIGTERM, signal_handler)
+
+        engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
+                                                     usage_context=usage_context,
+                                                     ipc_path=ipc_path,
+                                                     load_in_low_bit=load_in_low_bit)
+        engine.start()
+    except BaseException as e:
+        logger.exception(e)
+        engine_alive.value = False
+        raise e  # noqa
+
+if os.getenv("VLLM_USE_V1"):
+    IPEXLLMAsyncLLMEngine = IPEXLLMAsyncV1Engine
```
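`run_mp_engine` mirrors vLLM's multiprocessing serving path: the engine runs in a worker process behind an IPC socket, and `engine_alive` reports liveness back to the parent. A hedged launch sketch (the IPC path and model are placeholders, and the spawning server itself is outside this diff):

```python
# Hypothetical parent-process launch of the IPEX-LLM MQ engine, mirroring how
# vLLM's OpenAI-compatible server spawns run_mp_engine in the background.
import multiprocessing as mp
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext
from ipex_llm.vllm.cpu.engine import run_mp_engine

if __name__ == "__main__":
    engine_args = AsyncEngineArgs(model="/llm/models/Qwen2-7B-Instruct",
                                  device="cpu")
    engine_alive = mp.Value("b", True, lock=False)  # liveness flag for the parent
    proc = mp.Process(target=run_mp_engine,
                      args=(engine_args, UsageContext.OPENAI_API_SERVER,
                            "ipc:///tmp/ipex_llm_engine", "sym_int4",
                            engine_alive))
    proc.start()
```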