limit max response tokens (#36)
If max_response_tokens is set (50 by default), the assistant response is truncated to that many tokens. This helps avoid situations in which a very long response causes OOM issues.

See this comparison on the choice of default value:

https://3.basecamp.com/5478728/buckets/36374248/todos/8256002600
zqhuang211 authored Jan 27, 2025
1 parent 158929e commit dfd936d
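
For illustration, here is a minimal, self-contained sketch of the truncation this commit implements: the user prompt is kept intact and only the assistant response is capped. The function name and its arguments are hypothetical, not part of the codebase.

```python
from typing import List, Optional


def truncate_response(
    input_ids: List[int],
    user_token_len: int,
    max_response_tokens: Optional[int],
) -> List[int]:
    # Keep the full user prompt; cap the response at max_response_tokens.
    if max_response_tokens is None:
        return input_ids
    return input_ids[: user_token_len + max_response_tokens]


# A 10-token prompt with a 200-token response is cut to 10 + 50 = 60 tokens.
sample = list(range(210))
assert len(truncate_response(sample, user_token_len=10, max_response_tokens=50)) == 60
```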
Showing 7 changed files with 930 additions and 648 deletions.
1,509 changes: 883 additions & 626 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion pyproject.toml
```diff
@@ -12,7 +12,6 @@
 python = ">=3.10,<4.0"
 torch = ">=2.4"
 transformers = {version = ">=4.43.1,<4.48.0 || >4.48.0", extras = ["torch"]}
 huggingface-hub = ">=0.27.0"
-bitsandbytes = "~0.42.0"
 peft = "~0.11.1"
 simple-parsing = "~0.1.5"
 librosa = "~0.10.2.post1"
```
5 changes: 2 additions & 3 deletions ultravox/data/datasets.py
```diff
@@ -280,15 +280,14 @@ def _get_sample(self, row) -> Optional[data_sample.VoiceSample]:
             ).render(
                 **row,
                 text_proc=text_proc,
-                dataset=self,
                 **self._config.user_template_args,
             )
             assistant_content = jinja2.Template(
                 self._config.assistant_template, undefined=jinja2.StrictUndefined
-            ).render(**row, text_proc=text_proc, dataset=self)
+            ).render(**row, text_proc=text_proc)
             transcript = jinja2.Template(
                 self._config.transcript_template, undefined=jinja2.StrictUndefined
-            ).render(**row, text_proc=text_proc, dataset=self)
+            ).render(**row, text_proc=text_proc)
         except jinja2.TemplateError as e:
             print(f"Error rendering template: {e}")
             print(f"user_template: {self._config.user_template}")
```
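
Since the templates are rendered with `jinja2.StrictUndefined`, any template that still references the removed `dataset` variable will now fail loudly at render time rather than silently rendering an empty string. A standalone illustration (not from the repo):

```python
import jinja2

# With StrictUndefined, an undefined variable raises instead of rendering empty.
template = jinja2.Template("{{ dataset }}", undefined=jinja2.StrictUndefined)
try:
    template.render(text_proc=None)
except jinja2.TemplateError as e:
    print(f"Error rendering template: {e}")  # 'dataset' is undefined
```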
52 changes: 35 additions & 17 deletions ultravox/model/ultravox_data_proc.py
```diff
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import numpy as np
 
@@ -14,6 +14,7 @@ def __init__(
         train_on_inputs: bool = False,
         inference_mode: bool = False,
         include_alt_fields: bool = False,
+        max_response_tokens: Optional[int] = None,
     ) -> None:
         """
         Pre-processing for the Ultravox model: applies tokenization and audio processing using the UltravoxProcessor
@@ -37,6 +38,7 @@ def __init__(
         if self.inference_mode:
             self.train_on_inputs = True
         self.include_alt_fields = include_alt_fields
+        self.max_response_tokens = max_response_tokens
 
     def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
         if self.inference_mode:
@@ -71,6 +73,18 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
         # No need to shift the labels as the model does it internally
         labels = input_ids.clone()
 
+        # Compute the length of the user text
+        user_text = self.processor.tokenizer.apply_chat_template(
+            sample.messages[:-1], tokenize=False
+        )
+        # TODO: this might be slow due to calling audio_processor twice. We can compute modified input_text_len directly too.
+        # Revisit when using WhisperProcessor.
+        user_token_len = self.processor(
+            text=user_text,
+            audio=audio,
+            sampling_rate=sample.sample_rate,
+        )["input_ids"].shape[-1]
+
         if not self.train_on_inputs:
             # Mask the prompt tokens and only compute loss on the assistant message, not the prompt.
             # The idea is that the model should only be able to predict the assistant message given the user message.
@@ -81,18 +95,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
             # Labels: -100 -100 -100 -100 <assistant> Brown fox jumps over the lazy dog </s>
             #
             # Note: The above might look weird because I'm mixing token IDs and text, but that's just for illustration.
-            input_text = self.processor.tokenizer.apply_chat_template(
-                sample.messages[:-1], tokenize=False
-            )
-
-            # TODO: this might be slow due to calling audio_processor twice. We can compute modified input_text_len directly too.
-            # Revisit when using WhisperProcessor.
-            input_token_len = self.processor(
-                text=input_text,
-                audio=audio,
-                sampling_rate=sample.sample_rate,
-            )["input_ids"].shape[-1]
-            labels[:input_token_len] = -100
+            labels[:user_token_len] = -100
 
         # If include_alt_fields is True, also include alt_input_ids, alt_attention_mask, and alt_labels
         if self.include_alt_fields:
@@ -107,17 +110,32 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
             alt_input_ids = alt_inputs["input_ids"].squeeze_(0)
             alt_inputs["attention_mask"].squeeze_(0)
 
+            alt_user_token_len = user_token_len + len(alt_input_ids) - len(input_ids)
             alt_labels = alt_input_ids.clone()
             if not self.train_on_inputs:
-                alt_input_token_len = (
-                    input_token_len + len(alt_input_ids) - len(input_ids)
-                )
-                alt_labels[:alt_input_token_len] = -100
+                alt_labels[:alt_user_token_len] = -100
 
             inputs["alt_input_ids"] = alt_input_ids
             inputs["alt_attention_mask"] = alt_inputs["attention_mask"]
             inputs["alt_labels"] = alt_labels.tolist()
 
+        # Truncate the input_ids and labels if the response is longer than max_response_tokens
+        if (
+            self.max_response_tokens
+            and user_token_len + self.max_response_tokens < len(input_ids)
+        ):
+            max_tokens = user_token_len + self.max_response_tokens
+            inputs["input_ids"] = inputs["input_ids"][:max_tokens]
+            inputs["attention_mask"] = inputs["attention_mask"][:max_tokens]
+            labels = labels[:max_tokens]
+            if self.include_alt_fields:
+                max_alt_tokens = alt_user_token_len + self.max_response_tokens
+                inputs["alt_input_ids"] = inputs["alt_input_ids"][:max_alt_tokens]
+                inputs["alt_attention_mask"] = inputs["alt_attention_mask"][
+                    :max_alt_tokens
+                ]
+                inputs["alt_labels"] = inputs["alt_labels"][:max_alt_tokens]
+
         return {
             # input_ids, attention_mask, audio_values, audio_token_start_idx, audio_token_len
             # if include_alt_fields is True, also include alt_input_ids, alt_attention_mask, alt_labels
```
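
To make the masking and truncation above concrete, here is a hypothetical walk-through with made-up sizes; the real `_process` also truncates `attention_mask` and the `alt_*` fields the same way.

```python
import torch

# A 4-token user prompt followed by a 6-token assistant response,
# with max_response_tokens = 3.
input_ids = torch.arange(10)
user_token_len = 4
max_response_tokens = 3

labels = input_ids.clone()
labels[:user_token_len] = -100  # loss is computed only on the response tokens

max_tokens = user_token_len + max_response_tokens
input_ids = input_ids[:max_tokens]  # keep the prompt plus 3 response tokens
labels = labels[:max_tokens]

print(input_ids.tolist())  # [0, 1, 2, 3, 4, 5, 6]
print(labels.tolist())     # [-100, -100, -100, -100, 4, 5, 6]
```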
3 changes: 3 additions & 0 deletions ultravox/training/config_base.py
```diff
@@ -90,7 +90,10 @@ class TrainConfig:
 
     # Dataloader workers
     num_workers: int = 8 if torch.cuda.is_available() else 1
+    # Training sample control
     train_on_inputs: bool = False
+    # assistant response is truncated to avoid OOM errors
+    max_response_tokens: Optional[int] = 50
 
     # Device and dtype
     device: str = "cuda"
```
7 changes: 6 additions & 1 deletion ultravox/training/configs/meta_config.yaml
```diff
@@ -13,8 +13,13 @@ eval_sets:
   - name: covost2-en-zh
   - name: covost2-es-en
   - name: covost2-zh-en
+  - name: librispeech-clean-transcription
+  - name: librispeech-other-transcription
+  - name: commonvoice-en-transcription
+  - name: commonvoice-es-transcription
+  - name: commonvoice-ru-transcription
   - name: alpaca-tts-llama
 
 eval_dataset_args:
   max_samples: 2000
```
1 change: 1 addition & 0 deletions ultravox/training/model_types.py
```diff
@@ -99,6 +99,7 @@ def wrap_with_data_proc(self, dataset: datasets.SizedIterableDataset):
             processor=self.processor,
             train_on_inputs=self.args.train_on_inputs,
             include_alt_fields=self.model.loss_config.requires_alt_fields,
+            max_response_tokens=self.args.max_response_tokens,
         )
 
     def get_pipeline(self):
```
