Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine tacotron2, add nv configs and results #251

Merged
merged 4 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions training/benchmarks/tacotron2/pytorch/run_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main() -> Tuple[Any, Any]:
init_helper.set_seed(config.seed, model_driver.config.vendor)

world_size = dist_pytorch.get_world_size()
config.distributed = world_size > 1 or config.multiprocessing_distributed
config.distributed = world_size > 1

# 构建dataset, dataloader 【train && validate】
train_dataset = build_train_dataset(config)
Expand Down Expand Up @@ -93,19 +93,20 @@ def main() -> Tuple[Any, Any]:
# TRAIN_START
dist_pytorch.barrier(config.vendor)
model_driver.event(Event.TRAIN_START)
raw_train_start_time = logger.previous_log_time # 训练起始时间,单位为ms
raw_train_start_time = time.time()

# 训练过程
training_state.epoch = 1
while not training_state.end_training:
trainer.train_one_epoch(train_dataloader)
training_state.epoch += 1


# TRAIN_END事件
model_driver.event(Event.TRAIN_END)
raw_train_end_time = logger.previous_log_time # 训练结束时间,单位为ms

# 训练时长,单位为秒
training_state.raw_train_time = (raw_train_end_time -
raw_train_start_time) / 1e+3
training_state.raw_train_time = time.time() - raw_train_start_time

return config, training_state

Expand All @@ -119,18 +120,33 @@ def main() -> Tuple[Any, Any]:
# 训练信息写日志
e2e_time = time.time() - start
if config_update.do_train:
training_perf = (dist_pytorch.global_batch_size(config_update) *
state.global_steps) / state.raw_train_time
finished_info = {
"e2e_time": e2e_time,
"training_samples_per_second": training_perf,
"converged": state.converged,
"raw_train_time": state.raw_train_time,
"init_time": state.init_time,
"epoch": state.epoch,
"global_steps": state.global_steps,
"train_loss": state.train_loss,
"val_loss": state.val_loss,
"e2e_time":
e2e_time,
"converged":
state.converged,
"raw_train_time":
state.raw_train_time,
"init_time":
state.init_time,
"epoch":
state.epoch,
"global_steps":
state.global_steps,
"train_loss":
state.train_loss,
"val_loss":
state.val_loss,
"num_trained_samples":
state.num_mels,
"pure_training_computing_time":
state.pure_compute_time,
"throughput(ips)_raw":
state.num_mels / state.raw_train_time,
"throughput(ips)_no_eval":
state.num_mels / state.no_eval_time,
"throughput(ips)_pure_compute":
state.num_mels / state.pure_compute_time,
}
else:
finished_info = {"e2e_time": e2e_time}
Expand Down
48 changes: 30 additions & 18 deletions training/benchmarks/tacotron2/pytorch/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
#
# Licensed under the Apache License, Version 2.0 (the "License")

import torch
from torch.types import Device
import time
import os
import sys

import numpy as np
import torch
from torch.types import Device

from model import create_model, create_model_config
from model.loss.loss_function import get_loss_function
Expand All @@ -21,8 +23,6 @@
sys.path.append(os.path.abspath(os.path.join(CURR_PATH, "../../")))
from driver import Driver, Event

import torch.distributed as dist


class Trainer:

Expand Down Expand Up @@ -62,23 +62,26 @@ def init(self):
def train_one_epoch(self, train_dataloader):
state = self.training_state
driver = self.driver
state.epoch += 1
driver.event(Event.EPOCH_BEGIN, state.epoch)

if self.config.distributed:
self.train_dataloader.sampler.set_epoch(state.epoch)

no_eval_start_time = time.time()

for batch in train_dataloader:
self.train_one_step(batch)
state.no_eval_time += time.time() - no_eval_start_time

val_loss, _ = self.evaluator.evaluate(self)
state.val_loss = val_loss

epoch_data = {
"val_loss": val_loss,
"epoch": state.epoch,
"global_steps": state.global_steps
"global_steps": state.global_steps,
"num_trained_samples": state.num_mels,
"timestamp": int(time.time()),
}
print(epoch_data)

Expand All @@ -89,23 +92,19 @@ def train_one_step(self, batch):
driver = self.driver
state = self.training_state
args = self.config
state.global_steps += 1

adjust_learning_rate(self.training_state.epoch, self.optimizer,
args.learning_rate, args.lr_anneal_steps,
args.lr_anneal_factor)

self.model.zero_grad()
x, y, _ = batch_to_gpu(batch)
x, y, len_x = batch_to_gpu(batch)

loss = self.adapter.calculate_loss(self.model, self.config, self.criterion, x, y)
pure_compute_start_time = time.time()

if args.distributed:
reduced_loss = reduce_tensor(loss.data, self.world_size).item()
else:
reduced_loss = loss.item()

if np.isnan(reduced_loss):
raise Exception("loss is NaN")
loss = self.adapter.calculate_loss(self.model, self.config,
self.criterion, x, y)

if args.amp:
self.grad_scaler.scale(loss).backward()
Expand All @@ -119,13 +118,26 @@ def train_one_step(self, batch):
torch.nn.utils.clip_grad_norm_(self.model.parameters(),
args.grad_clip_thresh)
self.optimizer.step()
state.pure_compute_time += time.time() - pure_compute_start_time

