[Model] New interface and automatic detection for PP support #9000

Closed
wants to merge 23 commits into from
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -146,7 +146,9 @@ steps:
source_file_dependencies:
- vllm/
- tests/test_regression
command: pytest -v -s test_regression.py
commands:
- pip install modelscope
- pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional

- label: Engine Test # 10min
4 changes: 2 additions & 2 deletions requirements-test.txt
@@ -10,8 +10,8 @@ pytest-shard
awscli
einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio test
opencv-python # required for video test
librosa # required for audio tests
opencv-python # required for video tests
peft
requests
ray[adag]==2.35
52 changes: 49 additions & 3 deletions tests/models/test_registry.py
@@ -1,9 +1,55 @@
import warnings

import pytest
import torch.cuda

from vllm.model_executor.models import _MODELS, ModelRegistry
from vllm.platforms import current_platform

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
@pytest.mark.parametrize("model_arch", _MODELS)
def test_registry_imports(model_arch):
# Ensure all model classes can be imported successfully
ModelRegistry.resolve_model_cls([model_cls])
ModelRegistry.resolve_model_cls(model_arch)


@fork_new_process_for_each_test
@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [
("LlamaForCausalLM", False, False),
("MllamaForConditionalGeneration", True, False),
("LlavaForConditionalGeneration", True, True),
])
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm

if init_cuda and current_platform.is_cuda_alike():
assert not torch.cuda.is_initialized()

ModelRegistry.resolve_model_cls(model_arch)
if not torch.cuda.is_initialized():
warnings.warn(
"This model no longer initializes CUDA on import. "
"Please test using a different one.",
stacklevel=2)


@fork_new_process_for_each_test
@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
("MLPSpeculatorPreTrainedModel", False, False),
("DeepseekV2ForCausalLM", True, False),
("Qwen2VLForConditionalGeneration", True, True),
])
def test_registry_is_pp(model_arch, is_pp, init_cuda):
assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp

if init_cuda and current_platform.is_cuda_alike():
assert not torch.cuda.is_initialized()

ModelRegistry.resolve_model_cls(model_arch)
if not torch.cuda.is_initialized():
warnings.warn(
"This model no longer initializes CUDA on import. "
"Please test using a different one.",
stacklevel=2)
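
For reference, a minimal sketch of how the new registry helpers can be queried directly (outside pytest); the expected results mirror the parametrized cases above, and checks for not-yet-imported modules run in a subprocess so they do not initialize CUDA in the caller:

from vllm.model_executor.models import ModelRegistry

# Both helpers now accept either a single architecture name or a list of names.
assert ModelRegistry.is_multimodal_model("LlavaForConditionalGeneration")
assert not ModelRegistry.is_multimodal_model(["LlamaForCausalLM"])
assert ModelRegistry.is_pp_supported_model("DeepseekV2ForCausalLM")
assert not ModelRegistry.is_pp_supported_model("MLPSpeculatorPreTrainedModel")
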
60 changes: 18 additions & 42 deletions vllm/config.py
@@ -33,27 +33,6 @@
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096

_PP_SUPPORTED_MODELS = [
"AquilaForCausalLM",
"AquilaModel",
"DeepseekV2ForCausalLM",
"GPT2LMHeadModel",
"InternLM2ForCausalLM",
"InternLMForCausalLM",
"InternVLChatModel",
"JAISLMHeadModel",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"MixtralForCausalLM",
"NemotronForCausalLM",
"Phi3ForCausalLM",
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
"Qwen2VLForConditionalGeneration",
]


class ModelConfig:
"""Configuration for the model.
@@ -228,16 +207,14 @@ def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
) -> Optional["MultiModalConfig"]:
architectures = getattr(self.hf_config, "architectures", [])
if any(
ModelRegistry.is_multimodal_model(arch)
for arch in architectures):
if ModelRegistry.is_multimodal_model(architectures):
return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
else:
if limit_mm_per_prompt:
raise ValueError(
"limit_mm_per_prompt is only supported for multimodal "
"models.")
return None

if limit_mm_per_prompt:
raise ValueError("`limit_mm_per_prompt` is only supported for "
"multimodal models.")

return None

def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
@@ -249,8 +226,7 @@ def _verify_tokenizer_mode(self) -> None:

def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
self.embedding_mode = any(
ModelRegistry.is_embedding_model(arch) for arch in architectures)
self.embedding_mode = ModelRegistry.is_embedding_model(architectures)

def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
@@ -417,17 +393,17 @@ def verify_with_parallel_config(
f"({tensor_parallel_size}).")

pipeline_parallel_size = parallel_config.pipeline_parallel_size
architectures = getattr(self.hf_config, "architectures", [])
if not all(arch in _PP_SUPPORTED_MODELS
for arch in architectures) and pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")
if pipeline_parallel_size > 1:
architectures = getattr(self.hf_config, "architectures", [])
if not ModelRegistry.is_pp_supported_model(architectures):
raise NotImplementedError(
"Pipeline parallelism is not supported for this model. "
"Supported models implement the `SupportsPP` interface.")

if pipeline_parallel_size > 1 and self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False
if self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False

def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""
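The multimodal and embedding checks above follow the same pattern: the per-architecture `any(...)` loops are replaced by passing the whole architectures list to the registry. Condensed, the new pipeline-parallel validation amounts to the following sketch (a paraphrase of the diff, not the verbatim implementation; `hf_config` and `pipeline_parallel_size` stand in for the attributes used in `verify_with_parallel_config`):

from vllm.model_executor.models import ModelRegistry

def _verify_pp_support(hf_config, pipeline_parallel_size: int) -> None:
    # Only consult the registry when pipeline parallelism is actually requested.
    if pipeline_parallel_size > 1:
        architectures = getattr(hf_config, "architectures", [])
        if not ModelRegistry.is_pp_supported_model(architectures):
            raise NotImplementedError(
                "Pipeline parallelism is not supported for this model. "
                "Supported models implement the `SupportsPP` interface.")
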
147 changes: 125 additions & 22 deletions vllm/model_executor/models/__init__.py
@@ -1,12 +1,18 @@
import functools
import importlib
from typing import Dict, List, Optional, Tuple, Type
import string
import subprocess
import sys
import uuid
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

from .interfaces import supports_multimodal, supports_pp

logger = init_logger(__name__)

_GENERATION_MODELS = {
@@ -152,19 +158,25 @@
class ModelRegistry:

@staticmethod
@functools.lru_cache(maxsize=128)
def _get_model(model_arch: str):
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name

@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
@lru_cache(maxsize=128)
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
return None

module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
module = importlib.import_module(module_name)
return getattr(module, cls_name, None)

@staticmethod
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]

if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
@@ -175,11 +187,24 @@ def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])

return ModelRegistry._get_model(model_arch)
return None

@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return model

return ModelRegistry._try_get_model_stateful(model_arch)

@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
@@ -200,21 +225,99 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]):
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
global _OOT_MODELS

_OOT_MODELS[model_arch] = model_cls

@staticmethod
def is_embedding_model(model_arch: str) -> bool:
return model_arch in _EMBEDDING_MODELS
@lru_cache(maxsize=128)
def _check_stateless(
func: Callable[[Type[nn.Module]], bool],
model_arch: str,
*,
default: Optional[bool] = None,
) -> bool:
"""
Run a boolean function against a model and return the result.

If the model is not found, returns the provided default value.

If the model is not already imported, the function is run inside a
subprocess to avoid initializing CUDA for the main program.
"""
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return func(model)

if model_arch not in _MODELS and default is not None:
return default

module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)

valid_name_characters = string.ascii_letters + string.digits + "._"
if any(s not in valid_name_characters for s in module_name):
raise ValueError(f"Unsafe module name detected for {model_arch}")
if any(s not in valid_name_characters for s in cls_name):
raise ValueError(f"Unsafe class name detected for {model_arch}")
if any(s not in valid_name_characters for s in func.__module__):
raise ValueError(f"Unsafe module name detected for {func}")
if any(s not in valid_name_characters for s in func.__name__):
raise ValueError(f"Unsafe class name detected for {func}")

err_id = uuid.uuid4()

stmts = ";".join([
f"from {module_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])

result = subprocess.run([sys.executable, "-c", stmts],
capture_output=True)

if result.returncode != 0:
err_lines = [line.decode() for line in result.stderr.splitlines()]
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
err_str = "\n".join(err_lines)
raise RuntimeError(
"An unexpected error occurred while importing the model in "
f"another process. Error log:\n{err_str}")

return result.returncode == 0

@staticmethod
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

return any(arch in _EMBEDDING_MODELS for arch in architectures)

@staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

is_mm = partial(ModelRegistry._check_stateless,
supports_multimodal,
default=False)

return any(is_mm(arch) for arch in architectures)

@staticmethod
def is_multimodal_model(model_arch: str) -> bool:
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

is_pp = partial(ModelRegistry._check_stateless,
supports_pp,
default=False)

# TODO: find a way to avoid initializing CUDA prematurely to
# use `supports_multimodal` to determine if a model is multimodal
# model_cls = ModelRegistry._try_load_model_cls(model_arch)
# from vllm.model_executor.models.interfaces import supports_multimodal
return model_arch in _MULTIMODAL_MODELS
return any(is_pp(arch) for arch in architectures)


__all__ = [
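The `supports_pp` and `supports_multimodal` helpers (and the `SupportsPP` interface used elsewhere in this PR) come from `vllm/model_executor/models/interfaces.py`, which is not part of this diff. A plausible minimal form, for illustration only:

from typing import Type

import torch.nn as nn


class SupportsPP:
    """Marker interface: models that support pipeline parallelism inherit it."""


def supports_pp(model_cls: Type[nn.Module]) -> bool:
    # _check_stateless only needs a class-level yes/no answer, so a plain
    # subclass check is enough for this sketch.
    return issubclass(model_cls, SupportsPP)

When the target module has not been imported yet, `_check_stateless` effectively runs `python -c "from <module> import <Cls>; from <func_module> import <func>; assert <func>(<Cls>), '<uuid>'"` and treats a zero exit code as a positive answer, so the heavyweight model import happens in a child process instead of the caller.
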
6 changes: 3 additions & 3 deletions vllm/model_executor/models/deepseek_v2.py
@@ -40,8 +40,7 @@
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -50,6 +49,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers


@@ -472,7 +472,7 @@ def forward(
return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
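
As the diff shows, a model opts into pipeline parallelism simply by adding `SupportsPP` to its base classes. Out-of-tree models registered at runtime can do the same; a hypothetical sketch (`MyPPModel` and its architecture name are made up, and it assumes the real `supports_pp` recognizes `SupportsPP` subclasses):

import torch.nn as nn

from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.interfaces import SupportsPP


class MyPPModel(nn.Module, SupportsPP):
    # Hypothetical out-of-tree model; a real one would implement forward(),
    # load_weights(), etc.
    ...


ModelRegistry.register_model("MyPPModelForCausalLM", MyPPModel)

# Registered classes resolve through the stateless path, so this check runs
# in-process rather than spawning a subprocess.
assert ModelRegistry.is_pp_supported_model("MyPPModelForCausalLM")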