From f57ee5650dd402c6147980824c6936c96cfa59fe Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 26 Dec 2024 21:12:05 +0800
Subject: [PATCH] [Model] Modify MolmoForCausalLM MLP (#11510)

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/molmo.py | 42 ++++++++++++++++-------------
 1 file changed, 24 insertions(+), 18 deletions(-)

diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 63a25137f8aa9..8938f62d0c494 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -464,24 +464,27 @@ def forward(
 class MolmoMLP(nn.Module):
     """Molmo's LLM mlp."""
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        input_dim: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 input_dim: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 proj_name: str = "gate_up_proj") -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Feed-forward input projection.
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_dim or self.hidden_size,
-            [self.intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-        )
-
+        # Molmo's LLM projection weights are already merged on disk, while the
+        # image_projector weights are stored separately. Reusing the same
+        # proj_name would be ambiguous and make BNB and LoRA hard to support.
+        self.proj_name = proj_name
+        setattr(
+            self, proj_name,
+            MergedColumnParallelLinear(
+                input_dim or self.hidden_size,
+                [self.intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+            ))
         # Activation function.
         self.act_fn = SiluAndMul()
 
@@ -497,7 +500,7 @@ def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = self.gate_up_proj(x)
+        gate_up, _ = getattr(self, self.proj_name)(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -520,7 +523,9 @@ def __init__(
                                         prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config, quant_config=quant_config)
+        self.mlp = MolmoMLP(config,
+                            quant_config=quant_config,
+                            proj_name="gate_up_proj")
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -616,6 +621,7 @@ def __init__(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
+            proj_name="merged_linear",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -714,8 +720,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            ("merged_linear", "gate_proj", 0),
+            ("merged_linear", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
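
Note on the pattern above: the fused gate/up projection is now registered under a caller-chosen attribute name and fetched with getattr() in forward(), so the decoder layers keep the checkpoint name gate_up_proj while the image projector exposes merged_linear. Below is a minimal standalone sketch of that pattern in plain PyTorch rather than the vLLM classes; the TinySwiGLUMLP name, the layer sizes, and the inline SwiGLU stand-in are illustrative only.

import torch
import torch.nn as nn


class TinySwiGLUMLP(nn.Module):
    """Toy MLP that stores its fused gate/up projection under a configurable name."""

    def __init__(self, hidden: int, intermediate: int,
                 proj_name: str = "gate_up_proj") -> None:
        super().__init__()
        self.proj_name = proj_name
        # Register the fused projection under `proj_name` so different owners
        # (LLM decoder layer vs. image projector) expose distinct parameter names.
        setattr(self, proj_name, nn.Linear(hidden, 2 * intermediate, bias=False))
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = getattr(self, self.proj_name)(x)
        gate, up = gate_up.chunk(2, dim=-1)  # stand-in for vLLM's SiluAndMul
        return self.down_proj(nn.functional.silu(gate) * up)


llm_mlp = TinySwiGLUMLP(16, 32)                                # ...gate_up_proj.weight
vision_mlp = TinySwiGLUMLP(16, 32, proj_name="merged_linear")  # ...merged_linear.weight
print([name for name, _ in llm_mlp.named_parameters()])
print([name for name, _ in vision_mlp.named_parameters()])

Because the two modules now publish different parameter names, weight-name matching in the loader, BNB quantization targets, and LoRA target modules can each refer to exactly one of them without ambiguity.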
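
The load_weights change is the other half of the rename: the image-projector checkpoint ships gate_proj and up_proj separately, and the loader concatenates them into the module's merged_linear parameter via stacked_params_mapping, whereas the LLM checkpoint already contains a fused gate_up_proj tensor. A rough sketch of that merge step, with made-up checkpoint keys and toy shapes rather than the real Molmo ones:

import torch

# Hypothetical checkpoint fragment; keys and shapes are illustrative only.
vision_ckpt = {
    "vision_backbone.image_projector.gate_proj.weight": torch.randn(8, 4),
    "vision_backbone.image_projector.up_proj.weight": torch.randn(8, 4),
}

# Same shape of mapping as stacked_params_mapping in the patch:
# (param_name, shard_name, shard_id)
stacked_params_mapping = [
    ("merged_linear", "gate_proj", 0),
    ("merged_linear", "up_proj", 1),
]


def merge_stacked_weights(ckpt, intermediate=8):
    """Concatenate per-shard weights into a single fused parameter."""
    merged = {}
    for name, tensor in ckpt.items():
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            target = name.replace(shard_name, param_name)
            buf = merged.setdefault(
                target, torch.empty(2 * intermediate, tensor.shape[1]))
            buf[shard_id * intermediate:(shard_id + 1) * intermediate] = tensor
            break
        else:
            merged[name] = tensor  # non-stacked weights pass through unchanged
    return merged


print(sorted(merge_stacked_weights(vision_ckpt)))
# ['vision_backbone.image_projector.merged_linear.weight']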