From 6cc6935d12d38be581eead655040d764ddde5823 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 13 Nov 2024 18:55:39 -0800 Subject: [PATCH] [misc] error early for old-style class (#10304) Signed-off-by: youkaichao --- docs/source/design/class_hierarchy.rst | 39 +++++++++++++++++++ vllm/model_executor/model_loader/loader.py | 17 +++++++- vllm/model_executor/models/arctic.py | 2 +- vllm/model_executor/models/bloom.py | 6 +-- vllm/model_executor/models/chatglm.py | 6 +-- vllm/model_executor/models/dbrx.py | 6 +-- vllm/model_executor/models/eagle.py | 2 +- vllm/model_executor/models/falcon.py | 6 +-- vllm/model_executor/models/fuyu.py | 2 +- vllm/model_executor/models/gpt2.py | 6 +-- vllm/model_executor/models/gpt_bigcode.py | 6 +-- vllm/model_executor/models/gpt_j.py | 6 +-- vllm/model_executor/models/gpt_neox.py | 6 +-- vllm/model_executor/models/internvl.py | 2 +- vllm/model_executor/models/jais.py | 6 +-- vllm/model_executor/models/llava.py | 2 +- vllm/model_executor/models/llava_next.py | 2 +- .../model_executor/models/llava_next_video.py | 2 +- vllm/model_executor/models/llava_onevision.py | 2 +- vllm/model_executor/models/medusa.py | 2 +- vllm/model_executor/models/utils.py | 10 ++--- 21 files changed, 75 insertions(+), 63 deletions(-) diff --git a/docs/source/design/class_hierarchy.rst b/docs/source/design/class_hierarchy.rst index b3404f6b936e7..15f0c8ccf77ee 100644 --- a/docs/source/design/class_hierarchy.rst +++ b/docs/source/design/class_hierarchy.rst @@ -26,6 +26,45 @@ There are several important design choices behind this class hierarchy: 2. **Uniformity**: The model runner needs a unified interface to create and initialize the model. vLLM supports more than 50 types of popular open-source models. Each model has its own initialization logic. If the constructor signature varies with models, the model runner does not know how to call the constructor accordingly, without complicated and error-prone inspection logic. By making the constructor of the model class uniform, the model runner can easily create and initialize the model without knowing the specific model type. This is also useful for composing models. Vision-language models often consist of a vision model and a language model. By making the constructor uniform, we can easily create a vision model and a language model and compose them into a vision-language model. +.. note:: + + To support this change, all vLLM models' signatures have been updated to: + + .. code-block:: python + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: + + .. code-block:: python + + class MyOldModel(nn.Module): + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + ... + + from vllm.config import VllmConfig + class MyNewModel(MyOldModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + super().__init__(config, cache_config, quant_config, lora_config, prefix) + + if __version__ >= "0.6.4": + MyModel = MyNewModel + else: + MyModel = MyOldModel + + This way, the model can work with both old and new versions of vLLM. + 3. **Sharding and Quantization at Initialization**: Certain features require changing the model weights. For example, tensor parallelism needs to shard the model weights, and quantization needs to quantize the model weights. There are two possible ways to implement this feature. One way is to change the model weights after the model is initialized. The other way is to change the model weights during the model initialization. vLLM chooses the latter. The first approach is not scalable to large models. Suppose we want to run a 405B model (with roughly 810GB weights) with 16 H100 80GB GPUs. Ideally, every GPU should only load 50GB weights. If we change the model weights after the model is initialized, we need to load the full 810GB weights to every GPU and then shard the weights, leading to a huge memory overhead. Instead, if we shard the weights during the model initialization, every layer will only create a shard of the weights it needs, leading to a much smaller memory overhead. The same idea applies to quantization. Note that we also add an additional argument ``prefix`` to the model's constructor so that the model can initialize itself differently based on the prefix. This is useful for non-uniform quantization, where different parts of the model are quantized differently. The ``prefix`` is usually an empty string for the top-level model and a string like ``"vision"`` or ``"language"`` for the sub-models. In general, it matches the name of the module's state dict in the checkpoint file. One disadvantage of this design is that it is hard to write unit tests for individual components in vLLM because every component needs to be initialized by a complete config object. We solve this problem by providing a default initialization function that creates a default config object with all fields set to ``None``. If the component we want to test only cares about a few fields in the config object, we can create a default config object and set the fields we care about. This way, we can test the component in isolation. Note that many tests in vLLM are end-to-end tests that test the whole system, so this is not a big problem. diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 715e6c11f86ce..5bcae37961195 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -4,6 +4,7 @@ import dataclasses import fnmatch import glob +import inspect import json import math import os @@ -88,11 +89,23 @@ def device_loading_context(module: torch.nn.Module, logger = init_logger(__name__) -def _initialize_model(vllm_config: VllmConfig) -> nn.Module: +def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: """Initialize a model with the given configurations.""" model_config = vllm_config.model_config model_class, _ = get_model_architecture(model_config) - return model_class(vllm_config=vllm_config) + signatures = inspect.signature(model_class.__init__) + # collect all kw-only parameters + kw_only_params = [ + param.name for param in signatures.parameters.values() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + assert "vllm_config" in kw_only_params and "prefix" in kw_only_params, \ + ("vLLM model class must accept `vllm_config` and `prefix` as kw-only " + "arguments. Possibly you have an old-style model class registered from " + "out of tree and it is used for new vLLM version. " + "Please check https://docs.vllm.ai/en/latest/design/class_hierarchy.html " + "for the design and update the model class accordingly.") + return model_class(vllm_config=vllm_config, prefix=prefix) class BaseModelLoader(ABC): diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 7d4b9654b54ab..9ee2a2cc09a24 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -415,7 +415,7 @@ def forward( class ArcticForCausalLM(nn.Module, SupportsPP): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 2c14519fb9e0e..84adf574af5e2 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -281,11 +281,7 @@ def forward( class BloomForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 6ec2d5a2a3909..70e9b607b0642 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -593,11 +593,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, embedding_modules = {} embedding_padding_modules = [] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index d5f9b903183d4..fff8710f6b475 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -350,11 +350,7 @@ def forward( class DbrxForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index c902829994c7c..85c51e8404584 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -36,7 +36,7 @@ class EAGLE(nn.Module): in the draft checkpoint (using key token_map). Also, the draft config needs to have truncated_vocab_size (=k) as an attribute.""" - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 562ee5517e7f1..dcfcb6694feb5 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -401,11 +401,7 @@ class FalconForCausalLM(nn.Module, SupportsPP): ".dense_4h_to_h.", ] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index b39dfe706e0df..50701793b7b83 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -225,7 +225,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index adf2a7a51f737..cc85693f99526 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -242,11 +242,7 @@ def forward( class GPT2LMHeadModel(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index ae1495ebd7914..ab25c66c3a887 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -257,11 +257,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_padding_modules = [] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 610795b084b44..a83d03480dde1 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -229,11 +229,7 @@ def forward( class GPTJForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index f5603772e9862..794b141bfa4aa 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -242,11 +242,7 @@ def forward( class GPTNeoXForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 07165ea688f94..92579e3aae949 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -409,7 +409,7 @@ def dummy_data( @INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 4dc9271703a8d..65800c44e5a93 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -286,11 +286,7 @@ def forward( class JAISLMHeadModel(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 005ae5e03cfed..b13bcfa676811 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -259,7 +259,7 @@ def init_vision_tower_for_llava( @INPUT_REGISTRY.register_input_processor(input_processor_for_llava) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 0b621a23ec980..dd2fa6cac969f 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -281,7 +281,7 @@ def input_processor_for_llava_next(ctx: InputContext, class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index b030c2f5fdc47..5d5598d07bfde 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -253,7 +253,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index c129f140d8d12..a5b2108177830 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -404,7 +404,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 4cb1b4a929b9f..de5b2d89c0962 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -44,7 +44,7 @@ class Medusa(nn.Module): in the draft checkpoint (using key token_map). Also, the draft config needs to have truncated_vocab_size (=k) as an attribute.""" - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config super().__init__() self.config = config diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 1fc6c1be4b7bb..1d51885f9094a 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -14,7 +14,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models import ModelRegistry from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -240,12 +239,9 @@ def init_vllm_registered_model( Helper function to initialize an inner model registered to vLLM, based on the arguments passed to the outer vLLM model. """ - model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures) - - return model_class( - vllm_config=vllm_config.with_hf_config(hf_config), - prefix=prefix, - ) + from vllm.model_executor.model_loader.loader import _initialize_model + vllm_config = vllm_config.with_hf_config(hf_config) + return _initialize_model(vllm_config, prefix) @overload