From eea3fc547941ce10a2c5705648fbe21d89d783ae Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 14:09:41 +0000 Subject: [PATCH 01/22] Add SupportsPP interface and stateless protocol check --- vllm/config.py | 61 ++++-------- vllm/model_executor/models/__init__.py | 119 ++++++++++++++++++++--- vllm/model_executor/models/interfaces.py | 88 +++++++++++++++-- 3 files changed, 201 insertions(+), 67 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3139c5a08bfb8..365c4e906bb64 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -33,27 +33,6 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 -_PP_SUPPORTED_MODELS = [ - "AquilaForCausalLM", - "AquilaModel", - "DeepseekV2ForCausalLM", - "GPT2LMHeadModel", - "InternLM2ForCausalLM", - "InternLMForCausalLM", - "InternVLChatModel", - "JAISLMHeadModel", - "LlamaForCausalLM", - "LLaMAForCausalLM", - "MistralForCausalLM", - "MixtralForCausalLM", - "NemotronForCausalLM", - "Phi3ForCausalLM", - "Qwen2ForCausalLM", - "Qwen2MoeForCausalLM", - "QWenLMHeadModel", - "Qwen2VLForConditionalGeneration", -] - class ModelConfig: """Configuration for the model. @@ -228,16 +207,14 @@ def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] ) -> Optional["MultiModalConfig"]: architectures = getattr(self.hf_config, "architectures", []) - if any( - ModelRegistry.is_multimodal_model(arch) - for arch in architectures): + if ModelRegistry.is_multimodal_model(architectures): return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) - else: - if limit_mm_per_prompt: - raise ValueError( - "limit_mm_per_prompt is only supported for multimodal " - "models.") - return None + + if limit_mm_per_prompt: + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") + + return None def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() @@ -249,8 +226,7 @@ def _verify_tokenizer_mode(self) -> None: def _verify_embedding_mode(self) -> None: architectures = getattr(self.hf_config, "architectures", []) - self.embedding_mode = any( - ModelRegistry.is_embedding_model(arch) for arch in architectures) + self.embedding_mode = ModelRegistry.is_embedding_model(architectures) def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) @@ -417,17 +393,18 @@ def verify_with_parallel_config( f"({tensor_parallel_size}).") pipeline_parallel_size = parallel_config.pipeline_parallel_size - architectures = getattr(self.hf_config, "architectures", []) - if not all(arch in _PP_SUPPORTED_MODELS - for arch in architectures) and pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is only supported for the following " - f" architectures: {_PP_SUPPORTED_MODELS}.") + if pipeline_parallel_size > 1: + architectures = getattr(self.hf_config, "architectures", []) + if not ModelRegistry.is_pp_supported_model(architectures): + raise NotImplementedError( + "Pipeline parallelism is only supported for the following " + f"architectures: {ModelRegistry.get_pp_supported_archs()}." + ) - if pipeline_parallel_size > 1 and self.use_async_output_proc: - logger.warning("Async output processor is not supported with " - "pipeline parallelism currently. Disabling it.") - self.use_async_output_proc = False + if self.use_async_output_proc: + logger.warning("Async output processor is not supported with " + "pipeline parallelism currently. 
Disabling it.") + self.use_async_output_proc = False def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 682a2e71a1dbf..d5bb5ed6774f4 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,12 +1,18 @@ -import functools import importlib -from typing import Dict, List, Optional, Tuple, Type +import importlib.util +import string +import subprocess +import sys +from functools import lru_cache, partial +from typing import Dict, List, Optional, Tuple, Type, Union import torch.nn as nn from vllm.logger import init_logger from vllm.utils import is_hip +from .interfaces import SupportsMultiModal, SupportsPP + logger = init_logger(__name__) _GENERATION_MODELS = { @@ -117,6 +123,27 @@ **_CONDITIONAL_GENERATION_MODELS, } +_PP_SUPPORTED_MODELS = [ + "AquilaForCausalLM", + "AquilaModel", + "DeepseekV2ForCausalLM", + "GPT2LMHeadModel", + "InternLM2ForCausalLM", + "InternLMForCausalLM", + "InternVLChatModel", + "JAISLMHeadModel", + "LlamaForCausalLM", + "LLaMAForCausalLM", + "MistralForCausalLM", + "MixtralForCausalLM", + "NemotronForCausalLM", + "Phi3ForCausalLM", + "Qwen2ForCausalLM", + "Qwen2MoeForCausalLM", + "QWenLMHeadModel", + "Qwen2VLForConditionalGeneration", +] + # Architecture -> type. # out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} @@ -150,12 +177,48 @@ class ModelRegistry: @staticmethod - @functools.lru_cache(maxsize=128) - def _get_model(model_arch: str): - module_name, model_cls_name = _MODELS[model_arch] - module = importlib.import_module( - f"vllm.model_executor.models.{module_name}") - return getattr(module, model_cls_name, None) + def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: + module_relname, cls_name = _MODELS[model_arch] + return f"vllm.model_executor.models.{module_relname}", cls_name + + @staticmethod + @lru_cache(maxsize=128) + def _get_model(model_arch: str) -> Optional[Type[nn.Module]]: + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + module = importlib.import_module(module_name) + return getattr(module, cls_name, None) + + @staticmethod + @lru_cache(maxsize=128) + def _is_subclass_stateless(model_arch: str, class_: type) -> bool: + """ + Test whether a model is a subclass of the given type. + + This is run in a subprocess to avoid initializing CUDA for the main + program. 
+ """ + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + + valid_name_characters = string.ascii_letters + string.digits + "._" + if any(s not in valid_name_characters for s in module_name): + raise ValueError(f"Unsafe module name detected for {model_arch}") + if any(s not in valid_name_characters for s in cls_name): + raise ValueError(f"Unsafe class name detected for {model_arch}") + if any(s not in valid_name_characters for s in class_.__module__): + raise ValueError(f"Unsafe module name detected for {class_}") + if any(s not in valid_name_characters for s in class_.__name__): + raise ValueError(f"Unsafe class name detected for {class_}") + + stmts = ";".join([ + f"from {module_name} import {cls_name}", + f"from {class_.__module__} import {class_.__name__}", + f"assert isinstance({cls_name}, {class_.__name__})", + ]) + + result = subprocess.run([sys.executable, "-c", stmts], + capture_output=True) + + return result.returncode == 0 @staticmethod def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: @@ -202,17 +265,41 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): _OOT_MODELS[model_arch] = model_cls @staticmethod - def is_embedding_model(model_arch: str) -> bool: - return model_arch in _EMBEDDING_MODELS + def is_embedding_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return any(arch in _EMBEDDING_MODELS for arch in architectures) + + @staticmethod + def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_mm = partial(ModelRegistry._is_subclass_stateless, + class_=SupportsMultiModal) + + return any(is_mm(arch) for arch in architectures) @staticmethod - def is_multimodal_model(model_arch: str) -> bool: + def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_pp = partial(ModelRegistry._is_subclass_stateless, + class_=SupportsPP) - # TODO: find a way to avoid initializing CUDA prematurely to - # use `supports_multimodal` to determine if a model is multimodal - # model_cls = ModelRegistry._try_load_model_cls(model_arch) - # from vllm.model_executor.models.interfaces import supports_multimodal - return model_arch in _MULTIMODAL_MODELS + return any(is_pp(arch) for arch in architectures) + + @staticmethod + def get_pp_supported_archs() -> List[str]: + return list(_PP_SUPPORTED_MODELS) __all__ = [ diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 069948f812253..36e705fa2e787 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,11 +1,18 @@ -from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, - Union, overload, runtime_checkable) +import ast +import inspect +from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, + Protocol, Type, Union, overload, runtime_checkable) +import torch from typing_extensions import TypeIs -from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata 
+ from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig + from vllm.sequence import IntermediateTensors + logger = init_logger(__name__) @@ -22,7 +29,7 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ - def __init__(self, *, multimodal_config: MultiModalConfig) -> None: + def __init__(self, *, multimodal_config: "MultiModalConfig") -> None: ... @@ -32,7 +39,7 @@ def __init__(self, *, multimodal_config: MultiModalConfig) -> None: class _SupportsMultiModalType(Protocol): supports_multimodal: Literal[True] - def __call__(self, *, multimodal_config: MultiModalConfig) -> None: + def __call__(self, *, multimodal_config: "MultiModalConfig") -> None: ... @@ -75,7 +82,7 @@ class SupportsLoRA(Protocol): embedding_padding_modules: ClassVar[List[str]] # lora_config is None when LoRA is not enabled - def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: ... @@ -90,7 +97,7 @@ class _SupportsLoRAType(Protocol): embedding_modules: Dict[str, str] embedding_padding_modules: List[str] - def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: ... @@ -145,6 +152,69 @@ def _supports_lora( return isinstance(model, SupportsLoRA) +@runtime_checkable +class SupportsPP(Protocol): + """The interface required for all models that support pipeline parallel.""" + + supports_pp: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pipeline parallel. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + """ + Accept :class:`IntermediateTensors` when PP rank > 0. + + Return :class:`IntermediateTensors` only for the last PP rank. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsPPType(Protocol): + supports_pp: Literal[True] + + +@overload +def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: + ... + + +@overload +def supports_pp(model: object) -> TypeIs[SupportsPP]: + ... + + +def supports_pp( + model: Union[ast.ClassDef, Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + forward_params = inspect.signature(model_forward).parameters + if "intermediate_tensors" not in forward_params: + return False + + if isinstance(model, type): + return isinstance(model, _SupportsPPType) + + return isinstance(model, SupportsPP) + + @runtime_checkable class HasInnerState(Protocol): """The interface required for all models that has inner state.""" @@ -158,7 +228,7 @@ class HasInnerState(Protocol): def __init__(self, *, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + scheduler_config: Optional["SchedulerConfig"] = None) -> None: ... @@ -168,7 +238,7 @@ class _HasInnerStateType(Protocol): def __init__(self, *, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + scheduler_config: Optional["SchedulerConfig"] = None) -> None: ... 
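As an illustrative aside (not part of the patches themselves), the opt-in that PATCH 02 performs for each model boils down to the sketch below: the model class additionally inherits from the `SupportsPP` protocol defined above, and its `forward` accepts `intermediate_tensors`, which the stateless `supports_pp` check introspects for. The class name and the body of `forward` here are hypothetical, not taken from the patches.

# Minimal sketch of a model opting into pipeline parallelism via SupportsPP.
# "ToyForCausalLM" is a hypothetical model class used only for illustration.
from typing import List, Optional, Union

import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.sequence import IntermediateTensors


class ToyForCausalLM(nn.Module, SupportsPP):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # PP ranks other than the first receive the previous rank's hidden
        # states via `intermediate_tensors`; only the last PP rank returns
        # a plain tensor of hidden states.
        ...

For architectures wired up this way, `ModelRegistry.is_pp_supported_model(architectures)` can answer without initializing CUDA in the main program, since the subclass check runs in a subprocess.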
From b4ce5f7a3ad560694ef80da5f5063d09900ab2e6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 14:09:43 +0000 Subject: [PATCH 02/22] Subclass SupportsPP in relevant models --- vllm/model_executor/models/deepseek_v2.py | 6 +++--- vllm/model_executor/models/gpt2.py | 6 +++--- vllm/model_executor/models/internlm2.py | 6 +++--- vllm/model_executor/models/internvl.py | 4 ++-- vllm/model_executor/models/jais.py | 6 +++--- vllm/model_executor/models/llama.py | 7 +++---- vllm/model_executor/models/mixtral.py | 7 +++---- vllm/model_executor/models/nemotron.py | 7 +++---- vllm/model_executor/models/qwen.py | 7 +++---- vllm/model_executor/models/qwen2.py | 7 +++---- vllm/model_executor/models/qwen2_moe.py | 6 +++--- vllm/model_executor/models/qwen2_vl.py | 5 +++-- 12 files changed, 35 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 8cbd9435ec7ca..8bdeac29e2458 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -40,8 +40,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -50,6 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -472,7 +472,7 @@ def forward( return hidden_states -class DeepseekV2ForCausalLM(nn.Module): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fb5a297661ddc..52c92353784b0 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -32,8 +32,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,6 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP from .utils import is_pp_missing_parameter, make_layers @@ -234,7 +234,7 @@ def forward( return hidden_states -class GPT2LMHeadModel(nn.Module): +class GPT2LMHeadModel(nn.Module, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 11a8431a5e7f7..28acbd73625c1 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -18,8 +18,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from 
vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -28,6 +27,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -297,7 +297,7 @@ def forward( return hidden_states -class InternLM2ForCausalLM(nn.Module): +class InternLM2ForCausalLM(nn.Module, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e84990a2ab109..a5535aa5d5c81 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -32,7 +32,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) @@ -362,7 +362,7 @@ def dummy_data_for_internvl(ctx: InputContext, @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) -class InternVLChatModel(nn.Module, SupportsMultiModal): +class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: PretrainedConfig, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index b0fbb7e9829e0..d99bd4901656d 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -33,8 +33,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -43,6 +42,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import JAISConfig +from .interfaces import SupportsPP from .utils import is_pp_missing_parameter, make_layers @@ -279,7 +279,7 @@ def forward( return hidden_states -class JAISLMHeadModel(nn.Module): +class JAISLMHeadModel(nn.Module, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5ff31e3833ec9..c5d5c3595dd1b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,8 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -51,7 +50,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import is_hip -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import 
PPMissingLayer, is_pp_missing_parameter, make_layers @@ -344,7 +343,7 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module, SupportsLoRA): +class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 10cbfcf6432b3..40b9b0d7c7450 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -36,8 +36,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -47,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import is_pp_missing_parameter, make_layers @@ -306,7 +305,7 @@ def forward( return hidden_states -class MixtralForCausalLM(nn.Module, SupportsLoRA): +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False packed_modules_mapping = { diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index e9ff12de2094e..134aab59c6e0e 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -34,8 +34,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -46,7 +45,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers # The architecture is pretty similar to Llama, with these changes: @@ -372,7 +371,7 @@ def forward( return hidden_states -class NemotronForCausalLM(nn.Module, SupportsLoRA): +class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 761c1370b9776..704c65a80dee7 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -31,15 +31,13 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( 
ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -47,6 +45,7 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of +from .interfaces import SupportsMultiModal, SupportsPP from .utils import flatten_bn, is_pp_missing_parameter, make_layers logger = init_logger(__name__) @@ -860,7 +859,7 @@ def dummy_data_for_qwen( @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(nn.Module, SupportsMultiModal): +class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__( self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 5e6737ad7fa47..ac06ddb8cc60e 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -37,8 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -48,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -298,7 +297,7 @@ def forward( return hidden_states -class Qwen2ForCausalLM(nn.Module, SupportsLoRA): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index d80064601d993..55d5039e6edbc 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -42,8 +42,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -53,6 +52,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once +from .interfaces import SupportsPP from .utils import is_pp_missing_parameter, make_layers @@ -368,7 +368,7 @@ def forward( return hidden_states -class Qwen2MoeForCausalLM(nn.Module): +class Qwen2MoeForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c82e8ed6ed1e0..9c822541c627b 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -55,7 +55,6 @@ 
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalInputs) @@ -68,6 +67,7 @@ from vllm.transformers_utils.processor import get_processor from vllm.utils import is_cpu +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) @@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext, "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: Qwen2VLConfig, From 30e454a860599af50e2137407efed2a3adcc614d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 14:18:13 +0000 Subject: [PATCH 03/22] Remove hardcoded list --- vllm/config.py | 5 ++--- vllm/model_executor/models/__init__.py | 25 ------------------------- 2 files changed, 2 insertions(+), 28 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 365c4e906bb64..186c9e6955717 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -397,9 +397,8 @@ def verify_with_parallel_config( architectures = getattr(self.hf_config, "architectures", []) if not ModelRegistry.is_pp_supported_model(architectures): raise NotImplementedError( - "Pipeline parallelism is only supported for the following " - f"architectures: {ModelRegistry.get_pp_supported_archs()}." - ) + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") if self.use_async_output_proc: logger.warning("Async output processor is not supported with " diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d5bb5ed6774f4..bea499adc91e3 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -123,27 +123,6 @@ **_CONDITIONAL_GENERATION_MODELS, } -_PP_SUPPORTED_MODELS = [ - "AquilaForCausalLM", - "AquilaModel", - "DeepseekV2ForCausalLM", - "GPT2LMHeadModel", - "InternLM2ForCausalLM", - "InternLMForCausalLM", - "InternVLChatModel", - "JAISLMHeadModel", - "LlamaForCausalLM", - "LLaMAForCausalLM", - "MistralForCausalLM", - "MixtralForCausalLM", - "NemotronForCausalLM", - "Phi3ForCausalLM", - "Qwen2ForCausalLM", - "Qwen2MoeForCausalLM", - "QWenLMHeadModel", - "Qwen2VLForConditionalGeneration", -] - # Architecture -> type. 
# out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} @@ -297,10 +276,6 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: return any(is_pp(arch) for arch in architectures) - @staticmethod - def get_pp_supported_archs() -> List[str]: - return list(_PP_SUPPORTED_MODELS) - __all__ = [ "ModelRegistry", From e9ea5b7524ddb8d85f6439506cb914b033921551 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 14:19:07 +0000 Subject: [PATCH 04/22] Remove unused import --- vllm/model_executor/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index bea499adc91e3..b895c7a9a9c9f 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,5 +1,4 @@ import importlib -import importlib.util import string import subprocess import sys From 8b401761dd5176de7058404809676f1c00ed5055 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 14:22:20 +0000 Subject: [PATCH 05/22] Check using function --- vllm/model_executor/models/__init__.py | 28 ++++++++++++++------------ 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b895c7a9a9c9f..b5ac1892edd22 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -3,14 +3,14 @@ import subprocess import sys from functools import lru_cache, partial -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import torch.nn as nn from vllm.logger import init_logger from vllm.utils import is_hip -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import supports_multimodal, supports_pp logger = init_logger(__name__) @@ -168,7 +168,10 @@ def _get_model(model_arch: str) -> Optional[Type[nn.Module]]: @staticmethod @lru_cache(maxsize=128) - def _is_subclass_stateless(model_arch: str, class_: type) -> bool: + def _check_stateless( + model_arch: str, + func: Callable[[object], bool], + ) -> bool: """ Test whether a model is a subclass of the given type. 
@@ -182,15 +185,15 @@ def _is_subclass_stateless(model_arch: str, class_: type) -> bool: raise ValueError(f"Unsafe module name detected for {model_arch}") if any(s not in valid_name_characters for s in cls_name): raise ValueError(f"Unsafe class name detected for {model_arch}") - if any(s not in valid_name_characters for s in class_.__module__): - raise ValueError(f"Unsafe module name detected for {class_}") - if any(s not in valid_name_characters for s in class_.__name__): - raise ValueError(f"Unsafe class name detected for {class_}") + if any(s not in valid_name_characters for s in func.__module__): + raise ValueError(f"Unsafe module name detected for {func}") + if any(s not in valid_name_characters for s in func.__name__): + raise ValueError(f"Unsafe class name detected for {func}") stmts = ";".join([ f"from {module_name} import {cls_name}", - f"from {class_.__module__} import {class_.__name__}", - f"assert isinstance({cls_name}, {class_.__name__})", + f"from {func.__module__} import {func.__name__}", + f"assert {func.__name__}({cls_name})", ]) result = subprocess.run([sys.executable, "-c", stmts], @@ -258,8 +261,8 @@ def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: if not architectures: logger.warning("No model architectures are specified") - is_mm = partial(ModelRegistry._is_subclass_stateless, - class_=SupportsMultiModal) + is_mm = partial(ModelRegistry._check_stateless, + func=supports_multimodal) return any(is_mm(arch) for arch in architectures) @@ -270,8 +273,7 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: if not architectures: logger.warning("No model architectures are specified") - is_pp = partial(ModelRegistry._is_subclass_stateless, - class_=SupportsPP) + is_pp = partial(ModelRegistry._check_stateless, func=supports_pp) return any(is_pp(arch) for arch in architectures) From ec4c6b3ca637514a5d12dc85e3f2d3bae4dfff30 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 14:23:09 +0000 Subject: [PATCH 06/22] Update docstring --- vllm/model_executor/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b5ac1892edd22..887b325adab46 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -173,7 +173,7 @@ def _check_stateless( func: Callable[[object], bool], ) -> bool: """ - Test whether a model is a subclass of the given type. + Run a boolean function against a model and return the result. This is run in a subprocess to avoid initializing CUDA for the main program. From cdc4dbe86d47343598534ac7c486a3b0c5675f36 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 14:24:21 +0000 Subject: [PATCH 07/22] Simplify --- vllm/model_executor/models/__init__.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 887b325adab46..df1058f85fc38 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -169,8 +169,8 @@ def _get_model(model_arch: str) -> Optional[Type[nn.Module]]: @staticmethod @lru_cache(maxsize=128) def _check_stateless( - model_arch: str, func: Callable[[object], bool], + model_arch: str, ) -> bool: """ Run a boolean function against a model and return the result. 
@@ -261,9 +261,7 @@ def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: if not architectures: logger.warning("No model architectures are specified") - is_mm = partial(ModelRegistry._check_stateless, - func=supports_multimodal) - + is_mm = partial(ModelRegistry._check_stateless, supports_multimodal) return any(is_mm(arch) for arch in architectures) @staticmethod @@ -273,8 +271,7 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: if not architectures: logger.warning("No model architectures are specified") - is_pp = partial(ModelRegistry._check_stateless, func=supports_pp) - + is_pp = partial(ModelRegistry._check_stateless, supports_pp) return any(is_pp(arch) for arch in architectures) From dcc2a4912fa1fbf52a754b04ec354be1de5e50ce Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 14:29:02 +0000 Subject: [PATCH 08/22] Add tests --- tests/models/test_registry.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index b058e2755c245..b63c1638a6802 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -7,3 +7,19 @@ def test_registry_imports(model_cls): # Ensure all model classes can be imported successfully ModelRegistry.resolve_model_cls([model_cls]) + + +@pytest.mark.parametrize("model_cls,is_mm", [ + ("LlamaForCausalLM", False), + ("MllamaForConditionalGeneration", True), +]) +def test_registry_is_multimodal(model_cls, is_mm): + assert ModelRegistry.is_multimodal_model(model_cls) is is_mm + + +@pytest.mark.parametrize("model_cls,is_pp", [ + ("MLPSpeculatorPreTrainedModel", False), + ("DeepseekV2ForCausalLM", True), +]) +def test_registry_is_pp(model_cls, is_pp): + assert ModelRegistry.is_pp_supported_model(model_cls) is is_pp From 7280766dfd829cd12b9882010088ba4a21ceb588 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 15:01:26 +0000 Subject: [PATCH 09/22] Test CUDA initialization --- tests/models/test_registry.py | 55 ++++++++++++++++++++------ vllm/model_executor/models/__init__.py | 7 +++- 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index b63c1638a6802..e267b47f613b0 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -1,25 +1,54 @@ +import warnings + import pytest +import torch.cuda from vllm.model_executor.models import _MODELS, ModelRegistry +from ..utils import fork_new_process_for_each_test + -@pytest.mark.parametrize("model_cls", _MODELS) -def test_registry_imports(model_cls): +@pytest.mark.parametrize("model_arch", _MODELS) +def test_registry_imports(model_arch): # Ensure all model classes can be imported successfully - ModelRegistry.resolve_model_cls([model_cls]) + ModelRegistry.resolve_model_cls(model_arch) -@pytest.mark.parametrize("model_cls,is_mm", [ - ("LlamaForCausalLM", False), - ("MllamaForConditionalGeneration", True), +@fork_new_process_for_each_test +@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [ + ("LlamaForCausalLM", False, False), + ("MllamaForConditionalGeneration", True, False), + ("LlavaForConditionalGeneration", True, True), ]) -def test_registry_is_multimodal(model_cls, is_mm): - assert ModelRegistry.is_multimodal_model(model_cls) is is_mm +def test_registry_is_multimodal(model_arch, is_mm, init_cuda): + assert ModelRegistry.is_multimodal_model(model_arch) is is_mm + + if init_cuda: + assert not torch.cuda.is_initialized() + 
ModelRegistry.resolve_model_cls(model_arch) + if not torch.cuda.is_initialized(): + warnings.warn( + "This model no longer initializes CUDA on import. " + "Please test using a different model.", + stacklevel=2) -@pytest.mark.parametrize("model_cls,is_pp", [ - ("MLPSpeculatorPreTrainedModel", False), - ("DeepseekV2ForCausalLM", True), + +@fork_new_process_for_each_test +@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [ + ("MLPSpeculatorPreTrainedModel", False, False), + ("DeepseekV2ForCausalLM", True, False), + ("Qwen2VLForConditionalGeneration", True, True), ]) -def test_registry_is_pp(model_cls, is_pp): - assert ModelRegistry.is_pp_supported_model(model_cls) is is_pp +def test_registry_is_pp(model_arch, is_pp, init_cuda): + assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp + + if init_cuda: + assert not torch.cuda.is_initialized() + + ModelRegistry.resolve_model_cls(model_arch) + if not torch.cuda.is_initialized(): + warnings.warn( + "This model no longer initializes CUDA on import. " + "Please test using a different model.", + stacklevel=2) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index df1058f85fc38..ba9f670101e45 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -221,7 +221,12 @@ def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: @staticmethod def resolve_model_cls( - architectures: List[str]) -> Tuple[Type[nn.Module], str]: + architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + for arch in architectures: model_cls = ModelRegistry._try_load_model_cls(arch) if model_cls is not None: From 37cc51ba2797196a7d82a9bfe0ab1e54974bcd53 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 15:10:52 +0000 Subject: [PATCH 10/22] Add platform guard --- tests/models/test_registry.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index e267b47f613b0..cd5fe77d01db9 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -4,6 +4,7 @@ import torch.cuda from vllm.model_executor.models import _MODELS, ModelRegistry +from vllm.platforms import current_platform from ..utils import fork_new_process_for_each_test @@ -23,7 +24,7 @@ def test_registry_imports(model_arch): def test_registry_is_multimodal(model_arch, is_mm, init_cuda): assert ModelRegistry.is_multimodal_model(model_arch) is is_mm - if init_cuda: + if init_cuda and current_platform.is_cuda_alike(): assert not torch.cuda.is_initialized() ModelRegistry.resolve_model_cls(model_arch) @@ -43,7 +44,7 @@ def test_registry_is_multimodal(model_arch, is_mm, init_cuda): def test_registry_is_pp(model_arch, is_pp, init_cuda): assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp - if init_cuda: + if init_cuda and current_platform.is_cuda_alike(): assert not torch.cuda.is_initialized() ModelRegistry.resolve_model_cls(model_arch) From 38142463e0bb58657e828cae499d1d812188642f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 15:20:28 +0000 Subject: [PATCH 11/22] Trigger CI --- tests/models/test_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index cd5fe77d01db9..ee5c9e8ccb196 100644 --- 
a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -31,7 +31,7 @@ def test_registry_is_multimodal(model_arch, is_mm, init_cuda): if not torch.cuda.is_initialized(): warnings.warn( "This model no longer initializes CUDA on import. " - "Please test using a different model.", + "Please test using a different one.", stacklevel=2) @@ -51,5 +51,5 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda): if not torch.cuda.is_initialized(): warnings.warn( "This model no longer initializes CUDA on import. " - "Please test using a different model.", + "Please test using a different one.", stacklevel=2) From cf91f7b46f4171425406e3c04a1683c7817f45ec Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 16:17:10 +0000 Subject: [PATCH 12/22] Fix OOT registration --- vllm/model_executor/models/__init__.py | 109 +++++++++++++++---------- 1 file changed, 67 insertions(+), 42 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index ba9f670101e45..9fe71f10e433d 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -161,52 +161,19 @@ def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: @staticmethod @lru_cache(maxsize=128) - def _get_model(model_arch: str) -> Optional[Type[nn.Module]]: + def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in _MODELS: + return None + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) module = importlib.import_module(module_name) return getattr(module, cls_name, None) @staticmethod - @lru_cache(maxsize=128) - def _check_stateless( - func: Callable[[object], bool], - model_arch: str, - ) -> bool: - """ - Run a boolean function against a model and return the result. - - This is run in a subprocess to avoid initializing CUDA for the main - program. 
- """ - module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) - - valid_name_characters = string.ascii_letters + string.digits + "._" - if any(s not in valid_name_characters for s in module_name): - raise ValueError(f"Unsafe module name detected for {model_arch}") - if any(s not in valid_name_characters for s in cls_name): - raise ValueError(f"Unsafe class name detected for {model_arch}") - if any(s not in valid_name_characters for s in func.__module__): - raise ValueError(f"Unsafe module name detected for {func}") - if any(s not in valid_name_characters for s in func.__name__): - raise ValueError(f"Unsafe class name detected for {func}") - - stmts = ";".join([ - f"from {module_name} import {cls_name}", - f"from {func.__module__} import {func.__name__}", - f"assert {func.__name__}({cls_name})", - ]) - - result = subprocess.run([sys.executable, "-c", stmts], - capture_output=True) - - return result.returncode == 0 - - @staticmethod - def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch in _OOT_MODELS: return _OOT_MODELS[model_arch] - if model_arch not in _MODELS: - return None + if is_hip(): if model_arch in _ROCM_UNSUPPORTED_MODELS: raise ValueError( @@ -217,7 +184,15 @@ def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: "Model architecture %s is partially supported by ROCm: %s", model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) - return ModelRegistry._get_model(model_arch) + return None + + @staticmethod + def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return model + + return ModelRegistry._try_get_model_stateful(model_arch) @staticmethod def resolve_model_cls( @@ -250,6 +225,50 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls + @staticmethod + @lru_cache(maxsize=128) + def _check_stateless( + func: Callable[[Type[nn.Module]], bool], + model_arch: str, + *, + default: Optional[bool] = None, + ) -> bool: + """ + Run a boolean function against a model and return the result. + + This is run in a subprocess to avoid initializing CUDA for the main + program. 
+ """ + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return func(model) + + if model_arch not in _MODELS and default is not None: + return default + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + + valid_name_characters = string.ascii_letters + string.digits + "._" + if any(s not in valid_name_characters for s in module_name): + raise ValueError(f"Unsafe module name detected for {model_arch}") + if any(s not in valid_name_characters for s in cls_name): + raise ValueError(f"Unsafe class name detected for {model_arch}") + if any(s not in valid_name_characters for s in func.__module__): + raise ValueError(f"Unsafe module name detected for {func}") + if any(s not in valid_name_characters for s in func.__name__): + raise ValueError(f"Unsafe class name detected for {func}") + + stmts = ";".join([ + f"from {module_name} import {cls_name}", + f"from {func.__module__} import {func.__name__}", + f"assert {func.__name__}({cls_name})", + ]) + + result = subprocess.run([sys.executable, "-c", stmts], + capture_output=True) + + return result.returncode == 0 + @staticmethod def is_embedding_model(architectures: Union[str, List[str]]) -> bool: if isinstance(architectures, str): @@ -266,7 +285,10 @@ def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: if not architectures: logger.warning("No model architectures are specified") - is_mm = partial(ModelRegistry._check_stateless, supports_multimodal) + is_mm = partial(ModelRegistry._check_stateless, + supports_multimodal, + default=False) + return any(is_mm(arch) for arch in architectures) @staticmethod @@ -276,7 +298,10 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: if not architectures: logger.warning("No model architectures are specified") - is_pp = partial(ModelRegistry._check_stateless, supports_pp) + is_pp = partial(ModelRegistry._check_stateless, + supports_pp, + default=False) + return any(is_pp(arch) for arch in architectures) From 38b090ad10705fba10f6203f461be650a1f52ffe Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 16:18:44 +0000 Subject: [PATCH 13/22] Update docstring --- vllm/model_executor/models/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 9fe71f10e433d..4e8624b65a5ed 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -236,8 +236,8 @@ def _check_stateless( """ Run a boolean function against a model and return the result. - This is run in a subprocess to avoid initializing CUDA for the main - program. + If the model is not imported, the function is run inside a subprocess to + avoid initializing CUDA for the main program. 
""" model = ModelRegistry._try_get_model_stateless(model_arch) if model is not None: From d394985857a01c542a7482d33abf9d3bcaf2135e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 16:24:20 +0000 Subject: [PATCH 14/22] Remove unnecessary global --- vllm/model_executor/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 4e8624b65a5ed..11ae31351b77b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -222,7 +222,7 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): "Model architecture %s is already registered, and will be " "overwritten by the new model class %s.", model_arch, model_cls.__name__) - global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls @staticmethod From 6a4287afcd2e21206d0fb82e6e6351a663f9d04c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 3 Oct 2024 02:07:03 +0000 Subject: [PATCH 15/22] Update interfaces --- vllm/model_executor/models/__init__.py | 18 +++++- vllm/model_executor/models/interfaces.py | 78 +++++++++++++++++++++--- vllm/model_executor/models/utils.py | 22 ++++--- 3 files changed, 98 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 11ae31351b77b..2f3eadc6aa25d 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -2,6 +2,7 @@ import string import subprocess import sys +import uuid from functools import lru_cache, partial from typing import Callable, Dict, List, Optional, Tuple, Type, Union @@ -236,8 +237,10 @@ def _check_stateless( """ Run a boolean function against a model and return the result. - If the model is not imported, the function is run inside a subprocess to - avoid initializing CUDA for the main program. + If the model is not found, returns the provided default value. + + If the model is not already imported, the function is run inside a + subprocess to avoid initializing CUDA for the main program. """ model = ModelRegistry._try_get_model_stateless(model_arch) if model is not None: @@ -257,16 +260,25 @@ def _check_stateless( raise ValueError(f"Unsafe module name detected for {func}") if any(s not in valid_name_characters for s in func.__name__): raise ValueError(f"Unsafe class name detected for {func}") + + err_id = uuid.uuid4() stmts = ";".join([ f"from {module_name} import {cls_name}", f"from {func.__module__} import {func.__name__}", - f"assert {func.__name__}({cls_name})", + f"assert {func.__name__}({cls_name}), '{err_id}'", ]) result = subprocess.run([sys.executable, "-c", stmts], capture_output=True) + err_lines = [line.decode() for line in result.stderr.splitlines()] + if err_lines and err_lines[-1] != f"AssertionError: {err_id}": + err_str = "\n".join(err_lines) + raise RuntimeError( + "An unexpected error occurred while importing the model in " + f"another process. Error log:\n{err_str}") + return result.returncode == 0 @staticmethod diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 36e705fa2e787..298174fa05965 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,4 +1,3 @@ -import ast import inspect from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, Protocol, Type, Union, overload, runtime_checkable) @@ -165,6 +164,15 @@ class SupportsPP(Protocol): MRO of your model class. 
""" + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + """Called when PP rank > 0 for profiling purposes.""" + ... + def forward( self, input_ids: torch.Tensor, @@ -187,6 +195,24 @@ def forward( class _SupportsPPType(Protocol): supports_pp: Literal[True] + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + ... + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + ... + @overload def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: @@ -199,22 +225,58 @@ def supports_pp(model: object) -> TypeIs[SupportsPP]: def supports_pp( - model: Union[ast.ClassDef, Type[object], object], + model: Union[Type[object], object], ) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: - model_forward = getattr(model, "forward", None) - if not callable(model_forward): - return False + supports_attributes = _supports_pp_attributes(model) + supports_inspect = _supports_pp_inspect(model) + + if supports_attributes and not supports_inspect: + logger.warning( + "The model (%s) sets `supports_pp=True`, but does not accept " + "`intermediate_tensors` in its `forward` method", model) + + if not supports_attributes: + pp_attrs = ("make_empty_intermediate_tensors", ) + missing_attrs = tuple(attr for attr in pp_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_pp", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_pp=True`, " + "but is missing PP-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all PP-specific attributes, " + "but does not set `supports_pp=True`.", model) + + return supports_attributes and supports_inspect - forward_params = inspect.signature(model_forward).parameters - if "intermediate_tensors" not in forward_params: - return False +def _supports_pp_attributes( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: if isinstance(model, type): return isinstance(model, _SupportsPPType) return isinstance(model, SupportsPP) +def _supports_pp_inspect( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + forward_params = inspect.signature(model_forward).parameters + return "intermediate_tensors" in forward_params + + @runtime_checkable class HasInnerState(Protocol): """The interface required for all models that has inner state.""" diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f6218bad4ef1e..9d7432013165d 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -24,7 +24,7 @@ class WeightsGroup(UserDict): when attempting to access a weight component that does not exist. 
""" - def __getitem__(self, key: str) -> int: + def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]: try: return super().__getitem__(key) except KeyError as exc: @@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], def group_weights_with_prefix( - weights: Iterable[Tuple[str, torch.Tensor]] -) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]: + weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup: """ Helper function to group weights with prefix """ @@ -319,8 +318,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( - batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> IntermediateTensors: return IntermediateTensors({ key: torch.zeros((batch_size, hidden_size), dtype=dtype, @@ -342,8 +343,11 @@ def __init__(self, llm: nn.Module, name: str) -> None: self.model_name = name setattr(self, name, llm) - def forward(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name)(*args, **kwargs) + def __getattr__(self, key: str): + llm = super().__getattr__(self.model_name) + return getattr(llm, key) - def embed_tokens(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name).embed_tokens(*args, **kwargs) + # We need to explicitly override this + def __call__(self, *args: Any, **kwargs: Any) -> Any: + llm = super().__getattr__(self.model_name) + return llm(*args, **kwargs) From 1e010c79cdec4e7bc49725fec38529fe4d5560ca Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 3 Oct 2024 04:28:59 +0000 Subject: [PATCH 16/22] format --- vllm/model_executor/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 2f3eadc6aa25d..91f34b246b651 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -260,7 +260,7 @@ def _check_stateless( raise ValueError(f"Unsafe module name detected for {func}") if any(s not in valid_name_characters for s in func.__name__): raise ValueError(f"Unsafe class name detected for {func}") - + err_id = uuid.uuid4() stmts = ";".join([ From 1e0babaee7146b3e9b9ca5698d7b14f9d0dc7746 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 3 Oct 2024 13:10:34 +0800 Subject: [PATCH 17/22] Fix error check --- vllm/model_executor/models/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 91f34b246b651..621fdaf26c1ba 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -272,12 +272,13 @@ def _check_stateless( result = subprocess.run([sys.executable, "-c", stmts], capture_output=True) - err_lines = [line.decode() for line in result.stderr.splitlines()] - if err_lines and err_lines[-1] != f"AssertionError: {err_id}": - err_str = "\n".join(err_lines) - raise RuntimeError( - "An unexpected error occurred while importing the model in " - f"another process. 
Error log:\n{err_str}") + if result.returncode != 0: + err_lines = [line.decode() for line in result.stderr.splitlines()] + if err_lines and err_lines[-1] != f"AssertionError: {err_id}": + err_str = "\n".join(err_lines) + raise RuntimeError( + "An unexpected error occurred while importing the model in " + f"another process. Error log:\n{err_str}") return result.returncode == 0 From 9ef69deeec65cf8271045bf67b3a4d3a8c51fcd3 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 3 Oct 2024 13:11:25 +0800 Subject: [PATCH 18/22] Make `prefix` required --- vllm/model_executor/models/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 9d7432013165d..ebc136a1c4ca1 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -182,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, class LayerFn(Protocol): - def __call__( - self, - prefix="", - ) -> torch.nn.Module: + def __call__(self, prefix: str) -> torch.nn.Module: ... From a36f7ed4a5d0149b76167e12e85ab76596c24ec7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 3 Oct 2024 16:52:07 +0800 Subject: [PATCH 19/22] Fix environment variables not being copied over --- vllm/model_executor/models/__init__.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 621fdaf26c1ba..b3ec259b77b0e 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,4 +1,5 @@ import importlib +import os import string import subprocess import sys @@ -18,6 +19,7 @@ _GENERATION_MODELS = { "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BloomForCausalLM": ("bloom", "BloomForCausalLM"), @@ -36,9 +38,12 @@ "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), + "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), @@ -58,6 +63,7 @@ "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), @@ -68,14 +74,11 @@ "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), - "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), - "Phi3SmallForCausalLM": ("phi3_small", 
"Phi3SmallForCausalLM"), + # NOTE: The below models are for speculative decoding only "MedusaModel": ("medusa", "Medusa"), "EAGLEModel": ("eagle", "EAGLE"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM"), - "GraniteForCausalLM": ("granite", "GraniteForCausalLM") } _EMBEDDING_MODELS = { @@ -270,7 +273,8 @@ def _check_stateless( ]) result = subprocess.run([sys.executable, "-c", stmts], - capture_output=True) + capture_output=True, + env=os.environ.copy()) if result.returncode != 0: err_lines = [line.decode() for line in result.stderr.splitlines()] From ed669a52a33e50043256b9dba81ef820336c4ffd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 3 Oct 2024 12:44:43 +0000 Subject: [PATCH 20/22] Fix the real problem, which is that modelscope is not installed --- requirements-test.txt | 5 +++-- vllm/model_executor/models/__init__.py | 4 +--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 9c6fadb88865a..e350722671449 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,8 +10,9 @@ pytest-shard awscli einops # required for MPT, qwen-vl and Mamba httpx -librosa # required for audio test -opencv-python # required for video test +modelscope # required for modelscope tests +librosa # required for audio tests +opencv-python # required for video tests peft requests ray[adag]==2.35 diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b3ec259b77b0e..2f9cb2b760a82 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,5 +1,4 @@ import importlib -import os import string import subprocess import sys @@ -273,8 +272,7 @@ def _check_stateless( ]) result = subprocess.run([sys.executable, "-c", stmts], - capture_output=True, - env=os.environ.copy()) + capture_output=True) if result.returncode != 0: err_lines = [line.decode() for line in result.stderr.splitlines()] From b8958a97d3fdb8033e2c60ade7feb6e318a6334d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 3 Oct 2024 13:33:19 +0000 Subject: [PATCH 21/22] Move modelscope installation into regression test --- .buildkite/test-pipeline.yaml | 4 +++- requirements-test.txt | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f678436dd05e1..427dc14513d45 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -146,7 +146,9 @@ steps: source_file_dependencies: - vllm/ - tests/test_regression - command: pytest -v -s test_regression.py + commands: + - pip install modelscope + - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min diff --git a/requirements-test.txt b/requirements-test.txt index e350722671449..37c3bd8ba8794 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,7 +10,6 @@ pytest-shard awscli einops # required for MPT, qwen-vl and Mamba httpx -modelscope # required for modelscope tests librosa # required for audio tests opencv-python # required for video tests peft From e9f0601d46b5340e22bdca387946fe2be13e6f87 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 3 Oct 2024 14:14:00 +0000 Subject: [PATCH 22/22] Fix `LLMWrapper` --- vllm/model_executor/models/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 
ebc136a1c4ca1..761f0406b1333 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -342,6 +342,9 @@ def __init__(self, llm: nn.Module, name: str) -> None: def __getattr__(self, key: str): llm = super().__getattr__(self.model_name) + if key == self.model_name: + return llm + return getattr(llm, key) # We need to explicitly override this
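
For reference, a minimal sketch of the attribute delegation that the final `LLMWrapper` change is meant to provide; the `DummyLLM` class and the "language_model" name below are hypothetical stand-ins for illustration, not code taken from the patches:

import torch
import torch.nn as nn

from vllm.model_executor.models.utils import LLMWrapper


class DummyLLM(nn.Module):
    """Hypothetical inner model used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)


wrapper = LLMWrapper(DummyLLM(), name="language_model")

# Plain attribute access is forwarded to the wrapped module via __getattr__.
assert isinstance(wrapper.embed_tokens, nn.Embedding)

# Accessing the wrapper by its own registered name returns the wrapped module
# itself (the guard added in the last patch), instead of recursing into it and
# raising AttributeError.
assert isinstance(wrapper.language_model, DummyLLM)

# Calling the wrapper goes through the explicit __call__ override, since dunder
# lookup on the type bypasses __getattr__.
out = wrapper(torch.tensor([0, 1, 2]))
assert out.shape == (3, 4)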