From 9eaaea6972f1960ab1d412a1541062f27b1e210f Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Sun, 22 Sep 2024 13:38:07 +0000
Subject: [PATCH] Apply #8656

---
 vllm/model_executor/models/llava_onevision.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index ae4aea122cafd..9099d4f88222d 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -36,8 +35,8 @@
 from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                      dummy_video_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
@@ -859,22 +858,19 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators
-        (vit_weights, mlp_weights, newline_weights,
-         llm_weights) = itertools.tee(weights, 4)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
+
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
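
Note: the patch replaces the itertools.tee + filter_weights pattern with a single
group_weights_with_prefix call that buckets the checkpoint weights by the prefix
before the first dot in each name. The sketch below is not the vLLM implementation
(the real helper lives in vllm/model_executor/models/utils.py); it only illustrates
the grouping behaviour the new load_weights appears to rely on: the prefix selects a
component bucket and the remainder of the name is what that component's loader sees.

    # Illustrative sketch only; assumes the grouping semantics described above,
    # not the actual helper in vllm/model_executor/models/utils.py.
    from collections import defaultdict
    from typing import Dict, Iterable, List, Tuple

    import torch


    def group_weights_with_prefix(
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
        """Bucket (name, tensor) pairs by their leading prefix, e.g.
        "vision_tower.encoder.layer.0.weight" lands in groups["vision_tower"]
        under the stripped name "encoder.layer.0.weight"."""
        groups: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
        for name, tensor in weights:
            prefix, _, rest = name.partition(".")
            groups[prefix].append((rest, tensor))
        return groups

With grouping of this kind, self.vision_tower.load_weights(weights_group["vision_tower"])
receives names relative to the vision tower, which is why the per-component
filter_weights calls are no longer needed.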