Enable qwen2-vl multimodal input on v0.6.1 (#43)
* enable mrope model

* update minicpm

* update utils

* update qwen2_vl

* update

* update

* enable parallel multimodal input
hzjane authored Oct 11, 2024
1 parent 1075a90 commit 32c883f
Showing 4 changed files with 114 additions and 74 deletions.
4 changes: 2 additions & 2 deletions vllm/entrypoints/chat_utils.py
@@ -350,7 +350,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],

# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
return "\n".join(missing_placeholders + [text_prompt])
#return "\n".join(missing_placeholders + [text_prompt])
return "".join(missing_placeholders + [text_prompt])


# No need to validate using Pydantic again
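The switch above from "\n".join to "".join means any missing multimodal placeholders are now prepended to the text prompt with no separating newline. A minimal sketch of the effect (the placeholder token below is hypothetical):

# Illustrative only: the placeholder token is hypothetical.
missing_placeholders = ["<image>"]
text_prompt = "Describe this picture."
old = "\n".join(missing_placeholders + [text_prompt])   # "<image>\nDescribe this picture."
new = "".join(missing_placeholders + [text_prompt])     # "<image>Describe this picture."
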
@@ -398,7 +399,6 @@ def _parse_chat_message_content_parts(
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)

return [ConversationMessage(role=role, content=text_prompt)]


2 changes: 1 addition & 1 deletion vllm/model_executor/models/minicpmv.py
@@ -361,7 +361,7 @@ def __init__(
self.vpm.embeddings.embed_dim)
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
# self.resampler.to(device="cuda", dtype=param_dtype)
#self.resampler.to(device="cuda", dtype=param_dtype)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
133 changes: 63 additions & 70 deletions vllm/model_executor/models/qwen2_vl.py
@@ -200,37 +200,37 @@ def __init__(
quant_config=quant_config)

# Detect attention implementation.
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.get_device_capability()[0] >= 8
if device_available:
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
self._use_flash_attn = True
else:
logger.warning(
"Current Qwen2-VL implementation has a bug with "
"`vllm-flash-attn` inside vision module, so we use "
"xformers backend instead. You can run `pip install "
"flash-attn to use flash-attention backend.")
self._use_flash_attn = False
else:
self._use_flash_attn = False
else:
if selected_backend == _Backend.FLASH_ATTN:
self._use_flash_attn = True
elif selected_backend == _Backend.XFORMERS:
self._use_flash_attn = False
else:
raise RuntimeError(
f"Qwen2-VL does not support {selected_backend} backend now."
)
# selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
# if selected_backend is None:
# backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
# if backend_by_env_var is not None:
# selected_backend = backend_name_to_enum(backend_by_env_var)
# if selected_backend is None:
# # For Volta and Turing GPUs, use xformers instead.
# device_available = current_platform.get_device_capability()[0] >= 8
# if device_available:
# from transformers.utils import is_flash_attn_2_available

# if is_flash_attn_2_available():
# self._use_flash_attn = True
# else:
# logger.warning(
# "Current Qwen2-VL implementation has a bug with "
# "`vllm-flash-attn` inside vision module, so we use "
# "xformers backend instead. You can run `pip install "
# "flash-attn to use flash-attention backend.")
# self._use_flash_attn = False
# else:
# self._use_flash_attn = False
# else:
# if selected_backend == _Backend.FLASH_ATTN:
# self._use_flash_attn = True
# elif selected_backend == _Backend.XFORMERS:
# self._use_flash_attn = False
# else:
# raise RuntimeError(
# f"Qwen2-VL does not support {selected_backend} backend now."
# )

def forward(
self,
@@ -258,42 +258,37 @@ def forward(
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

if self._use_flash_attn:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func

q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
causal=False)

context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
else:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None)

context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()

output, _ = self.proj(context_layer)
query = q.movedim(1, 2)
key = k.movedim(1, 2)
value = v.movedim(1, 2)

seq_lens = []
for i in range(1, len(cu_seqlens)):
seq_lens.append(cu_seqlens[i]-cu_seqlens[i-1])
att_masks = [None] * len(seq_lens)

num_tokens = q.shape[0] * q.shape[1]
output = torch.empty(
(num_tokens, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
dtype=query.dtype, device=query.device)
start = 0
for seq_len, mask in zip(seq_lens,
att_masks):
end = start + seq_len
sub_out = torch.nn.functional.scaled_dot_product_attention(
query[:, :, start:end, :],
key[:, :, start:end, :],
value[:, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=False,
scale= self.hidden_size_per_attention_head**-0.5).squeeze(0).movedim(
0, 1)
output[start:end, :, :] = sub_out
start = end
output = output.view(-1, batch_size, self.hidden_size_per_attention_head * self.num_attention_heads_per_partition)

output, _ = self.proj(output)
return output
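The rewritten attention above drops the flash-attn/xformers branches and instead slices q/k/v at each cu_seqlens boundary, running torch.nn.functional.scaled_dot_product_attention once per sequence so patch tokens from different images never attend to each other. A self-contained sketch of that pattern with toy shapes (every name and size below is illustrative, not the module's real tensors):

# Illustrative sketch only: toy shapes standing in for the module's real tensors.
import torch
import torch.nn.functional as F

heads, head_dim = 2, 8
cu_seqlens = torch.tensor([0, 4, 10])        # two "images" of 4 and 6 patch tokens
total_tokens = int(cu_seqlens[-1])
q = k = v = torch.randn(1, heads, total_tokens, head_dim)

out = torch.empty(total_tokens, heads, head_dim)
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    sub = F.scaled_dot_product_attention(
        q[:, :, start:end, :], k[:, :, start:end, :], v[:, :, start:end, :],
        is_causal=False)                           # [1, heads, seq, head_dim]
    out[start:end] = sub.squeeze(0).movedim(0, 1)  # -> [seq, heads, head_dim]
# out now holds block-diagonal attention results, one block per image.
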


@@ -518,9 +513,7 @@ def forward(
grid_thw: torch.Tensor,
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)

# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)

@@ -926,7 +919,7 @@ def _parse_and_validate_video_input(

def _process_image_input(self,
image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
pixel_values = image_input["pixel_values"].to(torch.float16)
image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"])
return image_embeds
49 changes: 48 additions & 1 deletion vllm/worker/xpu_model_runner.py
@@ -23,6 +23,7 @@
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -62,6 +63,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)

@@ -90,6 +92,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -122,6 +125,8 @@ def __init__(self,
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
# Multi-modal data support
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
@@ -179,6 +184,40 @@ def _prepare_prompt(
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if self.runner.model_is_mrope and mm_data:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
"mrope embedding type requires the multi-modal input mapper "
"to return 'image_grid_thw' or 'video_grid_thw'.")

hf_config = self.runner.model_config.hf_config
token_ids = seq_data.get_token_ids()
temp_mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
)
seq_data.mrope_position_delta = mrope_position_delta
mrope_input_positions = [[] for _ in range(3)]
for idx in range(3):
# msections = temp_mrope_input_positions
# for _seq_mrope_input_positions in msections:
mrope_input_positions[idx].extend(
temp_mrope_input_positions[idx])
input_positions = mrope_input_positions


if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
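For orientation, the mrope path above replaces the usual flat position list with three parallel position lists (temporal, height, width), built from the [3][seq_len] structure that MRotaryEmbedding.get_input_positions returns along with a position delta kept on the sequence for the decode phase. A toy illustration of that layout (the numeric values are made up):

# Illustrative only: hand-written 3-axis positions for a 6-token prompt,
# mimicking the [3][seq_len] layout returned by MRotaryEmbedding.get_input_positions.
temp_mrope_input_positions = [
    [0, 1, 1, 1, 1, 2],   # temporal axis
    [0, 1, 1, 2, 2, 3],   # height axis (values are made up)
    [0, 1, 2, 1, 2, 3],   # width axis (values are made up)
]
mrope_input_positions = [[] for _ in range(3)]
for idx in range(3):
    mrope_input_positions[idx].extend(temp_mrope_input_positions[idx])
input_positions = mrope_input_positions   # three parallel lists instead of one flat list
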
@@ -241,7 +280,6 @@ def _prepare_prompt(
)

multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
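MultiModalInputs.batch then merges the per-sequence mapper outputs into a single kwargs dict for the model. A rough, library-agnostic sketch of that batching idea (this is not the actual MultiModalInputs implementation; the sketch simply concatenates tensors that share a key):

# Rough sketch of batching per-sequence multimodal kwargs (illustrative only;
# the real MultiModalInputs.batch handles more cases than plain concatenation).
from collections import defaultdict
from typing import Dict, List
import torch

def batch_mm_kwargs(per_seq: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    grouped = defaultdict(list)
    for kwargs in per_seq:
        for key, value in kwargs.items():
            grouped[key].append(value)
    return {key: torch.cat(values, dim=0) for key, values in grouped.items()}
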

@@ -411,6 +449,15 @@ def load_model(self) -> None:
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()

@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keeping "rope_deltas" between the prompt and decoding phases."""
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
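As a quick check of the property above: a config whose rope_scaling dict carries type "mrope" is detected, while a text-only model with rope_scaling = None is not. The sample dicts below are assumptions modeled on Qwen2-VL's published config, not values read from this repository:

# Hedged example: the rope_scaling dicts below are assumptions for illustration.
from types import SimpleNamespace

def model_is_mrope(hf_config) -> bool:
    rope_scaling = getattr(hf_config, "rope_scaling", {})
    if rope_scaling is None:
        return False
    return rope_scaling.get("type", None) == "mrope"

qwen2_vl_cfg = SimpleNamespace(rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]})
text_only_cfg = SimpleNamespace(rope_scaling=None)
print(model_is_mrope(qwen2_vl_cfg))   # True
print(model_is_mrope(text_only_cfg))  # False
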

@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
