# train_pitch.py (forked from ericwudayi/SkipVQVC)
import sys
import os
import argparse
import importlib

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from utils.dataloader import AudioNpyPitchLoader, Pitch_collate

# The logger and utils helpers live in sibling directories, so put them on the path.
sys.path.append('logger')
from logger import Logger
from logger_utils import prepare_directories_and_logger

sys.path.append('utils')
from save_and_load import load_checkpoint
parser = argparse.ArgumentParser()
parser.add_argument('-train_dir', '--train_dir', type=str, required=True,
                    help='directory of preprocessed npy files for training')
parser.add_argument('-test_dir', '--test_dir', type=str, required=False, default=None,
                    help='directory of preprocessed npy files for testing')
parser.add_argument('-m', '--model', type=str, required=True,
                    help='model type in the model dir')
parser.add_argument('-n', '--n_embed', type=str, required=True,
                    help='number of vectors in the codebook')
parser.add_argument('-ch', '--channel', type=str, required=True,
                    help='channel number in VQVC+')
parser.add_argument('-t', '--trainer', type=str, required=True,
                    help='which trainer to use (rhythm, mean_std, normal)')
# argparse's type=bool would treat any non-empty string as True, so expose this as a flag.
parser.add_argument('--load_checkpoint', action='store_true',
                    help='resume training from the latest checkpoint')
args = parser.parse_args()
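# A hypothetical invocation (the data paths, model name, and hyperparameter
# values below are placeholders, not values shipped with this repo):
#
#   python train_pitch.py -train_dir data/train_npy -test_dir data/test_npy \
#       -m vqvc_plus -n 64 -ch 256 -t normal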
# Logger writing to a run-specific output directory named after the run configuration.
logger = prepare_directories_and_logger(Logger, output_directory=f'output/{args.model}_n{args.n_embed}_ch{args.channel}_{args.trainer}')
# Dynamically resolve the requested trainer and model implementations
# (importlib is imported at the top of the file).
trainer_module = importlib.import_module(f'trainer.{args.trainer}')
train_ = getattr(trainer_module, 'train_')
model_module = importlib.import_module(f'model.{args.model}.vq_model')
VC_MODEL = getattr(model_module, 'VC_MODEL')
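# The dynamic imports above assume a package layout roughly like the sketch
# below; the trainer names come from the --trainer help text, while the model
# package name is whatever is passed via -m (shown here as a placeholder):
#
#   trainer/normal.py               -> defines train_(...)
#   model/<args.model>/vq_model.py  -> defines class VC_MODEL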
'''
Dataset and loader
'''
def make_inf_iterator(data_iterator):
    # Cycle over the DataLoader forever so test batches can be drawn
    # with next() without tracking epoch boundaries.
    while True:
        for data in data_iterator:
            yield data
audio_dir = args.train_dir  # "/home/ericwudayi/nas189/homes/ericwudayi/VCTK-Corpus/mel3/mel.melgan"
dataset = AudioNpyPitchLoader(audio_dir)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8,
                    collate_fn=Pitch_collate)

if args.test_dir is not None:
    audio_dir_test = args.test_dir  # "/home/ericwudayi/nas189/homes/ericwudayi/VCTK-Corpus/mel3/mel.test"
else:
    audio_dir_test = audio_dir
    print("No test dir given; using the train dir instead")

dataset_test = AudioNpyPitchLoader(audio_dir_test)
test_loader = DataLoader(dataset_test, batch_size=8, shuffle=True, num_workers=4,
                         collate_fn=Pitch_collate)
inf_iterator_test = make_inf_iterator(test_loader)
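# The trainer presumably draws evaluation batches from this stream with
# next(); illustrative only, and the unpacked field names are assumptions:
#
#   mel, pitch = next(inf_iterator_test)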
'''
Model initialization
'''
model = VC_MODEL(in_channel=80, channel=int(args.channel), n_embed=int(args.n_embed)).cuda()
opt = optim.Adam(model.parameters())
'''
Training
'''
criterion = nn.L1Loss()   # reconstruction loss
latent_loss_weight = 0.1  # weight on the VQ latent loss
iteration = 0

if args.load_checkpoint:
    model, opt, iteration = load_checkpoint(
        f'checkpoint/{args.model}_n{args.n_embed}_ch{args.channel}_{args.trainer}/gen',
        model, opt)

train_(args, model, opt, latent_loss_weight, criterion, loader, 800,
       inf_iterator_test, logger, iteration)
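# To resume an interrupted run, re-issue the same command with the
# --load_checkpoint flag; this script reads the checkpoint from
# checkpoint/{model}_n{n_embed}_ch{channel}_{trainer}/gen, assuming the
# trainer writes it there. Paths and names below are placeholders:
#
#   python train_pitch.py -train_dir data/train_npy -m vqvc_plus -n 64 \
#       -ch 256 -t normal --load_checkpoint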