Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add base_dir parameter #174

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ def prepare_run(args):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

run_name = args.name or args.tacotron_name or args.model
taco_checkpoint = os.path.join('logs-' + run_name, 'taco_' + args.checkpoint)
taco_checkpoint = os.path.join(args.base_dir, 'logs-' + run_name, 'taco_' + args.checkpoint)

run_name = args.name or args.wavenet_name or args.model
wave_checkpoint = os.path.join('logs-' + run_name, 'wave_' + args.checkpoint)
wave_checkpoint = os.path.join(args.base_dir, 'logs-' + run_name, 'wave_' + args.checkpoint)
return taco_checkpoint, wave_checkpoint, modified_hp

def get_sentences(args):
Expand All @@ -44,16 +44,16 @@ def synthesize(args, hparams, taco_checkpoint, wave_checkpoint, sentences):
def main():
accepted_modes = ['eval', 'synthesis', 'live']
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', default='pretrained/', help='Path to model checkpoint')
parser.add_argument('--base_dir', default='')
parser.add_argument('--checkpoint', default='pretrained', help='Path to model checkpoint')
parser.add_argument('--hparams', default='',
help='Hyperparameter overrides as a comma-separated list of name=value pairs')
parser.add_argument('--name', help='Name of logging directory if the two models were trained together.')
parser.add_argument('--tacotron_name', help='Name of logging directory of Tacotron. If trained separately')
parser.add_argument('--wavenet_name', help='Name of logging directory of WaveNet. If trained separately')
parser.add_argument('--model', default='Tacotron-2')
parser.add_argument('--input_dir', default='training_data/', help='folder to contain inputs sentences/targets')
parser.add_argument('--mels_dir', default='tacotron_output/eval/', help='folder to contain mels to synthesize audio from using the Wavenet')
parser.add_argument('--output_dir', default='output/', help='folder to contain synthesized mel spectrograms')
parser.add_argument('--input_dir', default='training_data', help='folder to contain inputs sentences/targets')
parser.add_argument('--output_dir', default='output', help='folder to contain synthesized mel spectrograms')
parser.add_argument('--mode', default='eval', help='mode of run: can be one of {}'.format(accepted_modes))
parser.add_argument('--GTA', default='True', help='Ground truth aligned synthesis, defaults to True, only considered in synthesis mode')
parser.add_argument('--text_list', default='', help='Text file contains list of texts to be synthesized. Valid if mode=eval')
Expand Down
20 changes: 9 additions & 11 deletions tacotron/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ def run_eval(args, checkpoint_path, output_dir, hparams, sentences):
eval_dir = os.path.join(output_dir, 'eval')
log_dir = os.path.join(output_dir, 'logs-eval')

if args.model == 'Tacotron-2':
assert os.path.normpath(eval_dir) == os.path.normpath(args.mels_dir) #mels_dir = wavenet_input_dir

