diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 29fa5d812deb2..ec8acb224fdf3 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -252,6 +252,11 @@ Multimodal Language Models - Image\ :sup:`E` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - + * - :code:`Qwen2VLForConditionalGeneration` + - Qwen2-VL (see note) + - Image\ :sup:`+` / Video\ :sup:`+` + - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. + - * - :code:`UltravoxModel` - Ultravox - Audio\ :sup:`E+` @@ -265,15 +270,14 @@ Multimodal Language Models For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 - For :code:`LLaVA-NeXT-Video`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. +.. note:: + For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. This can be installed by running the following command: - .. code-block:: bash pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830 - ---- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 2ec691608df6d..464eaf334e3de 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -179,6 +179,23 @@ def run_qwen_vl(question): return llm, prompt, stop_token_ids +# Qwen2-VL +def run_qwen2_vl(question): + model_name = "Qwen/Qwen2-VL-7B-Instruct" + + llm = LLM( + model=model_name, + max_num_seqs=5, + ) + + prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n") + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -191,6 +208,7 @@ def run_qwen_vl(question): "blip-2": run_blip2, "internvl_chat": run_internvl, "qwen_vl": run_qwen_vl, + "qwen2_vl": run_qwen2_vl, } diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index dd84627b9dc58..ed7e886d57806 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -6,7 +6,7 @@ from argparse import Namespace from typing import List -from transformers import AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer from vllm import LLM, SamplingParams from vllm.multimodal.utils import fetch_image @@ -30,7 +30,7 @@ def load_phi3v(question, image_urls: List[str]): for i, _ in enumerate(image_urls, start=1)) prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompt, stop_token_ids, None def load_internvl(question, image_urls: List[str]): @@ -60,18 +60,72 @@ def 
load_internvl(question, image_urls: List[str]): # https://huggingface.co/OpenGVLab/InternVL2-2B#service stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids + + return llm, prompt, stop_token_ids, None + + +def load_qwen2_vl(question, image_urls: List[str]): + try: + from qwen_vl_utils import process_vision_info + except ModuleNotFoundError: + print('WARNING: `qwen-vl-utils` not installed, input images will not ' + 'be automatically resized. You can enable this functionality by ' + '`pip install qwen-vl-utils`.') + process_vision_info = None + + model_name = "Qwen/Qwen2-VL-7B-Instruct" + + llm = LLM( + model=model_name, + max_num_seqs=5, + max_model_len=32768 if process_vision_info is None else 4096, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": + "user", + "content": [ + *placeholders, + { + "type": "text", + "text": question + }, + ], + }] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + stop_token_ids = None + + if process_vision_info is None: + image_data = [fetch_image(url) for url in image_urls] + else: + image_data, _ = process_vision_info(messages) + + return llm, prompt, stop_token_ids, image_data model_example_map = { "phi3_v": load_phi3v, "internvl_chat": load_internvl, + "qwen2_vl": load_qwen2_vl, } def run_generate(model, question: str, image_urls: List[str]): - llm, prompt, stop_token_ids = model_example_map[model](question, - image_urls) + llm, prompt, stop_token_ids, image_data = model_example_map[model]( + question, image_urls) + if image_data is None: + image_data = [fetch_image(url) for url in image_urls] sampling_params = SamplingParams(temperature=0.0, max_tokens=128, @@ -81,7 +135,7 @@ def run_generate(model, question: str, image_urls: List[str]): { "prompt": prompt, "multi_modal_data": { - "image": [fetch_image(url) for url in image_urls] + "image": image_data }, }, sampling_params=sampling_params) @@ -92,7 +146,7 @@ def run_generate(model, question: str, image_urls: List[str]): def run_chat(model: str, question: str, image_urls: List[str]): - llm, _, stop_token_ids = model_example_map[model](question, image_urls) + llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls) sampling_params = SamplingParams(temperature=0.0, max_tokens=128, diff --git a/requirements-common.txt b/requirements-common.txt index 49a290317f818..4e008112c6cb0 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -28,3 +28,4 @@ importlib_metadata mistral_common >= 1.3.4 pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 +einops # Required for Qwen2-VL. 
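For quick reference, the new single-image path added above in `examples/offline_inference_vision_language.py` can be exercised end to end with the short sketch below. It is illustrative only and not part of the diff: it reuses the chat-style prompt from `run_qwen2_vl`, the `fetch_image` helper, and the dict-based `llm.generate` call already used by the multi-image example; the image URL is a placeholder to substitute.

```python
# Illustrative sketch, not part of this diff: single-image inference with Qwen2-VL.
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct", max_num_seqs=5)

question = "What is shown in this image?"
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          f"{question}<|im_end|>\n"
          "<|im_start|>assistant\n")

# Placeholder URL; substitute an image of your choice.
image = fetch_image("https://example.com/your-image.jpg")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```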
diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index b058e2755c245..3930a5f465f70 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -1,9 +1,14 @@ import pytest +import transformers from vllm.model_executor.models import _MODELS, ModelRegistry @pytest.mark.parametrize("model_cls", _MODELS) def test_registry_imports(model_cls): + if (model_cls == "Qwen2VLForConditionalGeneration" + and transformers.__version__ < "4.45"): + pytest.skip("Waiting for next transformers release") + # Ensure all model classes can be imported successfully ModelRegistry.resolve_model_cls([model_cls]) diff --git a/vllm/config.py b/vllm/config.py index 4d9310af79ed1..b3e91701c60a4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -773,7 +773,7 @@ class LoadConfig: ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - + """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO @@ -1733,8 +1733,11 @@ def _get_and_verify_max_len( "with rope_scaling. Please raise an issue so we can " "investigate.") - assert "factor" in rope_scaling - scaling_factor = rope_scaling["factor"] + if rope_type == "mrope": + scaling_factor = 1 + else: + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] if rope_type == "yarn": derived_max_model_len = rope_scaling[ "original_max_position_embeddings"] diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index a42ad81b3eef4..a0b8e81f666c2 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -108,7 +108,7 @@ class ConversationMessage(TypedDict, total=False): """The tool calls generated by the model, such as function calls.""" -ModalityStr = Literal["image", "audio"] +ModalityStr = Literal["image", "audio", "video"] _T = TypeVar("_T") @@ -158,12 +158,18 @@ def _placeholder_str(self, modality: ModalityStr, hf_config.image_token_index) if model_type in ("chameleon", "internvl_chat"): return "" + if model_type == "qwen2_vl": + return "<|vision_start|><|image_pad|><|vision_end|>" raise TypeError(f"Unknown model type: {model_type}") elif modality == "audio": if model_type == "ultravox": return "<|reserved_special_token_0|>" raise TypeError(f"Unknown model type: {model_type}") + elif modality == "video": + if model_type == "qwen2_vl": + return "<|vision_start|><|video_pad|><|vision_end|>" + raise TypeError(f"Unknown model type: {model_type}") else: raise TypeError(f"Unknown modality: {modality}") diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d323f6cc432a2..7fa6c5e7fcde4 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -712,6 +712,179 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: return new_freqs +class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[List[int]] = None, + ) -> None: + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: 
torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward(). + + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat([ + m[i] + for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] + for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @staticmethod + def get_input_positions( + input_tokens: List[int], + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + context_len: int = 0, + ) -> Tuple[List[List[int]], int]: + """Get mrope input positions and delta value.""" + + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + if isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.tolist() + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h 
* llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> List[List[int]]: + return [ + list( + range(context_len + mrope_position_delta, + seq_len + mrope_position_delta)) for _ in range(3) + ] + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -752,7 +925,7 @@ def get_rope( # The correct one should be "longrope" but keep "su" here # for backward compatible if scaling_type not in {"su", "longrope"}: - scaling_factor = rope_scaling["factor"] + scaling_factor = rope_scaling.get("factor", 1.0) if scaling_type == "llama3": low_freq_factor = rope_scaling["low_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"] @@ -816,6 +989,16 @@ def get_rope( head_size, rotary_dim, max_position, original_max_position, base, is_neox_style, dtype, short_factor, long_factor, **extra_kwargs) + elif scaling_type == "mrope": + return MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index da907e8a75063..59e8f8866f66b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -53,6 +53,8 @@ "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), + "Qwen2VLForConditionalGeneration": + ("qwen2_vl", "Qwen2VLForConditionalGeneration"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), @@ -90,6 +92,8 @@ "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "UltravoxModel": ("ultravox", "UltravoxModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "Qwen2VLForConditionalGeneration": ("qwen2_vl", + "Qwen2VLForConditionalGeneration"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py new file mode 100644 index 0000000000000..3f8c590a39b00 --- /dev/null +++ b/vllm/model_executor/models/qwen2_vl.py @@ -0,0 +1,1088 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. 
All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +from array import array +from functools import lru_cache, partial +from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, + Union) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from PIL import Image +from transformers import Qwen2VLConfig +from transformers.image_utils import (get_image_size, + infer_channel_dimension_format, + to_numpy_array) +from transformers.models.qwen2_vl.configuration_qwen2_vl import ( + Qwen2VLVisionConfig) +from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + make_batched_images, make_batched_videos, smart_resize) + +import vllm.envs as envs +from vllm.attention import AttentionMetadata +from vllm.attention.selector import (_Backend, backend_name_to_enum, + get_global_forced_attn_backend) +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalInputs) +from vllm.multimodal.base import MultiModalData +from vllm.multimodal.image import cached_get_image_processor +from vllm.platforms import current_platform +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) +from vllm.transformers_utils.processor import get_processor + +logger = init_logger(__name__) + +# === Vision Inputs === # + + +class Qwen2VLImageInputs(TypedDict): + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. 
+ """ + + +class Qwen2VLVideoInputs(TypedDict): + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +# === Vision Encoder === # + + +class Qwen2VisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int = None, + act_layer: Type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.fc1 = ColumnParallelLinear(in_features, + hidden_features, + quant_config=quant_config) + self.act = act_layer() + self.fc2 = RowParallelLinear(hidden_features, + in_features, + quant_config=quant_config) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) + return output + + +class Qwen2VisionAttention(nn.Module): + + def __init__( + self, + embed_dim: Optional[int] = None, + num_heads: Optional[int] = None, + projection_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size) + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config) + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config) + + # Detect attention implementation. + selected_backend: Optional[_Backend] = get_global_forced_attn_backend() + if selected_backend is None: + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + # For Volta and Turing GPUs, use xformers instead. 
+ device_available = current_platform.get_device_capability()[0] >= 8 + if device_available: + from transformers.utils import is_flash_attn_2_available + + if is_flash_attn_2_available(): + self._use_flash_attn = True + else: + logger.warning( + "Current Qwen2-VL implementation has a bug with " + "`vllm-flash-attn` inside vision module, so we use " + "xformers backend instead. You can run `pip install " + "flash-attn` to use the flash-attention backend.") + self._use_flash_attn = False + else: + self._use_flash_attn = False + else: + if selected_backend == _Backend.FLASH_ATTN: + self._use_flash_attn = True + elif selected_backend == _Backend.XFORMERS: + self._use_flash_attn = False + else: + raise RuntimeError( + f"Qwen2-VL does not support {selected_backend} backend now." + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + x = x.view(*new_x_shape) + + # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] + q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) + batch_size = q.shape[1] + + q, k, v = [ + rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) + ] + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self._use_flash_attn: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + from flash_attn import flash_attn_varlen_func + + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False) + + context_layer = rearrange(output, + "(b s) ...
-> b s ...", + b=batch_size) + else: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Qwen2VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: Type[nn.Module] = QuickGELU, + norm_layer: Type[nn.Module] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.attn = Qwen2VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config) + self.mlp = Qwen2VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config) + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb) + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_chans: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_chans, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.embed_dim) + return x + + +class Qwen2VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Type[nn.Module] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.ModuleList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > 
self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**(torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) + / self.dim)) + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Qwen2VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + patch_size: int = vision_config.patch_size + temporal_patch_size: int = vision_config.temporal_patch_size + spatial_merge_size: int = vision_config.spatial_merge_size + in_chans: int = vision_config.in_chans + hidden_size: int = vision_config.hidden_size + embed_dim: int = vision_config.embed_dim + depth: int = vision_config.depth + num_heads: int = vision_config.num_heads + mlp_ratio: float = vision_config.mlp_ratio + + self.spatial_merge_size = spatial_merge_size + + self.patch_embed = Qwen2VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + ) for _ in range(depth) + ]) + self.merger = Qwen2VisionPatchMerger( + d_model=hidden_size, + context_dim=embed_dim, + norm_layer=norm_layer, + quant_config=quant_config, + ) + + @property + def dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + @property + def device(self) -> torch.device: + return self.blocks[0].mlp.fc2.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # transformers + x = x.unsqueeze(1) + for blk in self.blocks: + x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + # adapter + x = self.merger(x) + 
return x + + +# === Vision input helpers === # + +cached_get_processor = lru_cache(get_processor) + + +def mm_input_mapper_for_qwen2_vl( + ctx: InputContext, + data: MultiModalData[object], + data_type_key: str, +) -> MultiModalInputs: + """Input mapper for Qwen2-VL.""" + model_config = ctx.model_config + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + + images = None + videos = None + if data_type_key == "image": + images = data + else: + assert data_type_key == "video" + videos = data + + try: + batch_data = image_processor \ + .preprocess(images=images, videos=videos, return_tensors="pt") \ + .data + except Exception: + logger.error("Failed to process image (%s)", data) + raise + + return MultiModalInputs(batch_data) + + +image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, + data_type_key="image") +video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, + data_type_key="video") + + +def _get_vision_info( + image_processor, + height: int, + width: int, + min_pixels: int, + max_pixels: int, + do_resize: bool = True, + data_type_key: str = "image", + mm_count: int = 1, +): + """Get information (resized height / width and number of vision tokens) + of input image / video frame.""" + + if do_resize: + resized_height, resized_width = smart_resize( + height=height, + width=width, + factor=image_processor.patch_size * image_processor.merge_size, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + else: + resized_height, resized_width = height, width + + if data_type_key == "image": + grid_t = mm_count + else: + assert data_type_key == "video" + grid_t = max(mm_count // image_processor.temporal_patch_size, 1) + + grid_h = resized_height // image_processor.patch_size + grid_w = resized_width // image_processor.patch_size + vision_tokens = grid_t * grid_h * grid_w + llm_num_vision_tokens = (vision_tokens // image_processor.merge_size // + image_processor.merge_size) + + return resized_height, resized_width, llm_num_vision_tokens + + +def _get_max_image_info( + image_processor, + data_type_key: str = "image", + mm_count: int = 1, +): + return _get_vision_info( + image_processor, + height=9999999, + width=9999999, + + # Limit min / max pixels. 
+ min_pixels=max(image_processor.min_pixels, 28 * 28), + max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28), + data_type_key=data_type_key, + mm_count=mm_count, + ) + + +def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: + image_processor = cached_get_image_processor(ctx.model_config.model) + max_resized_height, max_resized_width, max_llm_image_tokens = \ + _get_max_image_info(image_processor, data_type_key=data_type_key, + mm_count=1) + return max_llm_image_tokens + + +get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens, + data_type_key="image") +get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens, + data_type_key="video") + + +def dummy_data_for_qwen2_vl( + ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] +) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: + image_processor = cached_get_image_processor(ctx.model_config.model) + + num_images = mm_counts["image"] + max_resized_height, max_resized_width, max_llm_image_tokens = \ + _get_max_image_info(image_processor, data_type_key="image", + mm_count=num_images) + if seq_len - max_llm_image_tokens - 2 < 0: + raise RuntimeError( + f"Qwen2-VL cannot process {num_images} images in a prompt, " + "please increase max_model_len or reduce image limit by " + "--limit-mm-per-prompt.") + + # Check video counts. + num_videos = mm_counts["video"] + max_resized_height, max_resized_width, max_llm_video_tokens = \ + _get_max_image_info(image_processor, data_type_key="video", + mm_count=num_videos) + if seq_len - max_llm_video_tokens - 2 < 0: + raise RuntimeError( + f"Qwen2-VL cannot process {num_images} videos in a prompt, " + "please increase max_model_len or reduce video limit by " + "--limit-mm-per-prompt.") + + hf_config = ctx.get_hf_config(Qwen2VLConfig) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [hf_config.vision_start_token_id]) + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [hf_config.image_token_id]) * max_llm_image_tokens + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [hf_config.vision_end_token_id]) + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - max_llm_image_tokens - 2) + dummy_seqdata = SequenceData(token_ids) + dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), + color=0) + + return dummy_seqdata, { + "image": dummy_image if num_images == 1 else [dummy_image] * num_images + } + + +def _get_llm_num_vision_tokens( + mm_inputs: list, + data_type_key: str, + image_processor, +): + """Get number of vision tokens of multimodal inputs. + + This method is derived from `transformers.models.qwen2_vl. + image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`. 
+ """ + image = to_numpy_array(mm_inputs[0]) + input_data_format = infer_channel_dimension_format(image) + height, width = get_image_size(image, channel_dim=input_data_format) + _, _, llm_num_vision_tokens = _get_vision_info( + image_processor, + height=height, + width=width, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + do_resize=image_processor.do_resize, + data_type_key=data_type_key, + mm_count=len(mm_inputs), + ) + return llm_num_vision_tokens + + +def input_processor_for_qwen2_vl(ctx: InputContext, + llm_inputs: LLMInputs) -> LLMInputs: + multi_modal_data = llm_inputs.get("multi_modal_data", None) + if multi_modal_data is None: + return llm_inputs + + image_inputs = multi_modal_data.get("image", None) + video_inputs = multi_modal_data.get("video", None) + + processor = cached_get_processor(ctx.model_config.model) + image_processor = processor.image_processor + hf_config = ctx.get_hf_config(Qwen2VLConfig) + + # To avoid redundant processing of vision objects (resize, rescale, etc.), + # we extract code of calculating number of vision tokens from + # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`. + # + # The following code is equivalent to: + # prompt = llm_inputs["prompt"] + # inputs = processor(text=[prompt], + # images=image_inputs, + # videos=video_inputs, + # padding=True, + # return_tensors="pt") + # prompt_token_ids = inputs["input_ids"][0].tolist() + + prompt_token_ids = llm_inputs.get("prompt_token_ids", None) + if prompt_token_ids is None: + prompt = llm_inputs["prompt"] + prompt_token_ids = processor.tokenizer( + prompt, + padding=True, + return_tensors=None, + )["input_ids"] + + # Expand image pad tokens. + if image_inputs is not None: + image_indices = [ + idx for idx, token in enumerate(prompt_token_ids) + if token == hf_config.image_token_id + ] + image_inputs = make_batched_images(image_inputs) + assert len(image_indices) == len(image_inputs) + + prompt_token_ids_with_image = [] + for image_cnt, image in enumerate(image_inputs): + num_image_tokens = _get_llm_num_vision_tokens( + [image], + data_type_key="image", + image_processor=image_processor, + ) + if image_cnt == 0: + non_image_tokens = prompt_token_ids[:image_indices[image_cnt]] + else: + non_image_tokens = prompt_token_ids[image_indices[image_cnt - + 1] + + 1:image_indices[image_cnt]] + prompt_token_ids_with_image.extend(non_image_tokens) + prompt_token_ids_with_image.extend( + hf_config.image_token_id for _ in range(num_image_tokens)) + prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] + + 1:]) + prompt_token_ids = prompt_token_ids_with_image + + # Expand video pad tokens. 
+ if video_inputs is not None: + video_indices = [ + idx for idx, token in enumerate(prompt_token_ids) + if token == hf_config.video_token_id + ] + video_inputs = make_batched_videos(video_inputs) + assert len(video_indices) == len(video_inputs) + + prompt_token_ids_with_video = [] + for video_cnt, video in enumerate(video_inputs): + num_video_tokens = _get_llm_num_vision_tokens( + video, + data_type_key="video", + image_processor=image_processor, + ) + if video_cnt == 0: + non_video_tokens = prompt_token_ids[:video_indices[video_cnt]] + else: + non_video_tokens = prompt_token_ids[video_indices[video_cnt - + 1] + + 1:video_indices[video_cnt]] + prompt_token_ids_with_video.extend(non_video_tokens) + prompt_token_ids_with_video.extend( + hf_config.video_token_id for _ in range(num_video_tokens)) + prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] + + 1:]) + prompt_token_ids = prompt_token_ids_with_video + + return LLMInputs( + prompt_token_ids=prompt_token_ids, + prompt=llm_inputs["prompt"], + multi_modal_data=multi_modal_data, + ) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper( + image_input_mapper_for_qwen2_vl) +@MULTIMODAL_REGISTRY.register_input_mapper("video", + video_input_mapper_for_qwen2_vl) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "video", get_max_qwen2_vl_video_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__(self, + config: Qwen2VLConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + assert not cache_config.enable_prefix_caching, \ + "Qwen2-VL currently does not support prefix caching" + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + + # NOTE: Qwen2-VL vision encoder does not support any + # quantization method now. + quant_config=None, + ) + + self.model = Qwen2Model(config, cache_config, quant_config) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def _validate_and_reshape_mm_tensor(self, + mm_input: Union[torch.Tensor, + List[torch.Tensor]], + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" + f"Got ndim: {mm_input.ndim}") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None: + return None + + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2VLImageInputs(pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None: + return None + + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2VLVideoInputs( + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input(self, + image_input: Qwen2VLImageInputs) -> torch.Tensor: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, + grid_thw=image_input["image_grid_thw"]) + return image_embeds + + def _process_video_input(self, + video_input: Qwen2VLVideoInputs) -> torch.Tensor: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, + grid_thw=video_input["video_grid_thw"]) + return video_embeds + + def _merge_multimodal_embeddings( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: torch.Tensor, + placeholder_token_id: int, + ) -> torch.Tensor: + mask = (input_ids == placeholder_token_id) + inputs_embeds[mask, :] = multimodal_embeddings + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> SamplerOutput: + """Run forward pass for Qwen2-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. 
+ """ + + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if getattr(self.config, "rope_scaling", {}).get("type", + None) == "mrope": + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + + inputs_embeds = self.model.embed_tokens(input_ids) + + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + + input_ids = None + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "up_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "visual" in name and "qkv.weight" in name: + visual_num_heads = self.config.vision_config.num_heads + visual_embed_dim = self.config.vision_config.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size, + visual_embed_dim) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) + elif "visual" in name and "qkv.bias" in name: + visual_num_heads = self.config.vision_config.num_heads + visual_embed_dim = self.config.vision_config.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1) + try: + param = params_dict[name] + except KeyError: + print(params_dict.keys()) + raise + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 17ef9938d0572..032964fe0ac4e 100644 --- 
a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -79,14 +79,12 @@ def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: if len(inputs_list) == 0: return {} - keys = inputs_list[0].keys() - item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) for inputs in inputs_list: - if inputs.keys() != keys: - msg = f"Inputs do not share the same keys ({keys})" - raise ValueError(msg) + # For models that supports multiple modalities (e.g. Qwen2-VL), + # different modalities will return different data keys, + # so batch() should skip the same key check. for k, v in inputs.items(): item_lists[k].append(v) diff --git a/vllm/sequence.py b/vllm/sequence.py index a5ebf152ce776..135586831e680 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -165,6 +165,9 @@ class SequenceData(msgspec.Struct, # is called. _new_appended_tokens: List[int] = msgspec.field(default_factory=list) + # It is used to compute mrope_position_ids. + _mrope_position_delta: Optional[int] = None + def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l" @@ -219,6 +222,14 @@ def output_token_ids_array(self) -> array: assert isinstance(self._output_token_ids, array) return self._output_token_ids + @property + def mrope_position_delta(self) -> Optional[int]: + return self._mrope_position_delta + + @mrope_position_delta.setter + def mrope_position_delta(self, new_mrope_position_delta): + self._mrope_position_delta = new_mrope_position_delta + def append_token_id(self, token_id: int, logprob: float) -> None: self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py new file mode 100644 index 0000000000000..2001746c5f7f9 --- /dev/null +++ b/vllm/transformers_utils/processor.py @@ -0,0 +1,37 @@ +from typing import cast + + +def get_processor( + processor_name: str, + *args, + trust_remote_code: bool = False, + **kwargs, +): + """Gets a processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor + from transformers.processing_utils import ProcessorMixin + + try: + processor = AutoProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the processor. 
If the processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return cast(ProcessorMixin, processor) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 74f7d4e0860d3..cf8bc3e6a18b8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -30,6 +30,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -181,6 +182,7 @@ class InterDataForSeqGroup: def simple_reinit(self): self.input_tokens[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore + self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore self.query_lens[0] = 0 # type: ignore @@ -206,6 +208,7 @@ def __init__( # Input tokens and positions. input_tokens: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None, + mrope_input_positions: Optional[List[List[List[int]]]] = None, # The sequence length (may be capped to the sliding window). seq_lens: Optional[List[int]] = None, @@ -266,6 +269,8 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_positions[seq_id].clear() + self.mrope_input_positions = None + if seq_lens: self.seq_lens = seq_lens else: @@ -327,6 +332,7 @@ def __init__( else: self.input_tokens = input_tokens or [] self.input_positions = input_positions or [] + self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] self.query_lens = query_lens or [] @@ -357,6 +363,7 @@ def __post_init__(self): self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)] + self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs self.query_lens = [0] * self.n_seqs @@ -493,6 +500,17 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.query_lens[ seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 + if seq_data.mrope_position_delta is not None: + if inter_data.mrope_input_positions is None: + inter_data.mrope_input_positions = [None] * inter_data.n_seqs + + inter_data.mrope_input_positions[ + seq_idx] = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + def _compute_for_prefix_cache_hit( self, inter_data: InterDataForSeqGroup, seq_idx: int, seq_group_metadata: SequenceGroupMetadata): @@ -636,6 +654,40 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, mm_kwargs = self.multi_modal_input_mapper(mm_data) inter_data.multi_modal_inputs = mm_kwargs + # special processing for mrope position deltas. 
+ if self.runner.model_is_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + + inter_data.mrope_input_positions = [None] * inter_data.n_seqs + for seq_idx in range(inter_data.n_seqs): + seq_data = seq_group_metadata.seq_data[ + inter_data.seq_ids[seq_idx]] + token_ids = seq_data.get_token_ids() + + mrope_input_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + context_len=inter_data.context_lens[seq_idx], + ) + + seq_data.mrope_position_delta = mrope_position_delta + inter_data.mrope_input_positions[ + seq_idx] = mrope_input_positions + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): """Add a sequence group to the builder.""" seq_ids = seq_group_metadata.seq_data.keys() @@ -684,10 +736,27 @@ def build(self) -> ModelInputForGPU: # prefix caching and there is no decode request. return self.model_input_cls() - input_positions = [] - for inter_data in self.inter_data_list: - for cur_input_positions in inter_data.input_positions: - input_positions.extend(cur_input_positions) + mrope_input_positions: Optional[List[List[int]]] = None + if any(inter_data.mrope_input_positions is not None + for inter_data in self.inter_data_list): + mrope_input_positions = [[] for _ in range(3)] + for idx in range(3): + for inter_data in self.inter_data_list: + msections = inter_data.mrope_input_positions + if msections is None: + for _seq_input_positions in inter_data.input_positions: + mrope_input_positions[idx].extend( + _seq_input_positions) + else: + for _seq_mrope_input_positions in msections: + mrope_input_positions[idx].extend( + _seq_mrope_input_positions[idx]) + input_positions = None + else: + input_positions = [] + for inter_data in self.inter_data_list: + for cur_input_positions in inter_data.input_positions: + input_positions.extend(cur_input_positions) seq_lens = [] max_decode_seq_len = 0 @@ -724,15 +793,24 @@ def build(self) -> ModelInputForGPU: # Tokens and positions. if cuda_graph_pad_size: input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) - input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) assert self.runner.device is not None input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, self.runner.device, self.runner.pin_memory) - input_positions_tensor = async_tensor_h2d(input_positions, torch.long, - self.runner.device, - self.runner.pin_memory) - + if mrope_input_positions is not None: + for idx in range(3): + mrope_input_positions[idx].extend( + itertools.repeat(0, cuda_graph_pad_size)) + input_positions_tensor = async_tensor_h2d(mrope_input_positions, + torch.long, + self.runner.device, + self.runner.pin_memory) + else: + input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) + input_positions_tensor = async_tensor_h2d(input_positions, + torch.long, + self.runner.device, + self.runner.pin_memory) # Sequence and query lengths. 
if cuda_graph_pad_size: seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) @@ -1199,6 +1277,15 @@ def list_prompt_adapters(self) -> Set[int]: raise RuntimeError("PromptAdapter is not enabled.") return self.prompt_adapter_manager.list_adapters() + @property + def model_is_mrope(self) -> bool: + """Detect if the model has the "mrope" rope_scaling type. + mrope requires keeping "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + @torch.inference_mode() def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: """Cuda graph capture a model. @@ -1229,7 +1316,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: max_batch_size = self.max_batchsize_to_capture input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() - + if self.model_is_mrope: + input_positions = torch.tile(input_positions, (3, 1)) # Prepare dummy previous_hidden_states only if needed by the model. # This is used by draft models such as EAGLE. previous_hidden_states = None @@ -1293,7 +1381,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "input_ids": input_tokens[:batch_size], "positions": - input_positions[:batch_size], + input_positions[..., :batch_size], "hidden_or_intermediate_states": hidden_or_intermediate_states[ virtual_engine] # type: ignore
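To make the mrope plumbing in `model_runner.py` more concrete, the sketch below (illustrative only, with arbitrary toy token ids and grid sizes; it assumes a checkout with this patch applied) shows how `MRotaryEmbedding.get_input_positions` builds the `(3, seq_len)` T/H/W position ids during prefill, and how `get_next_input_positions` later extends them during decode using nothing but the cached `mrope_position_delta`.

```python
# Illustrative sketch with toy ids, not part of this diff.
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding

VISION_START, VISION_END, IMAGE_PAD, VIDEO_PAD = 1000, 1001, 1002, 1003  # toy ids

# Prompt: 3 text tokens, one image with a 1x4x4 patch grid (1x2x2 = 4 LLM tokens
# after 2x2 spatial merging), then 2 trailing text tokens.
prompt_tokens = [10, 11, VISION_START, IMAGE_PAD, IMAGE_PAD, IMAGE_PAD,
                 IMAGE_PAD, VISION_END, 12]

positions, delta = MRotaryEmbedding.get_input_positions(
    prompt_tokens,
    image_grid_thw=[[1, 4, 4]],
    video_grid_thw=[],
    image_token_id=IMAGE_PAD,
    video_token_id=VIDEO_PAD,
    vision_start_token_id=VISION_START,
    vision_end_token_id=VISION_END,
    spatial_merge_size=2,
)
# positions has shape (3, 9): text tokens share one index across the T/H/W rows,
# while the four image tokens fan out over the 2x2 spatial grid:
#   T: [0, 1, 2, 3, 3, 3, 3, 5, 6]
#   H: [0, 1, 2, 3, 3, 4, 4, 5, 6]
#   W: [0, 1, 2, 3, 4, 3, 4, 5, 6]
# delta == -2 (max position + 1 - prompt length).

# During decode only the cached delta is needed to continue the positions,
# which is exactly what the model runner stores on SequenceData:
next_positions = MRotaryEmbedding.get_next_input_positions(
    mrope_position_delta=delta, context_len=9, seq_len=10)
# -> [[7], [7], [7]]: one new position, identical in all three rows.
```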