Skip to content

Commit

Permalink
[TTS] Add tutorial for training audio codecs
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed Mar 22, 2024
1 parent dbc8a6e commit 79ffef9
Show file tree
Hide file tree
Showing 7 changed files with 878 additions and 22 deletions.
1 change: 1 addition & 0 deletions examples/tts/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def main(cfg):
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
model.maybe_init_from_pretrained_checkpoint(cfg=cfg)
trainer.fit(model)


Expand Down
8 changes: 3 additions & 5 deletions examples/tts/conf/audio_codec/audio_codec_16000.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ model:
log_epochs: [1, 2, 3, 4, 5, 6]
epoch_frequency: 1
log_tensorboard: false
log_wandb: true
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
log_dequantized: true
log_encoding: false
log_dequantized: false

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
Expand Down Expand Up @@ -129,8 +129,6 @@ model:
_target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT
resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]

# The original EnCodec uses hinged loss, but squared-GAN loss is more stable
# and reduces the need to tune the loss weights or use a gradient balancer.
generator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

Expand Down
12 changes: 5 additions & 7 deletions examples/tts/conf/audio_codec/audio_codec_24000.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# If you want to train model on other dataset, you can change config values according to your dataset.
# Most dataset-specific arguments are in the head of the config file, see below.

name: EnCodec
name: AudioCodec

max_epochs: ???
# Adjust batch size based on GPU memory
Expand Down Expand Up @@ -90,13 +90,13 @@ model:
log_epochs: [10, 50, 100, 150, 200]
epoch_frequency: 100
log_tensorboard: false
log_wandb: true
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
log_dequantized: true
log_encoding: false
log_dequantized: false

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
Expand Down Expand Up @@ -127,8 +127,6 @@ model:
_target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT
resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]

# The original EnCodec uses hinged loss, but squared-GAN loss is more stable
# and reduces the need to tune the loss weights or use a gradient balancer.
generator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

Expand Down Expand Up @@ -162,7 +160,7 @@ exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: false
create_wandb_logger: true
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
Expand Down
8 changes: 4 additions & 4 deletions examples/tts/conf/audio_codec/encodec_24000.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ model:
log_epochs: [10, 50, 100, 150, 200]
epoch_frequency: 100
log_tensorboard: false
log_wandb: true
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
log_dequantized: true
log_encoding: false
log_dequantized: false

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
Expand Down Expand Up @@ -162,7 +162,7 @@ exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: false
create_wandb_logger: true
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
Expand Down
194 changes: 194 additions & 0 deletions examples/tts/conf/audio_codec/mel_codec_22050.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# This config contains the default values for training 22.05kHz audio codec model which encodes mel spectrogram
# instead of raw audio.
# If you want to train model on other dataset, you can change config values according to your dataset.
# Most dataset-specific arguments are in the head of the config file, see below.

name: MelCodec

max_epochs: ???
# Adjust batch size based on GPU memory
batch_size: 16
# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch.
# If null, then weighted sampling is disabled.
weighted_sampling_steps_per_epoch: null

# Dataset metadata for each manifest
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
train_ds_meta: ???
val_ds_meta: ???

log_ds_meta: ???
log_dir: ???

# Modify these values based on your sample rate
sample_rate: 22050
win_length: 1024
hop_length: 256
train_n_samples: 8192 # ~0.37 seconds
# The product of the up_sample_rates should match the hop_length.
# For example 8 * 8 * 2 * 2 = 256.
up_sample_rates: [8, 8, 2, 2]


model:

max_epochs: ${max_epochs}
steps_per_epoch: ${weighted_sampling_steps_per_epoch}

sample_rate: ${sample_rate}
samples_per_frame: ${hop_length}

mel_loss_l1_scale: 1.0
mel_loss_l2_scale: 0.0
stft_loss_scale: 20.0
time_domain_loss_scale: 0.0
commit_loss_scale: 0.0

# Probability of updating the discriminator during each training step
# For example, update the discriminator 1/2 times (1 update for every 2 batches)
disc_updates_per_period: 1
disc_update_period: 2