#Create output path if it doesn't exist
os.makedirs(eval_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
Expand All @@ -56,7 +53,6 @@ def run_eval(args, checkpoint_path, output_dir, hparams, sentences):
synth = Synthesizer()
synth.load(checkpoint_path, hparams)


with open(os.path.join(eval_dir, 'map.txt'), 'w') as file:
for i, text in enumerate(tqdm(sentences)):
start = time.time()
Expand All @@ -66,7 +62,7 @@ def run_eval(args, checkpoint_path, output_dir, hparams, sentences):
log('synthesized mel spectrograms at {}'.format(eval_dir))
return eval_dir

def run_synthesis(args, checkpoint_path, output_dir, hparams):
def run_synthesis(args, checkpoint_path, input_dir, output_dir, hparams):
GTA = (args.GTA == 'True')
if GTA:
synth_dir = os.path.join(output_dir, 'gta')
Expand All @@ -80,21 +76,22 @@ def run_synthesis(args, checkpoint_path, output_dir, hparams):
os.makedirs(synth_dir, exist_ok=True)


metadata_filename = os.path.join(args.input_dir, 'train.txt')
metadata_filename = os.path.join(input_dir, 'train.txt')
log(hparams_debug_string())
synth = Synthesizer()
synth.load(checkpoint_path, hparams, gta=GTA)

with open(metadata_filename, encoding='utf-8') as f:
metadata = [line.strip().split('|') for line in f]
frame_shift_ms = hparams.hop_size / hparams.sample_rate
hours = sum([int(x[4]) for x in metadata]) * frame_shift_ms / (3600)
log('Loaded metadata for {} examples ({:.2f} hours)'.format(len(metadata), hours))

metadata = [metadata[i: i+hparams.tacotron_synthesis_batch_size] for i in range(0, len(metadata), hparams.tacotron_synthesis_batch_size)]
metadata = [metadata[i: i + hparams.tacotron_synthesis_batch_size] for i in range(0, len(metadata), hparams.tacotron_synthesis_batch_size)]

log('starting synthesis')
mel_dir = os.path.join(args.input_dir, 'mels')
wav_dir = os.path.join(args.input_dir, 'audio')
mel_dir = os.path.join(input_dir, 'mels')
wav_dir = os.path.join(input_dir, 'audio')
with open(os.path.join(synth_dir, 'map.txt'), 'w') as file:
for i, meta in enumerate(tqdm(metadata)):
texts = [m[5] for m in meta]
Expand All @@ -109,7 +106,8 @@ def run_synthesis(args, checkpoint_path, output_dir, hparams):
return os.path.join(synth_dir, 'map.txt')

def tacotron_synthesize(args, hparams, checkpoint, sentences=None):
output_dir = 'tacotron_' + args.output_dir
input_dir = os.path.join(args.base_dir, args.input_dir)
output_dir = os.path.join(args.base_dir, 'tacotron_' + args.output_dir)

try:
checkpoint_path = tf.train.get_checkpoint_state(checkpoint).model_checkpoint_path
Expand All @@ -120,6 +118,6 @@ def tacotron_synthesize(args, hparams, checkpoint, sentences=None):
if args.mode == 'eval':
return run_eval(args, checkpoint_path, output_dir, hparams, sentences)
elif args.mode == 'synthesis':
return run_synthesis(args, checkpoint_path, output_dir, hparams)
return run_synthesis(args, checkpoint_path, input_dir, output_dir, hparams)
else:
run_live(args, checkpoint_path, hparams)
16 changes: 10 additions & 6 deletions test_wavenet_feeder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
import numpy as np
import os
import argparse
from hparams import hparams

from datasets import audio
from hparams import hparams
from tqdm import tqdm



def _limit_time(hparams):
'''Limit time resolution to save GPU memory.
'''
Expand Down Expand Up @@ -34,6 +34,7 @@ def get_groups(args, hparams, meta, local_condition):

return (input_data, local_condition_features, None, len(input_data))


def _adjust_time_resolution(hparams, batch, local_condition, max_time_steps):
'''Adjust time resolution between audio and local condition
'''
Expand Down Expand Up @@ -65,16 +66,19 @@ def _adjust_time_resolution(hparams, batch, local_condition, max_time_steps):
new_batch.append((x, c, g, l))
return new_batch


def _assert_ready_for_upsample(hparams, x, c):
    '''Assert that len(x) is an exact multiple of len(c) by the hop size.

    The assertion holds only when len(x) divides evenly by len(c) and the
    quotient equals audio.get_hop_size(hparams).
    '''
    ratio, leftover = divmod(len(x), len(c))
    # The remainder is tested first so the hop size is only looked up when
    # the lengths divide evenly.
    assert leftover == 0 and ratio == audio.get_hop_size(hparams)


def check_time_alignment(hparams, batch, local_condition):
    '''Run the time-resolution adjustment on one batch as an alignment check.

    The time-step limit is taken from _limit_time(hparams) and passed to
    _adjust_time_resolution together with the batch. The adjusted batch is
    discarded and nothing is returned: no checks beyond this step are needed
    when preparing data.
    '''
    # Time steps are capped to save GPU memory usage; the resolution is then
    # adjusted for upsampling.
    _adjust_time_resolution(hparams, batch, local_condition, _limit_time(hparams))


def _ensure_divisible(length, divisible_by=256, lower=True):
if length % divisible_by == 0:
return length
Expand All @@ -83,20 +87,20 @@ def _ensure_divisible(length, divisible_by=256, lower=True):
else:
return length + (divisible_by - length % divisible_by)


def run(args, hparams):
    '''Check time alignment for every batch built from the metadata file.

    Reads args.metadata line by line, splitting each stripped line on '|',
    builds one example per line with get_groups, slices the examples into
    chunks of hparams.wavenet_batch_size and passes each chunk to
    check_time_alignment.
    '''
    with open(args.metadata, 'r') as metadata_file:
        rows = [entry.strip().split('|') for entry in metadata_file]

    # Local conditioning is in use whenever cin_channels is positive.
    use_local_condition = hparams.cin_channels > 0

    examples = []
    for row in rows:
        examples.append(get_groups(args, hparams, row, use_local_condition))

    batch_size = hparams.wavenet_batch_size
    batches = [examples[start: start + batch_size] for start in range(0, len(examples), batch_size)]

    for group in tqdm(batches):
        check_time_alignment(hparams, group, use_local_condition)



def main():
parser = argparse.ArgumentParser()
parser.add_argument('--base_dir', default='')
Expand All @@ -110,4 +114,4 @@ def main():


if __name__ == '__main__':
main()
main()
17 changes: 9 additions & 8 deletions wavenet_vocoder/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from wavenet_vocoder.synthesizer import Synthesizer


def run_synthesis(args, checkpoint_path, output_dir, hparams):
def run_synthesis(args, checkpoint_path, input_dir, output_dir, hparams):
log_dir = os.path.join(output_dir, 'plots')
wav_dir = os.path.join(output_dir, 'wavs')

Expand All @@ -20,7 +20,7 @@ def run_synthesis(args, checkpoint_path, output_dir, hparams):

if args.model == 'Tacotron-2':
#If running all Tacotron-2, synthesize audio from evaluated mels
metadata_filename = os.path.join(args.mels_dir, 'map.txt')
metadata_filename = os.path.join(input_dir, 'map.txt')
with open(metadata_filename, encoding='utf-8') as f:
metadata = np.array([line.strip().split('|') for line in f])

Expand All @@ -31,7 +31,7 @@ def run_synthesis(args, checkpoint_path, output_dir, hparams):
speaker_ids = None if (speaker_ids == '<no_g>').all() else speaker_ids
else:
#else Get all npy files in input_dir (supposing they are mels)
mel_files = [os.path.join(args.mels_dir, f) for f in os.listdir(args.mels_dir) if f.split('.')[-1] == 'npy']
mel_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.split('.')[-1] == 'npy']
speaker_ids = None if args.speaker_id is None else args.speaker_id.replace(' ', '').split(',')

