Enable qwen2-vl multimodal input on v0.6.1 #43

Merged (8 commits) on Oct 11, 2024
4 changes: 2 additions & 2 deletions vllm/entrypoints/chat_utils.py
@@ -350,7 +350,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],

# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
return "\n".join(missing_placeholders + [text_prompt])
#return "\n".join(missing_placeholders + [text_prompt])
return "".join(missing_placeholders + [text_prompt])


# No need to validate using Pydantic again
@@ -398,7 +399,6 @@ def _parse_chat_message_content_parts(
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)

return [ConversationMessage(role=role, content=text_prompt)]


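For illustration (not part of the diff): a minimal sketch of what the join change in _get_full_multimodal_text_prompt does when a missing placeholder has to be prepended. The placeholder string below is made up for the example, not necessarily the exact token Qwen2-VL expects.

# Hypothetical values, for illustration only.
missing_placeholders = ["<|image_1|>"]
text_prompt = "What is in this image?"

old_prompt = "\n".join(missing_placeholders + [text_prompt])  # '<|image_1|>\nWhat is in this image?'
new_prompt = "".join(missing_placeholders + [text_prompt])    # '<|image_1|>What is in this image?'

The only behavioral difference is that the placeholder is now glued directly onto the user text instead of being separated from it by a newline.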
2 changes: 1 addition & 1 deletion vllm/model_executor/models/minicpmv.py
@@ -361,7 +361,7 @@ def __init__(
self.vpm.embeddings.embed_dim)
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
# self.resampler.to(device="cuda", dtype=param_dtype)
#self.resampler.to(device="cuda", dtype=param_dtype)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
133 changes: 63 additions & 70 deletions vllm/model_executor/models/qwen2_vl.py
@@ -200,37 +200,37 @@ def __init__(
quant_config=quant_config)

# Detect attention implementation.
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.get_device_capability()[0] >= 8
if device_available:
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
self._use_flash_attn = True
else:
logger.warning(
"Current Qwen2-VL implementation has a bug with "
"`vllm-flash-attn` inside vision module, so we use "
"xformers backend instead. You can run `pip install "
"flash-attn to use flash-attention backend.")
self._use_flash_attn = False
else:
self._use_flash_attn = False
else:
if selected_backend == _Backend.FLASH_ATTN:
self._use_flash_attn = True
elif selected_backend == _Backend.XFORMERS:
self._use_flash_attn = False
else:
raise RuntimeError(
f"Qwen2-VL does not support {selected_backend} backend now."
)
# selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
# if selected_backend is None:
# backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
# if backend_by_env_var is not None:
# selected_backend = backend_name_to_enum(backend_by_env_var)
# if selected_backend is None:
# # For Volta and Turing GPUs, use xformers instead.
# device_available = current_platform.get_device_capability()[0] >= 8
# if device_available:
# from transformers.utils import is_flash_attn_2_available

# if is_flash_attn_2_available():
# self._use_flash_attn = True
# else:
# logger.warning(
# "Current Qwen2-VL implementation has a bug with "
# "`vllm-flash-attn` inside vision module, so we use "
# "xformers backend instead. You can run `pip install "
# "flash-attn to use flash-attention backend.")
# self._use_flash_attn = False
# else:
# self._use_flash_attn = False
# else:
# if selected_backend == _Backend.FLASH_ATTN:
# self._use_flash_attn = True
# elif selected_backend == _Backend.XFORMERS:
# self._use_flash_attn = False
# else:
# raise RuntimeError(
# f"Qwen2-VL does not support {selected_backend} backend now."
# )

def forward(
self,
@@ -258,42 +258,37 @@ def forward(
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

if self._use_flash_attn:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func

q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
causal=False)

context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
else:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None)

context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()

output, _ = self.proj(context_layer)
query = q.movedim(1, 2)
key = k.movedim(1, 2)
value = v.movedim(1, 2)

seq_lens = []
for i in range(1, len(cu_seqlens)):
seq_lens.append(cu_seqlens[i]-cu_seqlens[i-1])
att_masks = [None] * len(seq_lens)

num_tokens = q.shape[0] * q.shape[1]
output = torch.empty(
(num_tokens, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
dtype=query.dtype, device=query.device)
start = 0
for seq_len, mask in zip(seq_lens,
att_masks):
end = start + seq_len
sub_out = torch.nn.functional.scaled_dot_product_attention(
query[:, :, start:end, :],
key[:, :, start:end, :],
value[:, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=False,
scale= self.hidden_size_per_attention_head**-0.5).squeeze(0).movedim(
0, 1)
output[start:end, :, :] = sub_out
start = end
output = output.view(-1, batch_size, self.hidden_size_per_attention_head * self.num_attention_heads_per_partition)

output, _ = self.proj(output)
return output
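For context (not part of the diff): the rewritten forward above drops the flash_attn_varlen_func and xformers block-diagonal paths and instead runs torch's scaled_dot_product_attention once per sequence, which also works on devices where neither kernel is available (presumably the XPU target this PR touches). A condensed, self-contained sketch of the same pattern, assuming q, k, v shaped (total_tokens, num_heads, head_dim) and cu_seqlens holding cumulative sequence lengths:

import torch
import torch.nn.functional as F

def sdpa_varlen(q, k, v, cu_seqlens):
    # Block-diagonal attention: each span [cu_seqlens[i-1], cu_seqlens[i])
    # attends only to itself, mirroring BlockDiagonalMask / varlen flash-attn.
    out = torch.empty_like(q)
    for i in range(1, len(cu_seqlens)):
        s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        # (tokens, heads, dim) -> (1, heads, tokens, dim), the layout SDPA expects
        qi, ki, vi = (x[s:e].transpose(0, 1).unsqueeze(0) for x in (q, k, v))
        oi = F.scaled_dot_product_attention(qi, ki, vi, is_causal=False)
        out[s:e] = oi.squeeze(0).transpose(0, 1)
    return out

With cu_seqlens = torch.tensor([0, 5, 9]) this attends tokens 0-4 and 5-8 independently, the same masking the removed xformers path built via BlockDiagonalMask.from_seqlens.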


@@ -518,9 +513,7 @@ def forward(
grid_thw: torch.Tensor,
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)

# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)

@@ -926,7 +919,7 @@ def _parse_and_validate_video_input(

def _process_image_input(self,
image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
pixel_values = image_input["pixel_values"].to(torch.float16)
image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"])
return image_embeds
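A side note on the change in _process_image_input above (not part of the diff): the original line cast pixel_values to self.visual.dtype, while the new one pins it to torch.float16, and the .to(device=..., dtype=...) cast inside the vision transformer's forward was removed in the hunk above, so this caller-side cast is now what guarantees the dtype. Hard-coding fp16 assumes the vision tower is loaded in float16; a dtype-agnostic variant would be:

# Follow the visual encoder's actual parameter dtype instead of hard-coding fp16.
pixel_values = image_input["pixel_values"].to(dtype=self.visual.dtype)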
49 changes: 48 additions & 1 deletion vllm/worker/xpu_model_runner.py
@@ -23,6 +23,7 @@
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -62,6 +63,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
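For context (not part of the diff): as_broadcastable_tensor_dict is the payload the driver worker broadcasts to the other workers, so adding "multi_modal_kwargs" here (and again in the next hunk) presumably ensures the image tensors travel along with the token inputs. A keys-only illustration, with the inner keys assumed from the Qwen2-VL inputs used elsewhere in this PR:

# Illustrative payload shape after this change (values elided).
tensor_dict = {
    "input_tokens": ...,
    "input_positions": ...,
    "multi_modal_kwargs": {"pixel_values": ..., "image_grid_thw": ...},
}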

@@ -90,6 +92,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -122,6 +125,8 @@ def __init__(self,
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
# Multi-modal data support
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
@@ -179,6 +184,40 @@ def _prepare_prompt(
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if self.runner.model_is_mrope and mm_data:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
"mrope embedding type requires the multi-modal input mapper "
"to return 'image_grid_thw' or 'video_grid_thw'.")

hf_config = self.runner.model_config.hf_config
token_ids = seq_data.get_token_ids()
temp_mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
)
seq_data.mrope_position_delta = mrope_position_delta
mrope_input_positions = [[] for _ in range(3)]
for idx in range(3):
# msections = temp_mrope_input_positions
# for _seq_mrope_input_positions in msections:
mrope_input_positions[idx].extend(
temp_mrope_input_positions[idx])
input_positions = mrope_input_positions


if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
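For illustration (not part of the diff): the net effect of the block above is that, for mrope models, input_positions stops being a flat list and becomes three aligned per-token lists, one per M-RoPE section (temporal, height, width in Qwen2-VL). Shape comparison only; the values below are invented, not output of MRotaryEmbedding:

flat_positions  = [0, 1, 2, 3, 4]      # ordinary rope: one index per token
mrope_positions = [[0, 1, 2, 3, 4],    # temporal section
                   [0, 1, 2, 3, 4],    # height section
                   [0, 1, 2, 3, 4]]    # width section
# For pure-text tokens the three sections coincide; they diverge only inside
# image/video spans, which get_input_positions derives from the grid_thw tensors.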
@@ -241,7 +280,6 @@ def _prepare_prompt(
)

multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)

@@ -411,6 +449,15 @@ def load_model(self) -> None:
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()

@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"

@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
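For reference (not part of the diff): a hedged example of the hf_config field that model_is_mrope inspects. The values below are what a Qwen2-VL config.json typically carries; other checkpoints may differ or omit rope_scaling entirely, in which case the property returns False.

# Illustrative rope_scaling entry from a Qwen2-VL-style config.
rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
rope_scaling.get("type", None) == "mrope"   # True -> model_is_mrope is True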