Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
jaywalnut310 committed Oct 25, 2020
1 parent 9b82819 commit 13e9976
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 31 deletions.
23 changes: 11 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ We also provide the [pretrained model](https://drive.google.com/open?id=1JiCMBVT
</tr>
</table>


## Update Notes*

This result was not included in the paper. Lately, we found that two modifications help to improve the synthesis quality of Glow-TTS: 1) moving to the [HiFi-GAN](https://arxiv.org/abs/2010.05646) vocoder to reduce noise, and 2) putting a blank token between any two input tokens to improve pronunciation. Specifically,
we used a vocoder fine-tuned with Tacotron 2, which is provided as a pretrained model in the [HiFi-GAN repo](https://github.com/jik876/hifi-gan). If you're interested, please listen to the samples in our [demo](https://jaywalnut310.github.io/glow-tts-demo/index.html).

For adding a blank token, we provide a [config file](./configs/base_blank.json) and a [pretrained model](https://drive.google.com/open?id=1RxR6JWg6WVBZYb-pIw58hi1XLNb5aHEi). We also provide an inference example, [inference_hifigan.ipynb](./inference_hifigan.ipynb). You may need to initialize the HiFi-GAN submodule: `git submodule init; git submodule update`


## 1. Environments we use

* Python3.6.9
Expand All @@ -37,7 +46,6 @@ For Mixed-precision training, we use [apex](https://github.com/NVIDIA/apex); com

a) Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/), then rename or create a link to the dataset folder: `ln -s /path/to/LJSpeech-1.1/wavs DUMMY`


b) Initialize WaveGlow submodule: `git submodule init; git submodule update`

Don't forget to download the pretrained WaveGlow model and place it into the waveglow folder.
Expand All @@ -47,7 +55,6 @@ c) Build Monotonic Alignment Search Code (Cython): `cd monotonic_align; python s

## 3. Training Example


```sh
sh train_ddi.sh configs/base.json base
```
Expand All @@ -57,17 +64,9 @@ sh train_ddi.sh configs/base.json base
See [inference.ipynb](./inference.ipynb)


## 5. Modifications after Paper Submission

This result was not included in the paper. Lately, we found that two modifications help to improve the synthesis quality of Glow-TTS.; 1) moving to a vocoder, HiFi-GAN (https://arxiv.org/abs/2010.05646) to reduce noise, 2) putting a blank token between any two input tokens to improve pronunciation. Specifically, we used a fine-tuned vocoder with Tacotron 2 which is provided as a pretrained model in the repo (https://github.com/jik876/hifi-gan). If you're interested, please listen to the three samples in our demo.

See [inference_hifigan.ipynb](./inference_hifigan.ipynb)

For adding a blank token, we provide a config file and a pretrained model.


## Acknowledgements
Our implementation is highly affected by the following repos:

Our implementation is hugely influenced by the following repos:
* [WaveGlow](https://github.com/NVIDIA/waveglow)
* [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor)
* [Mellotron](https://github.com/NVIDIA/mellotron)
2 changes: 1 addition & 1 deletion inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
"with torch.no_grad():\n",
" noise_scale = .667\n",
" length_scale = 1.0\n",
" (y_gen_tst, *r), attn_gen, *_ = model(x_tst, x_tst_lengths, gen=True, noise_scale=noise_scale, length_scale=length_scale)\n",
" (y_gen_tst, *_), *_, (attn_gen, *_) = model(x_tst, x_tst_lengths, gen=True, noise_scale=noise_scale, length_scale=length_scale)\n",
" try:\n",
" audio = waveglow.infer(y_gen_tst.half(), sigma=.666)\n",
" except:\n",
Expand Down
4 changes: 2 additions & 2 deletions inference_hifigan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
"with torch.no_grad():\n",
" noise_scale = .667\n",
" length_scale = 1.0\n",
" (y_gen_tst, *r), attn_gen, *_ = model(x_tst, x_tst_lengths, gen=True, noise_scale=noise_scale, length_scale=length_scale)\n",
" (y_gen_tst, *_), *_, (attn_gen, *_) = model(x_tst, x_tst_lengths, gen=True, noise_scale=noise_scale, length_scale=length_scale)\n",
"\n",
"# save mel-frames\n",
"if not os.path.exists('./hifi-gan/test_mel_files'):\n",
Expand All @@ -96,7 +96,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Finetuned HiFi-GAN with Tacotron 2, which is provided in the repo of HiFi-GAN.\n",
"# Use finetuned HiFi-GAN with Tacotron 2, which is provided in the repo of HiFi-GAN.\n",
"!python ./hifi-gan/inference_e2e.py --checkpoint_file /path/to/finetuned_model"
]
},
Expand Down
23 changes: 11 additions & 12 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,20 +287,20 @@ def forward(self, x, x_lengths, y=None, y_lengths=None, g=None, gen=False, noise
else:
y_max_length = y.size(2)
y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
z_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)

if gen:
attn = commons.generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
y_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

z = (y_m + torch.exp(y_logs) * torch.randn_like(y_m) * noise_scale) * y_mask
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
return (y, y_m, y_logs, logdet), attn, logw, logw_, x_m, x_logs
z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
y, logdet = self.decoder(z, z_mask, g=g, reverse=True)
return (y, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)
else:
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
z, logdet = self.decoder(y, z_mask, g=g, reverse=False)
with torch.no_grad():
x_s_sq_r = torch.exp(-2 * x_logs)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1) # [b, t, 1]
Expand All @@ -310,11 +310,10 @@ def forward(self, x, x_lengths, y=None, y_lengths=None, g=None, gen=False, noise
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']

attn = monotonic_align.maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

return (z, y_m, y_logs, logdet), attn, logw, logw_, x_m, x_logs
return (z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

def preprocess(self, y, y_lengths, y_max_length):
if y_max_length is not None:
Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def train(rank, epoch, hps, generator, optimizer_g, train_loader, logger, writer
# Train Generator
optimizer_g.zero_grad()

(z, y_m, y_logs, logdet), attn, logw, logw_, x_m, x_logs = generator(x, x_lengths, y, y_lengths, gen=False)
l_mle = commons.mle_loss(z, y_m, y_logs, logdet, y_mask)
(z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_) = generator(x, x_lengths, y, y_lengths, gen=False)
l_mle = commons.mle_loss(z, z_m, z_logs, logdet, z_mask)
l_length = commons.duration_loss(logw, logw_, x_lengths)

loss_gs = [l_mle, l_length]
Expand Down Expand Up @@ -156,8 +156,8 @@ def evaluate(rank, epoch, hps, generator, optimizer_g, val_loader, logger, write
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)


(z, y_m, y_logs, logdet), attn, logw, logw_, x_m, x_logs = generator(x, x_lengths, y, y_lengths, gen=False)
l_mle = commons.mle_loss(z, y_m, y_logs, logdet, y_mask)
(z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_) = generator(x, x_lengths, y, y_lengths, gen=False)
l_mle = commons.mle_loss(z, z_m, z_logs, logdet, z_mask)
l_length = commons.duration_loss(logw, logw_, x_lengths)

loss_gs = [l_mle, l_length]
Expand Down

0 comments on commit 13e9976

Please sign in to comment.