From 864c90e56046d7b7982e7a9fe0adcdde6a1af45d Mon Sep 17 00:00:00 2001
From: Xiang Xu
Date: Tue, 8 Oct 2024 13:53:18 -0700
Subject: [PATCH] add multi-image unit tests for mllama

Cover single and multiple images, leading and interleaved image
placement, and text-only prompts in the mllama vision-language tests,
and drop leftover debug prints from the model code.
---
 .../vision_language/test_mllama.py   | 108 +++++++++++-------
 vllm/model_executor/models/mllama.py |   3 -
 2 files changed, 66 insertions(+), 45 deletions(-)

diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py
index 78a5c8158e16e..7af15fcd90d14 100644
--- a/tests/models/encoder_decoder/vision_language/test_mllama.py
+++ b/tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -7,26 +7,35 @@
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
-from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
-                          _ImageAssets)
+from ....conftest import HfRunner, PromptImageInput, VllmRunner, _ImageAssets
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
-_LIMIT_IMAGE_PER_PROMPT = 1
-
-HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
-    "stop_sign":
-    "<|image|><|begin_of_text|>The meaning of the image is",
-    "cherry_blossom":
-    "<|image|><|begin_of_text|>The city is",
-})
-
-text_only_prompts = [
-    "The color of the sky is blue but sometimes it can also be",
+_LIMIT_IMAGE_PER_PROMPT = 2
+
+# ("prompt", [image indices])
+# image-0: stop sign, image-1: cherry blossom
+# For each entry, we will generate a batch of
+# samples with different image sizes.
+PROMPTS = [
+    # Single leading image.
+    ("<|image|><|begin_of_text|>The meaning of the image is", [0]),
+    ("<|image|><|begin_of_text|>The city is", [1]),
+    # Single interleaved image.
+    ("<|begin_of_text|>The meaning of the image<|image|> is", [0]),
+    # Multi leading images.
+    ("<|image|><|image|><|begin_of_text|>Between the first and second image, "
+     "which is stop sign and which is cherry blossom?", [0, 1]),
+    # Multi interleaved images.
+    ("<|begin_of_text|>Between the first image<|image|> and second "
+     "image<|image|>, which is stop sign and which is cherry blossom?",
+     [0, 1]),
+    # Text only.
+    ("The color of the sky is blue but sometimes it can also be", []),
 ]
 
 models = [
     "meta-llama/Llama-3.2-11B-Vision-Instruct",
 ]
@@ -59,32 +68,49 @@ def _get_inputs(
     *,
     size_factors: Optional[List[float]] = None,
     sizes: Optional[List[Tuple[int, int]]] = None,
-) -> List[Tuple[List[str], PromptImageInput]]:
-    images = [asset.pil_image for asset in image_assets]
-
-    if size_factors is not None:
-        inputs_per_image = [(
-            [prompt for _ in size_factors],
-            [rescale_image_size(image, factor) for factor in size_factors],
-        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-    elif sizes is not None:
-        inputs_per_image = [(
-            [
-                prompt if size is not None else text_only_prompts[0]
-                for size in sizes
-            ],
-            [
-                image.resize(size) if size is not None else None
-                for size in sizes
-            ],
-        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-        if len(sizes) == 0:
-            inputs_per_image.append(
-                (text_only_prompts, [None] * len(text_only_prompts)))
-    else:
-        raise ValueError("You must provide either `size_factors` or `sizes`")
-
-    return inputs_per_image
+) -> List[Tuple[List[str], List[Optional[PromptImageInput]]]]:
+    assets = [asset.pil_image for asset in image_assets]
+    assert len(assets) >= 2
+
+    # Inputs is a list of batches; each batch is a tuple of
+    # (prompts, images), where prompts is a list of strings and
+    # images is a nested list of PIL images.
+ # len(prompts) == len(images) + # A batch will trigger a generate run. + inputs = [] + for entry in PROMPTS: + prompt, image_indices = entry + images = [assets[i] for i in image_indices] + batch_prompts = [] + batch_images = [] + if size_factors is not None: + for factor in size_factors: + if factor is None: + batch_prompts.append(PROMPTS[-1][0]) + batch_images.append(None) + else: + batch_prompts.append(prompt) + resized_images = [ + rescale_image_size(image, factor) for image in images + ] + batch_images.append( + resized_images if resized_images else None) + elif sizes is not None: + for size in sizes: + if size is None: + batch_prompts.append(PROMPTS[-1][0]) + batch_images.append(None) + else: + batch_prompts.append(prompt) + resized_images = [image.resize(size) for image in images] + batch_images.append( + resized_images if resized_images else None) + else: + raise ValueError( + "You must provide either `size_factors` or `sizes`") + assert len(batch_prompts) == len(batch_images) + inputs.append((batch_prompts, batch_images)) + return inputs @overload @@ -151,7 +177,7 @@ def run_test( def _run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput]], + inputs: List[Tuple[List[str], List[Optional[PromptImageInput]]]], model: str, *, dtype: str, @@ -226,8 +252,6 @@ def process(hf_inputs: BatchEncoding): @pytest.mark.parametrize( "sizes", [ - # Text only - [], # Single-size [(512, 512)], # Single-size, batched diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index c68bbfd79cdcc..4e9cb73c1f6e9 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -136,7 +136,6 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): assert hf_config.vision_config.image_size % 14 == 0, \ "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 - print(f"vllm num_tiles: {num_tiles}") num_tokens = num_tiles * token_per_chunk llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID @@ -1151,7 +1150,6 @@ def get_cross_attention_states( ) -> Tuple[torch.Tensor, torch.Tensor]: # NOTE: llama's reference implementation runs vision model on CPU pixel_values = image_inputs['data'] - print(f"pixel_values={pixel_values.shape}") aspect_ratio_ids = image_inputs['aspect_ratio_ids'] aspect_ratio_mask = image_inputs['aspect_ratio_mask'] cross_attention_states = self.vision_model(pixel_values, @@ -1189,7 +1187,6 @@ def get_cross_attention_mask( get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID) for t in batch_token_ids ] - print(f"sparse_mask={sparse_mask}") # Skip generating cross-attention mask if all samples # are text-only or have only 1 leading image.