From f0fe4fe86d45763cb5904ac256ac6241c5eb2fde Mon Sep 17 00:00:00 2001
From: Xiang Xu <117880274+xiangxu-google@users.noreply.github.com>
Date: Mon, 14 Oct 2024 15:24:26 -0700
Subject: [PATCH] [Model] Make llama3.2 support multiple and interleaved
 images (#9095)

---
 ...e_inference_vision_language_multi_image.py |  23 ++
 .../vision_language/test_mllama.py            |  85 ++++-
 vllm/model_executor/models/mllama.py          | 318 +++++++++++++++---
 3 files changed, 384 insertions(+), 42 deletions(-)

diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py
index c4e4cdc0db95f..69f590fb7950d 100644
--- a/examples/offline_inference_vision_language_multi_image.py
+++ b/examples/offline_inference_vision_language_multi_image.py
@@ -234,12 +234,35 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     )
 
 
+def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
+    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+    # The configuration below has been confirmed to launch on a single L40 GPU.
+    llm = LLM(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=16,
+        enforce_eager=True,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    prompt = f"<|image|><|image|><|begin_of_text|>{question}"
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=None,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 model_example_map = {
     "phi3_v": load_phi3v,
     "internvl_chat": load_internvl,
     "NVLM_D": load_nvlm_d,
     "qwen2_vl": load_qwen2_vl,
     "qwen_vl_chat": load_qwenvl_chat,
+    "mllama": load_mllama,
 }
 
 
diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py
index 78a5c8158e16e..52f74ec885946 100644
--- a/tests/models/encoder_decoder/vision_language/test_mllama.py
+++ b/tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -12,7 +12,7 @@
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
-_LIMIT_IMAGE_PER_PROMPT = 1
+_LIMIT_IMAGE_PER_PROMPT = 3
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
@@ -244,8 +244,9 @@ def process(hf_inputs: BatchEncoding):
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
-                max_tokens, num_logprobs) -> None:
+def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
+                                     model, sizes, dtype, max_tokens,
+                                     num_logprobs) -> None:
     run_test(
         hf_runner,
         vllm_runner,
@@ -257,3 +258,81 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )
+
+
+@large_gpu_test(min_gb=48)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
+                                     model, dtype, max_tokens,
+                                     num_logprobs) -> None:
+
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
+            "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
+            "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.",  # noqa: E501
+        ],
+        [
+            [stop_sign, cherry_blossom],
+            # Images with different sizes.
+            [
+                stop_sign.resize((512, 512)),
+                stop_sign,
+            ],
+            [
+                stop_sign,
+                stop_sign.resize((512, 1536)),
+                cherry_blossom.resize((512, 1024)),
+            ],
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
+@large_gpu_test(min_gb=48)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
+                                   dtype, max_tokens, num_logprobs) -> None:
+
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
+            "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, "  # noqa: E501
+            "which is a stop sign and which is a cherry blossom?",  # noqa: E501
+        ],
+        [
+            [stop_sign],
+            [stop_sign, cherry_blossom],
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 45d6ad3c0efa5..66e9b2844620d 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -18,6 +18,7 @@
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -28,9 +29,12 @@
     CausalLMOutputWithPast)
 from transformers.models.mllama.image_processing_mllama import (
     get_optimal_tiled_canvas)
+from transformers.models.mllama.processing_mllama import (
+    get_cross_attention_token_mask)
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
@@ -72,6 +76,16 @@ class MllamaImagePixelInputs(TypedDict):
     # TODO: support LlamaImageEmbeddingInputs
 
 
+def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
+    num_images = 0
+    for token_id in prompt_token_ids[::-1]:
+        if token_id == MLLAMA_IMAGE_TOKEN_ID:
+            num_images += 1
+        elif num_images > 0:
+            break
+    return num_images
+
+
 def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
     # move encoder_prompt to prompt
     if llm_inputs.get("prompt") is None:
@@ -91,12 +105,16 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
         llm_inputs["encoder_multi_modal_data"] = {}
         return llm_inputs
 
-    # get num_tiles
     if isinstance(multi_modal_data['image'], Image.Image):
         multi_modal_data['image'] = [multi_modal_data['image']]
+    # Since only the last group of consecutive images
+    # is attended to by the decoded tokens, we only need
+    # the number of tiles for those images.
+    num_decode_images = _get_num_image_in_last_group(
+        llm_inputs["prompt_token_ids"])
     hf_config = ctx.model_config.hf_config
     num_tiles = 0
-    for image in multi_modal_data["image"]:
+    for image in multi_modal_data["image"][::-1]:
         width, height = image.size
         tile_size = hf_config.vision_config.image_size
         canvas_height, canvas_width = get_optimal_tiled_canvas(
@@ -108,8 +126,13 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
         num_tiles_height = canvas_height // tile_size
         num_tiles_width = canvas_width // tile_size
         num_tiles += num_tiles_height * num_tiles_width
+        num_decode_images -= 1
+        if num_decode_images == 0:
+            break
 
-    # set encoder prompt based on num_tiles
+    # Set the encoder prompt length based on the number of tiles.
+    # This tells the block manager to allocate the correct number
+    # of slots for encoder tokens.
     assert hf_config.vision_config.image_size % 14 == 0, \
         "chunk size should be multiple of 14"
     token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
@@ -675,6 +698,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
@@ -697,15 +721,71 @@ def forward(
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
 
-        output = self.attn(q,
-                           k,
-                           v,
-                           kv_cache,
-                           attn_metadata,
-                           attn_type=AttentionType.ENCODER_DECODER)
+        if attention_mask is not None:
+            output = self.attention_with_mask(q, k, v, kv_cache,
+                                              attention_mask,
+                                              kv_range_for_decode,
+                                              attn_metadata)
+        else:
+            output = self.attn(q,
+                               k,
+                               v,
+                               kv_cache,
+                               attn_metadata,
+                               attn_type=AttentionType.ENCODER_DECODER)
         out, _ = self.o_proj(output)
         return out
 
+    def attention_with_mask(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attention_mask: torch.Tensor,
+        kv_range_for_decode: List[Tuple[int, int]],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        # Skip writing kv-cache for the initial profiling run.
+        if len(kv_cache.shape) == 3:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_local_key_value_heads, self.head_dim)
+            cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+            cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+            PagedAttention.write_to_paged_cache(
+                cached_k, cached_v, key_cache, value_cache,
+                attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+        # We have to call torch.sdpa for prefill when using a
+        # custom cross-attention mask, because the mask is neither a
+        # standard causal mask nor a block-diagonal mask that
+        # xformers.BlockDiagonalMask could optimize.
+        # The mask is computed specifically to support multiple
+        # and interleaved images.
+        q_len = q.shape[0]
+        kv_len = k.shape[0]
+        q = q.transpose(0, 1).view(self.num_local_key_value_heads,
+                                   self.num_key_value_groups, q_len,
+                                   self.head_dim)
+        k = k.transpose(0,
+                        1)[:,
+                           None, :, :].expand(self.num_local_key_value_heads,
+                                              self.num_key_value_groups,
+                                              kv_len, self.head_dim)
+        v = v.transpose(0,
+                        1)[:,
+                           None, :, :].expand(self.num_local_key_value_heads,
+                                              self.num_key_value_groups,
+                                              kv_len, self.head_dim)
+        attention_mask = attention_mask.view(1, 1, q_len, kv_len)
+        output = F.scaled_dot_product_attention(q,
+                                                k,
+                                                v,
+                                                attn_mask=attention_mask,
+                                                is_causal=False)
+        output = output.permute(2, 0, 1, 3).reshape(
+            q_len, self.num_local_heads * self.head_dim)
+        return output
+
 
 class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
     """Cross-attention transformer block with tanh-gated attention
@@ -741,6 +821,7 @@ def forward(
         hidden_states: torch.Tensor,
         cross_attention_states: torch.Tensor,
         cross_attention_mask: torch.Tensor,
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: torch.Tensor,
         kv_cache: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
@@ -751,6 +832,7 @@ def forward(
         hidden_states = self.cross_attn(
             hidden_states=hidden_states,
             attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             cross_attention_states=cross_attention_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
@@ -804,6 +886,7 @@ def forward(
         positions: Optional[torch.LongTensor],
         cross_attention_states: Optional[torch.LongTensor],
         cross_attention_mask: Optional[torch.LongTensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                       torch.Tensor]],
         kv_caches: List[torch.Tensor],
@@ -820,6 +903,7 @@ def forward(
                     hidden_states=hidden_states,
                     cross_attention_states=cross_attention_states,
                     cross_attention_mask=cross_attention_mask,
+                    kv_range_for_decode=kv_range_for_decode,
                     full_text_row_masked_out_mask=
                     full_text_row_masked_out_mask,
                     kv_cache=kv_caches[idx],
@@ -868,6 +952,7 @@ def forward(
         positions: Optional[torch.LongTensor],
         cross_attention_states: Optional[torch.LongTensor],
         cross_attention_mask: Optional[torch.LongTensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                       torch.Tensor]],
         kv_caches: List[torch.Tensor],
@@ -879,6 +964,7 @@ def forward(
             positions=positions,
             cross_attention_states=cross_attention_states,
             cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
@@ -1026,36 +1112,102 @@ def _parse_and_validate_image_input(self, **kwargs: object):
         raise AssertionError("This line should be unreachable.")
 
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
-                            attn_metadata: AttentionMetadata):
+                            attn_metadata: AttentionMetadata,
+                            actual_encoder_seq_lens: List[int]):
 
         cross_attention_states_flat = torch.zeros(
-            sum(attn_metadata.encoder_seq_lens),
+            sum(actual_encoder_seq_lens),
             cross_attention_states.shape[-1],
             device=cross_attention_states.device,
             dtype=cross_attention_states.dtype)
         start_pos = 0
-        for seq_len, vision_token_in_batch in zip(
-                attn_metadata.encoder_seq_lens, cross_attention_states):
+        for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
+                                                  cross_attention_states):
             end_pos = start_pos + seq_len
             cross_attention_states_flat[
                 start_pos:end_pos] = vision_token_in_batch[:seq_len]
             start_pos = end_pos
         cross_attention_states = cross_attention_states_flat
+        return cross_attention_states
+
+    def get_cross_attention_states(
+        self,
+        image_inputs: MllamaImagePixelInputs,
+        attn_metadata: AttentionMetadata,
+        actual_encoder_seq_lens: List[int],
+    ) -> Tuple[torch.Tensor]:
+        # NOTE: llama's reference implementation runs vision model on CPU
+        pixel_values = image_inputs['data']
+        aspect_ratio_ids = image_inputs['aspect_ratio_ids']
+        aspect_ratio_mask = image_inputs['aspect_ratio_mask']
+        cross_attention_states = self.vision_model(pixel_values,
+                                                   aspect_ratio_ids,
+                                                   aspect_ratio_mask)
+        cross_attention_states = self.multi_modal_projector(
+            cross_attention_states)
+
+        bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
+        cross_attention_states = cross_attention_states.view(
+            bsz, -1, image_token_dim)
+
+        cross_attention_states = self.flat_encoder_result(
+            cross_attention_states, attn_metadata, actual_encoder_seq_lens)
+
+        return cross_attention_states
+
+    def get_cross_attention_mask(
+        self,
+        input_ids: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+        dtype: torch.dtype,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        token_ids = input_ids.tolist()
+        start = 0
+        batch_token_ids = []
+        for seq_len in attn_metadata.seq_lens:
+            batch_token_ids.append(token_ids[start:start + seq_len])
+            start += seq_len
+        sparse_mask = [
+            get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
+            for t in batch_token_ids
+        ]
+        # Skip generating the cross-attention mask if all samples
+        # are text-only or have only one leading image.
+        if skip_attention_mask(sparse_mask):
+            return None, None
+
+        dense_mask, tile_range_for_decode = \
+            convert_sparse_cross_attention_mask_to_dense(
+                sparse_mask, num_tiles, attn_metadata.seq_lens)
+        cross_attention_mask = \
+            convert_dense_cross_attention_mask_to_tensor(
+                dense_mask, num_tokens_per_tile, input_ids.device, dtype)
+        kv_range_for_decode = [[
+            t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
+        ] for t in tile_range_for_decode]
+
+        return cross_attention_mask, kv_range_for_decode
+
+    def get_full_text_row_masked_out_mask(
+        self,
+        attn_metadata: AttentionMetadata,
+        device: torch.device,
+    ) -> torch.Tensor:
         full_text_row_masked_out_mask = torch.ones(
             (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
         start_pos = 0
-        for seq_len, encoder_seq_len in zip(
-                attn_metadata.seq_lens_tensor.cpu(),
-                attn_metadata.encoder_seq_lens):
+        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
+                                            attn_metadata.encoder_seq_lens):
             if encoder_seq_len == 0:
                 full_text_row_masked_out_mask[start_pos:start_pos +
                                               seq_len] = False
             start_pos += seq_len
         full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
-            cross_attention_states.device)
-
-        return cross_attention_states, full_text_row_masked_out_mask
+            device)
+        return full_text_row_masked_out_mask
 
     def forward(
         self,
@@ -1069,39 +1221,54 @@ def forward(
                 attn_metadata.num_decode_tokens > 0:
             raise ValueError("Chunk prefill not supported")
         image_inputs = self._parse_and_validate_image_input(**kwargs)
+        cross_attention_states = None
+        cross_attention_mask = None
+        kv_range_for_decode = None
+
+        # For 1) text-only prefill and decode, 2) image-present decode.
         if image_inputs is None:
-            cross_attention_mask = None
             full_text_row_masked_out_mask = (
                 attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                     input_ids.device)
-            cross_attention_states = None
             skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
+
+        # For image-present prefill.
         else:
-            # NOTE: llama's reference implementation runs vision model on CPU
-            pixel_values = image_inputs['data']
-            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
-            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
-            cross_attention_states = self.vision_model(pixel_values,
-                                                       aspect_ratio_ids,
-                                                       aspect_ratio_mask)
-            cross_attention_states = self.multi_modal_projector(
-                cross_attention_states)
-
-            bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
-            cross_attention_states = cross_attention_states.view(
-                bsz, -1, image_token_dim)
-
-            cross_attention_states, full_text_row_masked_out_mask = \
-                self.flat_encoder_result(cross_attention_states, attn_metadata)
             skip_cross_attention = False
-            # TODO: support multi-image by this mask
-            cross_attention_mask = None
+
+            # Get the actual number of encoder tokens for each sample.
+            # attn_metadata.encoder_seq_lens only counts the last group
+            # of images for each sample, which is what the block manager
+            # uses to allocate blocks for those images only.
+            # See input_processor_for_mllama() for more details.
+            num_tiles_tensor = kwargs.pop("num_tiles")
+            num_tiles = [t[0].tolist() for t in num_tiles_tensor]
+            num_tokens_per_tile = (self.image_size // 14)**2 + 1
+            actual_encoder_seq_lens = [
+                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+            ]
+            for actual_len, last_group_len in zip(
+                    actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
+                assert actual_len >= last_group_len
+
+            cross_attention_states = self.get_cross_attention_states(
+                image_inputs, attn_metadata, actual_encoder_seq_lens)
+
+            full_text_row_masked_out_mask = \
+                self.get_full_text_row_masked_out_mask(
+                    attn_metadata, input_ids.device)
+
+            cross_attention_mask, kv_range_for_decode = \
+                self.get_cross_attention_mask(
+                    input_ids, attn_metadata, num_tiles,
+                    num_tokens_per_tile, cross_attention_states.dtype)
 
         outputs = self.language_model(
             input_ids=input_ids,
             positions=positions,
             cross_attention_states=cross_attention_states,
             cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
@@ -1140,3 +1307,76 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
+    for mask in sparse_mask:
+        # Skip text-only samples.
+        if len(mask) == 0:
+            continue
+        # If the sample contains more than one image,
+        # we can't skip the mask.
+        if len(mask) != 1:
+            return False
+        # If the sample contains only one image
+        # but it is not the leading one,
+        # we can't skip the mask.
+        if mask[0][0] != 0 or mask[0][1] != -1:
+            return False
+    return True
+
+
+def convert_sparse_cross_attention_mask_to_dense(
+    sparse_mask: List[List[List[int]]],
+    num_tiles: List[List[int]],
+    lengths: List[int],
+) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
+    total_length = sum(lengths)
+    total_tiles = sum([sum(tiles) for tiles in num_tiles])
+    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
+    # A list of ranges; range[i] = [start, end] means that
+    # tiles[start, end] of the i-th sample will be used for
+    # cross-attention decoding.
+    tile_range_for_decode = []
+
+    seq_start = 0
+    tile_start = 0
+    for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
+        ts, td = -1, 0
+        for mask, tile in zip(masks, tiles):
+            if len(mask) != 2:
+                continue
+            start, end = mask
+            end = min(end, length)
+            if end == -1:
+                end = length
+            if end == length:
+                if ts == -1:
+                    ts = tile_start
+                td += tile
+            dense_mask[seq_start + start:seq_start + end,
+                       tile_start:tile_start + tile] = 1
+            tile_start += tile
+        tile_range_for_decode.append((ts, ts + td))
+        seq_start += length
+
+    return dense_mask, tile_range_for_decode
+
+
+def convert_dense_cross_attention_mask_to_tensor(
+    cross_attention_token_mask: np.ndarray,
+    num_tokens_per_tile: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
+    mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)
+
+    mask = 1.0 - mask
+    mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)
+
+    ninf = torch.finfo(dtype).min
+    full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
+    mask *= full_text_mask
+    # (num_prompt_tokens, num_encoder_tokens)
+    return mask
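
A minimal usage sketch (not part of the patch itself): it shows how the multi-image support added above could be exercised offline, mirroring load_mllama() and the new interleaved-image test prompt. The {"prompt": ..., "multi_modal_data": {"image": [...]}} input format and the LLM/SamplingParams calls follow vLLM's existing offline multi-modal API rather than anything introduced by this change, and the image file paths are placeholders.

    # Hypothetical example; image paths are placeholders.
    from PIL import Image

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
        # Allow up to 2 images per prompt.
        limit_mm_per_prompt={"image": 2},
    )

    # Interleaved-image prompt, as covered by test_models_interleaved_images.
    prompt = ("<|begin_of_text|>Between the first image <|image|> and the "
              "second image<|image|>, which is a stop sign and which is a "
              "cherry blossom?")

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": [Image.open("stop_sign.jpg"),
                          Image.open("cherry_blossom.jpg")],
            },
        },
        SamplingParams(temperature=0.0, max_tokens=128),
    )
    print(outputs[0].outputs[0].text)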