if args.distributed:
reduced_loss = reduce_tensor(loss.data, self.world_size).item()
else:
reduced_loss = loss.item()

if np.isnan(reduced_loss):
raise Exception("loss is NaN")

self.model.zero_grad(set_to_none=True)

state.train_loss = reduced_loss
step_info = dict(step=state.global_steps, train_loss=reduced_loss)
state.num_mels += len_x.item() * self.world_size
step_info = dict(
step=state.global_steps,
train_loss=reduced_loss,
num_trained_samples=state.num_mels,
)

self.training_state.global_steps += 1
print(f"step_info:{step_info}")
driver.event(Event.STEP_END, state.global_steps)

Expand All @@ -136,7 +148,7 @@ def detect_training_status(self):
if state.val_loss <= config.target_val_loss:
state.converged_success()

if state.epoch > config.max_epochs:
if state.epoch >= config.max_epochs:
state.end_training = True

return state.end_training
Expand Down
4 changes: 4 additions & 0 deletions training/benchmarks/tacotron2/pytorch/train/training_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from dataclasses import dataclass


@dataclass
class TrainingState:
_trainer = None
Expand All @@ -16,9 +17,12 @@ class TrainingState:
epoch: int = 0
end_training: bool = False
converged: bool = False
num_mels = 0

init_time = 0
raw_train_time = 0
no_eval_time = 0
pure_compute_time = 0

def status(self):
if self.converged:
Expand Down
30 changes: 27 additions & 3 deletions training/nvidia/tacotron2-pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,33 @@


### 运行情况
| 训练资源 | 配置文件 | 运行时长(s) | 目标val_loss | 收敛val_loss | 性能(samples/s) |
| -------- | --------------- | ----------- | ------------ | ------------ | --------------- |
| 单机8卡 | config_A100x1x8 | 16116.25 | 0.4852 | 0.4437 | 417.55 |
* 通用指标

| 指标名称 | 指标值 | 特殊说明 |
| -------------- | ----------------------- | ------------------------------------------- |
| 任务类别 | SpeechSynthesis | |
| 模型 | tacotron2 | |
| 数据集 | LJSpeech | |
| 数据精度 | precision,见“性能指标” | 可选fp32/amp/fp16/tf32 |
| 超参修改 | fix_hp,见“性能指标” | 跑满硬件设备评测吞吐量所需特殊超参 |
| 硬件设备简称 | nvidia A100 | |
| 硬件存储使用 | mem,见“性能指标” | 通常称为“显存”,单位为GiB |
| 端到端时间 | e2e_time,见“性能指标” | 总时间+Perf初始化等时间 |
| 总吞吐量 | p_whole,见“性能指标” | 实际训练样本数除以总时间(performance_whole) |
| 训练吞吐量 | p_train,见“性能指标” | 不包含每个epoch末尾的评估部分耗时 |
| **计算吞吐量** | **p_core,见“性能指标”** | 不包含数据IO部分的耗时(p3>p2>p1) |
| 训练结果 | val_loss,见“性能指标” | 验证loss |
| 额外修改项 | 无 | |

* 性能指标

| 配置 | precision | fix_hp | e2e_time | p_whole | p_train | p_core | val_loss | mem |
| ----------------- | --------- | --------------- | -------- | ------- | ------- | ------ | -------- | --------- |
| A100单机8卡(1x8) | tf32 | / | 10719 | 257556 | 265661 | 280476 | 0.4774 | 37.5/40.0 |
| A100单机单卡(1x1) | tf32 | bs=128,lr=0.001 | | 34440 | 34591 | 35562 | | 35.2/40.0 |
| A100两机8卡(2x8) | tf32 | bs=128,lr=0.001 | | 484402 | 512004 | 558171 | | 37.7/40.0 |



注:
训练精度来源:https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2#results,根据官方仓库中的脚本,训练1500epoch得到val_loss=0.4852.
Expand Down
9 changes: 9 additions & 0 deletions training/nvidia/tacotron2-pytorch/config/config_A100x1x1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from config_common import *

train_batch_size = 128
eval_batch_size = train_batch_size

warmup = 0.2
learning_rate = 1e-3

seed = 23333
9 changes: 9 additions & 0 deletions training/nvidia/tacotron2-pytorch/config/config_A100x2x8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from config_common import *

train_batch_size = 128
eval_batch_size = train_batch_size

warmup = 0.2
learning_rate = 1e-3

seed = 23333
3 changes: 2 additions & 1 deletion training/nvidia/tacotron2-pytorch/config/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
librosa
librosa
inflect
yuzhou03 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion training/run_benchmarks/config/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
# "faster_rcnn:pytorch_1.8:A100:1:8:1": "/raid/dataset/fasterrcnn/coco2017/",
# "bigtransfer:pytorch_1.8:A100:1:8:1": "/raid/dataset/ImageNet_1k_2012/",

# "tacotron2:pytorch_1.8:A100:1:8:1": "/raid/dataset/tacotron2/LJSpeech/",
#"tacotron2:pytorch_1.13:A100:1:8:1": "/raid/dataset/tacotron2/LJSpeech/",
# "resnet50:pytorch_1.8:A100:1:8:1": "/raid/dataset/ImageNet_1k_2012/",
# "mask_rcnn:pytorch_1.8:A100:1:8:1": "/raid/dataset/maskrcnn/coco2017",

Expand Down