[Misc] Add BNB support to GLM4-V model (vllm-project#12184)
Signed-off-by: Isotr0py <[email protected]>
Isotr0py authored and Ubuntu committed Jan 19, 2025
1 parent b7b0865 commit 240a015
Showing 3 changed files with 60 additions and 53 deletions.
15 changes: 11 additions & 4 deletions vllm/model_executor/model_loader/loader.py
@@ -1105,15 +1105,22 @@ def _load_weights(self, model_config: ModelConfig,
weight_name,
index,
) in self.modules_mapping.inverse_packed_mapping.items():
shard_pos = quant_param_name.find(shard_name)
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight'.
if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
shard_pos = quant_param_name.find(shard_name)
can_correct_rename = (shard_pos > 0) and (
quant_param_name[shard_pos - 1] == ".")
# If the quant_param_name is packed, it won't occur in the
# param_dict before renaming.
new_quant_param_name = quant_param_name.replace(
shard_name, weight_name)
need_rename = (quant_param_name not in param_dict) \
and (new_quant_param_name in param_dict)
if can_correct_rename and need_rename:
shard_index = index
quant_param_name = quant_param_name.replace(
shard_name, weight_name)
quant_param_name = new_quant_param_name
break

# Models like Clip/Siglip may skip some layers in initialization,
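A minimal, self-contained sketch of the rename guard added in this hunk. The helper function and the example parameter names below are hypothetical illustrations, not part of vLLM:

```python
def rename_packed_param(quant_param_name: str, param_dict: dict,
                        shard_name: str, weight_name: str) -> str:
    """Rename a shard (e.g. 'kv_proj') to its packed name (e.g. 'qkv_proj')
    only when the match is a whole module component and the renamed key
    actually exists in the model's parameter dict."""
    shard_pos = quant_param_name.find(shard_name)
    # Require a '.' right before the match so 'kv_proj' is not mistaken
    # for the substring inside '...qkv_proj.weight'.
    can_correct_rename = shard_pos > 0 and quant_param_name[shard_pos - 1] == "."
    new_name = quant_param_name.replace(shard_name, weight_name)
    # Only rename when the original key is missing and the packed key exists.
    need_rename = quant_param_name not in param_dict and new_name in param_dict
    return new_name if can_correct_rename and need_rename else quant_param_name


# Hypothetical usage:
params = {"vpm.encoder.layers.0.self_attn.qkv_proj.weight": None}
print(rename_packed_param("vpm.encoder.layers.0.self_attn.kv_proj.weight",
                          params, "kv_proj", "qkv_proj"))
# -> 'vpm.encoder.layers.0.self_attn.qkv_proj.weight'
```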
95 changes: 47 additions & 48 deletions vllm/model_executor/models/chatglm.py
Expand Up @@ -41,7 +41,7 @@
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

@@ -605,9 +605,50 @@ def forward(
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()

for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "rotary_pos_emb.inv_freq" in name:
continue
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".word_embeddings": ""}, )

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
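
The new load_weights above uses stacked_params_mapping to route the vision projector's gate_proj and dense_h_to_4h checkpoint tensors into shard 0 and shard 1 of the single merged_proj parameter. A rough, framework-free illustration of that shard-wise copy; the shard_weight_loader below is a stand-in, not vLLM's MergedColumnParallelLinear loader:

```python
import torch

# Stand-in for a merged column-parallel weight: two stacked sub-projections.
gate = torch.randn(8, 4)            # 'linear_proj.gate_proj.weight'
up = torch.randn(8, 4)              # 'linear_proj.dense_h_to_4h.weight'
merged = torch.empty(16, 4)         # 'linear_proj.merged_proj.weight'

def shard_weight_loader(param: torch.Tensor, loaded: torch.Tensor,
                        shard_id: int) -> None:
    """Copy one checkpoint tensor into its slice of the merged parameter."""
    shard_size = param.shape[0] // 2
    param[shard_id * shard_size:(shard_id + 1) * shard_size].copy_(loaded)

# Mirrors the (param_name, shard_name, shard_id) tuples in the diff.
shard_weight_loader(merged, gate, 0)   # gate_proj -> first half
shard_weight_loader(merged, up, 1)     # dense_h_to_4h -> second half
assert torch.equal(merged[:8], gate) and torch.equal(merged[8:], up)
```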
@@ -660,52 +701,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
"transformer.vision.linear_proj.merged_proj.weight": {
"transformer.vision.linear_proj.gate_proj.weight": None,
"transformer.vision.linear_proj.dense_h_to_4h.weight": None,
}
}

params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
is_weight_to_be_merge = False
for _, merged_weight_dict in merged_weights_dict.items():
if name in merged_weight_dict:
assert merged_weight_dict[name] is None
merged_weight_dict[name] = loaded_weight
is_weight_to_be_merge = True
if is_weight_to_be_merge:
continue
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
name = name.replace(".word_embeddings", "")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)

for combined_name, merged_weight_dict in merged_weights_dict.items():
if combined_name in params_dict:
param = params_dict[combined_name]
combined_weight = torch.cat(list(merged_weight_dict.values()),
dim=0)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, combined_weight)
loaded_params.add(combined_name)
return loaded_params
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


class ChatGLM(ChatGLMBaseModel):
@@ -726,6 +724,7 @@ class ChatGLM(ChatGLMBaseModel):


class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):

packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"],
@@ -777,7 +776,7 @@ def __new__(
) -> None:
config = vllm_config.model_config.hf_config
# Initialize VL
if hasattr(config, "visual"):
if hasattr(config, "vision_config"):
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
# Initialize LLM
else:
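With the hand-written load_weights removed from ChatGLMBaseModel, weight loading is delegated to AutoWeightsLoader, and the old `.word_embeddings` special case becomes a declarative WeightsMapper rule. A simplified stand-in for the substring renaming the mapper applies before loading, not vLLM's actual implementation:

```python
from typing import Dict, Iterable, Tuple

import torch

def map_weight_names(weights: Iterable[Tuple[str, torch.Tensor]],
                     orig_to_new_substr: Dict[str, str]):
    """Rename checkpoint keys by substring before loading, e.g. strip
    '.word_embeddings' so 'embedding.word_embeddings.weight' matches the
    vLLM parameter name 'embedding.weight'."""
    for name, tensor in weights:
        for old, new in orig_to_new_substr.items():
            name = name.replace(old, new)
        yield name, tensor

# Hypothetical checkpoint entry:
ckpt = [("transformer.embedding.word_embeddings.weight", torch.zeros(4, 4))]
renamed = dict(map_weight_names(ckpt, {".word_embeddings": ""}))
assert "transformer.embedding.weight" in renamed
```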
3 changes: 2 additions & 1 deletion vllm/model_executor/models/glm4_vision_encoder.py
@@ -42,7 +42,8 @@ def forward(self, images: torch.Tensor) -> torch.Tensor:
torch.Tensor
Transformed tensor with shape (B, L, D)
"""
images = images.to(self.proj.weight.device)
images = images.to(device=self.proj.weight.device,
dtype=self.proj.weight.dtype)
x = self.proj(images)
x = x.flatten(2).transpose(1, 2)
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
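The glm4_vision_encoder.py change casts pixel inputs to the projection weight's dtype as well as its device, so float32 image tensors no longer meet a reduced-precision patch-embedding Conv2d. A toy illustration; the layer sizes are made up and not GLM-4V's real configuration:

```python
import torch
import torch.nn as nn

# Toy stand-in for the patch embedding: a Conv2d 'proj' layer running in
# reduced precision, as the real model would under bfloat16.
proj = nn.Conv2d(3, 16, kernel_size=14, stride=14).to(dtype=torch.bfloat16)

images = torch.rand(1, 3, 224, 224)           # pixel values arrive as float32
# Match both device and dtype of the projection weight before the conv;
# without the dtype cast, float32 inputs against bfloat16 weights error out.
images = images.to(device=proj.weight.device, dtype=proj.weight.dtype)
x = proj(images)                              # (B, D, H/14, W/14)
x = x.flatten(2).transpose(1, 2)              # (B, L, D) as in the docstring
print(x.shape)                                # torch.Size([1, 256, 16])
```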
