diff --git a/LICENSE.3rd_party_library b/LICENSE.3rd_party_library index aaf98a4..9fc5153 100644 --- a/LICENSE.3rd_party_library +++ b/LICENSE.3rd_party_library @@ -42,6 +42,21 @@ jieba Copyright (C) 2013 Sun Junyi MIT License +fastdtw + https://github.com/slaypni/fastdtw + Copyright (C) 2015 slaypni + MIT License + +epitran + https://github.com/dmort27/epitran + Copyright (c) 2016 David Mortensen + MIT License + +ko_pron + https://github.com/kord123/ko_pron + Copyright (c) 2018 Koretskyy Andriy + MIT License + joblib https://github.com/joblib/joblib Copyright (C) 2008-2016, The joblib developers. diff --git a/pororo/models/tts/__init__.py b/pororo/models/tts/__init__.py new file mode 100644 index 0000000..d921278 --- /dev/null +++ b/pororo/models/tts/__init__.py @@ -0,0 +1 @@ +from pororo.models.tts.synthesizer import MultilingualSpeechSynthesizer diff --git a/pororo/models/tts/hifigan/checkpoint.py b/pororo/models/tts/hifigan/checkpoint.py new file mode 100644 index 0000000..1a8cb1b --- /dev/null +++ b/pororo/models/tts/hifigan/checkpoint.py @@ -0,0 +1,17 @@ +import torch +import os +import glob + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + checkpoint_dict = torch.load(filepath, map_location=device) + return checkpoint_dict + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '*') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return '' + return sorted(cp_list)[-1] diff --git a/pororo/models/tts/hifigan/model.py b/pororo/models/tts/hifigan/model.py new file mode 100644 index 0000000..e0086b8 --- /dev/null +++ b/pororo/models/tts/hifigan/model.py @@ -0,0 +1,289 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + return loss*2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class ResBlock1(nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + 
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + 
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i-1](y) + y_hat = self.meanpools[i-1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/pororo/models/tts/synthesis.py b/pororo/models/tts/synthesis.py new file mode 100644 index 0000000..bbf6c28 --- /dev/null +++ b/pororo/models/tts/synthesis.py @@ -0,0 +1,53 @@ +import torch + +from pororo.models.tts.tacotron.params import Params as hp +from pororo.models.tts.utils import audio, text + + +def synthesize(model, input_data, force_cpu=False, device=None): + item = input_data.split("|") + clean_text = item[1] + + if not hp.use_punctuation: + clean_text = text.remove_punctuation(clean_text) + if not hp.case_sensitive: + clean_text = text.to_lower(clean_text) + if hp.remove_multiple_wspaces: + 
clean_text = text.remove_odd_whitespaces(clean_text) + + t = torch.LongTensor( + text.to_sequence(clean_text, use_phonemes=hp.use_phonemes)) + + if hp.multi_language: + l_tokens = item[3].split(",") + t_length = len(clean_text) + 1 + l = [] + for token in l_tokens: + l_d = token.split("-") + + language = [0] * hp.language_number + for l_cw in l_d[0].split(":"): + l_cw_s = l_cw.split("*") + language[hp.languages.index( + l_cw_s[0])] = (1 if len(l_cw_s) == 1 else float(l_cw_s[1])) + + language_length = int(l_d[1]) if len(l_d) == 2 else t_length + l += [language] * language_length + t_length -= language_length + l = torch.FloatTensor([l]) + else: + l = None + + s = (torch.LongTensor([hp.unique_speakers.index(item[2])]) if hp.multi_speaker else None) + + if torch.cuda.is_available() and not force_cpu: + t = t.to(device) + if l is not None: + l = l.to(device) + if s is not None: + s = s.to(device) + + s = model.inference(t, speaker=s, language=l).cpu().detach().numpy() + s = audio.denormalize_spectrogram(s, not hp.predict_linear) + + return s diff --git a/pororo/models/tts/synthesizer.py b/pororo/models/tts/synthesizer.py new file mode 100644 index 0000000..5c563bf --- /dev/null +++ b/pororo/models/tts/synthesizer.py @@ -0,0 +1,144 @@ +import torch +import json +import librosa +from typing import Tuple + +from pororo.models.tts.hifigan.checkpoint import load_checkpoint +from pororo.models.tts.hifigan.model import Generator +from pororo.models.tts.tacotron.params import Params as tacotron_hp +from pororo.models.tts.tacotron.tacotron2 import Tacotron +from pororo.models.tts.synthesis import synthesize +from pororo.models.tts.utils import remove_dataparallel_prefix +from pororo.models.tts.waveRNN.gen_wavernn import generate as wavernn_generate +from pororo.models.tts.waveRNN.params import hp as wavernn_hp +from pororo.models.tts.waveRNN.waveRNN import WaveRNN + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +class MultilingualSpeechSynthesizer(object): + def __init__( + self, + tacotron_path: str, + english_vocoder_path: str, + english_vocoder_config: str, + korean_vocoder_path: str, + korean_vocoder_config: str, + wavernn_path: str, + device: str, + lang: str = "en", + ): + self.lang = lang + self.device = device + self.vocoder_en_config = None + self.vocoder_ko_config = None + + self.tacotron, self.vocoder_en, self.vocoder_ko, self.vocoder_multi = self.build_model( + tacotron_path, + english_vocoder_path, + english_vocoder_config, + korean_vocoder_path, + korean_vocoder_config, + wavernn_path, + ) + + def _build_hifigan(self, config: str, hifigan_path: str) -> Generator: + with open(config) as f: + data = f.read() + + config = json.loads(data) + config = AttrDict(config) + + generator = Generator(config).to(self.device) + state_dict_g = load_checkpoint(hifigan_path, self.device) + generator.load_state_dict(state_dict_g['generator']) + generator.eval() + generator.remove_weight_norm() + + return generator + + def _build_tacotron(self, tacotron_path: str) -> Tacotron: + state = torch.load(tacotron_path, map_location=self.device) + tacotron_hp.load_state_dict(state["parameters"]) + tacotron = Tacotron() + tacotron.load_state_dict(remove_dataparallel_prefix(state["model"])) + tacotron.eval().to(self.device) + return tacotron + + def _build_wavernn(self, wavernn_path: str) -> WaveRNN: + wavernn = (WaveRNN( + rnn_dims=wavernn_hp.voc_rnn_dims, + fc_dims=wavernn_hp.voc_fc_dims, + bits=wavernn_hp.bits, + 
pad=wavernn_hp.voc_pad, + upsample_factors=wavernn_hp.voc_upsample_factors, + feat_dims=wavernn_hp.num_mels, + compute_dims=wavernn_hp.voc_compute_dims, + res_out_dims=wavernn_hp.voc_res_out_dims, + res_blocks=wavernn_hp.voc_res_blocks, + hop_length=wavernn_hp.hop_length, + sample_rate=wavernn_hp.sample_rate, + mode=wavernn_hp.voc_mode, + ).eval().to(self.device)) + wavernn.load(wavernn_path) + return wavernn + + def build_model( + self, + tacotron_path: str, + english_vocoder_path: str, + english_vocoder_config: str, + korean_vocoder_path: str, + korean_vocoder_config: str, + wavernn_path: str, + ) -> Tuple[Tacotron, Generator, Generator, WaveRNN]: + """Load and build tacotron a from checkpoint.""" + tacotron = self._build_tacotron(tacotron_path) + vocoder_multi = self._build_wavernn(wavernn_path) + vocoder_ko = self._build_hifigan(korean_vocoder_config, korean_vocoder_path) + vocoder_en = self._build_hifigan(english_vocoder_config, english_vocoder_path) + return tacotron, vocoder_en, vocoder_ko, vocoder_multi + + def _spectrogram_postprocess(self, spectrogram): + spectrogram = librosa.db_to_amplitude(spectrogram) + spectrogram = torch.log(torch.clamp(torch.Tensor(spectrogram), min=1e-5) * 1) + return spectrogram + + def predict(self, text: str, speaker: str): + speakers = speaker.split(',') + + spectrogram = synthesize(self.tacotron, f"|{text}", device=self.device) + + if len(speakers) > 1: + spectrogram = self._spectrogram_postprocess(spectrogram) + y_g_hat = self.vocoder_en(torch.Tensor(spectrogram).to(self.device).unsqueeze(0)) + audio = y_g_hat.squeeze() + audio = audio * 32768.0 + return audio.cpu().detach().numpy() + + if speaker in ("ko", "en"): + spectrogram = self._spectrogram_postprocess(spectrogram) + + if speaker == "ko": + y_g_hat = self.vocoder_ko(torch.Tensor(spectrogram).to(self.device).unsqueeze(0)) + else: + y_g_hat = self.vocoder_en(torch.Tensor(spectrogram).to(self.device).unsqueeze(0)) + + audio = y_g_hat.squeeze() + audio = audio * 32768.0 + return audio.cpu().detach().numpy() + + else: + audio = wavernn_generate( + self.vocoder_multi, + spectrogram, + wavernn_hp.voc_gen_batched, + wavernn_hp.voc_target, + wavernn_hp.voc_overlap, + ) + audio = audio * 32768.0 + return audio diff --git a/pororo/models/tts/tacotron/__init__.py b/pororo/models/tts/tacotron/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pororo/models/tts/tacotron/attention.py b/pororo/models/tts/tacotron/attention.py new file mode 100644 index 0000000..d9deaea --- /dev/null +++ b/pororo/models/tts/tacotron/attention.py @@ -0,0 +1,112 @@ +import torch +from torch.nn import Conv1d, Linear, Parameter +from torch.nn import functional as F + + +class AttentionBase(torch.nn.Module): + """Abstract attention class. 
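+ Subclasses implement _attent, _normalize and _combine_weights; forward then turns a single decoder query into a context vector and attention weights over the encoder memory.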
+ + Arguments: + representation_dim -- size of the hidden representation + query_dim -- size of the attention query input (probably decoder hidden state) + memory_dim -- size of the attention memory input (probably encoder outputs) + """ + + def __init__(self, representation_dim, query_dim, memory_dim): + super(AttentionBase, self).__init__() + self._bias = Parameter(torch.zeros(1, representation_dim)) + self._energy = Linear(representation_dim, 1, bias=False) + self._query = Linear(query_dim, representation_dim, bias=False) + self._memory = Linear(memory_dim, representation_dim, bias=False) + self._memory_dim = memory_dim + + def reset(self, encoded_input, batch_size, max_len, device): + """Initialize previous attention weights & prepare attention memory.""" + self._memory_transform = self._memory(encoded_input) + self._prev_weights = torch.zeros(batch_size, max_len, device=device) + self._prev_context = torch.zeros( + batch_size, + self._memory_dim, + device=device, + ) + return self._prev_context + + def _attent(self, query, memory_transform, weights): + raise NotImplementedError + + def _combine_weights(self, previsous_weights, weights): + raise NotImplementedError + + def _normalize(self, energies, mask): + raise NotImplementedError + + def forward(self, query, memory, mask, prev_decoder_output): + energies = self._attent( + query, + self._memory_transform, + self._prev_weights, + ) + attention_weights = self._normalize(energies, mask) + self._prev_weights = self._combine_weights( + self._prev_weights, + attention_weights, + ) + attention_weights = attention_weights.unsqueeze(1) + self._prev_context = torch.bmm(attention_weights, memory).squeeze(1) + return self._prev_context, attention_weights.squeeze(1) + + +class LocationSensitiveAttention(AttentionBase): + """ + Location Sensitive Attention: + Location-sensitive attention: https://arxiv.org/abs/1506.07503. + Extends additive attention (here https://arxiv.org/abs/1409.0473) + to use cumulative attention weights from previous decoder time steps. 
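+ The cumulative weights are turned into location features by a 1D convolution and added to the transformed query and memory before the energy is computed, which biases the alignment towards moving monotonically over the input.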
+ + Arguments: + kernel_size -- kernel size of the convolution calculating location features + channels -- number of channels of the convolution calculating location features + smoothing -- to normalize weights using softmax, use False (default) and True to use sigmoids + """ + + def __init__( + self, + kernel_size, + channels, + smoothing, + representation_dim, + query_dim, + memory_dim, + ): + super(LocationSensitiveAttention, + self).__init__(representation_dim, query_dim, memory_dim) + self._location = Linear(channels, representation_dim, bias=False) + self._loc_features = Conv1d( + 1, + channels, + kernel_size, + padding=(kernel_size - 1) // 2, + bias=False, + ) + self._smoothing = smoothing + + def _attent(self, query, memory_transform, cum_weights): + query = self._query(query.unsqueeze(1)) + cum_weights = cum_weights.unsqueeze(-1) + loc_features = self._loc_features(cum_weights.transpose(1, 2)) + loc_features = self._location(loc_features.transpose(1, 2)) + energy = query + memory_transform + loc_features + energy = self._energy(torch.tanh(energy + self._bias)) + return energy.squeeze(-1) + + def _normalize(self, energies, mask): + energies[~mask] = float("-inf") + if self._smoothing: + sigmoid = torch.sigmoid(energies) + total = torch.sum(sigmoid, dim=-1) + return sigmoid / total + else: + return F.softmax(energies, dim=1) + + def _combine_weights(self, previous_weights, weights): + return previous_weights + weights diff --git a/pororo/models/tts/tacotron/encoder.py b/pororo/models/tts/tacotron/encoder.py new file mode 100644 index 0000000..e23011b --- /dev/null +++ b/pororo/models/tts/tacotron/encoder.py @@ -0,0 +1,133 @@ +import torch +from torch.nn import Embedding, Sequential + +from pororo.models.tts.tacotron.layers import ( + ConvBlockGenerated, + HighwayConvBlockGenerated, +) + + +class GeneratedConvolutionalEncoder(torch.nn.Module): + """Convolutional encoder (possibly multi-lingual) with weights generated by another network. 
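+ All languages are processed in parallel as convolution groups whose weights are generated from per-language embeddings; at inference with a single sample, the per-character language weights are used to mix the outputs of the groups.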
+ + Arguments: + see ConvolutionalEncoder + embedding_dim -- size of the generator embedding (should be language embedding) + bottleneck_dim -- size of the generating layer + Keyword arguments: + see ConvolutionalEncoder + """ + + def __init__( + self, + input_dim, + output_dim, + dropout, + embedding_dim, + bottleneck_dim, + groups=1, + ): + super(GeneratedConvolutionalEncoder, self).__init__() + + self._groups = groups + self._input_dim = input_dim + self._output_dim = output_dim + + input_dim *= groups + output_dim *= groups + + layers = ([ + ConvBlockGenerated( + embedding_dim, + bottleneck_dim, + input_dim, + output_dim, + 1, + dropout=dropout, + activation="relu", + groups=groups, + ), + ConvBlockGenerated( + embedding_dim, + bottleneck_dim, + output_dim, + output_dim, + 1, + dropout=dropout, + groups=groups, + ), + ] + [ + HighwayConvBlockGenerated( + embedding_dim, + bottleneck_dim, + output_dim, + output_dim, + 3, + dropout=dropout, + dilation=3**i, + groups=groups, + ) for i in range(4) + ] + [ + HighwayConvBlockGenerated( + embedding_dim, + bottleneck_dim, + output_dim, + output_dim, + 3, + dropout=dropout, + dilation=3**i, + groups=groups, + ) for i in range(4) + ] + [ + HighwayConvBlockGenerated( + embedding_dim, + bottleneck_dim, + output_dim, + output_dim, + 3, + dropout=dropout, + dilation=1, + groups=groups, + ) for _ in range(2) + ] + [ + HighwayConvBlockGenerated( + embedding_dim, + bottleneck_dim, + output_dim, + output_dim, + 1, + dropout=dropout, + dilation=1, + groups=groups, + ) for _ in range(2) + ]) + + self._layers = Sequential(*layers) + self._embedding = Embedding(groups, embedding_dim) + + def forward(self, x, x_lenghts=None, x_langs=None): + + # x_langs is specified during inference with batch size 1, so we need to + # expand the single language to create complete groups (all langs. in parallel) + if x_langs is not None and x_langs.shape[0] == 1: + x = x.expand((self._groups, -1, -1)) + + # create generator embeddings for all groups + e = self._embedding(torch.arange(self._groups, device=x.device)) + + bs = x.shape[0] + x = x.transpose(1, 2) + x = x.reshape(bs // self._groups, self._groups * self._input_dim, -1) + _, x = self._layers((e, x)) + x = x.reshape(bs, self._output_dim, -1) + x = x.transpose(1, 2) + + if x_langs is not None and x_langs.shape[0] == 1: + xr = torch.zeros(1, x.shape[1], x.shape[2], device=x.device) + x_langs_normed = x_langs / x_langs.sum(2, keepdim=True)[0] + for l in range(self._groups): + w = x_langs_normed[0, :, l].reshape(-1, 1) + xr[0] += w * x[l] + x = xr + + return x diff --git a/pororo/models/tts/tacotron/generated.py b/pororo/models/tts/tacotron/generated.py new file mode 100644 index 0000000..4e320a7 --- /dev/null +++ b/pororo/models/tts/tacotron/generated.py @@ -0,0 +1,147 @@ +import torch +from torch.nn import Linear +from torch.nn import functional as F + + +class Conv1dGenerated(torch.nn.Module): + """One dimensional convolution with generated weights (each group has separate weights). 
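+ A bottleneck layer maps the generator embedding to the convolution kernel (and optional bias), so each group gets its own generated parameters instead of ordinary learned weights.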
+ + Arguments: + embedding_dim -- size of the meta embedding (should be language embedding) + bottleneck_dim -- size of the generating embedding + see torch.nn.Conv1d + """ + + def __init__( + self, + embedding_dim, + bottleneck_dim, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + super(Conv1dGenerated, self).__init__() + + self._in_channels = in_channels + self._out_channels = out_channels + self._kernel_size = kernel_size + self._stride = stride + self._padding = padding + self._dilation = dilation + self._groups = groups + + # in_channels and out_channels is divisible by groups + # tf.nn.functional.conv1d accepts weights of shape [out_channels, in_channels // groups, kernel] + + self._bottleneck = Linear(embedding_dim, bottleneck_dim) + self._kernel = Linear( + bottleneck_dim, + out_channels // groups * in_channels // groups * kernel_size) + self._bias = Linear( + bottleneck_dim, + out_channels // groups, + ) if bias else None + + def forward(self, generator_embedding, x): + + assert ( + generator_embedding.shape[0] == self._groups + ), "Number of groups of a convolutional layer must match the number of generators." + + e = self._bottleneck(generator_embedding) + kernel = self._kernel(e).view( + self._out_channels, + self._in_channels // self._groups, + self._kernel_size, + ) + bias = self._bias(e).view(self._out_channels) if self._bias else None + + return F.conv1d( + x, + kernel, + bias, + self._stride, + self._padding, + self._dilation, + self._groups, + ) + + +class BatchNorm1dGenerated(torch.nn.Module): + """One dimensional batch normalization with generated weights (each group has separate parameters). + + Arguments: + embedding_dim -- size of the meta embedding (should be language embedding) + bottleneck_dim -- size of the generating embedding + see torch.nn.BatchNorm1d + Keyword arguments: + groups -- number of groups with separate weights + """ + + def __init__( + self, + embedding_dim, + bottleneck_dim, + num_features, + groups=1, + eps=1e-8, + momentum=0.1, + ): + super(BatchNorm1dGenerated, self).__init__() + + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + self.register_buffer("num_batches_tracked", + torch.tensor(0, dtype=torch.long)) + + self._num_features = num_features // groups + self._eps = eps + self._momentum = momentum + self._groups = groups + + self._bottleneck = Linear(embedding_dim, bottleneck_dim) + self._affine = Linear( + bottleneck_dim, + self._num_features + self._num_features, + ) + + def forward(self, generator_embedding, x): + + assert ( + generator_embedding.shape[0] == self._groups + ), "Number of groups of a batchnorm layer must match the number of generators." 
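+ # The affine scale and bias used below are not learned Parameters of this module:
+ # they are generated per group from the generator embedding, while the running
+ # mean and variance are ordinary registered buffers.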
+ + if self._momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self._momentum + + e = self._bottleneck(generator_embedding) + affine = self._affine(e) + scale = affine[:, :self._num_features].contiguous().view(-1) + bias = affine[:, self._num_features:].contiguous().view(-1) + + if self.training: + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self._momentum is None: + exponential_average_factor = 1.0 / float( + self.num_batches_tracked) + else: + exponential_average_factor = self._momentum + + return F.batch_norm( + x, + self.running_mean, + self.running_var, + scale, + bias, + self.training, + exponential_average_factor, + self._eps, + ) diff --git a/pororo/models/tts/tacotron/layers.py b/pororo/models/tts/tacotron/layers.py new file mode 100644 index 0000000..2b76cea --- /dev/null +++ b/pororo/models/tts/tacotron/layers.py @@ -0,0 +1,286 @@ +import torch +from torch.nn import ( + BatchNorm1d, + ConstantPad1d, + Conv1d, + Dropout, + Identity, + ReLU, + Sequential, + Sigmoid, + Tanh, +) +from torch.nn import functional as F + +from pororo.models.tts.tacotron.generated import ( + BatchNorm1dGenerated, + Conv1dGenerated, +) + + +def get_activation(name): + """Get activation function by name.""" + return { + "relu": ReLU(), + "sigmoid": Sigmoid(), + "tanh": Tanh(), + "identity": Identity(), + }[name] + + +class ZoneoutLSTMCell(torch.nn.LSTMCell): + """Wrapper around LSTM cell providing zoneout regularization.""" + + def __init__( + self, + input_size, + hidden_size, + zoneout_rate_hidden, + zoneout_rate_cell, + bias=True, + ): + super(ZoneoutLSTMCell, self).__init__(input_size, hidden_size, bias) + self.zoneout_c = zoneout_rate_cell + self.zoneout_h = zoneout_rate_hidden + + def forward(self, cell_input, h, c): + new_h, new_c = super(ZoneoutLSTMCell, self).forward(cell_input, (h, c)) + if self.training: + new_h = (1 - self.zoneout_h) * F.dropout( + new_h - h, + self.zoneout_h, + ) + h + new_c = (1 - self.zoneout_c) * F.dropout( + new_c - c, + self.zoneout_c, + ) + c + else: + new_h = self.zoneout_h * h + (1 - self.zoneout_h) * new_h + new_c = self.zoneout_c * c + (1 - self.zoneout_c) * new_c + return new_h, new_c + + +class DropoutLSTMCell(torch.nn.LSTMCell): + """Wrapper around LSTM cell providing hidden state dropout regularization.""" + + def __init__(self, input_size, hidden_size, dropout_rate, bias=True): + super(DropoutLSTMCell, self).__init__(input_size, hidden_size, bias) + self._dropout = Dropout(dropout_rate) + + def forward(self, cell_input, h, c): + new_h, new_c = super(DropoutLSTMCell, self).forward(cell_input, (h, c)) + new_h = self._dropout(new_h) + return new_h, new_c + + +class ConvBlock(torch.nn.Module): + """ + One dimensional convolution with batchnorm and dropout, expected channel-first input. 
+ + Arguments: + input_channels -- number if input channels + output_channels -- number of output channels + kernel -- convolution kernel size ('same' padding is used) + Keyword arguments: + dropout (default: 0.0) -- dropout rate to be aplied after the block + activation (default 'identity') -- name of the activation function applied after batchnorm + dilation (default: 1) -- dilation of the inner convolution + groups (default: 1) -- number of groups of the inner convolution + batch_norm (default: True) -- set False to disable batch normalization + """ + + def __init__( + self, + input_channels, + output_channels, + kernel, + dropout=0.0, + activation="identity", + dilation=1, + groups=1, + batch_norm=True, + ): + super(ConvBlock, self).__init__() + + self._groups = groups + + p = (kernel - 1) * dilation // 2 + padding = p if kernel % 2 != 0 else (p, p + 1) + layers = [ + ConstantPad1d(padding, 0.0), + Conv1d( + input_channels, + output_channels, + kernel, + padding=0, + dilation=dilation, + groups=groups, + bias=(not batch_norm), + ), + ] + + if batch_norm: + layers += [BatchNorm1d(output_channels)] + + layers += [get_activation(activation)] + layers += [Dropout(dropout)] + + self._block = Sequential(*layers) + + def forward(self, x): + return self._block(x) + + +class ConvBlockGenerated(torch.nn.Module): + """One dimensional convolution with generated weights and with batchnorm and dropout, expected channel-first input. + + Arguments: + embedding_dim -- size of the meta embedding + bottleneck_dim -- size of the generating layer + input_channels -- number if input channels + output_channels -- number of output channels + kernel -- convolution kernel size ('same' padding is used) + Keyword arguments: + dropout (default: 0.0) -- dropout rate to be aplied after the block + activation (default 'identity') -- name of the activation function applied after batchnorm + dilation (default: 1) -- dilation of the inner convolution + groups (default: 1) -- number of groups of the inner convolution + batch_norm (default: True) -- set False to disable batch normalization + """ + + def __init__( + self, + embedding_dim, + bottleneck_dim, + input_channels, + output_channels, + kernel, + dropout=0.0, + activation="identity", + dilation=1, + groups=1, + batch_norm=True, + ): + super(ConvBlockGenerated, self).__init__() + + self._groups = groups + + p = (kernel - 1) * dilation // 2 + padding = p if kernel % 2 != 0 else (p, p + 1) + + self._padding = ConstantPad1d(padding, 0.0) + self._convolution = Conv1dGenerated( + embedding_dim, + bottleneck_dim, + input_channels, + output_channels, + kernel, + padding=0, + dilation=dilation, + groups=groups, + bias=(not batch_norm), + ) + self._regularizer = (BatchNorm1dGenerated( + embedding_dim, + bottleneck_dim, + output_channels, + groups=groups, + ) if batch_norm else None) + self._activation = Sequential( + get_activation(activation), + Dropout(dropout), + ) + + def forward(self, x): + e, x = x + x = self._padding(x) + x = self._convolution(e, x) + if self._regularizer is not None: + x = self._regularizer(e, x) + x = self._activation(x) + return e, x + + +class HighwayConvBlock(ConvBlock): + """Gated 1D covolution. 
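+ The inner convolution produces 2 * output_channels; one half is passed through a sigmoid gate p and the block returns h2 * p + x * (1 - p), i.e. a highway-style residual connection.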
+ + Arguments: + see ConvBlock + """ + + def __init__( + self, + input_channels, + output_channels, + kernel, + dropout=0.0, + activation="identity", + dilation=1, + groups=1, + batch_norm=True, + ): + super(HighwayConvBlock, self).__init__( + input_channels, + 2 * output_channels, + kernel, + dropout, + activation, + dilation, + groups, + batch_norm, + ) + self._gate = Sigmoid() + + def forward(self, x): + h = super(HighwayConvBlock, self).forward(x) + chunks = torch.chunk(h, 2 * self._groups, 1) + h1 = torch.cat(chunks[0::2], 1) + h2 = torch.cat(chunks[1::2], 1) + p = self._gate(h1) + return h2 * p + x * (1.0 - p) + + +class HighwayConvBlockGenerated(ConvBlockGenerated): + """Gated 1D covolution with generated weights. + + Arguments: + embedding_dim -- size of the meta embedding + bottleneck_dim -- size of the generating layer + see ConvBlockGenerated + """ + + def __init__( + self, + embedding_dim, + bottleneck_dim, + input_channels, + output_channels, + kernel, + dropout=0.0, + activation="identity", + dilation=1, + groups=1, + batch_norm=True, + ): + super(HighwayConvBlockGenerated, self).__init__( + embedding_dim, + bottleneck_dim, + input_channels, + 2 * output_channels, + kernel, + dropout, + activation, + dilation, + groups, + batch_norm, + ) + self._gate = Sigmoid() + + def forward(self, x): + e, x = x + _, h = super(HighwayConvBlockGenerated, self).forward((e, x)) + chunks = torch.chunk(h, 2 * self._groups, 1) + h1 = torch.cat(chunks[0::2], 1) + h2 = torch.cat(chunks[1::2], 1) + p = self._gate(h1) + return e, h2 * p + x * (1.0 - p) diff --git a/pororo/models/tts/tacotron/params.py b/pororo/models/tts/tacotron/params.py new file mode 100644 index 0000000..1323fc8 --- /dev/null +++ b/pororo/models/tts/tacotron/params.py @@ -0,0 +1,191 @@ +import json + + +class Params: + version = "1.0" # is used during training as name of checkpoints and Tensorboard logs (together with timestamp and reached loss) + """ + **************** PARAMETERS OF TRAINING LOOP **************** + """ + + epochs = 300 # training epochs + batch_size = 52 # batch size during training (is parallelization is True, each GPU has batch_size // num_gpus examples) + # if using encoder_type 'convolutional' or 'generated', should be divisible by the number of languages + learning_rate = 1e-3 # starting learning rate + learning_rate_decay = 0.5 # decay multiplier used by step learning rate scheduler (use 0.5 for halving) + learning_rate_decay_start = 15000 # number of training steps until the first lr decay, expected to be greater than learning_rate_decay_each + learning_rate_decay_each = 15000 # size of the learning rate scheduler step in training steps, it decays lr every this number steps + learning_rate_encoder = 1e-3 # initial learning rate of the encoder, just used if encoder_optimizer is set to True + weight_decay = 1e-6 # L2 regularization + encoder_optimizer = False # if True, different learning rates are used for the encoder and decoder, the ecoder uses learning_rate_encoder at start + max_output_length = 5000 # maximal number of frames produced by decoder, the number of frames is usualy much lower during synthesis + gradient_clipping = 0.25 # gradient norm clipping + reversal_gradient_clipping = 0.25 # used if reversal_classifier is True, clips gradients flowing from adversarial classifier to encoder + guided_attention_loss = True # if True, guided attention loss term is used + guided_attention_steps = 20000 # number of training steps for which the guided attention loss term is used + 
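+ # Sketch of a common guided-attention formulation (the loss itself is not part of
+ # this file): W[t, n] = 1 - exp(-((n / N - t / T) ** 2) / (2 * g ** 2)), with g
+ # starting at guided_attention_toleration and multiplied by guided_attention_gain
+ # after every batch, so the tolerated band around the diagonal widens until the
+ # term is switched off after guided_attention_steps steps.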
guided_attention_toleration = ( + 0.25 # starting variance of the guided attention (i.e. diagonal toleration) + ) + guided_attention_gain = ( + 1.00025 # multiplier applied after every batch to guided_attention_toleration + ) + constant_teacher_forcing = True # if True, ground-truth frames are with probability teacher_forcing passed into decoder, cosine decay is used otherwise + teacher_forcing = ( + 1.0 # ratio of ground-truth frames, used if constant_teacher_forcing is True + ) + teacher_forcing_steps = 100000 # used if constant_teacher_forcing is False, cosine decay spans this number of trainig steps starting at teacher_forcing_start_steps + teacher_forcing_start_steps = ( + 50000 # number of training steps after which the teacher forcing decay starts + ) + checkpoint_each_epochs = 10 # save a checkpoint every this number epochs + parallelization = True # if True, DataParallel (parallel batch) is used, supports any number of GPUs + """ + ******************* DATASET SPECIFICATION ******************* + """ + + dataset = "ljspeech" # one of: css10, ljspeech, vctk, my_blizzard, my_common_voice, mailabs, must have implementation in loaders.py + cache_spectrograms = True # if True, during iterating the dataset, it first tries to load spectrograms (mel or linear) from cached files + languages = [ + "en-us" + ] # list of lnguages which will be loaded from the dataset, codes should correspond to + # espeak format (see 'phonemize --help) in order support the converion to phonemes + balanced_sampling = False # enables balanced sampling per languages (not speakers), multi_language must be True + perfect_sampling = False # used just if balanced_sampling is True, should be used together with encoder_type 'convolutional' or 'generated' + # if True, each language has the same number of samples and these samples are grouped, batch_size must be divisible + # if False, samples are taken from the multinomial distr. with replacement + """ + *************************** TEXT **************************** + """ + + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz " # supported input alphabet, used for computation of character embeddings + # for lower-case russian, greek, latin and pinyin use " abcdefghijklmnopqrstuvwxyzçèéßäöōǎǐíǒàáǔüèéìūòóùúāēěīâêôûñőűабвгдежзийклмнопрстуфхцчшщъыьэюяё" + case_sensitive = True # if False, all characters are lowered before usage + remove_multiple_wspaces = True # if True, multiple whitespaces, leading and trailing whitespaces, etc. 
are removed + use_punctuation = ( + True # if True, punctuation is preserved and punctuations_{in, out} are used + ) + punctuations_out = '、。,"(),.:;¿?¡!\\' # punctuation which usualy occurs outside words (important during phonemization) + punctuations_in = "'-" # punctuation which can occur inside a word, so whitespaces do not have to be present + use_phonemes = False # phonemes are valid only if True, tacotron uses phonemes instead of characters + # all phonemes of IPA: 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧ ɚ˞ɫ' + phonemes = "ɹɐpbtdkɡfvθðszʃʒhmnŋlrwjeəɪɒuːɛiaʌʊɑɜɔx " # supported input phonemes, used if use_phonemes is True + """ + ******************** PARAMETERS OF MODEL ******************** + """ + + embedding_dimension = 512 # dimension of character embedding + encoder_type = "simple" # changes type of the encoder of the Tacotron 2 tacotron + # one of: simple (single vanilla encoder for all languages without embedding), + # separate (distinct vanilla encoders for each language) + # shared (single vanilla encoder for all languages with lang. embedding injected to character embeddings) + # convolutional (single grouped fully convolutional encoder without embedding, each group corresponds to a language) + # generated (same as convolutional but with parameters generated by a meta-learning network) + encoder_dimension = 512 # output dimension of the encoder + encoder_blocks = 3 # number of convolutional block in vanilla encoder + encoder_kernel_size = 5 # size of kernel of convolutional blocks in vanilla encoder + generator_dim = 8 # used if encoder_type is 'generated', size of the 'language embedding' which is used by layers to generate weights + generator_bottleneck_dim = 4 # used if encoder_type is 'generated', size of fully-connected layers which generate parameters for encoder layers + prenet_dimension = 256 # size of pre-net layers + prenet_layers = 2 # number of pre-net layers + attention_type = "location_sensitive" # Type of the attention mechanism. + # one of: location_sensitive (Tacotron 2 vanilla), + # forward (undebugged, should allow just monotonous att.) 
+ # forward_transition_agent (undebugged, fwd with explicit transition agent) + attention_dimension = 128 # + attention_kernel_size = 31 # kernel size of the convolution which extracts features from attention weights + attention_location_dimension = 32 # size of the features extracted by a convolutional layer from attention weights + decoder_dimension = 1024 # size of decoder RNNs + decoder_regularization = ( + "dropout" # regularization of decoder RNNs, one of: 'dropout', 'zoneout' + ) + zoneout_hidden = ( + 0.1 # used if decoder_regularization is 'zoneout', zoneout rate of LSTM h state + ) + zoneout_cell = ( + 0.1 # used if decoder_regularization is 'zoneout', zoneout rate of LSTM c state + ) + dropout_hidden = ( + 0.1 # used if decoder_regularization is 'dropout', dropout rate of LSTM output + ) + postnet_dimension = 512 # size of post-net layers + postnet_blocks = 5 # number of convolutional blocks in post-net + postnet_kernel_size = 5 # kernel size of convolutions in post-net blocks + dropout = 0.5 # dropout rate of convolutional block in the whole tacotron + predict_linear = False # if True, vanilla post-net is replaced by CBHG module which predicts linear spectrograms + cbhg_bank_kernels = 8 # used if predict_linear is True + cbhg_bank_dimension = 128 # used if predict_linear is True + cbhg_projection_kernel_size = 3 # used if predict_linear is True + cbhg_projection_dimension = 256 # used if predict_linear is True + cbhg_highway_dimension = 128 # used if predict_linear is True + cbhg_rnn_dim = 128 # used if predict_linear is True + cbhg_dropout = 0.0 # used if predict_linear is True + multi_speaker = False # if True, multi-speaker tacotron is used, speaker embeddings are concatenated to encoder outputs + multi_language = False # if True, multi-lingual tacotron is used, language embeddings are concatenated to encoder outputs + speaker_embedding_dimension = ( + 32 # used if multi_speaker is True, size of the speaker embedding + ) + language_embedding_dimension = ( + 4 # used if multi_language is True, size of the language embedding + ) + input_language_embedding = 4 # used if encoder_type is 'shared', language embedding of this size is concatenated to input char. embeddings + reversal_classifier = False # if True, adversarial classifier for predicting speakers from encoder outputs is used + reversal_classifier_type = "reversal" # one of: 'reversal' for a standard adversarial process with reverted gradients + # 'cosine' for a cosine similarity-based adversarial process, does not converge at all + reversal_classifier_dim = 256 # used if reversal_classifier is True and reversal_classifier_type id 'reversal' + # size of the hidden layer of the adversarial classifer + reversal_classifier_w = 1.0 # weight of the loss of the adversarial classifier (it is also reduced by number of mels, see TacotronLoss) + stop_frames = 5 # number of frames at the end which are considered as "ending sequence" and stop token probability should be one + speaker_number = 0 # do not set! + language_number = 0 # do not set! + """ + ******************** PARAMETERS OF AUDIO ******************** + """ + + sample_rate = 22050 # sample rate of source .wavs, used while computing spectrograms, MFCCs, etc. 
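+ # At 22050 Hz the 50 ms STFT window is int(22050 * 50 / 1000) = 1102 samples,
+ # matching num_fft below, and the 12.5 ms shift gives a hop of 275 samples
+ # (see ms_to_frames in utils/audio.py).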
+ num_fft = 1102 # number of frequency bins used during computation of spectrograms + num_mels = 80 # number of mel bins used during computation of mel spectrograms + num_mfcc = 13 # number of MFCCs, used just for MCD computation (during training) + stft_window_ms = 50 # size in ms of the Hann window of short-time Fourier transform, used during spectrogram computation + stft_shift_ms = ( + 12.5 # shift of the window (or better said gap between windows) in ms + ) + griffin_lim_iters = 60 # used if vocoding using Griffin-Lim algorithm (synthesize.py), greater value does not make much sense + griffin_lim_power = 1.5 # power applied to spectrograms before using GL + normalize_spectrogram = True # if True, spectrograms are normalized before passing into the tacotron, a per-channel normalization is used + # statistics (mean and variance) are computed from dataset at the start of training + use_preemphasis = True # if True, a preemphasis is applied to raw waveform before using them (spectrogram computation) + preemphasis = 0.97 # amount of preemphasis, used if use_preemphasis is True + + @staticmethod + def load_state_dict(d): + for k, v in d.items(): + setattr(Params, k, v) + + @staticmethod + def state_dict(): + members = [ + attr for attr in dir(Params) + if not callable(getattr(Params, attr)) and not attr.startswith("__") + ] + return {k: Params.__dict__[k] for k in members} + + @staticmethod + def load(json_path): + with open(json_path, "r", encoding="utf-8") as f: + params = json.load(f) + Params.load_state_dict(params) + + @staticmethod + def save(json_path): + with open(json_path, "w", encoding="utf-8") as f: + d = Params.state_dict() + json.dump(d, f, indent=4) + + @staticmethod + def symbols_count(): + symbols_count = len(Params.characters) + if Params.use_phonemes: + symbols_count = len(Params.phonemes) + if Params.use_punctuation: + symbols_count += len(Params.punctuations_out) + len( + Params.punctuations_in) + return symbols_count diff --git a/pororo/models/tts/tacotron/tacotron2.py b/pororo/models/tts/tacotron/tacotron2.py new file mode 100644 index 0000000..30f22cf --- /dev/null +++ b/pororo/models/tts/tacotron/tacotron2.py @@ -0,0 +1,529 @@ +from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential +from torch.nn import functional as F + +from pororo.models.tts.tacotron.attention import LocationSensitiveAttention +from pororo.models.tts.tacotron.encoder import GeneratedConvolutionalEncoder +from pororo.models.tts.tacotron.layers import ( + ConvBlock, + DropoutLSTMCell, + ZoneoutLSTMCell, +) +from pororo.models.tts.tacotron.params import Params as hp +from pororo.models.tts.utils import * + + +class Prenet(torch.nn.Module): + """Decoder pre-net module. + + Details: + stack of 2 linear layers with dropout which is enabled even during inference (output variation) + should act as a bottleneck for the attention + + Arguments: + input_dim -- size of the input (supposed the number of frame mels) + output_dim -- size of the output + num_layers -- number of the linear layers (at least one) + dropout -- dropout rate to be aplied after each layer (even during inference) + """ + + def __init__(self, input_dim, output_dim, num_layers, dropout): + super(Prenet, self).__init__() + assert num_layers > 0, "There must be at least one layer in the pre-net." 
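+ # Note that dropout in this pre-net stays active even at inference time
+ # (_layer_pass calls F.dropout with training=True), giving the output
+ # variation mentioned in the class docstring.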
+ self._dropout_rate = dropout + self._activation = ReLU() + layers = [Linear(input_dim, output_dim)] + [ + Linear(output_dim, output_dim) for _ in range(num_layers - 1) + ] + self._layers = ModuleList(layers) + + def _layer_pass(self, x, layer): + x = layer(x) + x = self._activation(x) + x = F.dropout(x, p=self._dropout_rate, training=True) + return x + + def forward(self, x): + for layer in self._layers: + x = self._layer_pass(x, layer) + return x + + +class Postnet(torch.nn.Module): + """Post-net module for output spectrogram enhancement. + + Details: + stack of 5 conv. layers 5 × 1 with BN and tanh (except last), dropout + + Arguments: + input_dimension -- size of the input and output (supposed the number of frame mels) + postnet_dimension -- size of the internal convolutional blocks + num_blocks -- number of the convolutional blocks (at least one) + kernel_size -- kernel size of the encoder's convolutional blocks + dropout -- dropout rate to be aplied after each convolutional block + """ + + def __init__(self, input_dimension, postnet_dimension, num_blocks, + kernel_size, dropout): + super(Postnet, self).__init__() + assert ( + num_blocks > 1 + ), "There must be at least two convolutional blocks in the post-net." + self._convs = Sequential( + ConvBlock(input_dimension, postnet_dimension, kernel_size, dropout, + "tanh"), *[ + ConvBlock( + postnet_dimension, + postnet_dimension, + kernel_size, + dropout, + "tanh", + ) for _ in range(num_blocks - 2) + ], + ConvBlock( + postnet_dimension, + input_dimension, + kernel_size, + dropout, + "identity", + )) + + def forward(self, x, x_lengths): + residual = x + x = self._convs(x) + x += residual + return x + + +class Decoder(torch.nn.Module): + """Tacotron 2 decoder with queries produced by the first RNN layer and output produced by the second RNN. + + Decoder: + stack of 2 uni-directional LSTM layers with 1024 units + first LSTM is used to query attention mechanism + input of the first LSTM is previous prediction (pre-net output) and previous context vector + second LSTM acts as a generator + input of the second LSTM is current context vector and output of the first LSTM + output is passed through stop token layer, frame prediction layer and pre-net + + Arguments: + output_dim -- size of the predicted frame, i.e. 
number of mels + decoder_dim -- size of the generator output (and also of all the LSTMs used in the decoder) + attention -- instance of the location-sensitive attention module + generator_rnn -- instance of generator RNN + attention_rnn -- instance of attention RNN + context_dim -- size of the context vector produced by the given attention + prenet -- instance of the pre-net module + prenet_dim -- output dimension of the pre-net + max_frames -- maximal number of the predicted frames + """ + + def __init__( + self, + output_dim, + decoder_dim, + attention, + generator_rnn, + attention_rnn, + context_dim, + prenet, + prenet_dim, + max_frames, + ): + super(Decoder, self).__init__() + self._prenet = prenet + self._attention = attention + self._output_dim = output_dim + self._decoder_dim = decoder_dim + self._max_frames = max_frames + self._attention_lstm = attention_rnn + self._generator_lstm = generator_rnn + self._frame_prediction = Linear(context_dim + decoder_dim, output_dim) + self._stop_prediction = Linear(context_dim + decoder_dim, 1) + + self._speaker_embedding, self._language_embedding = None, None + + if hp.multi_speaker and hp.speaker_embedding_dimension > 0: + self._speaker_embedding = self._get_embedding( + hp.speaker_embedding_dimension, hp.speaker_number) + if hp.multi_language and hp.language_embedding_dimension > 0: + self._language_embedding = self._get_embedding( + hp.language_embedding_dimension, len(hp.languages)) + + def _get_embedding(self, embedding_dimension, size=None): + embedding = Embedding(size, embedding_dimension) + torch.nn.init.xavier_uniform_(embedding.weight) + return embedding + + def _target_init(self, target, batch_size): + """Prepend target spectrogram with a zero frame and pass it through pre-net.""" + # the F.pad function has some issues: https://github.com/pytorch/pytorch/issues/13058 + first_frame = torch.zeros( + batch_size, + self._output_dim, + device=target.device, + ).unsqueeze(1) + target = target.transpose(1, 2) # [B, F, N_MEL] + target = torch.cat((first_frame, target), dim=1) + target = self._prenet(target) + return target + + def _decoder_init(self, batch_size, device): + """Initialize hidden and cell state of the deocder's RNNs.""" + h_att = torch.zeros(batch_size, self._decoder_dim, device=device) + c_att = torch.zeros(batch_size, self._decoder_dim, device=device) + h_gen = torch.zeros(batch_size, self._decoder_dim, device=device) + c_gen = torch.zeros(batch_size, self._decoder_dim, device=device) + return h_att, c_att, h_gen, c_gen + + def _add_conditional_embedding(self, encoded, layer, condition): + """Compute speaker (lang.) 
embedding and concat it to the encoder output.""" + embedded = layer(encoded if condition is None else condition) + return torch.cat((encoded, embedded), dim=-1) + + def _decode(self, encoded_input, mask, target, teacher_forcing_ratio, + speaker, language): + """Perform decoding of the encoded input sequence.""" + + batch_size = encoded_input.size(0) + max_length = encoded_input.size(1) + inference = target is None + max_frames = self._max_frames if inference else target.size(2) + input_device = encoded_input.device + + # obtain speaker and language embeddings (or a dummy tensor) + if hp.multi_speaker and self._speaker_embedding is not None: + encoded_input = self._add_conditional_embedding( + encoded_input, self._speaker_embedding, speaker) + if hp.multi_language and self._language_embedding is not None: + encoded_input = self._add_conditional_embedding( + encoded_input, self._language_embedding, language) + + # attention and decoder states initialization + context = self._attention.reset( + encoded_input, + batch_size, + max_length, + input_device, + ) + h_att, c_att, h_gen, c_gen = self._decoder_init( + batch_size, + input_device, + ) + + # prepare some inference or train specific variables (teacher forcing, max. predicted length) + frame = torch.zeros(batch_size, self._output_dim, device=input_device) + if not inference: + target = self._target_init(target, batch_size) + teacher = torch.rand( + [max_frames], device=input_device) > (1 - teacher_forcing_ratio) + + # tensors for storing output + spectrogram = torch.zeros( + batch_size, + max_frames, + self._output_dim, + device=input_device, + ) + alignments = torch.zeros( + batch_size, + max_frames, + max_length, + device=input_device, + ) + stop_tokens = torch.zeros( + batch_size, + max_frames, + 1, + device=input_device, + ) + + # decoding loop + stop_frames = -1 + for i in range(max_frames): + prev_frame = (self._prenet(frame) + if inference or not teacher[i] else target[:, i]) + + # run decoder attention and RNNs + attention_input = torch.cat((prev_frame, context), dim=1) + h_att, c_att = self._attention_lstm(attention_input, h_att, c_att) + context, weights = self._attention( + h_att, + encoded_input, + mask, + prev_frame, + ) + generator_input = torch.cat((h_att, context), dim=1) + h_gen, c_gen = self._generator_lstm(generator_input, h_gen, c_gen) + + # predict frame and stop token + proto_output = torch.cat((h_gen, context), dim=1) + frame = self._frame_prediction(proto_output) + stop_logits = self._stop_prediction(proto_output) + + # store outputs + spectrogram[:, i] = frame + alignments[:, i] = weights + stop_tokens[:, i] = stop_logits + + # stop decoding if predicted (just during inference) + if inference and torch.sigmoid(stop_logits).ge(0.5): + if stop_frames == -1: + stop_frames = hp.stop_frames + continue + stop_frames -= 1 + if stop_frames == 0: + return ( + spectrogram[:, :i + 1], + stop_tokens[:, :i + 1].squeeze(2), + alignments[:, :i + 1], + ) + + return spectrogram, stop_tokens.squeeze(2), alignments + + def forward( + self, + encoded_input, + encoded_lenghts, + target, + teacher_forcing_ratio, + speaker, + language, + ): + ml = encoded_input.size(1) + mask = lengths_to_mask(encoded_lenghts, max_length=ml) + return self._decode( + encoded_input, + mask, + target, + teacher_forcing_ratio, + speaker, + language, + ) + + def inference(self, encoded_input, speaker, language): + mask = lengths_to_mask(torch.LongTensor([encoded_input.size(1)])) + spectrogram, _, _ = self._decode( + encoded_input, + mask, + None, + 0.0, + 
speaker, + language, + ) + return spectrogram + + +class Tacotron(torch.nn.Module): + """ + Tacotron 2: + characters as learned embedding + encoder, attention, decoder which predicts frames of mel spectrogram + the predicted mel spectrogram is passed through post-net which + predicts a residual to add to the prediction + minimize MSE from before and after the post-net to aid convergence + """ + + def __init__(self): + super(Tacotron, self).__init__() + + # Encoder embedding + other_symbols = 3 # PAD, EOS, UNK + self._embedding = Embedding( + hp.symbols_count() + other_symbols, + hp.embedding_dimension, + padding_idx=0, + ) + torch.nn.init.xavier_uniform_(self._embedding.weight) + + # Encoder transforming graphmenes or phonemes into abstract input representation + self._encoder = self._get_encoder() + + # Prenet for transformation of previous predicted frame + self._prenet = Prenet( + hp.num_mels, + hp.prenet_dimension, + hp.prenet_layers, + hp.dropout, + ) + + # Speaker and language embeddings make decoder bigger + decoder_input_dimension = hp.encoder_dimension + if hp.multi_speaker: + decoder_input_dimension += hp.speaker_embedding_dimension + if hp.multi_language: + decoder_input_dimension += hp.language_embedding_dimension + + # Decoder attention layer + self._attention = self._get_attention(decoder_input_dimension) + + # Instantiate decoder RNN layers + gen_cell_dimension = decoder_input_dimension + hp.decoder_dimension + att_cell_dimension = decoder_input_dimension + hp.prenet_dimension + if hp.decoder_regularization == "zoneout": + generator_rnn = ZoneoutLSTMCell( + gen_cell_dimension, + hp.decoder_dimension, + hp.zoneout_hidden, + hp.zoneout_cell, + ) + attention_rnn = ZoneoutLSTMCell( + att_cell_dimension, + hp.decoder_dimension, + hp.zoneout_hidden, + hp.zoneout_cell, + ) + else: + generator_rnn = DropoutLSTMCell( + gen_cell_dimension, + hp.decoder_dimension, + hp.dropout_hidden, + ) + attention_rnn = DropoutLSTMCell( + att_cell_dimension, + hp.decoder_dimension, + hp.dropout_hidden, + ) + + # Decoder which controls attention and produces mel frames and stop tokens + self._decoder = Decoder( + hp.num_mels, + hp.decoder_dimension, + self._attention, + generator_rnn, + attention_rnn, + decoder_input_dimension, + self._prenet, + hp.prenet_dimension, + hp.max_output_length, + ) + + # Postnet transforming predicted mel frames (residual mel or linear frames) + self._postnet = self._get_postnet() + + def _get_encoder(self): + args = ( + hp.embedding_dimension, + hp.encoder_dimension, + hp.encoder_blocks, + hp.encoder_kernel_size, + hp.dropout, + ) + ln = 1 if not hp.multi_language else hp.language_number + return GeneratedConvolutionalEncoder( + hp.embedding_dimension, + hp.encoder_dimension, + 0.05, + hp.generator_dim, + hp.generator_bottleneck_dim, + groups=ln, + ) + + def _get_attention(self, memory_dimension): + args = (hp.attention_dimension, hp.decoder_dimension, memory_dimension) + return LocationSensitiveAttention( + hp.attention_kernel_size, + hp.attention_location_dimension, + False, + *args, + ) + + def _get_postnet(self): + return Postnet( + hp.num_mels, + hp.postnet_dimension, + hp.postnet_blocks, + hp.postnet_kernel_size, + hp.dropout, + ) + + def forward( + self, + text, + text_length, + target, + target_length, + speakers, + languages, + teacher_forcing_ratio=0.0, + ): + # enlarge speakers and languages to match sentence length if needed + if speakers is not None and speakers.dim() == 1: + speakers = speakers.unsqueeze(1).expand((-1, text.size(1))) + if languages is not 
None and languages.dim() == 1: + languages = languages.unsqueeze(1).expand((-1, text.size(1))) + + # encode input + embedded = self._embedding(text) + encoded = self._encoder(embedded, text_length, languages) + encoder_output = encoded + + # predict language as an adversarial task if needed + speaker_prediction = (self._reversal_classifier(encoded) + if hp.reversal_classifier else None) + + # decode + if languages is not None and languages.dim() == 3: + languages = torch.argmax( + languages, + dim=2, + ) # convert one-hot into indices + decoded = self._decoder( + encoded, + text_length, + target, + teacher_forcing_ratio, + speakers, + languages, + ) + prediction, stop_token, alignment = decoded + pre_prediction = prediction.transpose(1, 2) + post_prediction = self._postnet(pre_prediction, target_length) + + # mask output paddings + target_mask = lengths_to_mask(target_length, target.size(2)) + stop_token.masked_fill_(~target_mask, 1000) + target_mask = target_mask.unsqueeze(1).float() + pre_prediction = pre_prediction * target_mask + post_prediction = post_prediction * target_mask + + return ( + post_prediction, + pre_prediction, + stop_token, + alignment, + speaker_prediction, + encoder_output, + ) + + def inference(self, text, speaker=None, language=None): + # pretend having a batch of size 1 + text.unsqueeze_(0) + + if speaker is not None and speaker.dim() == 1: + speaker = speaker.unsqueeze(1).expand((-1, text.size(1))) + if language is not None and language.dim() == 1: + language = language.unsqueeze(1).expand((-1, text.size(1))) + + # encode input + embedded = self._embedding(text) + encoded = self._encoder( + embedded, + torch.LongTensor([text.size(1)]), + language, + ) + + # decode with respect to speaker and language embeddings + if language is not None and language.dim() == 3: + language = torch.argmax( + language, + dim=2, + ) # convert one-hot into indices + prediction = self._decoder.inference(encoded, speaker, language) + + # post process generated spectrogram + prediction = prediction.transpose(1, 2) + post_prediction = self._postnet( + prediction, + torch.LongTensor([prediction.size(2)]), + ) + return post_prediction.squeeze(0) diff --git a/pororo/models/tts/utils/__init__.py b/pororo/models/tts/utils/__init__.py new file mode 100644 index 0000000..135a16b --- /dev/null +++ b/pororo/models/tts/utils/__init__.py @@ -0,0 +1,18 @@ +from collections import OrderedDict + +import torch + + +def lengths_to_mask(lengths, max_length=None): + """Convert tensor of lengths into a boolean mask.""" + ml = torch.max(lengths) if max_length is None else max_length + return torch.arange(ml, device=lengths.device)[None, :] < lengths[:, None] + + +def remove_dataparallel_prefix(state_dict): + """Removes dataparallel prefix of layer names in a checkpoint state dictionary.""" + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] if k[:7] == "module." 
else k + new_state_dict[name] = v + return new_state_dict diff --git a/pororo/models/tts/utils/audio.py b/pororo/models/tts/utils/audio.py new file mode 100644 index 0000000..95f2d03 --- /dev/null +++ b/pororo/models/tts/utils/audio.py @@ -0,0 +1,113 @@ +import numpy as np +import scipy + +try: + import librosa + import librosa.effects + import librosa.feature +except ImportError: + raise ImportError("Please install librosa with: `pip install librosa`") +import soundfile as sf + +try: + from fastdtw import fastdtw +except ImportError: + raise ImportError("Please install fastdtw with: `pip install fastdtw`") +from pororo.models.tts.tacotron.params import Params as hp + + +def load(path): + """Load a sound file into numpy array.""" + data, sample_rate = sf.read(path) + assert ( + hp.sample_rate == sample_rate + ), f"Sample rate do not match: given {hp.sample_rate}, expected {sample_rate}" + return data + + +def save(data, path): + """Save numpy array as sound file.""" + sf.write(path, data, samplerate=hp.sample_rate) + + +def ms_to_frames(ms): + """Convert milliseconds into number of frames.""" + return int(hp.sample_rate * ms / 1000) + + +def trim_silence(data, window_ms, hop_ms, top_db=50, margin_ms=0): + """Trim leading and trailing silence from an audio signal.""" + wf = ms_to_frames(window_ms) + hf = ms_to_frames(hop_ms) + mf = ms_to_frames(margin_ms) + if mf != 0: + data = data[mf:-mf] + return librosa.effects.trim(data, + top_db=top_db, + frame_length=wf, + hop_length=hf) + + +def duration(data): + """Return duration of an audio signal in seconds.""" + return librosa.get_duration(data, sr=hp.sample_rate) + + +def amplitude_to_db(x): + """Convert amplitude to decibels.""" + return librosa.amplitude_to_db(x, ref=np.max, top_db=None) + + +def db_to_amplitude(x): + """Convert decibels to amplitude.""" + return librosa.db_to_amplitude(x) + + +def preemphasis(y): + """Preemphasize the signal.""" + # y[n] = x[n] - perc * x[n-1] + return scipy.signal.lfilter([1, -hp.preemphasis], [1], y) + + +def spectrogram(y, mel=False): + """Convert waveform to log-magnitude spectrogram.""" + if hp.use_preemphasis: + y = preemphasis(y) + wf = ms_to_frames(hp.stft_window_ms) + hf = ms_to_frames(hp.stft_shift_ms) + S = np.abs(librosa.stft(y, n_fft=hp.num_fft, hop_length=hf, win_length=wf)) + if mel: + S = librosa.feature.melspectrogram(S=S, + sr=hp.sample_rate, + n_mels=hp.num_mels) + return amplitude_to_db(S) + + +def mel_spectrogram(y): + """Convert waveform to log-mel-spectrogram.""" + return spectrogram(y, True) + + +def linear_to_mel(S): + """Convert linear to mel spectrogram (this does not return the same spec. as mel_spec. 
method due to the db->amplitude conversion).""" + S = db_to_amplitude(S) + S = librosa.feature.melspectrogram(S=S, + sr=hp.sample_rate, + n_mels=hp.num_mels) + return amplitude_to_db(S) + + +def normalize_spectrogram(S, is_mel): + """Normalize log-magnitude spectrogram.""" + if is_mel: + return (S - hp.mel_normalize_mean) / hp.mel_normalize_variance + else: + return (S - hp.lin_normalize_mean) / hp.lin_normalize_variance + + +def denormalize_spectrogram(S, is_mel): + """Denormalize log-magnitude spectrogram.""" + if is_mel: + return S * hp.mel_normalize_variance + hp.mel_normalize_mean + else: + return S * hp.lin_normalize_variance + hp.lin_normalize_mean diff --git a/pororo/models/tts/utils/display.py b/pororo/models/tts/utils/display.py new file mode 100644 index 0000000..e739d96 --- /dev/null +++ b/pororo/models/tts/utils/display.py @@ -0,0 +1,79 @@ +import sys +import time + + +def progbar(i, n, size=16): + done = (i * size) // n + bar = "" + for i in range(size): + bar += "█" if i <= done else "░" + return bar + + +def stream(message): + sys.stdout.write(f"\r{message}") + + +def simple_table(item_tuples): + border_pattern = "+---------------------------------------" + whitespace = " " + + headings, cells, = ( + [], + [], + ) + + for item in item_tuples: + + heading, cell = str(item[0]), str(item[1]) + + pad_head = True if len(heading) < len(cell) else False + + pad = abs(len(heading) - len(cell)) + pad = whitespace[:pad] + + pad_left = pad[:len(pad) // 2] + pad_right = pad[len(pad) // 2:] + + if pad_head: + heading = pad_left + heading + pad_right + else: + cell = pad_left + cell + pad_right + + headings += [heading] + cells += [cell] + + border, head, body = "", "", "" + + for i in range(len(item_tuples)): + + temp_head = f"| {headings[i]} " + temp_body = f"| {cells[i]} " + + border += border_pattern[:len(temp_head)] + head += temp_head + body += temp_body + + if i == len(item_tuples) - 1: + head += "|" + body += "|" + border += "+" + + print(border) + print(head) + print(border) + print(body) + print(border) + print(" ") + + +def time_since(started): + elapsed = time.time() - started + m = int(elapsed // 60) + s = int(elapsed % 60) + if m >= 60: + h = int(m // 60) + m = m % 60 + return f"{h}h {m}m {s}s" + else: + return f"{m}m {s}s" diff --git a/pororo/models/tts/utils/dsp.py b/pororo/models/tts/utils/dsp.py new file mode 100644 index 0000000..47befab --- /dev/null +++ b/pororo/models/tts/utils/dsp.py @@ -0,0 +1,125 @@ +import math + +import librosa +import numpy as np +from scipy.signal import lfilter + +from pororo.models.tts.waveRNN.params import hp + + +def label_2_float(x, bits): + return 2 * x / (2**bits - 1.0) - 1.0 + + +def float_2_label(x, bits): + assert abs(x).max() <= 1.0 + x = (x + 1.0) * (2**bits - 1) / 2 + return x.clip(0, 2**bits - 1) + + +def load_wav(path): + return librosa.load(path, sr=hp.sample_rate)[0] + + +def save_wav(x, path): + librosa.output.write_wav(path, x.astype(np.float32), sr=hp.sample_rate) + + +def split_signal(x): + unsigned = x + 2**15 + coarse = unsigned // 256 + fine = unsigned % 256 + return coarse, fine + + +def combine_signal(coarse, fine): + return coarse * 256 + fine - 2**15 + + +def encode_16bits(x): + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) + + +def linear_to_mel(spectrogram): + return librosa.feature.melspectrogram( + S=spectrogram, + sr=hp.sample_rate, + n_fft=hp.n_fft, + n_mels=hp.num_mels, + fmin=hp.fmin, + ) + + +def normalize(S): + return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1) + + +def 
denormalize(S): + return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db + + +def amp_to_db(x): + return 20 * np.log10(np.maximum(1e-5, x)) + + +def db_to_amp(x): + return np.power(10.0, x * 0.05) + + +def spectrogram(y): + D = stft(y) + S = amp_to_db(np.abs(D)) - hp.ref_level_db + return normalize(S) + + +def melspectrogram(y): + D = stft(y) + S = amp_to_db(linear_to_mel(np.abs(D))) + return normalize(S) + + +def stft(y): + return librosa.stft(y=y, + n_fft=hp.n_fft, + hop_length=hp.hop_length, + win_length=hp.win_length) + + +def pre_emphasis(x): + return lfilter([1, -hp.preemphasis], [1], x) + + +def de_emphasis(x): + return lfilter([1], [1, -hp.preemphasis], x) + + +def encode_mu_law(x, mu): + mu = mu - 1 + fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) + return np.floor((fx + 1) / 2 * mu + 0.5) + + +def decode_mu_law(y, mu, from_labels=True): + # TODO: get rid of log2 - makes no sense + if from_labels: + y = label_2_float(y, math.log2(mu)) + mu = mu - 1 + x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1) + return x + + +def reconstruct_waveform(mel, n_iter=32): + """Uses Griffin-Lim phase reconstruction to convert from a normalized + mel spectrogram back into a waveform.""" + denormalized = denormalize(mel) + amp_mel = db_to_amp(denormalized) + S = librosa.feature.inverse.mel_to_stft(amp_mel, + power=1, + sr=hp.sample_rate, + n_fft=hp.n_fft, + fmin=hp.fmin) + wav = librosa.core.griffinlim(S, + n_iter=n_iter, + hop_length=hp.hop_length, + win_length=hp.win_length) + return wav diff --git a/pororo/models/tts/utils/numerical_pinyin_converter.py b/pororo/models/tts/utils/numerical_pinyin_converter.py new file mode 100644 index 0000000..0b3e2ac --- /dev/null +++ b/pororo/models/tts/utils/numerical_pinyin_converter.py @@ -0,0 +1,134 @@ +""" +Copied from https://github.com/em-shea/tones +Copyright (c) em-shea +""" + +# Enables print statements useful for debugging +DEBUG_ENABLED = False + +# Dictionary with lists of tonal pinyin for each vowel +pinyin = { + "a": ["ā", "á", "ǎ", "à", "a"], + "e": ["ē", "é", "ě", "è", "e"], + "i": ["ī", "í", "ǐ", "ì", "i"], + "o": ["ō", "ó", "ǒ", "ò", "o"], + "u": ["ū", "ú", "ǔ", "ù", "u"], + "ü": ["ǖ", "ǘ", "ǚ", "ǜ", "ü"], +} + + +# Function to enable/disable debugging print statements +def debug(*args, **kwargs): + if DEBUG_ENABLED: + print(*args, **kwargs) + + +# Function that converts numerical pinyin (ni3) to tone marked pinyin (nǐ) +def convert_from_numerical_pinyin(word): + finished_word = [] + + # Splits word into individual character strings and calls convert_indiv_character for each + split_word = word.split(" ") + for indiv_character in split_word: + try: + finished_char = convert_indiv_character(indiv_character) + except: + continue + finished_word.append(finished_char) + + # Joins the returned indiv char back into one string + finished_string = " ".join(finished_word) + debug("Joined individual characters into finished word:", finished_string) + return finished_string + + +# Converts indiv char to tone marked chars +def convert_indiv_character(indiv_character): + debug("") + debug("------") + debug("Starting loop for word:", indiv_character) + + # Convert indiv char string into list of letters + letter_list = list(indiv_character) + + # Identify v letters, convert to ü + for index, letter in enumerate(letter_list): + if letter == "v": + letter_list[index] = "ü" + debug("Letter v converted to 'ü' at index:", index) + + # Start an empty counter and list in case of multiple vowels + counter = 0 + vowels = [] + + # Find and 
count vowels, and use tone mark logic if multiple found + for index, char in enumerate(letter_list): + if char in "aeiouü": + counter = counter + 1 + vowels.append(char) + debug("Found vowels:", vowels) + + # If multiple vowels are found, use this logic to choose vowel for tone mark + # a, e, or o takes tone mark - a takes tone in 'ao' + # else, second vowel takes tone mark + if counter > 1: + debug("Found multiple vowels, count:", counter) + + if "a" in vowels: + tone_vowel = "a" + elif "o" in vowels: + tone_vowel = "o" + elif "e" in vowels: + tone_vowel = "e" + else: + tone_vowel = vowels[1] + + debug("Selected vowel:", tone_vowel) + elif counter == 0: + # try: + + # If the character is r5 (儿), remove tone number and return + if letter_list == ["r", "5"]: + return "".join(letter_list[:-1]) + else: + raise ValueError( + "Invalid numerical pinyin. Input does not contain a vowel.") + + else: + tone_vowel = vowels[0] + debug("Only one vowel found:", tone_vowel) + + # Select tone number, which is last item in letter_list + tone = letter_list[-1] + + # Set integer to use as pinyin dict/list index + # Select tonal vowel from pinyin dict/list using tone_vowel and tone index + try: + tone_int = int(tone) - 1 + tonal_pinyin = pinyin[tone_vowel][tone_int] + + except Exception as e: + raise ValueError( + "Invalid numerical pinyin. The last letter must be an integer between 1-5." + ) + + debug("Found tone:", tone) + debug("Tone vowel converted:", tonal_pinyin) + + # Cal replace_tone_vowel to replace and reformat the string + return replace_tone_vowel(letter_list, tone_vowel, tonal_pinyin) + + +def replace_tone_vowel(letter_list, tone_vowel, tonal_pinyin): + # Replace the tone vowel with tone marked vowel + letter_list = [w.replace(tone_vowel, tonal_pinyin) for w in letter_list] + debug("Replaced tone vowel with tone mark:", letter_list) + + # Remove tone number + tone_number_removed = letter_list[:-1] + debug("Removed now unnecessary tone number:", tone_number_removed) + + # Reform string + finished_char = "".join(tone_number_removed) + debug("Made the letters list into a string:", finished_char) + return finished_char diff --git a/pororo/models/tts/utils/text.py b/pororo/models/tts/utils/text.py new file mode 100644 index 0000000..105ceca --- /dev/null +++ b/pororo/models/tts/utils/text.py @@ -0,0 +1,101 @@ +import re +import string + +import regex + +try: + import epitran +except ImportError: + raise ImportError("Please install epitran with: `pip install epitran`") +try: + from ko_pron import romanise +except ImportError: + raise ImportError("Please install ko_pron with: `pip install ko_pron`") + +from pororo.models.tts.tacotron.params import Params as hp + +_pad = "_" # a dummy character for padding sequences to align text in batches to the same length +_eos = "~" # character which marks the end of a sequnce, further characters are invalid +_unk = "@" # symbols which are not in hp.characters and are present are substituted by this + + +def jejueo_romanize(text): + word = "" + results = [] + for char in text: + if regex.search("\p{Hangul}", char) is not None: + word += char + else: + result = romanise(word, "rr") + results.append(result) + word = char + result = romanise(word, "rr") + results.append(result) + return "".join(results) + + +def _other_symbols(): + return [_pad, _eos, _unk] + list(hp.punctuations_in) + list( + hp.punctuations_out) + + +def to_lower(text): + """Convert uppercase text into lowercase.""" + return text.lower() + + +def remove_odd_whitespaces(text): + """Remove multiple and 
trailing/leading whitespaces.""" + return " ".join(text.split()) + + +def remove_punctuation(text): + """Remove punctuation from text.""" + punct_re = "[" + hp.punctuations_out + hp.punctuations_in + "]" + return re.sub(punct_re.replace("-", "\-"), "", text) + + +def to_sequence(text, use_phonemes=False): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.""" + transform_dict = { + s: i for i, s in enumerate(_other_symbols() + list(hp.phonemes if use_phonemes else hp.characters)) + } + sequence = [transform_dict[_unk] if c not in transform_dict else transform_dict[c] for c in text] + sequence.append(transform_dict[_eos]) + return sequence + + +def to_text(sequence, use_phonemes=False): + """Converts a sequence of IDs back to a string""" + transform_dict = { + i: s for i, s in + enumerate(_other_symbols() + + list(hp.phonemes if use_phonemes else hp.characters)) + } + result = "" + for symbol_id in sequence: + if symbol_id in transform_dict: + s = transform_dict[symbol_id] + if s == _eos: + break + result += s + return result + + +def romanize(text): + """ + Copied from https://github.com/kord123/ko_pron + Copyright (c) Andriy Koretskyy + """ + word = "" + results = [] + for char in text: + if regex.search("\p{Hangul}", char) is not None or char == " ": + word += char + elif char.isalpha(): + result = romanise(word, "rr") + results.append(result) + word = char + result = romanise(word, "rr") + results.append(result) + return "".join(results) diff --git a/pororo/models/tts/waveRNN/__init__.py b/pororo/models/tts/waveRNN/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pororo/models/tts/waveRNN/gen_wavernn.py b/pororo/models/tts/waveRNN/gen_wavernn.py new file mode 100644 index 0000000..ce2e794 --- /dev/null +++ b/pororo/models/tts/waveRNN/gen_wavernn.py @@ -0,0 +1,18 @@ +import torch + +from pororo.models.tts.utils.dsp import * +from pororo.models.tts.waveRNN.params import hp + + +def generate(model, spectrogram, batched, target, overlap, save_str=None): + mel = normalize(spectrogram) + if mel.ndim != 2 or mel.shape[0] != hp.num_mels: + raise ValueError(f"Expected a numpy array shaped (n_mels, n_hops) !") + _max = np.max(mel) + _min = np.min(mel) + if _max >= 1.01 or _min <= -0.01: + raise ValueError( + f"Expected spectrogram range in [0,1] but was instead [{_min}, {_max}]" + ) + mel = torch.tensor(mel).unsqueeze(0) + return model.generate(mel, save_str, batched, target, overlap, hp.mu_law) diff --git a/pororo/models/tts/waveRNN/params.py b/pororo/models/tts/waveRNN/params.py new file mode 100644 index 0000000..3097d95 --- /dev/null +++ b/pororo/models/tts/waveRNN/params.py @@ -0,0 +1,65 @@ +class HParams: + # CONFIG -----------------------------------------------------------------------------------------------------------# + + # Here are the input and output data paths (Note: you can override wav_path in preprocess.py) + data_path = "dataset/" + + # model ids are separate - that way you can use a new tts with an old wavernn and vice versa + # NB: expect undefined behaviour if models were trained on different DSP settings + voc_model_id = "css_raw" + + # DSP --------------------------------------------------------------------------------------------------------------# + + # Settings for all models + sample_rate = 22050 + n_fft = 2048 + fft_bins = n_fft // 2 + 1 + num_mels = 80 + hop_length = 275 # 12.5ms - in line with Tacotron 2 paper + win_length = 1100 # 50ms - same reason as above + fmin = 40 + min_level_db = -100 + 
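# min_level_db is the assumed dynamic-range floor: dsp.normalize() maps [min_level_db, 0] dB onto [0, 1], and dsp.spectrogram() subtracts ref_level_db before normalising. + 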
ref_level_db = 20 + bits = 10 # bit depth of signal + mu_law = ( + True # Recommended to suppress noise if using raw bits in hp.voc_mode below + ) + peak_norm = False # Normalise to the peak of each wav file + + # WAVERNN / VOCODER ------------------------------------------------------------------------------------------------# + + # Model Hparams + voc_mode = "RAW" # either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics) + voc_upsample_factors = ( + 5, + 5, + 11, + ) # NB - this needs to correctly factorise hop_length + voc_rnn_dims = 512 + voc_fc_dims = 512 + voc_compute_dims = 128 + voc_res_out_dims = 128 + voc_res_blocks = 10 + + # Training + voc_batch_size = 64 + voc_lr = 1e-3 + lr_decay = 0.5 + lr_decay_start = 100000 + lr_decay_each = 100000 + weight_decay = 1e-6 + voc_checkpoint_every = 25_000 + voc_gen_at_checkpoint = 10 # number of samples to generate at each checkpoint + voc_total_steps = 1_000_000 # Total number of training steps + voc_test_samples = 50 # How many unseen samples to put aside for testing + voc_pad = 2 # this will pad the input so that the resnet can 'see' wider than input length + voc_seq_len = hop_length * 5 # must be a multiple of hop_length + voc_clip_grad_norm = 4 # set to None if no gradient clipping needed + + # Generating / Synthesizing + voc_gen_batched = True # very fast (realtime+) single utterance batched generation + voc_target = 11_000 # target number of samples to be generated in each batch entry + voc_overlap = 550 # number of samples for crossfading between batches + + +hp = HParams() diff --git a/pororo/models/tts/waveRNN/waveRNN.py b/pororo/models/tts/waveRNN/waveRNN.py new file mode 100644 index 0000000..acaa647 --- /dev/null +++ b/pororo/models/tts/waveRNN/waveRNN.py @@ -0,0 +1,500 @@ +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from typing import Union + +from pororo.models.tts.utils.dsp import * + + +class ResBlock(nn.Module): + + def __init__(self, dims): + super().__init__() + self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.batch_norm1 = nn.BatchNorm1d(dims) + self.batch_norm2 = nn.BatchNorm1d(dims) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + return x + residual + + +class MelResNet(nn.Module): + + def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad): + super().__init__() + k_size = pad * 2 + 1 + self.conv_in = nn.Conv1d(in_dims, + compute_dims, + kernel_size=k_size, + bias=False) + self.batch_norm = nn.BatchNorm1d(compute_dims) + self.layers = nn.ModuleList() + for i in range(res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: + x = f(x) + x = self.conv_out(x) + return x + + +class Stretch2d(nn.Module): + + def __init__(self, x_scale, y_scale): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x): + b, c, h, w = x.size() + x = x.unsqueeze(-1).unsqueeze(3) + x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) + return x.view(b, c, h * self.y_scale, w * self.x_scale) + + +class UpsampleNetwork(nn.Module): + + def __init__(self, feat_dims, upsample_scales, compute_dims, res_blocks, + res_out_dims, pad): + 
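# Stretch2d + averaging Conv2d pairs upsample the conditioning mels by prod(upsample_scales) (5*5*11 = 275 = hop_length here), while MelResNet produces the auxiliary features that WaveRNN splits into a1..a4. + 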
super().__init__() + total_scale = np.cumproduct(upsample_scales)[-1] + self.indent = pad * total_scale + self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, + res_out_dims, pad) + self.resnet_stretch = Stretch2d(total_scale, 1) + self.up_layers = nn.ModuleList() + for scale in upsample_scales: + k_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(1, + 1, + kernel_size=k_size, + padding=padding, + bias=False) + conv.weight.data.fill_(1.0 / k_size[1]) + self.up_layers.append(stretch) + self.up_layers.append(conv) + + def forward(self, m): + aux = self.resnet(m).unsqueeze(1) + aux = self.resnet_stretch(aux) + aux = aux.squeeze(1) + m = m.unsqueeze(1) + for f in self.up_layers: + m = f(m) + m = m.squeeze(1)[:, :, self.indent:-self.indent] + return m.transpose(1, 2), aux.transpose(1, 2) + + +class WaveRNN(nn.Module): + + def __init__( + self, + rnn_dims, + fc_dims, + bits, + pad, + upsample_factors, + feat_dims, + compute_dims, + res_out_dims, + res_blocks, + hop_length, + sample_rate, + mode="RAW", + ): + super().__init__() + self.mode = mode + self.pad = pad + if self.mode == "RAW": + self.n_classes = 2**bits + elif self.mode == "MOL": + self.n_classes = 30 + else: + RuntimeError("Unknown model mode value - ", self.mode) + + # List of rnns to call `flatten_parameters()` on + self._to_flatten = [] + + self.rnn_dims = rnn_dims + self.aux_dims = res_out_dims // 4 + self.hop_length = hop_length + self.sample_rate = sample_rate + + self.upsample = UpsampleNetwork(feat_dims, upsample_factors, + compute_dims, res_blocks, res_out_dims, + pad) + self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) + + self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) + self._to_flatten += [self.rnn1, self.rnn2] + + self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) + self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) + self.fc3 = nn.Linear(fc_dims, self.n_classes) + + self.register_buffer("step", torch.zeros(1, dtype=torch.long)) + self.num_params() + + # Avoid fragmentation of RNN parameters and associated warning + self._flatten_parameters() + + def forward(self, x, mels): + device = next(self.parameters()).device # use same device as parameters + + # Although we `_flatten_parameters()` on init, when using DataParallel + # the model gets replicated, making it no longer guaranteed that the + # weights are contiguous in GPU memory. 
Hence, we must call it again + self._flatten_parameters() + + self.step += 1 + bsize = x.size(0) + h1 = torch.zeros(1, bsize, self.rnn_dims, device=device) + h2 = torch.zeros(1, bsize, self.rnn_dims, device=device) + mels, aux = self.upsample(mels) + + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0]:aux_idx[1]] + a2 = aux[:, :, aux_idx[1]:aux_idx[2]] + a3 = aux[:, :, aux_idx[2]:aux_idx[3]] + a4 = aux[:, :, aux_idx[3]:aux_idx[4]] + + x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + x = self.I(x) + res = x + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def generate(self, mels, save_path: Union[str, Path], batched, target, + overlap, mu_law): + self.eval() + + device = next(self.parameters()).device # use same device as parameters + + mu_law = mu_law if self.mode == "RAW" else False + + output = [] + start = time.time() + rnn1 = self.get_gru_cell(self.rnn1) + rnn2 = self.get_gru_cell(self.rnn2) + + with torch.no_grad(): + + mels = torch.as_tensor(mels, device=device) + wave_len = (mels.size(-1) - 1) * self.hop_length + mels = self.pad_tensor(mels.transpose(1, 2), + pad=self.pad, + side="both") + mels, aux = self.upsample(mels.transpose(1, 2)) + + if batched: + mels = self.fold_with_overlap(mels, target, overlap) + aux = self.fold_with_overlap(aux, target, overlap) + + b_size, seq_len, _ = mels.size() + + h1 = torch.zeros(b_size, self.rnn_dims, device=device) + h2 = torch.zeros(b_size, self.rnn_dims, device=device) + x = torch.zeros(b_size, 1, device=device) + + d = self.aux_dims + aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)] + + for i in range(seq_len): + + m_t = mels[:, i, :] + + a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) + + x = torch.cat([x, m_t, a1_t], dim=1) + x = self.I(x) + h1 = rnn1(x, h1) + + x = x + h1 + inp = torch.cat([x, a2_t], dim=1) + h2 = rnn2(inp, h2) + + x = x + h2 + x = torch.cat([x, a3_t], dim=1) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) + x = F.relu(self.fc2(x)) + + logits = self.fc3(x) + + if self.mode == "MOL": + sample = sample_from_discretized_mix_logistic( + logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + x = sample.transpose(0, 1) + + elif self.mode == "RAW": + posterior = F.softmax(logits, dim=1) + distrib = torch.distributions.Categorical(posterior) + + sample = 2 * distrib.sample().float() / (self.n_classes - + 1.0) - 1.0 + output.append(sample) + x = sample.unsqueeze(-1) + else: + raise RuntimeError("Unknown model mode value - ", self.mode) + + output = torch.stack(output).transpose(0, 1) + output = output.cpu().numpy() + output = output.astype(np.float64) + + if mu_law: + output = decode_mu_law(output, self.n_classes, False) + + if batched: + output = self.xfade_and_unfold(output, target, overlap) + else: + output = output[0] + + # Fade-out at the end to avoid signal cutting out suddenly + fade_out = np.linspace(1, 0, 20 * self.hop_length) + output = output[:wave_len] + output[-20 * self.hop_length:] *= fade_out + + if save_path is not None: + save_wav(output, save_path) + + self.train() + + return output + + def get_gru_cell(self, gru): + gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) + gru_cell.weight_hh.data = gru.weight_hh_l0.data + gru_cell.weight_ih.data = gru.weight_ih_l0.data + gru_cell.bias_hh.data = gru.bias_hh_l0.data + 
gru_cell.bias_ih.data = gru.bias_ih_l0.data + return gru_cell + + def pad_tensor(self, x, pad, side="both"): + # NB - this is just a quick method i need right now + # i.e., it won't generalise to other shapes/dims + b, t, c = x.size() + total = t + 2 * pad if side == "both" else t + pad + padded = torch.zeros(b, total, c, device=x.device) + if side == "before" or side == "both": + padded[:, pad:pad + t, :] = x + elif side == "after": + padded[:, :t, :] = x + return padded + + def fold_with_overlap(self, x, target, overlap): + """Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold() + + Args: + x (tensor) : Upsampled conditioning features. + shape=(1, timesteps, features) + target (int) : Target timesteps for each index of batch + overlap (int) : Timesteps for both xfade and rnn warmup + + Return: + (tensor) : shape=(num_folds, target + 2 * overlap, features) + + Details: + x = [[h1, h2, ... hn]] + + Where each h is a vector of conditioning features + + Eg: target=2, overlap=1 with x.size(1)=10 + + folded = [[h1, h2, h3, h4], + [h4, h5, h6, h7], + [h7, h8, h9, h10]] + """ + + _, total_len, features = x.size() + + # Calculate variables needed + num_folds = (total_len - overlap) // (target + overlap) + extended_len = num_folds * (overlap + target) + overlap + remaining = total_len - extended_len + + # Pad if some time steps poking out + if remaining != 0: + num_folds += 1 + padding = target + 2 * overlap - remaining + x = self.pad_tensor(x, padding, side="after") + + folded = torch.zeros(num_folds, + target + 2 * overlap, + features, + device=x.device) + + # Get the values for the folded tensor + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + folded[i] = x[:, start:end, :] + + return folded + + def xfade_and_unfold(self, y, target, overlap): + """Applies a crossfade and unfolds into a 1d array. + + Args: + y (ndarry) : Batched sequences of audio samples + shape=(num_folds, target + 2 * overlap) + dtype=np.float64 + overlap (int) : Timesteps for both xfade and rnn warmup + + Return: + (ndarry) : audio samples in a 1d array + shape=(total_len) + dtype=np.float64 + + Details: + y = [[seq1], + [seq2], + [seq3]] + + Apply a gain envelope at both ends of the sequences + + y = [[seq1_in, seq1_target, seq1_out], + [seq2_in, seq2_target, seq2_out], + [seq3_in, seq3_target, seq3_out]] + + Stagger and add up the groups of samples: + + [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] 
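+ The first half of each overlap is silenced (RNN warm-up) and the second half uses an equal-power sqrt crossfade, so the staggered folds sum back to a constant-amplitude signal.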
+ + """ + + num_folds, length = y.shape + target = length - 2 * overlap + total_len = num_folds * (target + overlap) + overlap + + # Need some silence for the rnn warmup + silence_len = overlap // 2 + fade_len = overlap - silence_len + silence = np.zeros((silence_len), dtype=np.float64) + linear = np.ones((silence_len), dtype=np.float64) + + # Equal power crossfade + t = np.linspace(-1, 1, fade_len, dtype=np.float64) + fade_in = np.sqrt(0.5 * (1 + t)) + fade_out = np.sqrt(0.5 * (1 - t)) + + # Concat the silence to the fades + fade_in = np.concatenate([silence, fade_in]) + fade_out = np.concatenate([linear, fade_out]) + + # Apply the gain to the overlap samples + y[:, :overlap] *= fade_in + y[:, -overlap:] *= fade_out + + unfolded = np.zeros((total_len), dtype=np.float64) + + # Loop to add up all the samples + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + unfolded[start:end] += y[i] + + return unfolded + + def get_step(self): + return self.step.data.item() + + def log(self, path, msg): + with open(path, "a") as f: + print(msg, file=f) + + def load(self, path: Union[str, Path]): + # Use device of model params as location for loaded state + device = next(self.parameters()).device + self.load_state_dict(torch.load(path, map_location=device), + strict=False) + + def save(self, path: Union[str, Path]): + # No optimizer argument because saving a model should not include data + # only relevant in the training process - it should only be properties + # of the model itself. Let caller take care of saving optimzier state. + torch.save(self.state_dict(), path) + + def num_params(self, print_out=True): + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + + return parameters + + def _flatten_parameters(self): + """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used + to improve efficiency and avoid PyTorch yelling at us.""" + [m.flatten_parameters() for m in self._to_flatten] + + +def sample_from_discretized_mix_logistic(y, log_scale_min=None): + """ + Sample from discretized mixture of logistic distributions + + Args: + y (Tensor): B x C x T + log_scale_min (float): Log scale minimum value + + Returns: + Tensor: sample in range of [-1, 1]. 
+ """ + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = F.one_hot(argmax, nr_mix).float() + # select logistic parameters + means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, + dim=-1), + min=log_scale_min) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u)) + + x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0) + + return x diff --git a/pororo/tasks/speech_synthesis.py b/pororo/tasks/speech_synthesis.py new file mode 100644 index 0000000..1fe5e7d --- /dev/null +++ b/pororo/tasks/speech_synthesis.py @@ -0,0 +1,221 @@ +"""Speech Synthesis related modeling class""" + +from typing import Optional, Tuple +from numpy import ndarray + +from pororo.tasks import ( + PororoFactoryBase, + PororoG2pFactory, + PororoSimpleBase, + download_or_load, +) + + +class PororoTtsFactory(PororoFactoryBase): + """ + Synthesis text to speech using trained model. + Output audio's sample rate is 22050. + + Multi (`tacotron`) + + - dataset: TBU + - metric: TBU + + Args: + text (str): text for speech synthesis + lang (str): text's language Ex) how are you?: en, 안녕하세요.: ko + speaker (str): designate a speaker such as ko, en, zh etc.. (default: lang) + + Returns: + ndarray: waveform of speech signal + + Examples: + >>> import IPython + >>> from IPython.display import Audio + >>> model = Pororo('tts', lang='multi') + >>> wave = model('how are you?', lang='en', speaker='en') + >>> IPython.display.display(IPython.display.Audio(data=wave, rate=22050)) + + >>> model = Pororo('tts', lang='multi') + >>> wave = model('저는 미국 사람이에요.', lang='ko', speaker='en') + >>> IPython.display.display(IPython.display.Audio(data=wave, rate=22050)) + + Notes: + Currently 10 languages supports. + Supported Languages: English, Korean, Japanese, Chinese, Jejueo, Dutch, German, Spanish, French, Russian + This task can designate a speaker such as ko, en, zh etc. 
+ + """ + + def __init__(self, task: str, lang: str = "multi", model: Optional[str] = None): + super().__init__(task, lang, model) + + @staticmethod + def get_available_langs(): + return ["multi"] + + @staticmethod + def get_available_models(): + return { + "multi": ["tacotron"], + } + + def load(self, device: str): + """ + Load user-selected task-specific model + + Args: + device (str): device information + + Returns: + object: User-selected task-specific model + + """ + if self.config.n_model == "tacotron": + from pororo.models.tts.synthesizer import ( + MultilingualSpeechSynthesizer, + ) + from pororo.models.tts.utils.numerical_pinyin_converter import ( + convert_from_numerical_pinyin, + ) + from pororo.models.tts.utils.text import jejueo_romanize, romanize + + tacotron_path = download_or_load("misc/tacotron2", self.config.lang) + english_vocoder_path = download_or_load("misc/hifigan_en", self.config.lang) + korean_vocoder_path = download_or_load("misc/hifigan_ko", self.config.lang) + english_vocoder_config = download_or_load("misc/hifigan_en_config.json", self.config.lang) + korean_vocoder_config = download_or_load("misc/hifigan_ko_config.json", self.config.lang) + wavernn_path = download_or_load("misc/wavernn.pyt", self.config.lang) + synthesizer = MultilingualSpeechSynthesizer( + tacotron_path, + english_vocoder_path, + english_vocoder_config, + korean_vocoder_path, + korean_vocoder_config, + wavernn_path, + device, + self.config.lang, + ) + return PororoTTS( + synthesizer, + device, + romanize, + jejueo_romanize, + convert_from_numerical_pinyin, + self.config, + ) + + +class PororoTTS(PororoSimpleBase): + + def __init__( + self, + synthesizer, + device, + romanize, + jejueo_romanize, + convert_from_numerical_pinyin, + config, + ): + super().__init__(config) + self._synthesizer = synthesizer + + self.g2p_ja = None + self.g2p_zh = None + + self.lang_dict = { + "en": "en", + "ko": "ko", + "ja": "jp", + "de": "de", + "nl": "nl", + "ru": "ru", + "es": "es", + "fr": "fr", + "zh": "zh", + "fi": "fi", + "je": "je", + } + self.device = device + + self.romanize = romanize + self.jejueo_romanize = jejueo_romanize + self.convert_from_numerical_pinyin = convert_from_numerical_pinyin + + def load_g2p_ja(self): + """ load g2p module for Japanese """ + self.g2p_ja = PororoG2pFactory( + task="g2p", + model="g2p.ja", + lang="ja", + ) + self.g2p_ja = self.g2p_ja.load(self.device) + + def load_g2p_zh(self): + """ load g2p module for Chinese """ + self.g2p_zh = PororoG2pFactory( + task="g2p", + model="g2p.zh", + lang="zh", + ) + self.g2p_zh = self.g2p_zh.load(self.device) + + def _preprocess( + self, + text: str, + lang: str = "en", + speaker: str = None, + ) -> Tuple[str, str]: + """ + Pre-process text for TTS format + + Args: + text (str): text for tts + lang (str): text language + speaker (speaker): designation of speaker + + Returns: + str: pre-processed text + + """ + if lang == "ko": + text = self.romanize(text) + elif lang == "ja": + if self.g2p_ja is None: + self.load_g2p_ja() + text = self.g2p_ja(text) + elif lang == "zh": + if self.g2p_zh is None: + self.load_g2p_zh() + text = self.g2p_zh(text).replace(" ", " ") + text = self.convert_from_numerical_pinyin(text) + elif lang == "je": + text = self.jejueo_romanize(text) + return f"{text}|00-{self.lang_dict[lang]}|{speaker}", speaker + + def predict(self, text: str, speaker: str) -> ndarray: + """ + Conduct speech synthesis on given text + + Args: + text (str): text for tts + speaker (speaker): designation of speaker + + Returns: + ndarray: 
waveform of speech signal + + """ + return self._synthesizer.predict(text, speaker) + + def __call__(self, text: str, lang: str = "en", speaker: str = None): + if speaker is None: + speaker = lang + + if speaker == "ja": + speaker = "jp" + + assert ( + speaker in self.lang_dict.values() + ), f"Unsupported speaker: {speaker}\nSupported speakers: {list(self.lang_dict.values())}" + text, speaker = self._preprocess(text, lang, speaker) + return self.predict(text, speaker) diff --git a/tests/test_speech_synthesis.py b/tests/test_speech_synthesis.py new file mode 100644 index 0000000..8d3e6cd --- /dev/null +++ b/tests/test_speech_synthesis.py @@ -0,0 +1,23 @@ +"""Test Text-To-Speech module""" + +import unittest + +import numpy as np + +from pororo import Pororo + + +class PororoTTSTester(unittest.TestCase): + + def test_modules(self): + tts = Pororo(task="tts", lang="multi") + wave = tts("how are you?", lang="en", speaker="en") + self.assertIsInstance(wave, np.ndarray) + + tts = Pororo(task="tts", lang="multi") + wave = tts("저는 미국 사람이에요.", lang="ko", speaker="en") + self.assertIsInstance(wave, np.ndarray) + + +if __name__ == "__main__": + unittest.main()
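For reference, a minimal usage sketch of the new TTS task outside the test suite (the output filename is illustrative; soundfile is already used by pororo/models/tts/utils/audio.py):

import soundfile as sf

from pororo import Pororo

# Build the multilingual TTS pipeline and synthesise Korean text with the English speaker,
# mirroring the second case in tests/test_speech_synthesis.py.
tts = Pororo(task="tts", lang="multi")
wave = tts("저는 미국 사람이에요.", lang="ko", speaker="en")

# The synthesizer returns a mono float waveform at 22050 Hz (see the PororoTtsFactory docstring),
# which can be written to disk with soundfile.
sf.write("sample_ko_en.wav", wave, samplerate=22050)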