if speaker_ids is not None:
Expand All @@ -43,9 +43,9 @@ def run_synthesis(args, checkpoint_path, output_dir, hparams):
os.makedirs(log_dir, exist_ok=True)
os.makedirs(wav_dir, exist_ok=True)

mel_files = [mel_files[i: i+hparams.wavenet_synthesis_batch_size] for i in range(0, len(mel_files), hparams.wavenet_synthesis_batch_size)]
speaker_ids = None if speaker_ids is None else [speaker_ids[i: i+hparams.wavenet_synthesis_batch_size] for i in range(0, len(speaker_ids), hparams.wavenet_synthesis_batch_size)]
texts = None if texts is None else [texts[i: i+hparams.wavenet_synthesis_batch_size] for i in range(0, len(texts), hparams.wavenet_synthesis_batch_size)]
mel_files = [mel_files[i: i + hparams.wavenet_synthesis_batch_size] for i in range(0, len(mel_files), hparams.wavenet_synthesis_batch_size)]
speaker_ids = None if speaker_ids is None else [speaker_ids[i: i + hparams.wavenet_synthesis_batch_size] for i in range(0, len(speaker_ids), hparams.wavenet_synthesis_batch_size)]
texts = None if texts is None else [texts[i: i + hparams.wavenet_synthesis_batch_size] for i in range(0, len(texts), hparams.wavenet_synthesis_batch_size)]

with open(os.path.join(wav_dir, 'map.txt'), 'w') as file:
for i, mel_batch in enumerate(tqdm(mel_files)):
Expand All @@ -68,12 +68,13 @@ def run_synthesis(args, checkpoint_path, output_dir, hparams):


def wavenet_synthesize(args, hparams, checkpoint):
output_dir = 'wavenet_' + args.output_dir
input_dir = os.path.join(args.base_dir, 'tacotron_' + args.output_dir, 'eval')
output_dir = os.path.join(args.base_dir, 'wavenet_' + args.output_dir)

try:
checkpoint_path = tf.train.get_checkpoint_state(checkpoint).model_checkpoint_path
log('loaded model at {}'.format(checkpoint_path))
except:
raise RuntimeError('Failed to load checkpoint at {}'.format(checkpoint))

run_synthesis(args, checkpoint_path, output_dir, hparams)
run_synthesis(args, checkpoint_path, input_dir, output_dir, hparams)