From 3d60be0102b95007a38f488b115f15b16e0c2913 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 4 Jan 2025 08:19:00 +0000 Subject: [PATCH 1/6] qwen2 audio Signed-off-by: Roger Wang --- docs/source/models/supported_models.md | 2 +- vllm/model_executor/models/qwen2_audio.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 7682ed104b8c5..b526b32b3c9fd 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -710,7 +710,7 @@ See [this page](#generative-models) for more information on how to use generativ - `Qwen/Qwen2-Audio-7B-Instruct` - - ✅︎ - - + - ✅︎ * - `Qwen2VLForConditionalGeneration` - Qwen2-VL - T + IE+ + VE+ diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index d050fd060353a..1352990ca2dcd 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -310,13 +310,16 @@ def _process_audio_input(self, selected_audio_feature = audio_outputs.last_hidden_state audio_features = self.multi_modal_projector(selected_audio_feature) num_audios, max_audio_tokens, embed_dim = audio_features.shape + audio_output_lengths = audio_output_lengths.unsqueeze(1) audio_features_mask = torch.arange(max_audio_tokens).expand( - num_audios, max_audio_tokens - ).to(audio_output_lengths.device) < audio_output_lengths.unsqueeze(1) + num_audios, max_audio_tokens).to( + audio_output_lengths.device) < audio_output_lengths masked_audio_features = audio_features[audio_features_mask].view( -1, embed_dim) - return masked_audio_features + # Split to tuple of embeddings for individual audio input. + return torch.split(masked_audio_features, + audio_output_lengths.flatten().tolist()) def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: audio_input = self._parse_and_validate_audio_input(**kwargs) From 0a1d1e2ca27d5cd2a40fec6a66a56d92a449db7f Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 4 Jan 2025 08:38:44 +0000 Subject: [PATCH 2/6] ultravox Signed-off-by: Roger Wang --- docs/source/models/supported_models.md | 2 +- vllm/model_executor/models/ultravox.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index b526b32b3c9fd..d3ff2e05f3b7d 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -724,7 +724,7 @@ See [this page](#generative-models) for more information on how to use generativ - `fixie-ai/ultravox-v0_3` - - ✅︎ - - + - ✅︎ ``` E Pre-computed embeddings can be inputted for this modality. diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 0b83684c9bac5..ad4a9a4db90d0 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -1,7 +1,7 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" - import math +import os from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -36,8 +36,10 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings, merge_multimodal_embeddings_from_map) +_AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -449,11 +451,15 @@ def get_input_embeddings( inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - # TODO(ywang96): use merge_multimodal_embeddings after - # v0 is deprecated - merge_multimodal_embeddings_from_map( - inputs_embeds, multimodal_embeddings, - attn_metadata.multi_modal_placeholder_index_maps["audio"]) + # TODO(ywang96): remove this block after v0 is deprecated. + if os.environ.get("VLLM_USE_V1") == "0": + merge_multimodal_embeddings_from_map( + inputs_embeds, multimodal_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["audio"]) + else: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + _AUDIO_PLACEHOLDER_TOKEN) return inputs_embeds def forward(self, From ed30132f5a8704f5f01eb5567f60af67294f8c68 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 5 Jan 2025 01:22:45 -0800 Subject: [PATCH 3/6] use envs instead Signed-off-by: Roger Wang --- vllm/model_executor/models/ultravox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index ad4a9a4db90d0..1e91268fa8f79 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -1,7 +1,6 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" import math -import os from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -15,6 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper.modeling_whisper import WhisperEncoder +from vllm import envs from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn @@ -452,7 +452,7 @@ def get_input_embeddings( if multimodal_embeddings is not None: # TODO(ywang96): remove this block after v0 is deprecated. - if os.environ.get("VLLM_USE_V1") == "0": + if not envs.VLLM_USE_V1: merge_multimodal_embeddings_from_map( inputs_embeds, multimodal_embeddings, attn_metadata.multi_modal_placeholder_index_maps["audio"]) From b33fc30720675f4a89c109ff9d604fe59373a542 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 7 Jan 2025 05:37:21 +0000 Subject: [PATCH 4/6] fix ultravox Signed-off-by: Roger Wang --- vllm/model_executor/models/ultravox.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 779f90954fe19..ec751891b1577 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -38,6 +38,7 @@ merge_multimodal_embeddings, merge_multimodal_embeddings_from_map) +_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>" _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -200,8 +201,12 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) - placeholder = hf_processor.audio_token_replacement # type: ignore + + # NOTE: Ultravox processing definition uses '<|eot_id|>' as the + # placeholder that will cause confusion with the actual end of turn + # token, thus we override placeholder with a reserved special + # token. + placeholder = _AUDIO_PLACEHOLDER_OVERRIDE def get_replacement_ultravox(item_idx: int): audio_token_len = out_mm_kwargs["audio_token_len"][item_idx] @@ -348,6 +353,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multi_modal_config = multimodal_config assert self.multi_modal_config + self.audio_token_id = config.audio_token_index self.secondary_weights = [] self.audio_tower = ModifiedWhisperEncoder(config.audio_config) if config.audio_model_id is not None: From 550913272a59ae451218b7008944042759697b8d Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 7 Jan 2025 06:14:27 +0000 Subject: [PATCH 5/6] update hf processor instead Signed-off-by: Roger Wang --- vllm/model_executor/models/ultravox.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index ec751891b1577..e27928ac3c98a 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -67,7 +67,14 @@ def _get_hf_processor( # Ignored in initialization sampling_rate: Optional[int] = None, ) -> ProcessorMixin: - return self.ctx.get_hf_processor() + hf_processor = self.ctx.get_hf_processor() + + # NOTE: Ultravox processing definition uses '<|eot_id|>' as the + # placeholder that will cause confusion with the actual end of turn + # token, thus we override placeholder with a reserved special + # token. + hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE + return hf_processor def _get_feature_extractor( self, @@ -201,12 +208,8 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - - # NOTE: Ultravox processing definition uses '<|eot_id|>' as the - # placeholder that will cause confusion with the actual end of turn - # token, thus we override placeholder with a reserved special - # token. - placeholder = _AUDIO_PLACEHOLDER_OVERRIDE + hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) + placeholder = hf_processor.audio_token_replacement # type: ignore def get_replacement_ultravox(item_idx: int): audio_token_len = out_mm_kwargs["audio_token_len"][item_idx] From 2f93c1bc0f7c85e0989d0cedf3606a77f98d20be Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 7 Jan 2025 06:15:26 +0000 Subject: [PATCH 6/6] remove unused attribute Signed-off-by: Roger Wang --- vllm/model_executor/models/ultravox.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index e27928ac3c98a..ecafd157b1d61 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -356,7 +356,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multi_modal_config = multimodal_config assert self.multi_modal_config - self.audio_token_id = config.audio_token_index self.secondary_weights = [] self.audio_tower = ModifiedWhisperEncoder(config.audio_config) if config.audio_model_id is not None: