From 1a0737068a2b0ae5e5008dfd308905c7253808e6 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 22 Sep 2024 12:33:27 +0800 Subject: [PATCH] [Bugfix] Refactor composite weight loading logic (#8656) --- vllm/model_executor/models/internvl.py | 16 ++++----- vllm/model_executor/models/llava.py | 16 ++++----- vllm/model_executor/models/llava_next.py | 20 ++++------- .../model_executor/models/llava_next_video.py | 17 ++++----- vllm/model_executor/models/paligemma.py | 14 +++----- vllm/model_executor/models/ultravox.py | 12 +++---- vllm/model_executor/models/utils.py | 36 ++++++++++++++++++- 7 files changed, 70 insertions(+), 61 deletions(-) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 507d7014714a2..005a24f10aa17 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -4,7 +4,6 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- -import itertools import re from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -33,8 +32,8 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) from .interfaces import SupportsMultiModal -from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) +from .utils import (flatten_bn, group_weights_with_prefix, + init_vllm_registered_model, merge_multimodal_embeddings) IMG_START = '' IMG_END = '' @@ -518,21 +517,18 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + weights_group = group_weights_with_prefix(weights) # load vision encoder - vit_weights = filter_weights(vit_weights, "vision_model") - self.vision_model.load_weights(vit_weights) + self.vision_model.load_weights(weights_group["vision_model"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "mlp1") mlp_params_dict = dict(self.mlp1.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["mlp1"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7a6c991fb133a..69eb177a7dea8 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,3 @@ -import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -26,8 +25,8 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) -from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) +from .utils import (flatten_bn, group_weights_with_prefix, + init_vllm_registered_model, merge_multimodal_embeddings) class LlavaImagePixelInputs(TypedDict): @@ -393,21 +392,18 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + weights_group = group_weights_with_prefix(weights) # load vision encoder - vit_weights = filter_weights(vit_weights, "vision_tower") - self.vision_tower.load_weights(vit_weights) + self.vision_tower.load_weights(weights_group["vision_tower"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["multi_modal_projector"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index d550a249ee822..96034b254e49b 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,4 +1,3 @@ -import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -30,8 +29,8 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) -from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) +from .utils import (flatten_bn, group_weights_with_prefix, + init_vllm_registered_model, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -637,25 +636,21 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( - weights, 4) + weights_group = group_weights_with_prefix(weights) # load vision encoder - vit_weights = filter_weights(vit_weights, "vision_tower") - self.vision_tower.load_weights(vit_weights) + self.vision_tower.load_weights(weights_group["vision_tower"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["multi_modal_projector"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load newline - newline_weights = filter_weights(newline_weights, "image_newline") - for name, loaded_weight in newline_weights: + for name, loaded_weight in weights_group["image_newline"]: assert name == "" param = self.image_newline weight_loader = getattr(param, "weight_loader", @@ -663,5 +658,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 7fe85e5e4ab3d..a8b5176dc43cf 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -1,4 +1,3 @@ -import itertools import math from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -30,7 +29,7 @@ from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip) -from .utils import (filter_weights, init_vllm_registered_model, +from .utils import (group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -449,23 +448,19 @@ def sample( return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # prepare weight iterators - vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( - weights, 4) + # prepare weight iterators for components + weights_group = group_weights_with_prefix(weights) # load vision encoder - vit_weights = filter_weights(vit_weights, "vision_tower") - self.vision_tower.load_weights(vit_weights) + self.vision_tower.load_weights(weights_group["vision_tower"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["multi_modal_projector"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 5fd39b5e35be6..68b6d0cf808e1 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,4 +1,3 @@ -import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -23,7 +22,7 @@ from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) -from .utils import filter_weights, merge_multimodal_embeddings +from .utils import group_weights_with_prefix, merge_multimodal_embeddings logger = init_logger(__name__) @@ -286,21 +285,18 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + weights_group = group_weights_with_prefix(weights) # load vision tower - vit_weights = filter_weights(vit_weights, "vision_tower") - self.vision_tower.load_weights(vit_weights) + self.vision_tower.load_weights(weights_group["vision_tower"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["multi_modal_projector"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 87f59f487f87b..b89c9dafd9cd8 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -1,7 +1,6 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" -import itertools import math from array import array from functools import lru_cache @@ -29,7 +28,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.utils import (filter_weights, flatten_bn, +from vllm.model_executor.models.utils import (flatten_bn, + group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -467,11 +467,10 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - projector_weights, llm_weights = itertools.tee(weights, 2) + weights_group = group_weights_with_prefix(weights) # load projector weights - projector_weights = filter_weights(projector_weights, - "multi_modal_projector") + projector_weights = weights_group["multi_modal_projector"] projector_params_dict = dict( self.multi_modal_projector.named_parameters()) for name, loaded_weight in projector_weights: @@ -481,5 +480,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 8b80dda96db49..38d6a4653ebd6 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,3 +1,5 @@ +import itertools +from collections import UserDict from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, Union, overload) @@ -16,7 +18,23 @@ from vllm.utils import is_pin_memory_available -def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): +class WeightsGroup(UserDict): + """ + Wraps grouped weights dictionary for a more informative error message + when attempting to access a weight component that does not exist. + """ + + def __getitem__(self, key: str) -> int: + try: + return super().__getitem__(key) + except KeyError as exc: + msg = (f"There is no weights named with the prefix: {key}. " + f"Available prefix: {set(self.keys())}") + raise KeyError(msg) from exc + + +def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], + prefix: str) -> Iterable[Tuple[str, torch.Tensor]]: """ Helper function to load weights for inner vLLM models. @@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): yield name, loaded_weight +def group_weights_with_prefix( + weights: Iterable[Tuple[str, torch.Tensor]] +) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]: + """ + Helper function to group weights with prefix + """ + init_weights, repeated_weights = itertools.tee(weights, 2) + weights_prefix = {name.split(".")[0] for name, _ in init_weights} + repeated_weights = itertools.tee(repeated_weights, len(weights_prefix)) + + return WeightsGroup({ + prefix: filter_weights(component, prefix) + for component, prefix in zip(repeated_weights, weights_prefix) + }) + + def init_vllm_registered_model( hf_config: PretrainedConfig, cache_config: Optional[CacheConfig],