Skip to content

Commit

Permalink
Support Pixtral-Large HF by using llava multimodal_projector_bias config (#12710)

Browse files Browse the repository at this point in the history

Signed-off-by: mgoin <[email protected]>
  • Loading branch information
mgoin authored Feb 4, 2025
1 parent 73b35cc commit 5d98d56
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
6 changes: 4 additions & 2 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
def __init__(self,
             vision_hidden_size: int,
             text_hidden_size: int,
             projector_hidden_act: str,
             multimodal_projector_bias: bool,
             quant_config: Optional[QuantizationConfig] = None,
             prefix: str = ""):
    """Two-layer MLP that projects vision-encoder features into the
    language model's embedding space.

    Args:
        vision_hidden_size: width of the incoming vision features.
        text_hidden_size: width of the language model's hidden states
            (both the intermediate and output width of this projector).
        projector_hidden_act: name of the activation applied between the
            two linear layers (resolved via ``get_act_fn``).
        multimodal_projector_bias: whether both linear layers carry a bias
            term. Read from the HF config instead of being hard-coded to
            ``True`` so bias-free checkpoints (e.g. Pixtral-Large HF) load.
        quant_config: optional quantization settings for the linear layers.
        prefix: parameter-name prefix used for weight loading.
    """
    super().__init__()

    # bias must follow the config, not a hard-coded True, or checkpoints
    # exported without projector biases fail to load.
    self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                         text_hidden_size,
                                         bias=multimodal_projector_bias,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.linear_1")
    self.act = get_act_fn(projector_hidden_act)
    self.linear_2 = RowParallelLinear(text_hidden_size,
                                      text_hidden_size,
                                      bias=multimodal_projector_bias,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.linear_2")

Expand Down Expand Up @@ -503,6 +504,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"))

Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=vision_hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias)

self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
Expand Down
9 changes: 5 additions & 4 deletions vllm/model_executor/models/llava_next_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,16 +253,16 @@ def forward(self, image_features: torch.Tensor):
class LlavaNextMultiModalProjector(nn.Module):

def __init__(self, vision_hidden_size: int, text_hidden_size: int,
             projector_hidden_act: str, multimodal_projector_bias: bool):
    """Two-layer MLP projecting vision features to the text hidden size.

    Args:
        vision_hidden_size: width of the incoming vision features.
        text_hidden_size: intermediate and output width of the projector.
        projector_hidden_act: activation name applied between the two
            linear layers (resolved via ``get_act_fn``).
        multimodal_projector_bias: whether both linear layers carry a
            bias term; taken from the HF ``multimodal_projector_bias``
            config rather than hard-coded ``True`` so bias-free
            checkpoints load correctly.
    """
    super().__init__()

    self.linear_1 = nn.Linear(vision_hidden_size,
                              text_hidden_size,
                              bias=multimodal_projector_bias)
    self.act = get_act_fn(projector_hidden_act)
    self.linear_2 = nn.Linear(text_hidden_size,
                              text_hidden_size,
                              bias=multimodal_projector_bias)

def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
Expand Down Expand Up @@ -298,7 +298,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/llava_onevision.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,11 @@ def __init__(self, config: LlavaOnevisionConfig):

self.linear_1 = nn.Linear(config.vision_config.hidden_size,
config.text_config.hidden_size,
bias=True)
bias=config.multimodal_projector_bias)
self.act = get_act_fn(config.projector_hidden_act)
self.linear_2 = nn.Linear(config.text_config.hidden_size,
config.text_config.hidden_size,
bias=True)
bias=config.multimodal_projector_bias)

def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
Expand Down

0 comments on commit 5d98d56

Please sign in to comment.