Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 committed Sep 22, 2024
1 parent c3adcd6 commit 9eaaea6
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions vllm/model_executor/models/llava_onevision.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
Expand Down Expand Up @@ -36,8 +35,8 @@
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)

logger = init_logger(__name__)

Expand Down Expand Up @@ -859,22 +858,19 @@ def sample(
return self.language_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators
(vit_weights, mlp_weights, newline_weights,
llm_weights) = itertools.tee(weights, 4)
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)

# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
self.vision_tower.load_weights(vit_weights)
self.vision_tower.load_weights(weights_group["vision_tower"])

# load mlp projector
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in mlp_weights:
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
self.language_model.load_weights(weights_group["language_model"])

0 comments on commit 9eaaea6

Please sign in to comment.