[Qwen2Audio] handle input ids expansion during processing #35534

Merged
merged 12 commits on Jan 7, 2025
25 changes: 25 additions & 0 deletions docs/source/en/model_doc/qwen2_audio.md
@@ -34,6 +34,31 @@ The abstract from the paper is the following:

`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)

### Inference

```python
from io import BytesIO
from urllib.request import urlopen
import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True, device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True)

prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
audio, sr = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)

generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]

response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

# We can also omit the audio_bos and audio_eos tokens
prompt = "<|AUDIO|><|audio_eos|>Generate the caption in English:"
```

The following demonstrates how to use `Qwen2-Audio-7B-Instruct` for inference, supporting both voice chat and audio analysis modes. Note that we use the ChatML format for dialog; in this demo we show how to leverage `apply_chat_template` for this purpose.
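As a rough illustration of that pattern (a sketch rather than part of this diff; the conversation format with `audio_url` entries mirrors the Qwen2-Audio model card, and the checkpoint name is assumed):

```python
from io import BytesIO
from urllib.request import urlopen

import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"},
            {"type": "text", "text": "What's that sound?"},
        ],
    },
]
# Render the ChatML prompt; the template inserts the <|AUDIO|> placeholder for the audio entry.
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

# Load every audio referenced in the conversation at the feature extractor's sampling rate.
audios = [
    librosa.load(BytesIO(urlopen(el["audio_url"]).read()), sr=processor.feature_extractor.sampling_rate)[0]
    for message in conversation
    for el in message["content"]
    if isinstance(el, dict) and el.get("type") == "audio"
]

inputs = processor(text=text, audios=audios, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
```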

### Voice Chat Inference
31 changes: 28 additions & 3 deletions src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -1197,9 +1197,34 @@ def forward(
selected_audio_feature = audio_outputs.last_hidden_state
audio_features = self.multi_modal_projector(selected_audio_feature)

inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
)
# if we have consecutive audio tokens, then it means we expanded input_ids in processing
audio_tokens = input_ids == self.config.audio_token_index
legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0

if legacy_processing:
logger.warning_once(
"Expanding inputs for audio tokens in Qwen2Audio should be done in processing."
)
inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
)
else:
num_audios, max_audio_tokens, embed_dim = audio_features.shape
audio_features_mask = torch.arange(max_audio_tokens, device=audio_output_lengths.device)[None, :]
audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
audio_features = audio_features[audio_features_mask]

n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
n_audio_features = audio_features.shape[0]

if n_audio_tokens != n_audio_features:
raise ValueError(
f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
)
special_audio_mask = (input_ids == self.config.audio_token_index).to(inputs_embeds.device)
special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)

outputs = self.language_model(
attention_mask=attention_mask,
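For intuition about the new non-legacy branch above, here is a toy, self-contained sketch (not taken from the PR) of the `masked_scatter` step it performs once the processor has already expanded the prompt to one `<|AUDIO|>` placeholder per audio frame embedding:

```python
import torch

audio_token_index = 7                             # hypothetical id standing in for <|AUDIO|>
input_ids = torch.tensor([[101, 7, 7, 7, 102]])   # already expanded: three audio placeholders
inputs_embeds = torch.randn(1, 5, 4)              # (batch, seq_len, embed_dim)
audio_features = torch.randn(3, 4)                # one projected feature row per placeholder

# Broadcast the placeholder mask over the embedding dimension, then scatter the audio
# features into the placeholder slots, one feature row per <|AUDIO|> position, in order.
special_audio_mask = (input_ids == audio_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)

assert inputs_embeds.shape == (1, 5, 4)
```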
72 changes: 70 additions & 2 deletions src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -40,15 +40,32 @@ class Qwen2AudioProcessor(ProcessorMixin):
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used.
audio_token (`str`, *optional*, defaults to `"<|AUDIO|>"`):
The token to use for audio tokens.
audio_bos_token (`str`, *optional*, defaults to `"<|audio_bos|>"`):
The token to use for audio bos tokens.
audio_eos_token (`str`, *optional*, defaults to `"<|audio_eos|>"`):
The token to use for audio eos tokens.
"""

attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "WhisperFeatureExtractor"
tokenizer_class = "AutoTokenizer"

def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
def __init__(
self,
feature_extractor=None,
tokenizer=None,
chat_template=None,
audio_token="<|AUDIO|>",
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
):
if chat_template is None:
chat_template = self.default_chat_template
self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
self.audio_bos_token = tokenizer.audio_bos_token if hasattr(tokenizer, "audio_bos_token") else audio_bos_token
self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)

def __call__(
@@ -88,7 +105,18 @@ def __call__(

if text is None:
raise ValueError("You need to specify either a `text` input to process.")
inputs = self.tokenizer(text, padding=padding, **kwargs)
elif isinstance(text, str):
text = [text]
elif not isinstance(text, list) or not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

# ensure we have as many audio inputs as audio tokens
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
num_audios = 1 if isinstance(audios, np.ndarray) else len(audios)
if num_audio_tokens != num_audios:
raise ValueError(
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
)

if audios is not None:
audio_inputs = self.feature_extractor(
@@ -97,6 +125,46 @@ def __call__(
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask"
) # rename attention_mask to prevent conflicts later on

expanded_text = []
audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()

for sample in text:
replace_str = []
while self.audio_token in sample:
audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1

expanded_audio_token = self.audio_token * num_audio_tokens

audio_token_start_idx = sample.find(self.audio_token)
audio_token_end_idx = audio_token_start_idx + len(self.audio_token)

has_bos = (
sample[audio_token_start_idx - len(self.audio_bos_token) : audio_token_start_idx]
== self.audio_bos_token
)
has_eos = (
sample[audio_token_end_idx : audio_token_end_idx + len(self.audio_eos_token)]
== self.audio_eos_token
)

# Check if this audio token is surrounded by bos/eos tokens
if not has_bos and not has_eos:
expanded_audio_token = self.audio_bos_token + expanded_audio_token + self.audio_eos_token

replace_str.append(expanded_audio_token)
sample = sample.replace(self.audio_token, "<placeholder>", 1)

while "<placeholder>" in sample:
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
expanded_text.append(sample)
text = expanded_text

inputs = self.tokenizer(text, padding=padding, **kwargs)

if audios is not None:
inputs.update(audio_inputs)

return BatchFeature(data={**inputs})
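To make the two length formulas in the processor concrete, here is a small worked example (the 404-frame input is a hypothetical value, chosen because it yields the 101 expanded tokens that the integration test below expects):

```python
# Valid mel frames for one clip, i.e. feature_attention_mask.sum() for that clip (assumed value).
audio_length = 404

# Length after the audio encoder's stride-2 convolution, per the processor's formula.
input_length = (audio_length - 1) // 2 + 1       # 202

# Length after the subsequent stride-2 pooling: the number of <|AUDIO|> tokens to insert.
num_audio_tokens = (input_length - 2) // 2 + 1   # 101

print(input_length, num_audio_tokens)  # 202 101
```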
65 changes: 60 additions & 5 deletions tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -51,7 +51,7 @@ def __init__(
parent,
ignore_index=-100,
audio_token_index=0,
seq_length=7,
seq_length=25,
feat_seq_length=60,
text_config={
"model_type": "qwen2",
@@ -93,7 +93,7 @@ def __init__(
self.is_training = is_training

self.batch_size = 3
self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1
self.encoder_seq_length = seq_length

def get_config(self):
return Qwen2AudioConfig(
@@ -118,11 +118,13 @@ def prepare_config_and_inputs(self):
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_features_values, feature_attention_mask = config_and_inputs
input_length = (input_features_values.shape[-1] - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
attention_mask[:, :1] = 0
# we are giving 3 audios, let's make sure we pass in 3 audio tokens
input_ids[:, 1] = config.audio_token_index
input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
inputs_dict = {
"input_features": input_features_values,
"feature_attention_mask": feature_attention_mask,
@@ -262,7 +264,9 @@ def test_small_model_integration_test_single(self):
25,
220,
151647,
151646,
]
+ [151646] * 101
+ [
151648,
198,
3838,
@@ -280,13 +284,64 @@
)
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
EXPECTED_DECODED_TEXT = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
+ "<|AUDIO|>" * 101
+ "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
)

self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=False),
EXPECTED_DECODED_TEXT,
)

# test the error when incorrect number of audio tokens
inputs["input_ids"] = torch.tensor(
[
[
151644,
8948,
198,
2610,
525,
264,
10950,
17847,
13,
151645,
198,
151644,
872,
198,
14755,
220,
16,
25,
220,
151647,
]
+ [151646] * 200
+ [
151648,
198,
3838,
594,
429,
5112,
30,
151645,
198,
151644,
77091,
198,
]
]
)
Collaborator review comment: use # fmt: skip

with self.assertRaisesRegex(
ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
):
model.generate(**inputs, max_new_tokens=32)

@slow
def test_small_model_integration_test_batch(self):
# Let's make sure we test the preprocessing to replace what is used