Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][Bugfix]: correct Aria model output #12309

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/offline_inference/vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ def run_aria(question: str, modality: str):
llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=2,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
"<|im_end|>\n<|im_start|>assistant\n")
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved

stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
Expand Down
51 changes: 49 additions & 2 deletions vllm/model_executor/models/aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm.sequence import IntermediateTensors

# yapf: disable
from .idefics2_vision_model import Idefics2VisionConfig
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
Expand All @@ -50,6 +51,50 @@ class AriaImagePixelInputs(TypedDict):
"""


class AriaVisionTransformer(Idefics3VisionTransformer):

def __init__(
self,
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, quant_config, prefix)
self.post_layernorm = nn.Identity()
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:

# NOTE: post_layernorm is not used in Aria
if "post_layernorm" in name:
continue

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class AriaProjectorMLP(nn.Module):

def __init__(
Expand Down Expand Up @@ -228,8 +273,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_output = torch.nn.functional.linear(hidden_states,
self.router_weight)

hidden_states_copy = hidden_states.clone()
# NOTE: hidden_states will be modified inplace by `FusedMoE`
sparse_expert_output = self.experts(hidden_states, router_output)
shared_expert_output = self.shared_experts(hidden_states)
shared_expert_output = self.shared_experts(hidden_states_copy)

return sparse_expert_output + shared_expert_output

Expand Down Expand Up @@ -445,7 +492,7 @@ def __init__(
quant_config = vllm_config.quant_config

self.config = config
self.vision_tower = Idefics3VisionTransformer(
self.vision_tower = AriaVisionTransformer(
config.vision_config,
quant_config,
prefix=f"{prefix}.vision_tower",
Expand Down
Loading