[Model] Upgrade Aria to transformers 4.48 (vllm-project#12203)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Matthew Hendrey <[email protected]>
DarkLight1337 authored and mhendrey committed Jan 23, 2025
1 parent 5cc6a09 commit 9f3d5a6
Showing 10 changed files with 178 additions and 379 deletions.
3 changes: 0 additions & 3 deletions examples/offline_inference/vision_language.py
@@ -26,11 +26,8 @@ def run_aria(question: str, modality: str):

# NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(model=model_name,
tokenizer_mode="slow",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
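For context, a minimal sketch of how the Aria example can be driven after this change. Which keyword arguments the commit actually drops is not fully visible in this rendering, so the retained arguments, image path, and sampling settings below are assumptions rather than the committed example:

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Sketch only: mirrors the run_aria() example above with the slow-tokenizer
# workaround removed; the exact kwargs kept by the commit are inferred.
llm = LLM(model="rhymes-ai/Aria",
          max_model_len=4096,
          max_num_seqs=2)

question = "What is shown in this image?"
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
          "<|im_end|>\n<|im_start|>assistant\n")

image = Image.open("example.jpg")  # placeholder image path
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```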
7 changes: 2 additions & 5 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -10,7 +10,6 @@
import pytest
from transformers import AutoModelForVision2Seq
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.utils import is_flash_attn_2_available

from vllm.platforms import current_platform
from vllm.utils import identity
@@ -140,9 +139,7 @@
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
tokenizer_mode="slow",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
max_model_len=4096,
@@ -158,8 +155,8 @@
max_tokens=64,
marks=[
pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="Model needs flash-attn for numeric convergence.",
TRANSFORMERS_VERSION < "4.48.0",
reason="HF model requires transformers>=4.48.0",
),
large_gpu_mark(min_gb=64),
],
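The replaced mark above now gates the HF-side comparison on the installed transformers version instead of flash-attn availability. A generic sketch of that gating pattern, using packaging's Version so the comparison is numeric rather than lexicographic (the in-tree mark compares the raw version strings):

```python
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION

# Skip when the installed transformers is older than the HF implementation
# needs; Version() avoids string-ordering surprises (e.g. "4.9" > "4.48").
requires_transformers_4_48 = pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) < Version("4.48.0"),
    reason="HF model requires transformers>=4.48.0",
)


@requires_transformers_4_48
def test_aria_against_hf_reference():
    ...  # the HF-vs-vLLM comparison would run here
```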
12 changes: 5 additions & 7 deletions tests/models/multimodal/processing/test_common.py
@@ -11,6 +11,7 @@
from vllm.multimodal.utils import cached_get_tokenizer

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS


def _test_processing_correctness(
@@ -20,12 +21,9 @@ def _test_processing_correctness(
num_batches: int,
simplify_rate: float,
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
elif model_id == "deepseek-ai/deepseek-vl2-tiny":
hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
else:
hf_overrides = {}
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

limit_mm_per_prompt = {
modality: 3 if supports_multi else 1
@@ -41,7 +39,7 @@
seed=0,
dtype="float16",
revision=None,
hf_overrides=hf_overrides,
hf_overrides=model_info.hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)

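The hard-coded per-model branches are replaced by ``hf_overrides`` read from the shared example-model registry (added below). Conceptually, an overrides dict of this kind is a set of attributes forced onto the loaded Hugging Face config; a minimal illustration with plain transformers, outside the vLLM plumbing:

```python
from transformers import AutoConfig

# Force the architecture name the vLLM test expects, regardless of what the
# hub config reports for this checkpoint.
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}

config = AutoConfig.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
for key, value in hf_overrides.items():
    setattr(config, key, value)

print(config.architectures)  # ['MantisForConditionalGeneration']
```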
67 changes: 62 additions & 5 deletions tests/models/registry.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import AbstractSet, Mapping, Optional
from typing import AbstractSet, Any, Literal, Mapping, Optional

import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION


@dataclass(frozen=True)
@@ -38,6 +42,50 @@ class _HfExamplesInfo:
trust_remote_code: bool = False
"""The ``trust_remote_code`` level required to load the model."""

hf_overrides: dict[str, Any] = field(default_factory=dict)
"""The ``hf_overrides`` required to load the model."""

def check_transformers_version(
self,
*,
on_fail: Literal["error", "skip"],
) -> None:
"""
If the installed transformers version does not meet the requirements,
perform the given action.
"""
if self.min_transformers_version is None:
return

current_version = TRANSFORMERS_VERSION
required_version = self.min_transformers_version
if Version(current_version) < Version(required_version):
msg = (
f"You have `transformers=={current_version}` installed, but "
f"`transformers>={required_version}` is required to run this "
"model")

if on_fail == "error":
raise RuntimeError(msg)
else:
pytest.skip(msg)

def check_available_online(
self,
*,
on_fail: Literal["error", "skip"],
) -> None:
"""
If the model is not available online, perform the given action.
"""
if not self.is_available_online:
msg = "Model is not available online"

if on_fail == "error":
raise RuntimeError(msg)
else:
pytest.skip(msg)


# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
@@ -48,8 +96,6 @@ class _HfExamplesInfo:
trust_remote_code=True),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
@@ -176,14 +222,17 @@ class _HfExamplesInfo:

_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
min_transformers_version="4.48"),
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
extras={"text_only": "THUDM/chatglm3-6b"},
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
@@ -194,7 +243,8 @@ class _HfExamplesInfo:
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -247,5 +297,12 @@ def get_supported_archs(self) -> AbstractSet[str]:
def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
return self.hf_models[model_arch]

def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
for info in self.hf_models.values():
if info.default == model_id:
return info

raise ValueError(f"No example model defined for {model_id}")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
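Assuming the helpers above, tests can defer their gating entirely to the registry entry, as the remaining diffs below do. A usage sketch (the test name and parametrization here are illustrative):

```python
import pytest

from ...registry import HF_EXAMPLE_MODELS  # as imported in test_common.py above


@pytest.mark.parametrize("model_id", ["rhymes-ai/Aria"])
def test_processing_with_registry_gating(model_id: str):
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)

    # Skip (rather than fail) when the environment cannot run this model.
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    # ... the real test would build its config here, passing
    # hf_overrides=model_info.hf_overrides ...
```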
14 changes: 2 additions & 12 deletions tests/models/test_initialization.py
@@ -1,9 +1,7 @@
from unittest.mock import patch

import pytest
from packaging.version import Version
from transformers import PretrainedConfig
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm import LLM

@@ -13,16 +11,8 @@
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
if not model_info.is_available_online:
pytest.skip("Model is not available online")
if model_info.min_transformers_version is not None:
current_version = TRANSFORMERS_VERSION
required_version = model_info.min_transformers_version
if Version(current_version) < Version(required_version):
pytest.skip(
f"You have `transformers=={current_version}` installed, but "
f"`transformers>={required_version}` is required to run this "
"model")
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
3 changes: 3 additions & 0 deletions tests/models/test_registry.py
@@ -21,6 +21,9 @@

@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(on_fail="skip")

# Ensure all model classes can be imported successfully
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
