From 1cd050c389f8a61c1a0f847a83e929e6355276f9 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 16:25:31 +0100
Subject: [PATCH 01/11] add audio_token attribute to proc

---
 .../models/qwen2_audio/processing_qwen2_audio.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
index eabf5b7069f2..e4c856bd584f 100644
--- a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -40,15 +40,18 @@ class Qwen2AudioProcessor(ProcessorMixin):
         chat_template (`Optional[str]`, *optional*):
             The Jinja template to use for formatting the conversation. If not provided, the default chat template
             is used.
+        audio_token (`str`, *optional*, defaults to `"<|AUDIO|>"`):
+            The token to use for audio tokens.
     """

     attributes = ["feature_extractor", "tokenizer"]
     feature_extractor_class = "WhisperFeatureExtractor"
     tokenizer_class = "AutoTokenizer"

-    def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
+    def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None, audio_token="<|AUDIO|>"):
         if chat_template is None:
             chat_template = self.default_chat_template
+        self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
         super().__init__(feature_extractor, tokenizer, chat_template=chat_template)

     def __call__(

From 4de1294bb0ce98e428aa30ad8a4cf92d6fb4bebc Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 16:27:56 +0100
Subject: [PATCH 02/11] expand input_ids

---
 .../qwen2_audio/processing_qwen2_audio.py     | 23 ++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
index e4c856bd584f..a33a5dc97924 100644
--- a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -91,7 +91,8 @@ def __call__(

         if text is None:
             raise ValueError("You need to specify either a `text` input to process.")
-        inputs = self.tokenizer(text, padding=padding, **kwargs)
+        elif isinstance(text, str):
+            text = [text]

         if audios is not None:
             audio_inputs = self.feature_extractor(
@@ -100,6 +101,26 @@ def __call__(
             audio_inputs["feature_attention_mask"] = audio_inputs.pop(
                 "attention_mask"
             )  # rename attention_mask to prevent conflicts later on
+
+            expanded_text = []
+            audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
+            for sample in text:
+                replace_str = []
+                while self.audio_token in sample:
+                    audio_length = audio_lengths.pop(0)
+                    input_length = (audio_length - 1) // 2 + 1
+                    num_audio_tokens = (input_length - 2) // 2 + 1
+                    replace_str.append(self.audio_token * num_audio_tokens)
+                    sample = sample.replace(self.audio_token, "<placeholder>", 1)
+
+                while "<placeholder>" in sample:
+                    sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
+                expanded_text.append(sample)
+            text = expanded_text
+
+        inputs = self.tokenizer(text, padding=padding, **kwargs)
+
+        if audios is not None:
             inputs.update(audio_inputs)

         return BatchFeature(data={**inputs})

From 9b82708c04d88a10da4dc3c715861986edb9b13d Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 16:42:22 +0100
Subject: [PATCH 03/11] and legacy and expanded input_ids

---
 .../qwen2_audio/modeling_qwen2_audio.py       | 36 +++++++++++++++++--
 1 file changed, 33 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 44a5b5ce3155..012eb8fd1857 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -1197,9 +1197,39 @@ def forward(
             selected_audio_feature = audio_outputs.last_hidden_state
             audio_features = self.multi_modal_projector(selected_audio_feature)

-            inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
-                audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
-            )
+            # if we have consecutive audio tokens, then it means we expanded input_ids in processing
+            audio_tokens = input_ids == self.config.audio_token_index
+            legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0
+
+            if legacy_processing:
+                logger.warning_once(
+                    "Expanding inputs for audio tokens in Qwen2Audio should be done in processing."
+                )
+                inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
+                    audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
+                )
+            else:
+                num_audios, max_audio_tokens, embed_dim = audio_features.shape
+                audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
+                    audio_output_lengths.device
+                ) < audio_output_lengths.unsqueeze(1)
+                audio_features = audio_features[audio_features_mask]
+
+                n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
+                n_audio_features = audio_features.shape[0]
+
+                if n_audio_tokens != n_audio_features:
+                    raise ValueError(
+                        f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
+                    )
+                special_audio_mask = (
+                    (input_ids == self.config.audio_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
+                audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)

         outputs = self.language_model(
             attention_mask=attention_mask,

From 89d0d1bbb4f54f539574396c5ae415d50c655974 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 16:42:35 +0100
Subject: [PATCH 04/11] test update

---
 .../qwen2_audio/test_modeling_qwen2_audio.py  | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
index 4806ec2c72d3..36fba03daad2 100644
--- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
+++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -51,7 +51,7 @@ def __init__(
         parent,
         ignore_index=-100,
         audio_token_index=0,
-        seq_length=7,
+        seq_length=25,
         feat_seq_length=60,
         text_config={
             "model_type": "qwen2",
@@ -93,7 +93,7 @@ def __init__(
         self.is_training = is_training

         self.batch_size = 3
-        self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1
+        self.encoder_seq_length = seq_length

     def get_config(self):
         return Qwen2AudioConfig(
@@ -118,11 +118,13 @@ def prepare_config_and_inputs(self):
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         config, input_features_values, feature_attention_mask = config_and_inputs
+        input_length = (input_features_values.shape[-1] - 1) // 2 + 1
+        num_audio_tokens = (input_length - 2) // 2 + 1
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
         attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
         attention_mask[:, :1] = 0
         # we are giving 3 audios let's make sure we pass in 3 audios tokens
-        input_ids[:, 1] = config.audio_token_index
+        input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
         inputs_dict = {
             "input_features": input_features_values,
             "feature_attention_mask": feature_attention_mask,
@@ -262,7 +264,9 @@ def test_small_model_integration_test_single(self):
                     25,
                     220,
                     151647,
-                    151646,
+                ]
+                + [151646] * 101
+                + [
                     151648,
                     198,
                     3838,
                     594,
                     429,
                     5112,
                     30,
                     151645,
                     198,
                     151644,
                     77091,
                     198,
                 ]
             ]
         )
         self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

-        EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
+        EXPECTED_DECODED_TEXT = (
+            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
+            + "<|AUDIO|>" * 101
+            + "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
+        )

         self.assertEqual(
             self.processor.decode(output[0], skip_special_tokens=False),
             EXPECTED_DECODED_TEXT,
         )

From bd4ec1c743da6185abc9d0238b16138a51aac5a8 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 17:19:28 +0100
Subject: [PATCH 05/11] split lines

---
 .../models/qwen2_audio/modeling_qwen2_audio.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 012eb8fd1857..743e4c089c13 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -1210,9 +1210,8 @@ def forward(
                 )
             else:
                 num_audios, max_audio_tokens, embed_dim = audio_features.shape
-                audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
-                    audio_output_lengths.device
-                ) < audio_output_lengths.unsqueeze(1)
+                audio_features_mask = torch.arange(max_audio_tokens, device=audio_output_lengths.device)[None, :]
+                audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
                 audio_features = audio_features[audio_features_mask]

                 n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
@@ -1222,12 +1221,8 @@ def forward(
                     raise ValueError(
                         f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
                     )
-                special_audio_mask = (
-                    (input_ids == self.config.audio_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
+                special_audio_mask = (input_ids == self.config.audio_token_index).to(inputs_embeds.device)
+                special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
                 audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)

From fa85ac401bc91c6faf326a27b11a55bc341fa47d Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 19:02:12 +0100
Subject: [PATCH 06/11] add possibility not to provide eos and bos audio tokens

---
 .../qwen2_audio/processing_qwen2_audio.py     | 37 ++++++++++++++++++-
 1 file changed, 35 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
index a33a5dc97924..5c7cbddb4eee 100644
--- a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -42,16 +42,30 @@ class Qwen2AudioProcessor(ProcessorMixin):
             is used.
         audio_token (`str`, *optional*, defaults to `"<|AUDIO|>"`):
             The token to use for audio tokens.
+        audio_bos_token (`str`, *optional*, defaults to `"<|audio_bos|>"`):
+            The token to use for audio bos tokens.
+        audio_eos_token (`str`, *optional*, defaults to `"<|audio_eos|>"`):
+            The token to use for audio eos tokens.
     """

     attributes = ["feature_extractor", "tokenizer"]
     feature_extractor_class = "WhisperFeatureExtractor"
     tokenizer_class = "AutoTokenizer"

-    def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None, audio_token="<|AUDIO|>"):
+    def __init__(
+        self,
+        feature_extractor=None,
+        tokenizer=None,
+        chat_template=None,
+        audio_token="<|AUDIO|>",
+        audio_bos_token="<|audio_bos|>",
+        audio_eos_token="<|audio_eos|>",
+    ):
         if chat_template is None:
             chat_template = self.default_chat_template
         self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
+        self.audio_bos_token = tokenizer.audio_bos_token if hasattr(tokenizer, "audio_bos_token") else audio_bos_token
+        self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token
         super().__init__(feature_extractor, tokenizer, chat_template=chat_template)

     def __call__(
@@ -110,7 +124,26 @@ def __call__(
                     audio_length = audio_lengths.pop(0)
                     input_length = (audio_length - 1) // 2 + 1
                     num_audio_tokens = (input_length - 2) // 2 + 1
-                    replace_str.append(self.audio_token * num_audio_tokens)
+
+                    expanded_audio_token = self.audio_token * num_audio_tokens
+
+                    audio_token_start_idx = sample.find(self.audio_token)
+                    audio_token_end_idx = audio_token_start_idx + len(self.audio_token)
+
+                    has_bos = (
+                        sample[audio_token_start_idx - len(self.audio_bos_token) : audio_token_start_idx]
+                        == self.audio_bos_token
+                    )
+                    has_eos = (
+                        sample[audio_token_end_idx : audio_token_end_idx + len(self.audio_eos_token)]
+                        == self.audio_eos_token
+                    )
+
+                    # Check if this audio token is surrounded by bos/eos tokens
+                    if not has_bos and not has_eos:
+                        expanded_audio_token = self.audio_bos_token + expanded_audio_token + self.audio_eos_token
+
+                    replace_str.append(expanded_audio_token)
                     sample = sample.replace(self.audio_token, "<placeholder>", 1)

                 while "<placeholder>" in sample:
                     sample = sample.replace("<placeholder>", replace_str.pop(0), 1)

From 71ce83f54d5afc8ed77d0b771b849125c6ce0f0a Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 19:02:58 +0100
Subject: [PATCH 07/11] raise errors

---
 .../models/qwen2_audio/processing_qwen2_audio.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
index 5c7cbddb4eee..82bfab524d65 100644
--- a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -107,6 +107,16 @@ def __call__(
             raise ValueError("You need to specify either a `text` input to process.")
         elif isinstance(text, str):
             text = [text]
+        elif not isinstance(text, list) and not isinstance(text[0], str):
+            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+        # ensure we have as much audios as audio tokens
+        num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
+        num_audios = 1 if type(audios) == np.ndarray else len(audios)
+        if num_audio_tokens != num_audios:
+            raise ValueError(
+                f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
+            )

         if audios is not None:
             audio_inputs = self.feature_extractor(
@@ -118,6 +128,7 @@ def __call__(

             expanded_text = []
             audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
+
             for sample in text:
                 replace_str = []
                 while self.audio_token in sample:

From 09870d849db69e597ebef759f8a1567fa093fc19 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 19:03:49 +0100
Subject: [PATCH 08/11] test incorrect number of audio tokens

---
 .../qwen2_audio/test_modeling_qwen2_audio.py  | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
index 36fba03daad2..94527b3d809b 100644
--- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
+++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -295,6 +295,53 @@ def test_small_model_integration_test_single(self):
             EXPECTED_DECODED_TEXT,
         )

+        # test the error when incorrect number of audio tokens
+        inputs["input_ids"] = torch.tensor(
+            [
+                [
+                    151644,
+                    8948,
+                    198,
+                    2610,
+                    525,
+                    264,
+                    10950,
+                    17847,
+                    13,
+                    151645,
+                    198,
+                    151644,
+                    872,
+                    198,
+                    14755,
+                    220,
+                    16,
+                    25,
+                    220,
+                    151647,
+                ]
+                + [151646] * 200
+                + [
+                    151648,
+                    198,
+                    3838,
+                    594,
+                    429,
+                    5112,
+                    30,
+                    151645,
+                    198,
+                    151644,
+                    77091,
+                    198,
+                ]
+            ]
+        )
+        with self.assertRaisesRegex(
+            ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
+        ):
+            model.generate(**inputs, max_new_tokens=32)
+
     @slow
     def test_small_model_integration_test_batch(self):
         # Let' s make sure we test the preprocessing to replace what is used

From 4569047a908f48fdd21d9e91c11ff5d76aa9e882 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 6 Jan 2025 19:18:05 +0100
Subject: [PATCH 09/11] add example

---
 docs/source/en/model_doc/qwen2_audio.md | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/docs/source/en/model_doc/qwen2_audio.md b/docs/source/en/model_doc/qwen2_audio.md
index f399a7e7320c..d994e38ed275 100644
--- a/docs/source/en/model_doc/qwen2_audio.md
+++ b/docs/source/en/model_doc/qwen2_audio.md
@@ -34,6 +34,31 @@ The abstract from the paper is the following:

 `Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)

+### Inference
+
+```python
+from io import BytesIO
+from urllib.request import urlopen
+import librosa
+from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
+
+model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True, device_map="auto")
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_code=True)
+
+prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
+url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
+audio, sr = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
+inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
+
+generate_ids = model.generate(**inputs, max_length=256)
+generate_ids = generate_ids[:, inputs.input_ids.size(1):]
+
+response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+# We can also omit the audio_bos and audio_eos tokens
+prompt = "<|AUDIO|><|audio_eos|>Generate the caption in English:"
+```
+
 In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for the inference, supporting both voice chat and audio analysis modes. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose.

 ### Voice Chat Inference

From d8227e479ffa29253a8d8a6549b5fb73b90e26d3 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Tue, 7 Jan 2025 16:35:31 +0100
Subject: [PATCH 10/11] fmt

---
 .../qwen2_audio/test_modeling_qwen2_audio.py  | 96 +++----------------
 1 file changed, 14 insertions(+), 82 deletions(-)

diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
index 94527b3d809b..0c1afb0fd2bd 100644
--- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
+++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -241,47 +241,13 @@ def test_small_model_integration_test_single(self):

         output = model.generate(**inputs, max_new_tokens=32)

-        EXPECTED_INPUT_IDS = torch.tensor(
-            [
-                [
-                    151644,
-                    8948,
-                    198,
-                    2610,
-                    525,
-                    264,
-                    10950,
-                    17847,
-                    13,
-                    151645,
-                    198,
-                    151644,
-                    872,
-                    198,
-                    14755,
-                    220,
-                    16,
-                    25,
-                    220,
-                    151647,
-                ]
-                + [151646] * 101
-                + [
-                    151648,
-                    198,
-                    3838,
-                    594,
-                    429,
-                    5112,
-                    30,
-                    151645,
-                    198,
-                    151644,
-                    77091,
-                    198,
-                ]
-            ]
-        )
+        # fmt: off
+        EXPECTED_INPUT_IDS = torch.tensor([[
+            151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
+            *[151646] * 101,
+            151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
+        ]])
+        # fmt: on
         self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

         EXPECTED_DECODED_TEXT = (
             "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
             + "<|AUDIO|>" * 101
             + "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
         )

         self.assertEqual(
             self.processor.decode(output[0], skip_special_tokens=False),
             EXPECTED_DECODED_TEXT,
         )

         # test the error when incorrect number of audio tokens
-        inputs["input_ids"] = torch.tensor(
-            [
-                [
-                    151644,
-                    8948,
-                    198,
-                    2610,
-                    525,
-                    264,
-                    10950,
-                    17847,
-                    13,
-                    151645,
-                    198,
-                    151644,
-                    872,
-                    198,
-                    14755,
-                    220,
-                    16,
-                    25,
-                    220,
-                    151647,
-                ]
-                + [151646] * 200
-                + [
-                    151648,
-                    198,
-                    3838,
-                    594,
-                    429,
-                    5112,
-                    30,
-                    151645,
-                    198,
-                    151644,
-                    77091,
-                    198,
-                ]
-            ]
-        )
+        # fmt: off
+        inputs["input_ids"] = torch.tensor([[
+            151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
+            *[151646] * 200,
+            151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
+        ]])
+        # fmt: on
         with self.assertRaisesRegex(
             ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
         ):
             model.generate(**inputs, max_new_tokens=32)

From e7b482686eec3ad470a7814f2ef2bd90d2d25058 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Tue, 7 Jan 2025 16:35:50 +0100
Subject: [PATCH 11/11] typo

---
 docs/source/en/model_doc/qwen2_audio.md | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/docs/source/en/model_doc/qwen2_audio.md b/docs/source/en/model_doc/qwen2_audio.md
index d994e38ed275..2ef947ce430d 100644
--- a/docs/source/en/model_doc/qwen2_audio.md
+++ b/docs/source/en/model_doc/qwen2_audio.md
@@ -56,7 +56,13 @@ generate_ids = generate_ids[:, inputs.input_ids.size(1):]
 response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

 # We can also omit the audio_bos and audio_eos tokens
-prompt = "<|AUDIO|><|audio_eos|>Generate the caption in English:"
+prompt = "<|AUDIO|>Generate the caption in English:"
+inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
+
+generate_ids = model.generate(**inputs, max_length=256)
+generate_ids = generate_ids[:, inputs.input_ids.size(1):]
+
+response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 ```

 In the following, we demonstrate how to use `Qwen2-Audio-7B-Instruct` for the inference, supporting both voice chat and audio analysis modes. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose.
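
A minimal sketch of the `apply_chat_template` flow referenced above, going through the expanded audio-token processing; the checkpoint id `Qwen/Qwen2-Audio-7B-Instruct`, the reuse of the glass-breaking clip, and the generation settings are illustrative assumptions, and exact outputs depend on the checkpoint:

```python
from io import BytesIO
from urllib.request import urlopen

import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

# assumed checkpoint: any Qwen2-Audio checkpoint that ships a chat template works the same way
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto")

# one audio entry per audio placeholder the chat template will emit
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": url},
            {"type": "text", "text": "What's that sound?"},
        ],
    },
]
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)

# the processor expands each <|AUDIO|> placeholder to the per-clip number of audio frames,
# so input_ids already line up with the audio features inside the model
inputs = processor(text=text, audios=[audio], return_tensors="pt", padding=True).to(model.device)

generate_ids = model.generate(**inputs, max_new_tokens=32)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(response)
```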