From 9430eca7b895bb5d8713ff25688e035fe98d692d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 10 Feb 2025 06:02:50 +0000 Subject: [PATCH 1/8] Fix doc link Signed-off-by: DarkLight1337 --- docs/source/features/compatibility_matrix.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md index b0018ebccf5ba..ee5db70c7d5c8 100644 --- a/docs/source/features/compatibility_matrix.md +++ b/docs/source/features/compatibility_matrix.md @@ -297,7 +297,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ? - * [✗](gh-issue:7968>) + * [✗](gh-issue:7968) * ? * ✅ * From 7ce57c33b414b8f202a71db0c255d75472bbb0d2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 10 Feb 2025 06:03:26 +0000 Subject: [PATCH 2/8] Fix processor tests Signed-off-by: DarkLight1337 --- tests/models/multimodal/processing/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 8658e60bc5b2e..a56a9e2beef22 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -89,7 +89,7 @@ def _test_processing_correctness( mm_data = { k: [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) - for _ in range(rng.randint(limit))] + for _ in range(rng.randint(limit + 1))] for k, limit in limit_mm_per_prompt.items() } From 9668969b64509f8aebd6ea651a397d4680627d2a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 10 Feb 2025 06:03:48 +0000 Subject: [PATCH 3/8] Remove old workaround for Qwen2-VL Signed-off-by: DarkLight1337 --- tests/multimodal/utils.py | 3 --- vllm/model_executor/models/qwen2_vl.py | 10 +++------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py index 9a336b7e60ffc..40fcfeeeac7d0 100644 --- a/tests/multimodal/utils.py +++ b/tests/multimodal/utils.py @@ -17,10 +17,7 @@ def random_video( min_wh: int, max_wh: int, ): - # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 num_frames = rng.randint(min_frames, max_frames) - num_frames = (num_frames // 2) * 2 - w, h = rng.randint(min_wh, max_wh, size=(2, )) return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 34ae7b8c94697..f2071eaff481f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -885,14 +885,10 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int: max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - num_frames = min(max(max_total_frames // max(max_videos, 1), 1), - _MAX_FRAMES_PER_VIDEO) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) - # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 - if num_frames > 1 and num_frames % 2 == 1: - num_frames += 1 - - return num_frames + return max(max_frames_per_video, 1) def get_max_video_tokens(self, seq_len: int) -> int: target_width, target_height = self.get_image_size_with_most_features() From d17d8c6a5388017dd1fd20207374b159daf8962b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 10 Feb 2025 06:04:22 +0000 Subject: [PATCH 
4/8] Clean up Qwen-VL code Signed-off-by: DarkLight1337 --- vllm/model_executor/models/qwen.py | 87 +++++++++++++++--------------- 1 file changed, 42 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 8970661243148..533973c671697 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -63,18 +63,6 @@ logger = init_logger(__name__) -# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad; -# for the time being, these tags are not considered as special at encoding -# time. This may change as VLLMs multimodal API changes in the future. -IMG_START = "" -IMG_END = "" -IMG_PAD = "" -# Image context is fixed at 256 for all images -MAX_QWEN_IMG_TOKENS = 256 -# Image normalization params -CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) -CLIP_STD = (0.26862954, 0.26130258, 0.27577711) - class QwenImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -622,25 +610,6 @@ def forward( return hidden_states -def build_normalization_transform(image_size: int) -> transforms.Compose: - """ - Build a normalization transform which can be applied to one or - more input images from which we want to extract visual features. - - Args: - image_size: size of the image to be processed for visual embeddings. - - Returns: - Callable transform for normalizing and resizing one RGB image. - """ - return transforms.Compose([ - transforms.Resize((image_size, image_size), - interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), - ]) - - @lru_cache(maxsize=1) def _get_tokenizer_without_image_pad( tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: @@ -716,16 +685,34 @@ def __init__( self.config = config self.tokenizer = tokenizer - if hasattr(self.config, "visual"): - self.image_transform = build_normalization_transform( - config.visual["image_size"]) + if vision_config := getattr(self.config, "visual", None): + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) else: self.image_transform = None - special_tokens: dict[str, - int] = tokenizer.special_tokens # type: ignore - self.img_start_id = special_tokens[IMG_START] - self.img_end_id = special_tokens[IMG_END] + @property + def image_start_tag(self) -> str: + return self.tokenizer.image_start_tag # type: ignore + + @property + def image_end_tag(self) -> str: + return self.tokenizer.image_end_tag # type: ignore + + @property + def image_pad_tag(self) -> str: + return self.tokenizer.image_pad_tag # type: ignore def __call__( self, @@ -787,7 +774,14 @@ def get_mm_max_tokens_per_item( return {"image": self.get_num_image_tokens()} def get_num_image_tokens(self) -> int: - return MAX_QWEN_IMG_TOKENS + hf_config = self.get_hf_config() + if not (vision_config := getattr(hf_config, "visual", None)): + return 0 + + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]): @@ -798,10 +792,12 @@ def get_dummy_processor_inputs( mm_counts: Mapping[str, int], ) -> ProcessorInputs: hf_config = self.info.get_hf_config() - if not 
hasattr(hf_config, "visual"): + if not (vision_config := getattr(hf_config, "visual", None)): return ProcessorInputs(prompt_text="", mm_data={}) - vision_config = hf_config.visual + processor = self.info.get_hf_processor() + img_start = processor.image_start_tag + img_end = processor.image_end_tag target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) @@ -814,7 +810,7 @@ def get_dummy_processor_inputs( } return ProcessorInputs( - prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n" + prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n" for i in range(1, num_images + 1)), mm_data=mm_data, ) @@ -873,9 +869,10 @@ def _get_prompt_replacements( special_tokens: dict[str, int] = tokenizer.special_tokens # type: ignore - img_start_id = special_tokens[IMG_START] - img_end_id = special_tokens[IMG_END] - img_pad_id = special_tokens[IMG_PAD] + processor = self.info.get_hf_processor() + img_start_id = special_tokens[processor.image_start_tag] + img_end_id = special_tokens[processor.image_end_tag] + img_pad_id = special_tokens[processor.image_pad_tag] num_image_tokens = self.info.get_num_image_tokens() image_tokens = [img_pad_id] * num_image_tokens From 2940cf620d6c7c515600099de39f39874848e407 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 10 Feb 2025 06:41:24 +0000 Subject: [PATCH 5/8] Fix GLM4V Signed-off-by: DarkLight1337 --- vllm/model_executor/models/chatglm.py | 158 ++++++++++---------------- 1 file changed, 61 insertions(+), 97 deletions(-) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 9ee9e9ca80092..696a30f068622 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -4,8 +4,8 @@ # https://github.com/THUDM/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace -from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple, - TypedDict, Union) +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import torch from torch import nn @@ -19,7 +19,6 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -37,12 +36,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, BatchFeature, - BoundPromptReplacement, MultiModalFieldConfig, - PlaceholderFeaturesInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -53,39 +50,6 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) -logger = init_logger(__name__) - -IMAGE_TOKEN_ID = 151329 - - -def build_normalization_transform(image_size: int) -> transforms.Compose: - """ - Build a normalization transform which can be applied to one or - more input images from which we want 
to extract visual features. - - Args: - image_size: size of the image to be processed for visual embeddings. - - Returns: - Callable transform for normalizing and resizing one RGB image. - """ - - return transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - (0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711), - ), - ]) - - -def calculate_image_placeholder(vision_config): - return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2 - class GLMImagePixelInputs(TypedDict): pixel_values: torch.Tensor @@ -109,9 +73,20 @@ def __init__( self.config = config self.tokenizer = tokenizer - if hasattr(self.config, "vision_config"): - self.image_transform = build_normalization_transform( - config.vision_config["image_size"]) + if vision_config := getattr(config, "vision_config", None): + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) else: self.image_transform = None @@ -150,9 +125,19 @@ def __call__( class GLM4VProcessingInfo(BaseProcessingInfo): - def __init__(self, ctx): - super().__init__(ctx) - self._pre_calculate() + def get_tokenizer(self): + tokenizer = self.ctx.tokenizer + assert isinstance(tokenizer, PreTrainedTokenizer) + return tokenizer + + def get_hf_config(self): + return self.ctx.get_hf_config(ChatGLMConfig) + + def get_hf_processor(self) -> GLM4VProcessor: + return GLM4VProcessor( + self.get_hf_config(), + self.get_tokenizer(), + ) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} @@ -162,27 +147,21 @@ def get_mm_max_tokens_per_item( seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: - - return {"image": self.image_token_num + 2} - - def _pre_calculate(self): - hf_config = self.get_hf_config() - vision_config = hf_config.vision_config - self.image_token_num = calculate_image_placeholder(vision_config) - self.image_size = vision_config["image_size"] + return {"image": self.get_num_image_feature_tokens()} def get_num_image_tokens(self) -> int: - return self.image_token_num + 2 - - def get_image_size(self) -> ImageSize: + hf_config = self.get_hf_config() + if not (vision_config := getattr(hf_config, "vision_config", None)): + return 0 - return ImageSize(height=self.image_size, width=self.image_size) + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length - def get_hf_processor(self) -> GLM4VProcessor: - return GLM4VProcessor( - self.get_hf_config(), - self.get_tokenizer(), - ) + def get_num_image_feature_tokens(self) -> int: + # EVA2CLIPModel has embeddings for boi and eoi tokens as well + return self.get_num_image_tokens() + 2 class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): @@ -192,8 +171,12 @@ def get_dummy_processor_inputs( seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + if not (vision_config := getattr(hf_config, "vision_config", None)): + return ProcessorInputs(prompt_text="", mm_data={}) + + target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) - 
target_width, target_height = self.info.get_image_size() mm_data = { "image": @@ -201,9 +184,11 @@ def get_dummy_processor_inputs( height=target_height, num_images=num_images) } - text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + + base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + return ProcessorInputs( - prompt_text=text, + prompt_text=base_text * num_images, mm_data=mm_data, ) @@ -223,47 +208,26 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: + config = self.info.get_hf_config() + + boi_token_id = config.boi_token_id + image_token_id = config.pad_token_id + eoi_token_id = config.eoi_token_id def get_replacement(item_idx: int): - image_tokens = self.info.image_token_num - return [IMAGE_TOKEN_ID] * image_tokens + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [boi_token_id] + image_tokens + [eoi_token_id] return [ PromptReplacement( modality="image", - target=[IMAGE_TOKEN_ID], + target=[boi_token_id, image_token_id, eoi_token_id], replacement=get_replacement, ), ] - def _apply_prompt_replacements( - self, - token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], - mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - token_ids, text, placeholders = super()._apply_prompt_replacements( - token_ids=token_ids, - mm_prompt_repls=mm_prompt_repls, - mm_item_counts=mm_item_counts, - ) - hf_config = self.info.get_hf_config() - boi_token_id = hf_config.boi_token_id - eoi_token_id = hf_config.eoi_token_id - placeholders = { - modality: [ - PlaceholderFeaturesInfo( - modality=p.modality, - item_idx=p.item_idx, - start_idx=p.start_idx - 1, - tokens=[boi_token_id] + p.tokens + [eoi_token_id], - ) for p in ps - ] - for modality, ps in placeholders.items() - } - - return token_ids, text, placeholders - class GLMAttention(nn.Module): @@ -618,7 +582,7 @@ def get_input_embeddings( multimodal_embeddings=multimodal_embeddings, placeholder_token_id=[ self.config.boi_token_id, - IMAGE_TOKEN_ID, + self.config.pad_token_id, self.config.eoi_token_id, ], ) From a5a5ebb693e27ea7694e6de08862765e9de5e43a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 10 Feb 2025 07:09:33 +0000 Subject: [PATCH 6/8] Early exit for text-only models in `_get_prompt_replacements` Signed-off-by: DarkLight1337 --- vllm/model_executor/models/chatglm.py | 10 ++++++---- vllm/model_executor/models/qwen.py | 4 ++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 696a30f068622..153c85cfb2141 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -208,11 +208,13 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - config = self.info.get_hf_config() + hf_config = self.info.get_hf_config() + if not hasattr(hf_config, "vision_config"): + return [] - boi_token_id = config.boi_token_id - image_token_id = config.pad_token_id - eoi_token_id = config.eoi_token_id + boi_token_id = hf_config.boi_token_id + image_token_id = hf_config.pad_token_id + eoi_token_id = hf_config.eoi_token_id def get_replacement(item_idx: int): num_image_tokens = self.info.get_num_image_tokens() diff --git a/vllm/model_executor/models/qwen.py 
b/vllm/model_executor/models/qwen.py index 533973c671697..4b8aeaddbdd37 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -865,6 +865,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + if not hasattr(hf_config, "visual"): + return [] + tokenizer = self.info.get_tokenizer() special_tokens: dict[str, int] = tokenizer.special_tokens # type: ignore From c764d651efc73da1cfb55a717390cf6dd1a0f49d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 10 Feb 2025 07:17:02 +0000 Subject: [PATCH 7/8] Avoid regressions Signed-off-by: DarkLight1337 --- tests/models/decoder_only/language/test_models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 1ad56241535b8..c07dbb2665992 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -26,6 +26,9 @@ "google/gemma-1.1-2b-it", # gemma marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + pytest.param( + "THUDM/chatglm3-6b", # ChatGLM (text-only) + ), pytest.param( "meta-llama/Llama-3.2-1B-Instruct", # llama marks=[pytest.mark.core_model, pytest.mark.cpu_model], @@ -43,6 +46,9 @@ "microsoft/phi-2", # phi marks=[pytest.mark.core_model], ), + pytest.param( + "Qwen/Qwen-7B", # qwen (text-only) + ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 marks=[pytest.mark.core_model], From e993bd3fcc02adff9794cb2588c29351361422b7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 10 Feb 2025 07:34:56 +0000 Subject: [PATCH 8/8] Patch Signed-off-by: DarkLight1337 --- tests/models/decoder_only/language/test_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index c07dbb2665992..c6d5244318a32 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -74,6 +74,10 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: + if model.startswith("THUDM/chatglm3"): + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.transformer.output_layer + hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs)
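Reviewer note on PATCH 2/8: NumPy's RandomState.randint(high) samples from the half-open
interval [0, high), so the old rng.randint(limit) could never yield `limit` items for a
modality and the per-prompt limit itself went untested. A minimal standalone sketch of the
behaviour (the names below are illustrative and not taken from the test file):

    import numpy as np

    rng = np.random.RandomState(0)
    limit = 3

    # randint(limit) draws from [0, limit): the limit itself is never produced.
    old_counts = {rng.randint(limit) for _ in range(1000)}      # support is {0, 1, 2}

    # randint(limit + 1) draws from [0, limit], which is what the test intends.
    new_counts = {rng.randint(limit + 1) for _ in range(1000)}  # support is {0, 1, 2, 3}

    print(sorted(old_counts), sorted(new_counts))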
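Reviewer note on PATCH 4/8 and 5/8: both processors now derive the image placeholder count
from the vision config instead of hard-coded constants (the removed MAX_QWEN_IMG_TOKENS = 256
and calculate_image_placeholder). A rough sketch of the arithmetic, assuming the commonly
published Qwen-VL visual config values image_size=448 and patch_size=14 — these two numbers
are assumptions for illustration only; at runtime the values are read from hf_config.visual
or hf_config.vision_config:

    # Assumed example values; the real ones come from the checkpoint config.
    image_size = 448
    patch_size = 14

    # Grid length per side, exactly as get_num_image_tokens() computes it.
    grid_length = image_size // patch_size // 2       # 448 // 14 // 2 = 16
    num_image_tokens = grid_length * grid_length       # 16 * 16 = 256

    # GLM4V additionally reserves two feature tokens for the boi/eoi embeddings.
    num_image_feature_tokens = num_image_tokens + 2    # 258

    print(num_image_tokens, num_image_feature_tokens)

With these assumed values the formula reproduces the old hard-coded 256, which is why the
constant could be dropped without changing behaviour for the published checkpoints.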