# All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
loss_resolutions: [
[32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]
]
mel_loss_dims: [5, 10, 20, 40, 80, 160, 320]
mel_loss_log_guard: 1.0
stft_loss_log_guard: 1.0
feature_loss_type: absolute

train_ds:
dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
dataset_meta: ${train_ds_meta}
weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
sample_rate: ${sample_rate}
n_samples: ${train_n_samples}
min_duration: 0.4
max_duration: null

dataloader_params:
batch_size: ${batch_size}
drop_last: true
num_workers: 4

validation_ds:
dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
sample_rate: ${sample_rate}
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss
dataset_meta: ${val_ds_meta}

dataloader_params:
batch_size: 4
num_workers: 2

# Configures how audio samples are generated and saved during training.
# Remove this section to disable logging.
log_config:
log_dir: ${log_dir}
log_epochs: [10, 50, 100, 150, 200]
epoch_frequency: 100
log_tensorboard: false
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: false
log_dequantized: false

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
sample_rate: ${sample_rate}
n_samples: null
min_duration: null
max_duration: null
trunc_duration: 10.0 # Only log the first 10 seconds of generated audio.
dataset_meta: ${log_ds_meta}

dataloader_params:
batch_size: 4
num_workers: 2

audio_encoder:
_target_: nemo.collections.tts.modules.audio_codec_modules.MultiBandMelEncoder
mel_bands: [[0, 10], [10, 20], [20, 30], [30, 40], [40, 50], [50, 60], [60, 70], [70, 80]]
out_channels: 4 # The dimension of each codebook
hidden_channels: 128
filters: 256
mel_processor:
_target_: nemo.collections.tts.modules.audio_codec_modules.MelSpectrogramProcessor
mel_dim: 80
sample_rate: ${sample_rate}
win_length: ${win_length}
hop_length: ${hop_length}

audio_decoder:
_target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder
up_sample_rates: ${up_sample_rates}
input_dim: 32 # Should be equal to len(audio_encoder.mel_bands) * audio_encoder.out_channels
base_channels: 1024 # This is double the base channels of HiFi-GAN V1, making it approximately 4x larger.

vector_quantizer:
_target_: nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer
num_groups: 8 # Should equal len(audio_encoder.mel_bands)
num_levels_per_group: [8, 5, 5, 5] # 8 * 5 * 5 * 5 = 1000 entries per codebook

discriminator:
_target_: nemo.collections.tts.modules.audio_codec_modules.Discriminator
discriminators:
- _target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT
resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]
- _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator

generator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

discriminator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss

optim:
_target_: torch.optim.Adam
lr: 2e-4
betas: [0.8, 0.99]

sched:
name: ExponentialLR
gamma: 0.998

trainer:
num_nodes: 1
devices: 1
accelerator: gpu
strategy: ddp_find_unused_parameters_true
precision: 16
max_epochs: ${max_epochs}
accumulate_grad_batches: 1
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
log_every_n_steps: 100
check_val_every_n_epoch: 5
benchmark: false

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: false
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
create_checkpoint_callback: true
checkpoint_callback_params:
monitor: val_loss
mode: min
save_top_k: 5
save_best_model: true
always_save_nemo: true
resume_if_exists: false
resume_ignore_no_checkpoint: false
10 changes: 4 additions & 6 deletions examples/tts/conf/audio_codec/mel_codec_44100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ model:
log_epochs: [10, 50, 100, 150, 200]
epoch_frequency: 100
log_tensorboard: false
log_wandb: true
log_wandb: false

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
log_dequantized: true
log_encoding: false
log_dequantized: false

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
Expand Down Expand Up @@ -146,8 +146,6 @@ model:
resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]]
- _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator

# The original EnCodec uses hinged loss, but squared-GAN loss is more stable
# and reduces the need to tune the loss weights or use a gradient balancer.
generator_loss:
_target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

Expand Down Expand Up @@ -181,7 +179,7 @@ exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: false
create_wandb_logger: true
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
Expand Down
Loading

0 comments on commit 79ffef9

Please sign in to comment.