From 2fc6834acad080bddceda12ce10acb5897693b49 Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Wed, 16 Oct 2024 13:35:47 +0800 Subject: [PATCH] Enable qwen-vl multimodal on 062 (#44) * Enable qwen2-vl multimodal input on v0.6.1 (#43) * enable mrope model * update minicpm * update utils * update qwen2_vl * update * update * enable parallel multimodal input * update * remove error --- vllm/entrypoints/chat_utils.py | 3 +- vllm/model_executor/layers/activation.py | 10 +- vllm/model_executor/models/qwen2_vl.py | 148 ++++++++++------------- vllm/worker/xpu_model_runner.py | 49 +++++++- 4 files changed, 118 insertions(+), 92 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4a575ae8f8537..0806499b4d1a5 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -352,7 +352,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], # NOTE: For now we always add missing placeholders at the front of # the prompt. This may change to be customizable in the future. - return "\n".join(missing_placeholders + [text_prompt]) + #return "\n".join(missing_placeholders + [text_prompt]) + return "".join(missing_placeholders + [text_prompt]) # No need to validate using Pydantic again diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 43056786d35c9..1e07ecee33477 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -151,12 +151,12 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_quick(out, x) return out - def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - from vllm._ipex_ops import ipex_ops as ops + # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + # from vllm._ipex_ops import ipex_ops as ops - out = torch.empty_like(x) - ops.gelu_quick(out, x) - return out + # out = torch.empty_like(x) + # ops.gelu_quick(out, x) + # return out # TODO implement forward_xpu for QuickGELU # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 889ebc6c2e1ff..28d071c4213a6 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -202,37 +202,37 @@ def __init__( quant_config=quant_config) # Detect attention implementation. - selected_backend: Optional[_Backend] = get_global_forced_attn_backend() - if selected_backend is None: - backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND - if backend_by_env_var is not None: - selected_backend = backend_name_to_enum(backend_by_env_var) - if selected_backend is None: - # For Volta and Turing GPUs, use xformers instead. - device_available = current_platform.has_device_capability(80) - if device_available: - from transformers.utils import is_flash_attn_2_available - - if is_flash_attn_2_available(): - self._use_flash_attn = True - else: - logger.warning( - "Current Qwen2-VL implementation has a bug with " - "`vllm-flash-attn` inside vision module, so we use " - "xformers backend instead. 
You can run `pip install " - "flash-attn to use flash-attention backend.") - self._use_flash_attn = False - else: - self._use_flash_attn = False - else: - if selected_backend == _Backend.FLASH_ATTN: - self._use_flash_attn = True - elif selected_backend == _Backend.XFORMERS: - self._use_flash_attn = False - else: - raise RuntimeError( - f"Qwen2-VL does not support {selected_backend} backend now." - ) + # selected_backend: Optional[_Backend] = get_global_forced_attn_backend() + # if selected_backend is None: + # backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + # if backend_by_env_var is not None: + # selected_backend = backend_name_to_enum(backend_by_env_var) + # if selected_backend is None: + # # For Volta and Turing GPUs, use xformers instead. + # device_available = current_platform.get_device_capability()[0] >= 8 + # if device_available: + # from transformers.utils import is_flash_attn_2_available + + # if is_flash_attn_2_available(): + # self._use_flash_attn = True + # else: + # logger.warning( + # "Current Qwen2-VL implementation has a bug with " + # "`vllm-flash-attn` inside vision module, so we use " + # "xformers backend instead. You can run `pip install " + # "flash-attn to use flash-attention backend.") + # self._use_flash_attn = False + # else: + # self._use_flash_attn = False + # else: + # if selected_backend == _Backend.FLASH_ATTN: + # self._use_flash_attn = True + # elif selected_backend == _Backend.XFORMERS: + # self._use_flash_attn = False + # else: + # raise RuntimeError( + # f"Qwen2-VL does not support {selected_backend} backend now." + # ) def forward( self, @@ -260,57 +260,37 @@ def forward( if rotary_pos_emb is not None: q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) - - if self._use_flash_attn: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) - from flash_attn import flash_attn_varlen_func - - q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0, - causal=False) - - context_layer = rearrange(output, - "(b s) ... 
-> b s ...", - b=batch_size) - elif is_cpu(): - seq_length = q.size(1) - q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] - attention_mask = torch.zeros([1, seq_length, seq_length], - device=q.device, - dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], - cu_seqlens[i - 1]:cu_seqlens[i]] = True - output = F.scaled_dot_product_attention(q, - k, - v, - attention_mask, - dropout_p=0.0) - context_layer = rearrange(output, "b h s d -> b s h d ") - else: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None) - - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() - - output, _ = self.proj(context_layer) + query = q.movedim(1, 2) + key = k.movedim(1, 2) + value = v.movedim(1, 2) + + seq_lens = [] + for i in range(1, len(cu_seqlens)): + seq_lens.append(cu_seqlens[i]-cu_seqlens[i-1]) + att_masks = [None] * len(seq_lens) + + num_tokens = q.shape[0] * q.shape[1] + output = torch.empty( + (num_tokens, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + dtype=query.dtype, device=query.device) + start = 0 + for seq_len, mask in zip(seq_lens, + att_masks): + end = start + seq_len + sub_out = torch.nn.functional.scaled_dot_product_attention( + query[:, :, start:end, :], + key[:, :, start:end, :], + value[:, :, start:end, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=False, + scale= self.hidden_size_per_attention_head**-0.5).squeeze(0).movedim( + 0, 1) + output[start:end, :, :] = sub_out + start = end + output = output.view(-1, batch_size, self.hidden_size_per_attention_head * self.num_attention_heads_per_partition) + + output, _ = self.proj(output) return output @@ -535,9 +515,7 @@ def forward( grid_thw: torch.Tensor, ) -> torch.Tensor: # patchify - x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) - # compute position embedding rotary_pos_emb = self.rot_pos_emb(grid_thw) @@ -948,7 +926,7 @@ def _parse_and_validate_video_input( def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) + pixel_values = image_input["pixel_values"].to(torch.float16) image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) return image_embeds diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 70599fc05adb3..88ad760e56c70 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -25,6 +25,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -64,6 +65,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -92,6 
+94,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -124,6 +127,8 @@ def __init__(self, self.sliding_window = self.runner.sliding_window self.block_size = self.runner.block_size self.device = self.runner.device + # Multi-modal data support + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) @@ -181,6 +186,40 @@ def _prepare_prompt( # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.extend(list(range(computed_len, seq_len))) + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + mm_kwargs = self.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) + if self.runner.model_is_mrope and mm_data: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + temp_mrope_input_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, + context_len=0, + ) + seq_data.mrope_position_delta = mrope_position_delta + mrope_input_positions = [[] for _ in range(3)] + for idx in range(3): + # msections = temp_mrope_input_positions + # for _seq_mrope_input_positions in msections: + mrope_input_positions[idx].extend( + temp_mrope_input_positions[idx]) + input_positions = mrope_input_positions + if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -243,7 +282,6 @@ def _prepare_prompt( ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) - return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) @@ -417,6 +455,15 @@ def load_model(self) -> None: def vocab_size(self) -> int: return self.model_config.get_vocab_size() + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. + mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage.
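NOTE (illustration only, not part of the patch): the vision-attention rewrite above drops the flash-attn/xformers paths and instead runs torch scaled_dot_product_attention once per packed sequence, using cu_seqlens to locate each sequence's slice. The standalone sketch below reproduces that pattern on dummy tensors; the helper name sdpa_varlen and the [num_heads, total_tokens, head_dim] layout are assumptions made for the example, not vLLM internals.

import torch
import torch.nn.functional as F

def sdpa_varlen(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                cu_seqlens: torch.Tensor) -> torch.Tensor:
    """q, k, v: [num_heads, total_tokens, head_dim]; cu_seqlens: [num_seqs + 1]."""
    out = torch.empty_like(q)
    for i in range(1, len(cu_seqlens)):
        start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        # Non-causal attention restricted to one sequence at a time; over the
        # packed batch this is equivalent to a block-diagonal attention mask.
        out[:, start:end] = F.scaled_dot_product_attention(
            q[:, start:end], k[:, start:end], v[:, start:end],
            dropout_p=0.0, is_causal=False)
    return out

if __name__ == "__main__":
    num_heads, head_dim = 4, 32
    cu_seqlens = torch.tensor([0, 5, 12, 20])      # three sequences packed end to end
    total = int(cu_seqlens[-1])
    q, k, v = (torch.randn(num_heads, total, head_dim) for _ in range(3))
    print(sdpa_varlen(q, k, v, cu_seqlens).shape)  # torch.Size([4, 20, 32])

Looping per sequence avoids materialising the full seq_len x seq_len boolean mask that the previous CPU fallback constructed, at the cost of one SDPA call per image or video in the batch.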