diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f678436dd05e1..427dc14513d45 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -146,7 +146,9 @@ steps: source_file_dependencies: - vllm/ - tests/test_regression - command: pytest -v -s test_regression.py + commands: + - pip install modelscope + - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min diff --git a/requirements-test.txt b/requirements-test.txt index 9c6fadb88865a..37c3bd8ba8794 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,8 +10,8 @@ pytest-shard awscli einops # required for MPT, qwen-vl and Mamba httpx -librosa # required for audio test -opencv-python # required for video test +librosa # required for audio tests +opencv-python # required for video tests peft requests ray[adag]==2.35 diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index b058e2755c245..ee5c9e8ccb196 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -1,9 +1,55 @@ +import warnings + import pytest +import torch.cuda from vllm.model_executor.models import _MODELS, ModelRegistry +from vllm.platforms import current_platform + +from ..utils import fork_new_process_for_each_test -@pytest.mark.parametrize("model_cls", _MODELS) -def test_registry_imports(model_cls): +@pytest.mark.parametrize("model_arch", _MODELS) +def test_registry_imports(model_arch): # Ensure all model classes can be imported successfully - ModelRegistry.resolve_model_cls([model_cls]) + ModelRegistry.resolve_model_cls(model_arch) + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [ + ("LlamaForCausalLM", False, False), + ("MllamaForConditionalGeneration", True, False), + ("LlavaForConditionalGeneration", True, True), +]) +def test_registry_is_multimodal(model_arch, is_mm, init_cuda): + assert ModelRegistry.is_multimodal_model(model_arch) is is_mm + + if init_cuda and current_platform.is_cuda_alike(): + assert not torch.cuda.is_initialized() + + ModelRegistry.resolve_model_cls(model_arch) + if not torch.cuda.is_initialized(): + warnings.warn( + "This model no longer initializes CUDA on import. " + "Please test using a different one.", + stacklevel=2) + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [ + ("MLPSpeculatorPreTrainedModel", False, False), + ("DeepseekV2ForCausalLM", True, False), + ("Qwen2VLForConditionalGeneration", True, True), +]) +def test_registry_is_pp(model_arch, is_pp, init_cuda): + assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp + + if init_cuda and current_platform.is_cuda_alike(): + assert not torch.cuda.is_initialized() + + ModelRegistry.resolve_model_cls(model_arch) + if not torch.cuda.is_initialized(): + warnings.warn( + "This model no longer initializes CUDA on import. 
" + "Please test using a different one.", + stacklevel=2) diff --git a/vllm/config.py b/vllm/config.py index 1310c07ade482..0dc805f46fe69 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -33,27 +33,6 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 -_PP_SUPPORTED_MODELS = [ - "AquilaForCausalLM", - "AquilaModel", - "DeepseekV2ForCausalLM", - "GPT2LMHeadModel", - "InternLM2ForCausalLM", - "InternLMForCausalLM", - "InternVLChatModel", - "JAISLMHeadModel", - "LlamaForCausalLM", - "LLaMAForCausalLM", - "MistralForCausalLM", - "MixtralForCausalLM", - "NemotronForCausalLM", - "Phi3ForCausalLM", - "Qwen2ForCausalLM", - "Qwen2MoeForCausalLM", - "QWenLMHeadModel", - "Qwen2VLForConditionalGeneration", -] - class ModelConfig: """Configuration for the model. @@ -228,16 +207,14 @@ def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] ) -> Optional["MultiModalConfig"]: architectures = getattr(self.hf_config, "architectures", []) - if any( - ModelRegistry.is_multimodal_model(arch) - for arch in architectures): + if ModelRegistry.is_multimodal_model(architectures): return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) - else: - if limit_mm_per_prompt: - raise ValueError( - "limit_mm_per_prompt is only supported for multimodal " - "models.") - return None + + if limit_mm_per_prompt: + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") + + return None def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() @@ -249,8 +226,7 @@ def _verify_tokenizer_mode(self) -> None: def _verify_embedding_mode(self) -> None: architectures = getattr(self.hf_config, "architectures", []) - self.embedding_mode = any( - ModelRegistry.is_embedding_model(arch) for arch in architectures) + self.embedding_mode = ModelRegistry.is_embedding_model(architectures) def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) @@ -417,17 +393,17 @@ def verify_with_parallel_config( f"({tensor_parallel_size}).") pipeline_parallel_size = parallel_config.pipeline_parallel_size - architectures = getattr(self.hf_config, "architectures", []) - if not all(arch in _PP_SUPPORTED_MODELS - for arch in architectures) and pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is only supported for the following " - f" architectures: {_PP_SUPPORTED_MODELS}.") + if pipeline_parallel_size > 1: + architectures = getattr(self.hf_config, "architectures", []) + if not ModelRegistry.is_pp_supported_model(architectures): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") - if pipeline_parallel_size > 1 and self.use_async_output_proc: - logger.warning("Async output processor is not supported with " - "pipeline parallelism currently. Disabling it.") - self.use_async_output_proc = False + if self.use_async_output_proc: + logger.warning("Async output processor is not supported with " + "pipeline parallelism currently. 
Disabling it.") + self.use_async_output_proc = False def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 3a57db0d04fab..2f9cb2b760a82 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,12 +1,18 @@ -import functools import importlib -from typing import Dict, List, Optional, Tuple, Type +import string +import subprocess +import sys +import uuid +from functools import lru_cache, partial +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import torch.nn as nn from vllm.logger import init_logger from vllm.utils import is_hip +from .interfaces import supports_multimodal, supports_pp + logger = init_logger(__name__) _GENERATION_MODELS = { @@ -152,19 +158,25 @@ class ModelRegistry: @staticmethod - @functools.lru_cache(maxsize=128) - def _get_model(model_arch: str): - module_name, model_cls_name = _MODELS[model_arch] - module = importlib.import_module( - f"vllm.model_executor.models.{module_name}") - return getattr(module, model_cls_name, None) + def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: + module_relname, cls_name = _MODELS[model_arch] + return f"vllm.model_executor.models.{module_relname}", cls_name @staticmethod - def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: - if model_arch in _OOT_MODELS: - return _OOT_MODELS[model_arch] + @lru_cache(maxsize=128) + def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch not in _MODELS: return None + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + module = importlib.import_module(module_name) + return getattr(module, cls_name, None) + + @staticmethod + def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] + if is_hip(): if model_arch in _ROCM_UNSUPPORTED_MODELS: raise ValueError( @@ -175,11 +187,24 @@ def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: "Model architecture %s is partially supported by ROCm: %s", model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) - return ModelRegistry._get_model(model_arch) + return None + + @staticmethod + def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return model + + return ModelRegistry._try_get_model_stateful(model_arch) @staticmethod def resolve_model_cls( - architectures: List[str]) -> Tuple[Type[nn.Module], str]: + architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + for arch in architectures: model_cls = ModelRegistry._try_load_model_cls(arch) if model_cls is not None: @@ -200,21 +225,99 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): "Model architecture %s is already registered, and will be " "overwritten by the new model class %s.", model_arch, model_cls.__name__) - global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls @staticmethod - def is_embedding_model(model_arch: str) -> bool: - return model_arch in _EMBEDDING_MODELS + @lru_cache(maxsize=128) + def _check_stateless( + func: Callable[[Type[nn.Module]], bool], + model_arch: str, + *, + default: Optional[bool] = None, + ) 
-> bool: + """ + Run a boolean function against a model and return the result. + + If the model is not found, returns the provided default value. + + If the model is not already imported, the function is run inside a + subprocess to avoid initializing CUDA for the main program. + """ + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return func(model) + + if model_arch not in _MODELS and default is not None: + return default + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + + valid_name_characters = string.ascii_letters + string.digits + "._" + if any(s not in valid_name_characters for s in module_name): + raise ValueError(f"Unsafe module name detected for {model_arch}") + if any(s not in valid_name_characters for s in cls_name): + raise ValueError(f"Unsafe class name detected for {model_arch}") + if any(s not in valid_name_characters for s in func.__module__): + raise ValueError(f"Unsafe module name detected for {func}") + if any(s not in valid_name_characters for s in func.__name__): + raise ValueError(f"Unsafe class name detected for {func}") + + err_id = uuid.uuid4() + + stmts = ";".join([ + f"from {module_name} import {cls_name}", + f"from {func.__module__} import {func.__name__}", + f"assert {func.__name__}({cls_name}), '{err_id}'", + ]) + + result = subprocess.run([sys.executable, "-c", stmts], + capture_output=True) + + if result.returncode != 0: + err_lines = [line.decode() for line in result.stderr.splitlines()] + if err_lines and err_lines[-1] != f"AssertionError: {err_id}": + err_str = "\n".join(err_lines) + raise RuntimeError( + "An unexpected error occurred while importing the model in " + f"another process. Error log:\n{err_str}") + + return result.returncode == 0 + + @staticmethod + def is_embedding_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return any(arch in _EMBEDDING_MODELS for arch in architectures) + + @staticmethod + def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_mm = partial(ModelRegistry._check_stateless, + supports_multimodal, + default=False) + + return any(is_mm(arch) for arch in architectures) @staticmethod - def is_multimodal_model(model_arch: str) -> bool: + def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_pp = partial(ModelRegistry._check_stateless, + supports_pp, + default=False) - # TODO: find a way to avoid initializing CUDA prematurely to - # use `supports_multimodal` to determine if a model is multimodal - # model_cls = ModelRegistry._try_load_model_cls(model_arch) - # from vllm.model_executor.models.interfaces import supports_multimodal - return model_arch in _MULTIMODAL_MODELS + return any(is_pp(arch) for arch in architectures) __all__ = [ diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 8cbd9435ec7ca..8bdeac29e2458 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -40,8 +40,7 @@ ReplicatedLinear, RowParallelLinear) from 
vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -50,6 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -472,7 +472,7 @@ def forward( return hidden_states -class DeepseekV2ForCausalLM(nn.Module): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fb5a297661ddc..52c92353784b0 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -32,8 +32,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,6 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP from .utils import is_pp_missing_parameter, make_layers @@ -234,7 +234,7 @@ def forward( return hidden_states -class GPT2LMHeadModel(nn.Module): +class GPT2LMHeadModel(nn.Module, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 069948f812253..298174fa05965 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,11 +1,17 @@ -from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, - Union, overload, runtime_checkable) +import inspect +from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, + Protocol, Type, Union, overload, runtime_checkable) +import torch from typing_extensions import TypeIs -from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig + from vllm.sequence import IntermediateTensors + logger = init_logger(__name__) @@ -22,7 +28,7 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ - def __init__(self, *, multimodal_config: MultiModalConfig) -> None: + def __init__(self, *, multimodal_config: "MultiModalConfig") -> None: ... @@ -32,7 +38,7 @@ def __init__(self, *, multimodal_config: MultiModalConfig) -> None: class _SupportsMultiModalType(Protocol): supports_multimodal: Literal[True] - def __call__(self, *, multimodal_config: MultiModalConfig) -> None: + def __call__(self, *, multimodal_config: "MultiModalConfig") -> None: ... 
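Note on the interfaces.py hunks above: the config imports move behind `TYPE_CHECKING` and the annotations become string forward references, so importing `interfaces.py` no longer pulls in `vllm.config` (and whatever it drags along) at runtime. A minimal, self-contained sketch of that idiom; the `Example` class is made up, and the guarded import is never executed, so this runs even without vLLM installed:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never imported at runtime,
    # so importing this module stays lightweight.
    from vllm.config import MultiModalConfig


class Example:
    # The annotation is a string (forward reference) because the name
    # is not bound at runtime.
    def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
        self.multimodal_config = multimodal_config
```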
@@ -75,7 +81,7 @@ class SupportsLoRA(Protocol): embedding_padding_modules: ClassVar[List[str]] # lora_config is None when LoRA is not enabled - def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: ... @@ -90,7 +96,7 @@ class _SupportsLoRAType(Protocol): embedding_modules: Dict[str, str] embedding_padding_modules: List[str] - def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: ... @@ -145,6 +151,132 @@ def _supports_lora( return isinstance(model, SupportsLoRA) +@runtime_checkable +class SupportsPP(Protocol): + """The interface required for all models that support pipeline parallel.""" + + supports_pp: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pipeline parallel. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + """Called when PP rank > 0 for profiling purposes.""" + ... + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + """ + Accept :class:`IntermediateTensors` when PP rank > 0. + + Return :class:`IntermediateTensors` only for the last PP rank. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsPPType(Protocol): + supports_pp: Literal[True] + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + ... + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + ... + + +@overload +def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: + ... + + +@overload +def supports_pp(model: object) -> TypeIs[SupportsPP]: + ... 
+ + +def supports_pp( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + supports_attributes = _supports_pp_attributes(model) + supports_inspect = _supports_pp_inspect(model) + + if supports_attributes and not supports_inspect: + logger.warning( + "The model (%s) sets `supports_pp=True`, but does not accept " + "`intermediate_tensors` in its `forward` method", model) + + if not supports_attributes: + pp_attrs = ("make_empty_intermediate_tensors", ) + missing_attrs = tuple(attr for attr in pp_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_pp", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_pp=True`, " + "but is missing PP-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all PP-specific attributes, " + "but does not set `supports_pp=True`.", model) + + return supports_attributes and supports_inspect + + +def _supports_pp_attributes( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + if isinstance(model, type): + return isinstance(model, _SupportsPPType) + + return isinstance(model, SupportsPP) + + +def _supports_pp_inspect( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + forward_params = inspect.signature(model_forward).parameters + return "intermediate_tensors" in forward_params + + @runtime_checkable class HasInnerState(Protocol): """The interface required for all models that has inner state.""" @@ -158,7 +290,7 @@ class HasInnerState(Protocol): def __init__(self, *, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + scheduler_config: Optional["SchedulerConfig"] = None) -> None: ... @@ -168,7 +300,7 @@ class _HasInnerStateType(Protocol): def __init__(self, *, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + scheduler_config: Optional["SchedulerConfig"] = None) -> None: ... 
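For reference, a hypothetical skeleton (not part of this diff) showing what `supports_pp` accepts under the checks defined above: the class exposes the inherited `supports_pp` flag and `make_empty_intermediate_tensors`, and its `forward` takes an `intermediate_tensors` parameter, which is what `_supports_pp_inspect` looks for. The method bodies are stubs; a real model returns `IntermediateTensors` sized for its hidden states.

```python
import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsPP, supports_pp


class ToyPPModel(nn.Module, SupportsPP):
    """Hypothetical model used only to illustrate the checks."""

    def make_empty_intermediate_tensors(self, batch_size, dtype, device):
        raise NotImplementedError  # stub; real models build IntermediateTensors

    def forward(self, input_ids, position_ids, kv_caches, attn_metadata,
                intermediate_tensors):
        raise NotImplementedError  # accepting `intermediate_tensors` is the key part


assert supports_pp(ToyPPModel)    # class-level check (via _SupportsPPType)
assert supports_pp(ToyPPModel())  # instance-level check (via SupportsPP)
```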
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 11a8431a5e7f7..28acbd73625c1 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -18,8 +18,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -28,6 +27,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -297,7 +297,7 @@ def forward( return hidden_states -class InternLM2ForCausalLM(nn.Module): +class InternLM2ForCausalLM(nn.Module, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e84990a2ab109..a5535aa5d5c81 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -32,7 +32,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) @@ -362,7 +362,7 @@ def dummy_data_for_internvl(ctx: InputContext, @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) -class InternVLChatModel(nn.Module, SupportsMultiModal): +class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: PretrainedConfig, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index b0fbb7e9829e0..d99bd4901656d 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -33,8 +33,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -43,6 +42,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import JAISConfig +from .interfaces import SupportsPP from .utils import is_pp_missing_parameter, make_layers @@ -279,7 +279,7 @@ def forward( return hidden_states -class JAISLMHeadModel(nn.Module): +class JAISLMHeadModel(nn.Module, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5ff31e3833ec9..c5d5c3595dd1b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,8 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from 
vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -51,7 +50,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import is_hip -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -344,7 +343,7 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module, SupportsLoRA): +class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 10cbfcf6432b3..40b9b0d7c7450 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -36,8 +36,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -47,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import is_pp_missing_parameter, make_layers @@ -306,7 +305,7 @@ def forward( return hidden_states -class MixtralForCausalLM(nn.Module, SupportsLoRA): +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False packed_modules_mapping = { diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index e9ff12de2094e..134aab59c6e0e 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -34,8 +34,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -46,7 +45,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers # The architecture is pretty similar to Llama, with these changes: @@ -372,7 +371,7 @@ def forward( return hidden_states -class NemotronForCausalLM(nn.Module, SupportsLoRA): +class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 761c1370b9776..704c65a80dee7 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -31,15 +31,13 @@ 
QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -47,6 +45,7 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of +from .interfaces import SupportsMultiModal, SupportsPP from .utils import flatten_bn, is_pp_missing_parameter, make_layers logger = init_logger(__name__) @@ -860,7 +859,7 @@ def dummy_data_for_qwen( @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(nn.Module, SupportsMultiModal): +class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 5e6737ad7fa47..ac06ddb8cc60e 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -37,8 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -48,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -298,7 +297,7 @@ def forward( return hidden_states -class Qwen2ForCausalLM(nn.Module, SupportsLoRA): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index d80064601d993..55d5039e6edbc 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -42,8 +42,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -53,6 +52,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils 
import print_warning_once +from .interfaces import SupportsPP from .utils import is_pp_missing_parameter, make_layers @@ -368,7 +368,7 @@ def forward( return hidden_states -class Qwen2MoeForCausalLM(nn.Module): +class Qwen2MoeForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c82e8ed6ed1e0..9c822541c627b 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -55,7 +55,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalInputs) @@ -68,6 +67,7 @@ from vllm.transformers_utils.processor import get_processor from vllm.utils import is_cpu +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) @@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext, "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: Qwen2VLConfig, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f6218bad4ef1e..761f0406b1333 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -24,7 +24,7 @@ class WeightsGroup(UserDict): when attempting to access a weight component that does not exist. """ - def __getitem__(self, key: str) -> int: + def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]: try: return super().__getitem__(key) except KeyError as exc: @@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], def group_weights_with_prefix( - weights: Iterable[Tuple[str, torch.Tensor]] -) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]: + weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup: """ Helper function to group weights with prefix """ @@ -183,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, class LayerFn(Protocol): - def __call__( - self, - prefix="", - ) -> torch.nn.Module: + def __call__(self, prefix: str) -> torch.nn.Module: ... 
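Since `group_weights_with_prefix` now advertises a `WeightsGroup` return type (with the friendlier `KeyError` described in its docstring), here is a small usage sketch; the weight names are made up, and it assumes vLLM at this revision is importable:

```python
import torch

from vllm.model_executor.models.utils import group_weights_with_prefix

# Toy checkpoint stream; real models receive this iterator from the weight loader.
weights = [
    ("language_model.embed_tokens.weight", torch.zeros(8, 8)),
    ("vision_model.proj.weight", torch.zeros(8, 8)),
]

groups = group_weights_with_prefix(weights)
for name, tensor in groups["vision_model"]:
    # Only the vision_model.* entries end up in this group.
    print(name, tuple(tensor.shape))

# groups["missing_component"] raises a KeyError naming the missing component,
# which is what the WeightsGroup.__getitem__ change is for.
```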
@@ -319,8 +315,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( - batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> IntermediateTensors: return IntermediateTensors({ key: torch.zeros((batch_size, hidden_size), dtype=dtype, @@ -342,8 +340,14 @@ def __init__(self, llm: nn.Module, name: str) -> None: self.model_name = name setattr(self, name, llm) - def forward(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name)(*args, **kwargs) + def __getattr__(self, key: str): + llm = super().__getattr__(self.model_name) + if key == self.model_name: + return llm - def embed_tokens(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name).embed_tokens(*args, **kwargs) + return getattr(llm, key) + + # We need to explicitly override this + def __call__(self, *args: Any, **kwargs: Any) -> Any: + llm = super().__getattr__(self.model_name) + return llm(*args, **kwargs)
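The final hunk replaces the wrapper's explicit `forward`/`embed_tokens` delegation with `__getattr__` forwarding plus an explicit `__call__`. The explicit override is needed because implicit special-method invocation (`obj(...)`) is resolved on the type and never consults instance `__getattr__`. A simplified, self-contained stand-in (class names are illustrative, not vLLM's):

```python
class Inner:
    def __call__(self, x: int) -> int:
        return x + 1

    def helper(self) -> str:
        return "hello"


class Forwarder:
    """Toy stand-in for the wrapper class modified above."""

    def __init__(self, inner: Inner) -> None:
        self.inner = inner

    def __getattr__(self, key: str):
        # Reached only for attributes not found through normal lookup.
        return getattr(self.inner, key)

    # `f(x)` looks up __call__ on type(f), bypassing __getattr__,
    # so delegation of calls must be spelled out explicitly.
    def __call__(self, *args, **kwargs):
        return self.inner(*args, **kwargs)


f = Forwarder(Inner())
print(f.helper())  # "hello", forwarded via __getattr__
print(f(41))       # 42, works only because __call__ is defined explicitly
```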