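# convert.py: FreeVC voice-conversion inference. For each "title|src|tgt" line
# in a text file, converts the source utterance to the target speaker's voice
# and writes the resulting wav to the output directory.
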
import os
import argparse
import time
import logging

import torch
import librosa
from scipy.io.wavfile import write
from tqdm import tqdm

import utils
from models import SynthesizerTrn
from mel_processing import mel_spectrogram_torch
from wavlm import WavLM, WavLMConfig
from speaker_encoder.voice_encoder import SpeakerEncoder

logging.getLogger('numba').setLevel(logging.WARNING)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hpfile", type=str, default="configs/freevc.json", help="path to json config file")
    parser.add_argument("--ptfile", type=str, default="checkpoints/freevc.pth", help="path to pth file")
    parser.add_argument("--txtpath", type=str, default="convert.txt", help="path to txt file listing conversions")
    parser.add_argument("--outdir", type=str, default="output/freevc", help="path to output dir")
    parser.add_argument("--use_timestamp", default=False, action="store_true", help="prefix output filenames with a timestamp")
    args = parser.parse_args()
    os.makedirs(args.outdir, exist_ok=True)
    hps = utils.get_hparams_from_file(args.hpfile)

    print("Loading model...")
    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model).cuda()
    _ = net_g.eval()

    print("Loading checkpoint...")
    _ = utils.load_checkpoint(args.ptfile, net_g, None, True)

    print("Loading WavLM for content...")
    cmodel = utils.get_cmodel(0)

    if hps.model.use_spk:
        print("Loading speaker encoder...")
        smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
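
    # The conversion list (--txtpath) is a plain text file with one
    # "title|source_wav|target_wav" line per conversion; the paths below are
    # illustrative placeholders, not files shipped with the repo:
    #   sample1|wavs/source.wav|wavs/target.wav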
print("Processing text...")
titles, srcs, tgts = [], [], []
with open(args.txtpath, "r") as f:
for rawline in f.readlines():
title, src, tgt = rawline.strip().split("|")
titles.append(title)
srcs.append(src)
tgts.append(tgt)
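
    # For each (title, src, tgt) triple: derive the target speaker representation
    # (a speaker-encoder embedding when use_spk is set, otherwise the target mel
    # spectrogram), extract WavLM content features from the source, then decode
    # the converted waveform with the generator.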
print("Synthesizing...")
with torch.no_grad():
for line in tqdm(zip(titles, srcs, tgts)):
title, src, tgt = line
# tgt
wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
if hps.model.use_spk:
g_tgt = smodel.embed_utterance(wav_tgt)
g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).cuda()
else:
wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).cuda()
mel_tgt = mel_spectrogram_torch(
wav_tgt,
hps.data.filter_length,
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax
)
            # src: extract WavLM content features from the source utterance
            wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
            wav_src = torch.from_numpy(wav_src).unsqueeze(0).cuda()
            c = utils.get_content(cmodel, wav_src)

            if hps.model.use_spk:
                audio = net_g.infer(c, g=g_tgt)
            else:
                audio = net_g.infer(c, mel=mel_tgt)
            audio = audio[0][0].data.cpu().float().numpy()

            if args.use_timestamp:
                timestamp = time.strftime("%m-%d_%H-%M", time.localtime())
                write(os.path.join(args.outdir, f"{timestamp}_{title}.wav"), hps.data.sampling_rate, audio)
            else:
                write(os.path.join(args.outdir, f"{title}.wav"), hps.data.sampling_rate, audio)