diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..18765d0 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +datasets/ +trainings/ +tmp/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c384084 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +#################### Python ##################### +**/__pycache__/ +**.egg**/ +dist/ +build/ + + +#################### Daft-Exprt ##################### +datasets/ +trainings/ +tmp/ diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..4b702cc --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2021 Ubisoft Entertainment
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..0cf1b19
--- /dev/null
+++ b/README.md
@@ -0,0 +1,244 @@
+
+# Daft-Exprt: Robust Prosody Transfer Across Speakers for Expressive Speech Synthesis
+
+
+### Julian Zaïdi, Hugo Seuté, Benjamin van Niekerk, Marc-André Carbonneau
+In our recent [paper](https://arxiv.org/abs/2108.02271) we propose Daft-Exprt, a multi-speaker acoustic model advancing the state of the art in inter-speaker and inter-text prosody transfer. This improvement is achieved using FiLM conditioning layers, alongside adversarial training that encourages disentanglement between prosodic information and speaker identity. The acoustic model inherits attractive qualities from FastSpeech 2, such as fast inference and local prosody attribute prediction for finer-grained control over generation. Moreover, results indicate that adversarial training effectively discards speaker identity information from the prosody representation, which ensures Daft-Exprt consistently generates speech with the desired voice.
+
+Experimental results show that Daft-Exprt accurately transfers prosody, while yielding naturalness comparable to state-of-the-art expressive models. Visit our [demo page](https://ubisoft-laforge.github.io/speech/daft-exprt/) for audio samples related to the paper experiments.
+
+### Pre-trained model
+**Full disclosure**: The model provided in this repository is not the same as the one used in the paper evaluation. The model from the paper was trained with proprietary data, which prevents us from releasing it publicly.
+Instead, we pre-train Daft-Exprt on a combination of the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/) and the emotional speech dataset (ESD) from [Zhou et al](https://github.com/HLTSingapore/Emotional-Speech-Data).
+Visit the [releases](https://github.com/ubisoft/ubisoft-laforge-daft-exprt/releases) of this repository to download the pre-trained model and to listen to prosody transfer examples generated with this same model.
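+
+For reference, loading this checkpoint in Python boils down to the logic already used in [synthesize.py](scripts/synthesize.py), sketched below. This snippet is only an illustration and assumes a CUDA-capable machine; the checkpoint path is a placeholder for the file downloaded from the releases page.
+```python
+import torch
+
+from daft_exprt.hparams import HyperParams
+from daft_exprt.model import DaftExprt
+
+# placeholder path to the checkpoint downloaded from the releases page
+checkpoint_path = 'path/to/DaftExprt_checkpoint'
+
+# rebuild the hyper-parameters stored in the checkpoint and instantiate the model
+checkpoint_dict = torch.load(checkpoint_path, map_location='cuda:0')
+hparams = HyperParams(verbose=False, **checkpoint_dict['config_params'])
+model = DaftExprt(hparams).cuda(0)
+
+# strip the 'module.' prefix added by distributed training before loading the weights
+state_dict = {k.replace('module.', ''): v for k, v in checkpoint_dict['state_dict'].items()}
+model.load_state_dict(state_dict)
+```
+In practice you do not need to write this yourself: the [TTS Synthesis](#tts-synthesis) section below shows how to run [synthesize.py](scripts/synthesize.py) directly.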
+
+
+## Table of Contents
+- [Installation](#installation)
+  - [Local Environment](#local-environment)
+  - [Docker Image](#docker-image)
+- [Quick Start Example](#quick-start-example)
+  - [Introduction](#introduction)
+  - [Dataset Formatting](#dataset-formatting)
+  - [Data Pre-Processing](#data-pre-processing)
+  - [Training](#training)
+  - [Fine-Tuning](#fine-tuning)
+  - [TTS Synthesis](#tts-synthesis)
+- [Citation](#citation)
+- [Contributing](#contributing)
+
+## Installation
+
+### Local Environment
+Requirements:
+- Ubuntu >= 20.04
+- Python >= 3.8
+- [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx?lang=en-us) >= 450.80.02
+- [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) >= 11.1
+- [CuDNN](https://developer.nvidia.com/rdp/cudnn-archive) >= v8.0.5
+
+We recommend using conda for Python environment management; for example, download and install [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
+Create your Python environment and install dependencies using the Makefile:
+1. `conda create -n daft_exprt python=3.8 -y`
+2. `conda activate daft_exprt`
+3. `cd environment`
+4. `make`
+
+All Linux/Conda/Python dependencies will be installed by the Makefile, and the repository will be installed as a pip package in editable mode.
+
+### Docker Image
+Requirements:
+- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
+- [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx?lang=en-us) >= 450.80.02
+
+Build the Docker image using the associated Dockerfile:
+1. `docker build -f environment/Dockerfile -t daft_exprt .`
+
+## Quick Start Example
+
+### Introduction
+This quick start guide illustrates how to use the different scripts of this repository to:
+1. Format datasets
+2. Pre-process these datasets
+3. Train Daft-Exprt on the pre-processed data
+4. Generate a dataset for vocoder fine-tuning
+5. Use Daft-Exprt for TTS synthesis
+
+All scripts are located in the [scripts](scripts) directory.
+Daft-Exprt source code is located in the [daft_exprt](src/daft_exprt) directory.
+Config parameters used in the scripts are all instantiated in [hparams.py](src/daft_exprt/hparams.py).
+
+As a quick start example, we consider using the 22kHz [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/) and the 16kHz emotional speech dataset (ESD) from [Zhou et al](https://github.com/HLTSingapore/Emotional-Speech-Data).
+This combines a total of 11 speakers. All speaker datasets must be in the same root directory. For example:
+```
+/data_dir
+    LJ_Speech
+    ESD
+        spk_1
+        ...
+        spk_N
+```
+
+In this example, we use the Docker image built in the previous section:
+```
+docker run -it --gpus all -v /path/to/data_dir:/workdir/data_dir -v /path/to/repo_dir:/workdir/repo_dir IMAGE_ID
+```
+
+
+### Dataset Formatting
+The source code expects the following tree structure for each speaker dataset:
+```
+/speaker_dir
+    metadata.csv
+    /wavs
+        wav_file_name_1.wav
+        ...
+        wav_file_name_N.wav
+```
+
+metadata.csv must be formatted as follows:
+```
+wav_file_name_1|text_1
+...
+wav_file_name_N|text_N
+```
+
+Since each dataset has its own nomenclature, this project does not provide a ready-made universal formatting script.
+However, the script [format_dataset.py](scripts/format_dataset.py) already provides the code to format LJ and ESD:
+```
+python format_dataset.py \
+    --data_set_dir /workdir/data_dir/LJ_Speech \
+    LJ
+
+python format_dataset.py \
+    --data_set_dir /workdir/data_dir/ESD \
+    ESD \
+    --language english
+```
+
+### Data Pre-Processing
+In this section, the code will:
+1. Align data using MFA (Montreal Forced Aligner)
+2. Extract features for training
+3. Create train and validation sets
+4. Extract feature statistics on the train set for speaker standardization
+
+To pre-process all available formatted data (i.e. LJ and ESD in this example):
+```
+python training.py \
+    --experiment_name EXPERIMENT_NAME \
+    --data_set_dir /workdir/data_dir \
+    pre_process
+```
+
+This will pre-process data using the default hyper-parameters, which are set for 22kHz audio.
+All outputs related to the experiment will be stored in `/workdir/repo_dir/trainings/EXPERIMENT_NAME`.
+You can also target specific speakers for data pre-processing. For example, to consider only ESD speakers:
+```
+python training.py \
+    --experiment_name EXPERIMENT_NAME \
+    --speakers ESD/spk_1 ... ESD/spk_N \
+    --data_set_dir /workdir/data_dir \
+    pre_process
+```
+
+The pre-process function takes several arguments:
+- `--features_dir`: absolute path where pre-processed data will be stored. Defaults to `/workdir/repo_dir/datasets`.
+- `--proportion_validation`: proportion of examples (in %) that will be in the validation set. Defaults to `0.1`% per speaker.
+- `--nb_jobs`: number of cores to use for Python multi-processing. If set to `max`, all CPU cores are used. Defaults to `6`.
+
+Note that the first time you pre-process the data, this step will take several hours.
+You can decrease computing time by increasing the `--nb_jobs` parameter.
+
+### Training
+Once pre-processing is finished, launch training. To train on all pre-processed data:
+```
+python training.py \
+    --experiment_name EXPERIMENT_NAME \
+    --data_set_dir /workdir/data_dir \
+    train
+```
+
+Or, if you targeted specific speakers during pre-processing (e.g. ESD speakers):
+```
+python training.py \
+    --experiment_name EXPERIMENT_NAME \
+    --speakers ESD/spk_1 ... ESD/spk_N \
+    --data_set_dir /workdir/data_dir \
+    train
+```
+
+All outputs related to the experiment will be stored in `/workdir/repo_dir/trainings/EXPERIMENT_NAME`.
+
+The train function takes several arguments:
+- `--checkpoint`: absolute path of a Daft-Exprt checkpoint. Defaults to `""`.
+- `--no_multiprocessing_distributed`: disable PyTorch multi-processing distributed training. Defaults to `False`.
+- `--world_size`: number of nodes for distributed training. Defaults to `1`.
+- `--rank`: node rank for distributed training. Defaults to `0`.
+- `--master`: URL used to set up distributed training. Defaults to `tcp://localhost:54321`.
+
+These default values will launch a new training starting at iteration 0, using all available GPUs on the machine.
+Default [batch size](src/daft_exprt/hparams.py#L66) and [gradient accumulation](src/daft_exprt/hparams.py#L67) hyper-parameters assume that only 1 GPU is available on the machine; they are set to reproduce the batch size of 48 used in the paper.
+
+The code also supports TensorBoard logging. To display logging outputs:
+`tensorboard --logdir_spec=EXPERIMENT_NAME:/workdir/repo_dir/trainings/EXPERIMENT_NAME/logs`
+
+### Fine-Tuning
+Once training is finished, you can create a dataset for vocoder fine-tuning:
+```
+python training.py \
+    --experiment_name EXPERIMENT_NAME \
+    --data_set_dir /workdir/data_dir \
+    fine_tune \
+    --checkpoint CHECKPOINT_PATH
+```
+
+Or, if you targeted specific speakers during pre-processing and training (e.g. ESD speakers):
+```
+python training.py \
+    --experiment_name EXPERIMENT_NAME \
+    --speakers ESD/spk_1 ... ESD/spk_N \
+    --data_set_dir /workdir/data_dir \
+    fine_tune \
+    --checkpoint CHECKPOINT_PATH
+```
+
+The fine-tuning dataset will be stored in `/workdir/repo_dir/trainings/EXPERIMENT_NAME/fine_tuning_dataset`.
+
+### TTS Synthesis
+For an example of how to use Daft-Exprt for TTS synthesis, run the script [synthesize.py](scripts/synthesize.py):
+```
+python synthesize.py \
+    --output_dir OUTPUT_DIR \
+    --checkpoint CHECKPOINT
+```
+
+Default sentences and reference utterances are used in the script.
+
+The script also offers the following options:
+- `--batch_size`: process batches of sentences in parallel
+- `--real_time_factor`: estimate the Daft-Exprt real-time factor for the chosen batch size
+- `--control`: perform local prosody control
+
+
+## Citation
+```
+@article{Zaidi2021,
+journal = {arXiv},
+arxivId = {2108.02271},
+author = {Za{\"{i}}di, Julian and Seut{\'{e}}, Hugo and van Niekerk, Benjamin and Carbonneau, Marc-Andr{\'{e}}},
+eprint = {2108.02271},
+title = {{Daft-Exprt: Robust Prosody Transfer Across Speakers for Expressive Speech Synthesis}},
+url = {https://arxiv.org/pdf/2108.02271.pdf},
+year = {2021}
+}
+```
+
+## Contributing
+Any contribution to this repository is more than welcome!
+If you have any feedback, please send it to julian.zaidi@ubisoft.com.
+
+
+© [2021] Ubisoft Entertainment. All Rights Reserved
\ No newline at end of file
diff --git a/environment/Dockerfile b/environment/Dockerfile
new file mode 100644
index 0000000..81d5016
--- /dev/null
+++ b/environment/Dockerfile
@@ -0,0 +1,40 @@
+FROM nvidia/cuda:11.2.0-base-ubuntu20.04
+
+# set environment variables
+ARG DEBIAN_FRONTEND=noninteractive
+ENV CONDA_AUTO_UPDATE_CONDA=false
+ENV PATH=/root/miniconda3/bin:$PATH
+
+# install linux packages
+RUN apt-get update && apt-get install -y curl libsndfile1 libopenblas-dev
+RUN rm -rf /var/lib/apt/lists/*
+
+# install miniconda and python 3.8
+RUN curl -sLo miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
+    && chmod +x ./miniconda.sh \
+    && ./miniconda.sh -b -p /root/miniconda3 \
+    && rm miniconda.sh \
+    && conda install -y python==3.8 \
+    && conda clean -ya
+
+# install conda and pip packages
+COPY ./environment/conda_requirements.txt /opt/daft_exprt/environment/
+COPY ./environment/pip_requirements.txt /opt/daft_exprt/environment/
+WORKDIR /opt/daft_exprt/environment
+RUN conda install -c conda-forge -y --file conda_requirements.txt
+RUN pip install -r pip_requirements.txt
+
+# install PyTorch
+RUN pip install torch==1.9.0+cu111 torchaudio==0.9.0 tensorboard -f https://download.pytorch.org/whl/torch_stable.html
+
+# install MFA thirdparty packages
+RUN mfa thirdparty download
+RUN mfa thirdparty validate
+
+# download pre-trained MFA models for english
+RUN mfa download acoustic english
+RUN mfa download g2p english_g2p
+RUN mfa download dictionary english
+
+# set working directory
+WORKDIR /workdir
diff --git a/environment/Makefile b/environment/Makefile
new file mode 100644
index 0000000..4458bb7
--- /dev/null
+++ b/environment/Makefile
@@ -0,0 +1,21 @@
+.PHONY: all
+
+all: linux_requirements python_requirements MFA_thirdparty MFA_pretrained
+
+linux_requirements:
+	sudo apt-get update && sudo apt-get install -y libsndfile1 libopenblas-dev
+
+python_requirements:
+	conda install -c conda-forge -y --file conda_requirements.txt
+	pip install pip setuptools --upgrade
+	pip install -e ../.[pytorch] \
+		--find-links https://download.pytorch.org/whl/torch_stable.html
+
+MFA_thirdparty:
+	mfa thirdparty download
mfa thirdparty validate + +MFA_pretrained: + mfa download acoustic english + mfa download g2p english_g2p + mfa download dictionary english diff --git a/environment/conda_requirements.txt b/environment/conda_requirements.txt new file mode 100644 index 0000000..2f592c4 --- /dev/null +++ b/environment/conda_requirements.txt @@ -0,0 +1,6 @@ +# MFA dependencies +baumwelch +ngram +openblas +openfst +pynini diff --git a/environment/pip_requirements.txt b/environment/pip_requirements.txt new file mode 100644 index 0000000..3f2e7e6 --- /dev/null +++ b/environment/pip_requirements.txt @@ -0,0 +1,7 @@ +inflect==5.3.0 +librosa == 0.8.1 +matplotlib == 3.4.3 +montreal-forced-aligner == 2.0.0a24 +python-dateutil == 2.8.2 +tgt == 1.4.4 +unidecode == 1.2.0 \ No newline at end of file diff --git a/scripts/benchmarks/english/sentences.txt b/scripts/benchmarks/english/sentences.txt new file mode 100644 index 0000000..a5ebf28 --- /dev/null +++ b/scripts/benchmarks/english/sentences.txt @@ -0,0 +1,19 @@ +This is the kind of quality that we get with our algorithm. +Unknown words may not be synthetized by all bingbongratata jablow. +The whole thing of doing the movie was a risk. +Wish I could be there on Sunday in person but I can't. +The blue lagoon is a nineteen eighty American romance adventure film. +The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky. +Aspects of the sublime in English poetry and painting, seventeen seventy to eighteen fifty. +How many pickled peppers did Peter Piper pick? +Sally sells sea shells by the sea shore. +He thought it was time to present the present. +We are not out of the wood yet, but we still do not think of the witch. +To my readers. +But just now, my loving tirans, won't allow me. +And as one of my small friends, Hadly states. +It isn't a real Oz story without her. +Then he dreamed about it, and waking or dreaming he found the tale hard to believe. +Do you think this sentence will have the good pitch? +Do you rest? +Is it a good choice? \ No newline at end of file diff --git a/scripts/format_dataset.py b/scripts/format_dataset.py new file mode 100644 index 0000000..ce7c8b8 --- /dev/null +++ b/scripts/format_dataset.py @@ -0,0 +1,117 @@ +import argparse +import logging +import os + +from shutil import copyfile + + +_logger = logging.getLogger(__name__) + + +''' + This script modifies speakers data sets to match the required format + Each speaker data set must be of the following format: + + /speaker_name + metadata.csv + /wavs + wav_file_name_1.wav + wav_file_name_2.wav + ... + + metadata.csv must be formatted as follows (pipe "|" separator): + wav_file_name_1|text_1 + wav_file_name_2|text_2 + ... 
+''' + + +def format_LJ_speech(lj_args): + ''' Format LJ data set + Only metadata.csv needs to be modified + ''' + # read metadata lines + _logger.info('Formatting LJ Speech') + metadata = os.path.join(lj_args.data_set_dir, 'metadata.csv') + assert(os.path.isfile(metadata)), _logger.error(f'There is no such file {metadata}') + with open(metadata, 'r', encoding='utf-8') as f: + metadata_lines = f.readlines() + # create new metadata.csv + metadata_lines = [line.strip().split(sep='|') for line in metadata_lines] + metadata_lines = [f'{line[0]}|{line[2]}\n' for line in metadata_lines] + with open(metadata, 'w', encoding='utf-8') as f: + f.writelines(metadata_lines) + _logger.info('Done!') + + +def format_ESD(esd_args): + ''' Format ESD data set + ''' + # extract speaker dirs depending on the language + _logger.info(f'Formatting ESD -- Language = {esd_args.language}') + speakers = [x for x in os.listdir(esd_args.data_set_dir) if + os.path.isdir(os.path.join(esd_args.data_set_dir, x))] + speakers.sort() + if esd_args.language == 'english': + for speaker in speakers[10:]: + _logger.info(f'Speaker -- {speaker}') + speaker_dir = os.path.join(esd_args.data_set_dir, speaker) + spk_out_dir = os.path.join(esd_args.data_set_dir, esd_args.language, speaker) + os.makedirs(spk_out_dir, exist_ok=True) + # read metadata lines + if speaker == speakers[10]: + metadata = os.path.join(speaker_dir,f'{speaker}.txt') + assert(os.path.isfile(metadata)), _logger.error(f'There is no such file {metadata}') + with open(metadata, 'r', encoding='utf-8') as f: + metadata_lines = f.readlines() + metadata_lines = [line.strip().split(sep='\t') for line in metadata_lines] + # create new metadata.csv + spk_metadata_lines = [f'{speaker}_{line[0].strip().split(sep="_")[1]}|{line[1]}\n' + for line in metadata_lines] + with open(os.path.join(spk_out_dir, 'metadata.csv'), 'w', encoding='utf-8') as f: + f.writelines(spk_metadata_lines) + # copy all audio files to /wavs directory + wavs_dir = os.path.join(spk_out_dir, 'wavs') + os.makedirs(wavs_dir, exist_ok=True) + for root, _, files in os.walk(speaker_dir): + wav_files = [x for x in files if x.endswith('.wav')] + for wav_file in wav_files: + src = os.path.join(root, wav_file) + dst = os.path.join(wavs_dir, wav_file) + copyfile(src, dst) + elif esd_args.language == 'mandarin': + _logger.error(f'"mandarin" not implemented') + else: + _logger.error(f'"language" must be either "english" or "mandarin", not "{esd_args.language}"') + _logger.info('Done!') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='script to format speakers data sets') + subparsers = parser.add_subparsers(help='commands for targeting a specific data set') + + parser.add_argument('-dd', '--data_set_dir', type=str, + help='path to the directory containing speakers data sets to format') + + parser_LJ = subparsers.add_parser('LJ', help='format LJ data set') + parser_LJ.set_defaults(func=format_LJ_speech) + + parser_ESD = subparsers.add_parser('ESD', help='format emotional speech dataset from Zhou et al.') + parser_ESD.set_defaults(func=format_ESD) + parser_ESD.add_argument('-lg', '--language', type=str, + help='either english or mandarin') + + args = parser.parse_args() + + # set logger config + logging.basicConfig( + handlers=[ + logging.StreamHandler(), + ], + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO + ) + + # run args + args.func(args) diff --git a/scripts/style_bank/english/0012_000567.wav 
b/scripts/style_bank/english/0012_000567.wav new file mode 100644 index 0000000..2f8e4e4 Binary files /dev/null and b/scripts/style_bank/english/0012_000567.wav differ diff --git a/scripts/style_bank/english/0012_001611.wav b/scripts/style_bank/english/0012_001611.wav new file mode 100644 index 0000000..fcdbefd Binary files /dev/null and b/scripts/style_bank/english/0012_001611.wav differ diff --git a/scripts/style_bank/english/0015_000490.wav b/scripts/style_bank/english/0015_000490.wav new file mode 100644 index 0000000..f46878c Binary files /dev/null and b/scripts/style_bank/english/0015_000490.wav differ diff --git a/scripts/style_bank/english/0015_001566.wav b/scripts/style_bank/english/0015_001566.wav new file mode 100644 index 0000000..f00a72c Binary files /dev/null and b/scripts/style_bank/english/0015_001566.wav differ diff --git a/scripts/style_bank/english/0018_000536.wav b/scripts/style_bank/english/0018_000536.wav new file mode 100644 index 0000000..0aadefb Binary files /dev/null and b/scripts/style_bank/english/0018_000536.wav differ diff --git a/scripts/style_bank/english/0018_001645.wav b/scripts/style_bank/english/0018_001645.wav new file mode 100644 index 0000000..429ce82 Binary files /dev/null and b/scripts/style_bank/english/0018_001645.wav differ diff --git a/scripts/style_bank/english/0019_000607.wav b/scripts/style_bank/english/0019_000607.wav new file mode 100644 index 0000000..1d586d1 Binary files /dev/null and b/scripts/style_bank/english/0019_000607.wav differ diff --git a/scripts/style_bank/english/0019_001536.wav b/scripts/style_bank/english/0019_001536.wav new file mode 100644 index 0000000..94d6784 Binary files /dev/null and b/scripts/style_bank/english/0019_001536.wav differ diff --git a/scripts/style_bank/english/0_audio_ref.wav b/scripts/style_bank/english/0_audio_ref.wav new file mode 100644 index 0000000..3adb4bb Binary files /dev/null and b/scripts/style_bank/english/0_audio_ref.wav differ diff --git a/scripts/style_bank/english/18_audio_ref.wav b/scripts/style_bank/english/18_audio_ref.wav new file mode 100644 index 0000000..fcc0566 Binary files /dev/null and b/scripts/style_bank/english/18_audio_ref.wav differ diff --git a/scripts/style_bank/english/24_audio_ref.wav b/scripts/style_bank/english/24_audio_ref.wav new file mode 100644 index 0000000..028dbc2 Binary files /dev/null and b/scripts/style_bank/english/24_audio_ref.wav differ diff --git a/scripts/style_bank/english/26_audio_ref.wav b/scripts/style_bank/english/26_audio_ref.wav new file mode 100644 index 0000000..7f96ec6 Binary files /dev/null and b/scripts/style_bank/english/26_audio_ref.wav differ diff --git a/scripts/style_bank/english/2_audio_ref.wav b/scripts/style_bank/english/2_audio_ref.wav new file mode 100644 index 0000000..3ed3503 Binary files /dev/null and b/scripts/style_bank/english/2_audio_ref.wav differ diff --git a/scripts/style_bank/english/36_audio_ref.wav b/scripts/style_bank/english/36_audio_ref.wav new file mode 100644 index 0000000..a334cdd Binary files /dev/null and b/scripts/style_bank/english/36_audio_ref.wav differ diff --git a/scripts/style_bank/english/38_audio_ref.wav b/scripts/style_bank/english/38_audio_ref.wav new file mode 100644 index 0000000..31e6691 Binary files /dev/null and b/scripts/style_bank/english/38_audio_ref.wav differ diff --git a/scripts/synthesize.py b/scripts/synthesize.py new file mode 100644 index 0000000..9ea7d2c --- /dev/null +++ b/scripts/synthesize.py @@ -0,0 +1,149 @@ +import argparse +import logging +import os +import random +import sys 
+import time + +import torch + +from shutil import copyfile + +FILE_ROOT = os.path.dirname(os.path.realpath(__file__)) +PROJECT_ROOT = os.path.dirname(FILE_ROOT) +os.environ['PYTHONPATH'] = os.path.join(PROJECT_ROOT, 'src') +sys.path.append(os.path.join(PROJECT_ROOT, 'src')) + +from daft_exprt.generate import extract_reference_parameters, generate_mel_specs, prepare_sentences_for_inference +from daft_exprt.hparams import HyperParams +from daft_exprt.model import DaftExprt +from daft_exprt.utils import get_nb_jobs + + +_logger = logging.getLogger(__name__) +random.seed(1234) + + +''' + Script example that showcases how to generate with Daft-Exprt + using a target sentence, a target speaker, and a target prosody +''' + + +def synthesize(args, dur_factor=None, energy_factor=None, pitch_factor=None, + pitch_transform=None, use_griffin_lim=False, get_time_perf=False): + ''' Generate with DaftExprt + ''' + # get hyper-parameters that were used to create the checkpoint + checkpoint_dict = torch.load(args.checkpoint, map_location=f'cuda:{0}') + hparams = HyperParams(verbose=False, **checkpoint_dict['config_params']) + # load model + torch.cuda.set_device(0) + model = DaftExprt(hparams).cuda(0) + state_dict = {k.replace('module.', ''): v for k, v in checkpoint_dict['state_dict'].items()} + model.load_state_dict(state_dict) + + # prepare sentences + n_jobs = get_nb_jobs('max') + sentences, file_names = prepare_sentences_for_inference(args.text_file, args.output_dir, hparams, n_jobs) + # extract reference parameters + audio_refs = [os.path.join(args.style_bank, x) for x in os.listdir(args.style_bank) if x.endswith('.wav')] + for audio_ref in audio_refs: + extract_reference_parameters(audio_ref, args.style_bank, hparams) + # choose a random reference per sentence + refs = [os.path.join(args.style_bank, x) for x in os.listdir(args.style_bank) if x.endswith('.npz')] + refs = [random.choice(refs) for _ in range(len(sentences))] + # choose a random speaker ID per sentence + speaker_ids = [random.choice(hparams.speakers_id) for _ in range(len(sentences))] + + # add duration factors for each symbol in the sentence + dur_factors = [] if dur_factor is not None else None + energy_factors = [] if energy_factor is not None else None + pitch_factors = [pitch_transform, []] if pitch_factor is not None else None + for sentence in sentences: + # count number of symbols in the sentence + nb_symbols = 0 + for item in sentence: + if isinstance(item, list): # correspond to phonemes of a word + nb_symbols += len(item) + else: # correspond to word boundaries + nb_symbols += 1 + # append to lists + if dur_factors is not None: + dur_factors.append([dur_factor for _ in range(nb_symbols)]) + if energy_factors is not None: + energy_factors.append([energy_factor for _ in range(nb_symbols)]) + if pitch_factors is not None: + pitch_factors[1].append([pitch_factor for _ in range(nb_symbols)]) + + # generate mel-specs and synthesize audios with Griffin-Lim + generate_mel_specs(model, sentences, file_names, speaker_ids, refs, args.output_dir, + hparams, dur_factors, energy_factors, pitch_factors, args.batch_size, + n_jobs, use_griffin_lim, get_time_perf) + + return file_names, refs, speaker_ids + + +def pair_ref_and_generated(args, file_names, refs, speaker_ids): + ''' Simplify prosody transfer evaluation by matching generated audio with its reference + ''' + # save references to output dir to make prosody transfer evaluation easier + for idx, (file_name, ref, speaker_id) in enumerate(zip(file_names, refs, speaker_ids)): + # extract 
reference audio + ref_file_name = os.path.basename(ref).replace('.npz', '') + audio_ref = os.path.join(args.style_bank, f'{ref_file_name}.wav') + # check correponding synthesized audio exists + synthesized_file_name = f'{file_name}_spk_{speaker_id}_ref_{ref_file_name}' + synthesized_audio = os.path.join(args.output_dir, f'{synthesized_file_name}.wav') + assert(os.path.isfile(synthesized_audio)), _logger.error(f'There is no such file {synthesized_audio}') + # rename files + os.rename(synthesized_audio, f'{os.path.join(args.output_dir, f"{idx}_{synthesized_file_name}.wav")}') + copyfile(audio_ref, f'{os.path.join(args.output_dir, f"{idx}_ref.wav")}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='script to synthesize sentences with Daft-Exprt') + + parser.add_argument('-out', '--output_dir', type=str, + help='output dir to store synthesis outputs') + parser.add_argument('-chk', '--checkpoint', type=str, + help='checkpoint path to use for synthesis') + parser.add_argument('-tf', '--text_file', type=str, default=os.path.join(PROJECT_ROOT, 'scripts', 'benchmarks', 'english', 'sentences.txt'), + help='text file to use for synthesis') + parser.add_argument('-sb', '--style_bank', type=str, default=os.path.join(PROJECT_ROOT, 'scripts', 'style_bank', 'english'), + help='directory path containing the reference utterances to use for synthesis') + parser.add_argument('-bs', '--batch_size', type=int, default=50, + help='batch of sentences to process in parallel') + parser.add_argument('-rtf', '--real_time_factor', action='store_true', + help='get Daft-Exprt real time factor performance given the batch size') + parser.add_argument('-ctrl', '--control', action='store_true', + help='perform local prosody control during synthesis') + + args = parser.parse_args() + + # set logger config + logging.basicConfig( + handlers=[ + logging.StreamHandler(), + ], + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO + ) + + if args.real_time_factor: + synthesize(args, get_time_perf=True) + time.sleep(5) + _logger.info('') + if args.control: + # small hard-coded example that showcases duration and pitch control + # control is performed on the sentence level in this example + # however, the code also supports control on the word/phoneme level + dur_factor = 1.25 # decrease speed + pitch_transform = 'add' # pitch shift + pitch_factor = 50 # 50Hz + synthesize(args, dur_factor=dur_factor, pitch_factor=pitch_factor, + pitch_transform=pitch_transform, use_griffin_lim=True) + else: + file_names, refs, speaker_ids = synthesize(args, use_griffin_lim=True) + pair_ref_and_generated(args, file_names, refs, speaker_ids) diff --git a/scripts/training.py b/scripts/training.py new file mode 100644 index 0000000..3473b57 --- /dev/null +++ b/scripts/training.py @@ -0,0 +1,203 @@ +import argparse +import json +import logging +import os +import sys + +from shutil import copyfile +from subprocess import call + +# ROOT directory +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +os.environ['PYTHONPATH'] = os.path.join(PROJECT_ROOT, 'src') +sys.path.append(os.path.join(PROJECT_ROOT, 'src')) + +from daft_exprt.create_sets import create_sets +from daft_exprt.extract_features import check_features_config_used, extract_features +from daft_exprt.features_stats import extract_features_stats +from daft_exprt.hparams import HyperParams +from daft_exprt.mfa import mfa +from daft_exprt.utils import get_nb_jobs + + +_logger = 
logging.getLogger(__name__) + + +def list_all_speakers(data_set_dir): + ''' List all speakers contained in data_set_dir + ''' + # initialize variables + speakers = [] + data_set_dir = os.path.normpath(data_set_dir) + # walk into data_set_dir + for root, directories, files in os.walk(data_set_dir): + if 'wavs' in directories and 'metadata.csv' in files: + # extract speaker data set relative path + spk_relative_path = os.path.relpath(root, data_set_dir) + spk_relative_path = os.path.normpath(spk_relative_path) + speakers.append(f'{spk_relative_path}') + + return speakers + + +def pre_process(pre_process_args): + ''' Pre-process speakers data sets for training + ''' + # check experiment folder is new + checkpoint_dir = os.path.join(output_directory, 'checkpoints') + if os.path.isdir(checkpoint_dir): + print(f'"{output_directory}" has already been used for a previous training experiment') + print(f'Cannot perform pre-processing') + print(f'Please change the "experiment_name" script argument\n') + sys.exit(1) + + # set logger config + log_dir = os.path.join(output_directory, 'logs') + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, 'pre_processing.log') + logging.basicConfig( + handlers=[ + logging.StreamHandler(), + logging.FileHandler(log_file, mode='w') + ], + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO + ) + + # create default location for features dir if not specified by the user + features_dir = os.path.join(pre_process_args.features_dir, pre_process_args.language, f'{hparams.sampling_rate}Hz') \ + if pre_process_args.features_dir == os.path.join(PROJECT_ROOT, "datasets") else pre_process_args.features_dir + # check current config is the same than the one used in features dir + if os.path.isdir(features_dir): + same_config = check_features_config_used(features_dir, hparams) + assert(same_config), _logger.error(f'"{features_dir}" contains data that were extracted using a different set ' + f'of hyper-parameters. 
Please change the "features_dir" script argument') + + # set number of parallel jobs + nb_jobs = get_nb_jobs(pre_process_args.nb_jobs) + # perform alignment using MFA + mfa(data_set_dir, hparams, nb_jobs) + + # copy metadata.csv + for speaker in hparams.speakers: + spk_features_dir = os.path.join(features_dir, speaker) + os.makedirs(spk_features_dir, exist_ok=True) + metadata_src = os.path.join(data_set_dir, speaker, 'metadata.csv') + metadata_dst = os.path.join(features_dir, speaker, 'metadata.csv') + assert(os.path.isfile(metadata_src)), _logger.error(f'There is no such file: {metadata_src}') + copyfile(metadata_src, metadata_dst) + + # extract features + extract_features(data_set_dir, features_dir, hparams, nb_jobs) + # create train and valid sets + create_sets(features_dir, hparams, pre_process_args.proportion_validation) + # extract features stats on the training set + stats = extract_features_stats(hparams, nb_jobs) + with open(stats_file, 'w') as f: + json.dump(stats, f, indent=4, sort_keys=True) + + +def train(train_args): + ''' Train Daft-Exprt on the pre-processed data sets + ''' + # launch training in distributed mode or not + training_script = os.path.join(PROJECT_ROOT, 'src', 'daft_exprt', 'train.py') + process = ['python', f'{training_script}', + '--data_set_dir', f'{data_set_dir}', + '--config_file', f'{config_file}', + '--benchmark_dir', f'{benchmark_dir}', + '--log_file', f"{os.path.join(output_directory, 'logs', 'training.log')}", + '--world_size', f'{train_args.world_size}', + '--rank', f'{train_args.rank}', + '--master', f'{train_args.master}'] + if not train_args.no_multiprocessing_distributed: + process.append('--multiprocessing_distributed') + call(process) + + +def fine_tune(fine_tune_args): + ''' Generate data sets with the Daft-Exprt trained model for vocoder fine-tuning + ''' + # launch fine-tuning + fine_tune_script = os.path.join(PROJECT_ROOT, 'src', 'daft_exprt', 'fine_tune.py') + process = ['python', f'{fine_tune_script}', + '--data_set_dir', f'{data_set_dir}', + '--config_file', f'{config_file}', + '--log_file', f"{os.path.join(output_directory, 'logs', 'fine_tuning.log')}"] + call(process) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='script to pre-process speakers data sets and train with Daft-Exprt') + subparsers = parser.add_subparsers(help='commands for pre-processing, training and generating data for vocoder fine-tuning') + + parser.add_argument('-en', '--experiment_name', type=str, + help='directory name where all pre-process, training and fine-tuning outputs will be stored') + parser.add_argument('-dd', '--data_set_dir', type=str, + help='path to the directory containing speakers data sets') + parser.add_argument('-spks', '--speakers', nargs='*', default=[], + help='speakers to use for training. 
' + 'If [], finds all speakers contained in data_set_dir') + parser.add_argument('-lg', '--language', type=str, default='english', + help='spoken language of the speakers that are stored in data_set_dir') + + parser_pre_process = subparsers.add_parser('pre_process', help='pre-process speakers data sets for training') + parser_pre_process.set_defaults(func=pre_process) + parser_pre_process.add_argument('-fd', '--features_dir', type=str, default=f'{os.path.join(PROJECT_ROOT, "datasets")}', + help='path to the directory where pre-processed data sets will be stored') + parser_pre_process.add_argument('-pv', '--proportion_validation', type=float, default=0.1, + help='for each speaker, proportion of examples (%) that will be in the validation set') + parser_pre_process.add_argument('-nj', '--nb_jobs', type=str, default='6', + help='number of cores to use for python multi-processing') + + parser_train = subparsers.add_parser('train', help='train Daft-Exprt on the pre-processed data sets') + parser_train.set_defaults(func=train) + parser_train.add_argument('-chk', '--checkpoint', type=str, default='', + help='checkpoint path to use to restart training at a specific iteration. ' + 'If empty, starts training at iteration 0') + parser_train.add_argument('-nmpd', '--no_multiprocessing_distributed', action='store_true', + help='disable PyTorch multi-processing distributed training') + parser_train.add_argument('-ws', '--world_size', type=int, default=1, + help='number of nodes for distributed training') + parser_train.add_argument('-r', '--rank', type=int, default=0, + help='node rank for distributed training') + parser_train.add_argument('-m', '--master', type=str, default='tcp://localhost:54321', + help='url used to set up distributed training') + + parser_fine_tune = subparsers.add_parser('fine_tune', help='generate data sets with the Daft-Exprt trained model for vocoder fine-tuning') + parser_fine_tune.set_defaults(func=fine_tune) + parser_fine_tune.add_argument('-chk', '--checkpoint', type=str, + help='checkpoint path to use for creating the data set for fine-tuning') + + args = parser.parse_args() + + # create path variables + data_set_dir = args.data_set_dir + output_directory = os.path.join(PROJECT_ROOT, 'trainings', args.experiment_name) + training_files = os.path.join(output_directory, f'train_{args.language}.txt') + validation_files = os.path.join(output_directory, f'validation_{args.language}.txt') + config_file = os.path.join(output_directory, 'config.json') + stats_file = os.path.join(output_directory, 'stats.json') + benchmark_dir = os.path.join(PROJECT_ROOT, 'scripts', 'benchmarks') + + # find all speakers in data_set_dir if not specified in the args + args.speakers = list_all_speakers(data_set_dir) if len(args.speakers) == 0 else args.speakers + + # fill hparams dictionary with mandatory keyword arguments + hparams_kwargs = { + 'training_files': training_files, + 'validation_files': validation_files, + 'output_directory': output_directory, + 'language': args.language, + 'speakers': args.speakers + } + # fill hparams dictionary to overwrite default hyper-param values + hparams_kwargs['checkpoint'] = args.checkpoint if hasattr(args, 'checkpoint') else '' + + # create hyper-params object and save config parameters + hparams = HyperParams(**hparams_kwargs) + hparams.save_hyper_params(config_file) + + # run args + args.func(args) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..afc18bd --- /dev/null +++ b/setup.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +import os + 
+from setuptools import setup, find_packages
+
+with open('README.md') as readme_file:
+    readme = readme_file.read()
+
+setup(
+    name='daft_exprt',
+    author='Julian Zaidi',
+    author_email='julian.zaidi@ubisoft.com',
+    description='Package for training and generating speech representations with Daft-Exprt acoustic model.',
+    url='https://github.com/ubisoft/ubisoft-laforge-daft-exprt',
+    license='© [2021] Ubisoft Entertainment. All Rights Reserved',
+    long_description=readme,
+    classifiers=[
+        'Programming Language :: Python :: 3.8',
+        'Operating System :: Linux'
+    ],
+    setup_requires=['setuptools_scm'],
+    python_requires='>=3.8',
+    install_requires=open(os.path.join('environment', 'pip_requirements.txt')).readlines(),
+    extras_require={
+        'pytorch': ['torch==1.9.0+cu111', 'torchaudio==0.9.0', 'tensorboard']
+    },
+    package_dir={'': 'src'},
+    packages=find_packages('src'),
+    use_scm_version={
+        'root': '.',
+        'relative_to': __file__,
+        'version_scheme': 'post-release',
+        'local_scheme': 'dirty-tag'
+    }
+)
diff --git a/src/daft_exprt/__init__.py b/src/daft_exprt/__init__.py
new file mode 100644
index 0000000..3d8e2fb
--- /dev/null
+++ b/src/daft_exprt/__init__.py
@@ -0,0 +1,20 @@
+import os
+import platform
+import subprocess
+
+
+# check platform
+if platform.system() == "Linux":
+    # REAPER binary
+    binary_dir = os.path.join(os.path.dirname(__file__), 'bin', 'reaper', 'linux')
+    os.environ['PATH'] += os.pathsep + binary_dir
+    # binary requires minimum version for glibc
+    ldd_version = subprocess.check_output("ldd --version | awk '/ldd/{print $NF}'", shell=True)
+    ldd_version = float(ldd_version.decode('utf-8').strip())
+    if ldd_version < 2.29:
+        raise Exception(f'REAPER binary -- Unsupported ldd version: {ldd_version} < 2.29')
+    # make binary executable for all groups
+    binary_file = os.path.join(binary_dir, 'reaper')
+    os.chmod(binary_file, 0o0755)
+else:
+    raise Exception(f'Unsupported platform: {platform.system()}')
diff --git a/src/daft_exprt/bin/reaper/linux/reaper b/src/daft_exprt/bin/reaper/linux/reaper
new file mode 100644
index 0000000..fbcbfa1
Binary files /dev/null and b/src/daft_exprt/bin/reaper/linux/reaper differ
diff --git a/src/daft_exprt/cleaners.py b/src/daft_exprt/cleaners.py
new file mode 100644
index 0000000..5085249
--- /dev/null
+++ b/src/daft_exprt/cleaners.py
@@ -0,0 +1,148 @@
+import re
+
+from unidecode import unidecode
+
+from daft_exprt.normalize_numbers import normalize_numbers
+
+
+'''
+Cleaners are transformations that need to be applied to in-the-wild text before it is sent to the acoustic model
+
+greatly inspired by https://github.com/keithito/tacotron
+'''
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r'\s+')
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.'
% x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def hyphen_remover(text): + text = re.sub('–', ', ', text) + text = re.sub(' -- ', ', ', text) + return re.sub('-', ' ', text) + + +def quote_remover(text): + return re.sub('"', '', text) + + +def parenthesis_remover(text): + return re.sub('\(|\)', '', text) + + +def space_coma_replacer(text): + return re.sub('[\s,]*,+[\s,]*', ', ', text) + + +def incorrect_starting_character_remover(text): + while text.startswith((',', ' ', '.', '!', '?', '-')): + text = text[1:] + return text + + +def apostrophee_formater(text): + return re.sub('’', '\'', text) + + +def dot_coma_replacer(text): + return re.sub(';', ',', text) + + +def double_dot_replacer(text): + return re.sub(':', ',', text) + + +def underscore_replacer(text): + return re.sub('_', ' ', text) + + +def triple_dot_replacer(text): + text = re.sub('…', '.', text) + return re.sub('[\s\.]*\.+[\s\.]*', '. ', text) + + +def multiple_punctuation_fixer(text): + text = re.sub('[\s\.,?!]*\?+[\s\.,?!]*', '? ', text) + text = re.sub('[\s\.,!]*\!+[\s\.,!]*', '! ', text) + return re.sub('[\s\.,]*\.+[\s\.,]*', '. ', text) + + +def english_cleaners(text): + ''' pipeline for English text, including number and abbreviation expansion + + :param text: sentence to process + ''' + # convert to regular english letters in lowercase. 
+ text = convert_to_ascii(text) + text = lowercase(text) + + # replace all abbreviations and numbers with text + text = expand_numbers(text) + text = expand_abbreviations(text) + + # deal with punctuation + text = hyphen_remover(text) + text = quote_remover(text) + text = dot_coma_replacer(text) # replace by a coma + text = double_dot_replacer(text) # replace by a coma + text = triple_dot_replacer(text) # replace by a coma + text = apostrophee_formater(text) + text = parenthesis_remover(text) + text = space_coma_replacer(text) + text = underscore_replacer(text) + text = collapse_whitespace(text) + text = incorrect_starting_character_remover(text) + text = multiple_punctuation_fixer(text) + text = text.strip() + + return text + + +def text_cleaner(text, lang='english'): + if lang.lower() == 'english': + text = english_cleaners(text) + + return text diff --git a/src/daft_exprt/create_sets.py b/src/daft_exprt/create_sets.py new file mode 100644 index 0000000..904a50a --- /dev/null +++ b/src/daft_exprt/create_sets.py @@ -0,0 +1,55 @@ +import logging +import os + + +_logger = logging.getLogger(__name__) + + +def create_sets(features_dir, hparams, proportion_validation=0.1): + ''' Create train and validation sets, for all specified speakers + + :param features_dir: directory containing all the speakers features files + :param hparams: hyper-parameters used for pre-processing + :param proportion_validation: for each speaker, proportion of examples (%) that will be in the validation set + ''' + # create directory where extracted train/validation sets will be saved + os.makedirs(os.path.dirname(hparams.training_files), exist_ok=True) + os.makedirs(os.path.dirname(hparams.validation_files), exist_ok=True) + # create train/validation text files + file_training = open(hparams.training_files, 'w', encoding='utf-8') + file_validation = open(hparams.validation_files, 'w', encoding='utf-8') + + # iterate over speakers + _logger.info('--' * 30) + _logger.info('Creating training and validation sets'.upper()) + _logger.info('--' * 30) + for speaker, speaker_id in zip(hparams.speakers, hparams.speakers_id): + _logger.info(f'Speaker: "{speaker}" -- ID: {speaker_id} -- Validation files: {proportion_validation}%') + # check metadata file exists + spk_features_dir = os.path.join(features_dir, speaker) + metadata = os.path.join(spk_features_dir, 'metadata.csv') + # read metadata lines + with open(metadata, 'r', encoding='utf-8') as f: + lines = f.readlines() + lines = [x.strip().split(sep='|') for x in lines] # [[file_name, text], ...] 
+ # get available features files for training + # some metadata files might miss because there was no .markers associated to the file + file_names = [line[0].strip() for line in lines] + features_files = [x for x in file_names if os.path.isfile(os.path.join(spk_features_dir, f'{x}.npy'))] + nb_feats_files = len(features_files) + + ctr = 0 + validation_ctr = 0 + for feature_file in features_files: + # store the line + ctr += 1 + new_line = f'{spk_features_dir}|{feature_file}|{speaker_id}\n' + if ctr % int(100 / proportion_validation) == 0 or (ctr == nb_feats_files and validation_ctr == 0): + file_validation.write(new_line) + validation_ctr += 1 + else: + file_training.write(new_line) + _logger.info('') + + file_training.close() + file_validation.close() diff --git a/src/daft_exprt/data_loader.py b/src/daft_exprt/data_loader.py new file mode 100644 index 0000000..f299c2d --- /dev/null +++ b/src/daft_exprt/data_loader.py @@ -0,0 +1,243 @@ +import os +import random + +import numpy as np +import torch + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + + +class DaftExprtDataLoader(Dataset): + ''' Load PyTorch Data Set + 1) load features, symbols and speaker ID + 2) convert symbols to sequence of one-hot vectors + ''' + def __init__(self, data_file, hparams, shuffle=True): + # check data file exists and extract lines + assert(os.path.isfile(data_file)) + with open(data_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + self.data = [line.strip().split(sep='|') for line in lines] + self.hparams = hparams + + # shuffle + if shuffle: + random.seed(hparams.seed) + random.shuffle(self.data) + + def get_mel_spec(self, mel_spec): + ''' Extract PyTorch float tensor from .npy mel-spec file + ''' + # transform to PyTorch tensor and check size + mel_spec = torch.from_numpy(np.load(mel_spec)) + assert(mel_spec.size(0) == self.hparams.n_mel_channels) + + return mel_spec + + def get_symbols_and_durations(self, markers): + ''' Extract PyTorch int tensor from an input symbols sequence + Extract PyTorch float and int duration for each symbol + ''' + # initialize variables + symbols, durations_float, durations_int = [], [], [] + + # read lines of markers file + with open(markers, 'r', encoding='utf-8') as f: + lines = f.readlines() + markers = [line.strip().split(sep='\t') for line in lines] + + # iterate over markers + for marker in markers: + begin, end, int_dur, symbol, _, _ = marker + symbols.append(self.hparams.symbols.index(symbol)) + durations_float.append(float(end) - float(begin)) + durations_int.append(int(int_dur)) + + # convert lists to PyTorch tensors + symbols = torch.IntTensor(symbols) + durations_float = torch.FloatTensor(durations_float) + durations_int = torch.IntTensor(durations_int) + + return symbols, durations_float, durations_int + + def get_energies(self, energies, speaker_id, normalize=True): + ''' Extract standardized PyTorch float tensor for energies + ''' + # read energy lines + with open(energies, 'r', encoding='utf-8') as f: + lines = f.readlines() + energies = np.array([float(line.strip()) for line in lines]) + # standardize energies based on speaker stats + if normalize: + zero_idxs = np.where(energies == 0.)[0] + energies -= self.hparams.stats[f'spk {speaker_id}']['energy']['mean'] + energies /= self.hparams.stats[f'spk {speaker_id}']['energy']['std'] + energies[zero_idxs] = 0. 
+ # convert to PyTorch float tensor + energies = torch.FloatTensor(energies) + + return energies + + def get_pitch(self, pitch, speaker_id, normalize=True): + ''' Extract standardized PyTorch float tensor for pitch + ''' + # read pitch lines + with open(pitch, 'r', encoding='utf-8') as f: + lines = f.readlines() + pitch = np.array([float(line.strip()) for line in lines]) + # standardize voiced pitch based on speaker stats + if normalize: + zero_idxs = np.where(pitch == 0.)[0] + pitch -= self.hparams.stats[f'spk {speaker_id}']['pitch']['mean'] + pitch /= self.hparams.stats[f'spk {speaker_id}']['pitch']['std'] + pitch[zero_idxs] = 0. + # convert to PyTorch float tensor + pitch = torch.FloatTensor(pitch) + + return pitch + + def get_data(self, data): + ''' Extract features, symbols and speaker ID + ''' + # get mel-spec path, markers path, pitch path and speaker ID + features_dir = data[0] + feature_file = data[1] + speaker_id = int(data[2]) + + mel_spec = os.path.join(features_dir, f'{feature_file}.npy') + markers = os.path.join(features_dir, f'{feature_file}.markers') + symbols_energy = os.path.join(features_dir, f'{feature_file}.symbols_nrg') + frames_energy = os.path.join(features_dir, f'{feature_file}.frames_nrg') + symbols_pitch = os.path.join(features_dir, f'{feature_file}.symbols_f0') + frames_pitch = os.path.join(features_dir, f'{feature_file}.frames_f0') + + # extract data + mel_spec = self.get_mel_spec(mel_spec) + symbols, durations_float, durations_int = self.get_symbols_and_durations(markers) + symbols_energy = self.get_energies(symbols_energy, speaker_id) + frames_energy = self.get_energies(frames_energy, speaker_id, normalize=False) + symbols_pitch = self.get_pitch(symbols_pitch, speaker_id) + frames_pitch = self.get_pitch(frames_pitch, speaker_id, normalize=False) + + # check everything is correct with sizes + assert(len(symbols_energy) == len(symbols)) + assert(len(symbols_pitch) == len(symbols)) + assert(len(frames_energy) == mel_spec.size(1)) + assert(len(frames_pitch) == mel_spec.size(1)) + assert(torch.sum(durations_int) == mel_spec.size(1)) + + return symbols, durations_float, durations_int, symbols_energy, symbols_pitch, \ + frames_energy, frames_pitch, mel_spec, speaker_id, features_dir, feature_file + + def __getitem__(self, index): + return self.get_data(self.data[index]) + + def __len__(self): + return len(self.data) + + +class DaftExprtDataCollate(): + ''' Zero-pads model inputs and targets + ''' + def __init__(self, hparams): + self.hparams = hparams + + def __call__(self, batch): + ''' Collate training batch + + :param batch: [[symbols, durations_float, durations_int, symbols_energy, symbols_pitch, + frames_energy, frames_pitch, mel_spec, speaker_id, features_dir, feature_file], ...] 
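+
+        note: samples are sorted by decreasing symbols sequence length, then every sequence is
+              right zero-padded to the longest item in the batch (mel-spec tensors likewise)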
+ + :return: collated batch of training samples + ''' + # find symbols sequence max length + input_lengths, ids_sorted_decreasing = \ + torch.sort(torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True) + max_input_len = input_lengths[0] + + # right zero-pad sequences to max input length + symbols = torch.LongTensor(len(batch), max_input_len).zero_() + durations_float = torch.FloatTensor(len(batch), max_input_len).zero_() + durations_int = torch.LongTensor(len(batch), max_input_len).zero_() + symbols_energy = torch.FloatTensor(len(batch), max_input_len).zero_() + symbols_pitch = torch.FloatTensor(len(batch), max_input_len).zero_() + speaker_ids = torch.LongTensor(len(batch)) + + for i in range(len(ids_sorted_decreasing)): + # extract batch sequences + symbols_seq = batch[ids_sorted_decreasing[i]][0] + dur_float_seq = batch[ids_sorted_decreasing[i]][1] + dur_int_seq = batch[ids_sorted_decreasing[i]][2] + symbols_energy_seq = batch[ids_sorted_decreasing[i]][3] + symbols_pitch_seq = batch[ids_sorted_decreasing[i]][4] + # fill padded arrays + symbols[i, :symbols_seq.size(0)] = symbols_seq + durations_float[i, :dur_float_seq.size(0)] = dur_float_seq + durations_int[i, :dur_int_seq.size(0)] = dur_int_seq + symbols_energy[i, :symbols_energy_seq.size(0)] = symbols_energy_seq + symbols_pitch[i, :symbols_pitch_seq.size(0)] = symbols_pitch_seq + # add corresponding speaker ID + speaker_ids[i] = batch[ids_sorted_decreasing[i]][8] + + # find mel-spec max length + max_output_len = max([x[7].size(1) for x in batch]) + + # right zero-pad mel-specs to max output length + frames_energy = torch.FloatTensor(len(batch), max_output_len).zero_() + frames_pitch = torch.FloatTensor(len(batch), max_output_len).zero_() + mel_specs = torch.FloatTensor(len(batch), self.hparams.n_mel_channels, max_output_len).zero_() + output_lengths = torch.LongTensor(len(batch)) + + for i in range(len(ids_sorted_decreasing)): + # extract batch sequences + frames_energy_seq = batch[ids_sorted_decreasing[i]][5] + frames_pitch_seq = batch[ids_sorted_decreasing[i]][6] + mel_spec = batch[ids_sorted_decreasing[i]][7] + # fill padded arrays + frames_energy[i, :frames_energy_seq.size(0)] = frames_energy_seq + frames_pitch[i, :frames_pitch_seq.size(0)] = frames_pitch_seq + mel_specs[i, :, :mel_spec.size(1)] = mel_spec + output_lengths[i] = mel_spec.size(1) + + # store file identification + # only used in fine_tune.py script + feature_dirs, feature_files = [], [] + for i in range(len(ids_sorted_decreasing)): + feature_dirs.append(batch[ids_sorted_decreasing[i]][9]) + feature_files.append(batch[ids_sorted_decreasing[i]][10]) + + return symbols, durations_float, durations_int, symbols_energy, symbols_pitch, input_lengths, \ + frames_energy, frames_pitch, mel_specs, output_lengths, speaker_ids, feature_dirs, feature_files + + +def prepare_data_loaders(hparams, num_workers=1, drop_last=True): + ''' Initialize train and validation Data Loaders + + :param hparams: hyper-parameters used for training + :param num_workers: number of workers involved in the Data Loader + + :return: Data Loaders for train and validation sets + ''' + # get data and collate function ready + train_set = DaftExprtDataLoader(hparams.training_files, hparams) + val_set = DaftExprtDataLoader(hparams.validation_files, hparams) + collate_fn = DaftExprtDataCollate(hparams) + + # get number of training examples + nb_training_examples = len(train_set) + + # use distributed sampler if we use distributed training + if hparams.multiprocessing_distributed: + train_sampler = 
DistributedSampler(train_set, shuffle=False) + else: + train_sampler = None + + # build training and validation data loaders + # drop_last=True because we shuffle data set at each epoch + train_loader = DataLoader(train_set, num_workers=num_workers, shuffle=(train_sampler is None), sampler=train_sampler, + batch_size=hparams.batch_size, pin_memory=True, drop_last=drop_last, collate_fn=collate_fn) + val_loader = DataLoader(val_set, num_workers=num_workers, shuffle=False, batch_size=hparams.batch_size, + pin_memory=True, drop_last=False, collate_fn=collate_fn) + + return train_loader, train_sampler, val_loader, nb_training_examples diff --git a/src/daft_exprt/extract_features.py b/src/daft_exprt/extract_features.py new file mode 100644 index 0000000..77936e8 --- /dev/null +++ b/src/daft_exprt/extract_features.py @@ -0,0 +1,553 @@ +import json +import logging +import logging.handlers +import os +import re +import subprocess +import types +import uuid + +import librosa +import numpy as np +import torch + +from shutil import rmtree + +from librosa.filters import mel as librosa_mel_fn +from scipy.io import wavfile + +from daft_exprt.symbols import ascii, eos, punctuation, SIL_WORD_SYMBOL, whitespace +from daft_exprt.utils import launch_multi_process + + +_logger = logging.getLogger(__name__) +FILE_ROOT = os.path.dirname(os.path.realpath(__file__)) +TMP_DIR = os.path.join(FILE_ROOT, 'tmp') +FEATURES_HPARAMS = ['centered', 'cutoff', 'f0_interval', 'filter_length', 'hop_length', + 'language', 'mel_fmax', 'mel_fmin', 'min_clipping', 'max_f0', 'min_f0', + 'n_mel_channels', 'order', 'sampling_rate', 'symbols', 'uv_cost', 'uv_interval'] + + +def check_features_config_used(features_dir, hparams): + ''' Check current config is the same than the one used in features directory + ''' + # hyper-params that are important for feature extraction + same_config = True + for root, _, file_names in os.walk(os.path.normpath(features_dir)): + # extract config files + configs = [x for x in file_names if x.endswith('.json')] + if len(configs) != 0: + # get previous config + with open(os.path.join(root, configs[0])) as f: + data = f.read() + config = json.loads(data) + hparams_prev = types.SimpleNamespace(**config) + # compare params + for param in FEATURES_HPARAMS: + if getattr(hparams, param) != getattr(hparams_prev, param): + same_config = False + _logger.warning(f'Parameter "{param}" is different in "{root}" -- ' + f'Was {getattr(hparams_prev, param)} and now is {getattr(hparams, param)}') + + return same_config + + +def get_min_phone_duration(lines, min_phone_dur=1000.): + ''' Extract shortest phone duration in the current .markers file + ''' + # iterate over phones + for line in lines: + line = line.strip().split(sep='\t') + # extract phone duration + begin, end = float(line[0]), float(line[1]) + if end - begin < min_phone_dur: + min_phone_dur = end - begin + + return min_phone_dur + + +def duration_to_integer(float_durations, hparams, nb_samples=None): + ''' Convert phoneme float durations to integer frame durations + ''' + # estimate number of samples in audio + if nb_samples is None: + # get total duration of audio + # float_durations = [[phone_begin, phone_end], ...] 
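+        # (when the waveform is not provided, the summed phone spans stand in for the audio length;
+        #  phones are assumed to be contiguous over the trimmed sentence)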
+ total_duration = sum([(x[1] - x[0]) for x in float_durations]) + # convert in number of samples + nb_samples = int(total_duration * hparams.sampling_rate) + # get nb spectrogram frames + # ignore padding for the moment + nb_frames = 1 + int((nb_samples - hparams.filter_length) / hparams.hop_length) + # get spectrogram frames index + frames_idx = [int(hparams.filter_length / 2) + hparams.hop_length * i for i in range(nb_frames)] + + # compute number of frames per phoneme + curr_frame = 1 + int_durations = [] + while curr_frame <= nb_frames: + # extract phoneme duration + begin, end = float_durations.pop(0) + if begin != end: + # convert to sample idx + begin, end = int(begin * hparams.sampling_rate), int(end * hparams.sampling_rate) + # get corresponding frames + nb_phone_frames = len([idx for idx in frames_idx if begin < idx <= end]) + int_durations.append(nb_phone_frames) + curr_frame += nb_phone_frames + else: # we should not have 0 durations + raise ValueError + # add edge frames if padding is on + if hparams.centered: + nb_edge_frames = int(hparams.filter_length / 2 / hparams.hop_length) + # left padding + int_durations[0] += nb_edge_frames + # right padding + if len(float_durations) != 0: # correspond to last phoneme + int_durations.append(nb_edge_frames) + else: + int_durations[-1] += nb_edge_frames + + return int_durations + + +def update_markers(file_name, lines, sentence, sent_begin, int_durations, hparams, logger): + ''' Update markers: + - change timings to start from 0 + - add punctuation or whitespace at word boundaries + - add EOS token at end of sentence + - add int durations + ''' + # characters to consider in the sentence + if hparams.language == 'english': + all_chars = ascii + punctuation + else: + raise NotImplementedError() + + ''' + match words in the sentence with the ones in markers lines + Sentence: ,THAT's, an example'! ' of a sentence. . .' + Markers words: that s an example of a sentence + ''' + # split sentence: + # [',', "that's", ',', 'an', "example'", '!', "'", 'of', 'a', 'sentence', '.', '.', '.', "'"] + sent_words = re.findall(f"[\w']+|[{punctuation}]", sentence.lower().strip()) + # remove characters that are not letters or punctuation: + # [',', "that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence', '.', '.', '.'] + sent_words = [x for x in sent_words if len(re.sub(f'[^{all_chars}]', '', x)) != 0] + # be sure to begin the sentence with a word and not a punctuation + # ["that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence', '.', '.', '.'] + while sent_words[0] in punctuation: + sent_words.pop(0) + # keep only one punctuation type at the end + # ["that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence'] + punctuation_end = None + while sent_words[-1] in punctuation: + punctuation_end = sent_words.pop(-1) + + # split markers lines -- [[begin, end, phone, word, word_idx], ....] 
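+    # (word_idx groups the phone lines that belong to the same word; dict.fromkeys below keeps
+    #  the first occurrence of each word index while preserving order)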
+ markers = [line.strip().split(sep='\t') for line in lines] + # extract markers words + # they are no '' at beginning and end of sentence because we trimmed the audio + # ['that', 's', 'an', example'', '', 'of', 'a', 'sentence'] + words_idx = [marker[4] for marker in markers] + lines_idx = [words_idx.index(word_idx) for word_idx in list(dict.fromkeys(words_idx).keys())] + marker_words = [markers[line_idx][3] for line_idx in lines_idx] + + # update markers with word boundaries + sent_words_copy, markers_old = sent_words.copy(), markers.copy() + markers, word_idx, word_error = [], 0, False + while len(sent_words) != 0: + # extract word in .lab sentence and .markers file + sent_word = sent_words.pop(0) + marker_word, marker_word_idx = markers_old[0][3], markers_old[0][4] + if marker_word != sent_word: + # we should have the same words + # generally the issue comes from the symbol ' + # e.g. example' vs example or that's vs [that, s] + regex_word = re.findall(f"[\w]+|[{punctuation}]", sent_word) + if len(regex_word) == 1: # ['example'] + sent_word = regex_word[0] + else: # ['that', 's'] + sent_words = regex_word + sent_words + sent_word = sent_words.pop(0) + if marker_word != sent_word: + # cannot fix the mismatch between words + word_error = True + logger.warning(f'Correspondance issue between words in the .lab sentence and those in .markers file -- ' + f'File name: {file_name} -- Sentence: {sent_words_copy} -- ' + f'Markers: {marker_words} -- Problematic words: {sent_word} -- {marker_word}') + break + # retrieve all markers lines that correspond to the word + while len(markers_old) != 0 and markers_old[0][4] == marker_word_idx: + begin, end, phone, word, _ = markers_old.pop(0) + begin = f'{float(begin) - sent_begin:.3f}' + end = f'{float(end) - sent_begin:.3f}' + int_dur = str(int_durations.pop(0)) + markers.append([begin, end, int_dur, phone, word, str(word_idx)]) + # at this point we pass to the next word + # we must add a word boundary between two consecutive words + word_idx += 1 + if len(sent_words) != 0: + word_bound = sent_words.pop(0) if sent_words[0] in punctuation else whitespace + # check if a silence marker is associated to the word boundary + if markers_old[0][3] == SIL_WORD_SYMBOL: + begin, end, _, _, _ = markers_old.pop(0) + begin = f'{float(begin) - sent_begin:.3f}' + end = f'{float(end) - sent_begin:.3f}' + int_dur = str(int_durations.pop(0)) + markers.append([begin, end, int_dur, word_bound, word_bound, str(word_idx)]) + else: + end_prev = markers[-1][1] + markers.append([end_prev, end_prev, str(0), word_bound, word_bound, str(word_idx)]) + word_idx += 1 + + if not word_error: + # add end punctuation if there is one + if punctuation_end is not None: + end_prev = markers[-1][1] + markers.append([end_prev, end_prev, str(0), punctuation_end, punctuation_end, str(word_idx)]) + word_idx += 1 + # add EOS token + end_prev = markers[-1][1] + markers.append([end_prev, end_prev, str(0), eos, eos, str(word_idx)]) + # check everything is correct + assert(len(sent_words) == len(markers_old) == len(int_durations) == 0), \ + logger.error(f'File name: {file_name} -- length mismatch between lists: ({sent_words}, {markers_old}, {int_durations})') + return markers + else: + return None + + +def extract_pitch(wav, fs, hparams): + ''' Extract pitch frames from audio using REAPER binary + Convert pitch to log scale and set unvoiced values to 0. 
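+        Returns one log-pitch value per mel-spectrogram frame (unvoiced frames are set to 0.)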
+ ''' + # REAPER asks for int16 audios + # audio is in float32 + wav = wav * 32768.0 + wav = wav.astype('int16') + # save audio file locally + rand_name = str(uuid.uuid4()) + out_dir = os.path.join(TMP_DIR, 'reaper') + os.makedirs(out_dir, exist_ok=True) + wav_file = os.path.join(out_dir, f'{rand_name}.wav') + wavfile.write(wav_file, fs, wav) + + # extract pitch values + f0_file = wav_file.replace('.wav', '.f0') + process = ['reaper', '-i', f'{wav_file}', + '-a', '-f', f'{f0_file}', + '-e', f'{hparams.f0_interval}', + '-m', f'{hparams.min_f0}', + '-x', f'{hparams.max_f0}', + '-u', f'{hparams.uv_interval}', + '-w', f'{hparams.uv_cost}'] + with open(os.devnull, 'wb') as devnull: + subprocess.check_call(process, stdout=devnull, stderr=subprocess.STDOUT) + # read PCM file + with open(f0_file, 'rb') as f: + buf = f.read() + pitch = np.frombuffer(buf, dtype='int16') + # extract unvoiced indexes + pitch = np.copy(pitch) + uv_idxs = np.where(pitch <= 0.)[0] + # put to log scale + pitch[uv_idxs] = 1000. + pitch = np.log(pitch) + # set unvoiced values to 0. + pitch[uv_idxs] = 0. + # extract pitch for each mel-spec frame + pitch_frames = pitch[::hparams.hop_length] + # edge case + if len(pitch) % hparams.hop_length == 0: + pitch_frames = np.append(pitch_frames, pitch[-1]) + # delete files + os.remove(wav_file) + os.remove(f0_file) + + return pitch_frames + + +def get_symbols_pitch(pitch, markers): + ''' Compute mean pitch per symbol + + pitch = NumPy array of shape (nb_mel_spec_frames, ) + markers = [[begin, end, int_dur, symbol, word, word_idx], ...] + ''' + idx = 0 + symbols_pitch = [] + for marker in markers: + # number of mel-spec frames assigned to the symbol + int_dur = int(marker[2]) + if int_dur != 0: + # ignore unvoiced values + symbol_pitch = pitch[idx: idx + int_dur] + symbol_pitch = symbol_pitch[symbol_pitch > 0.] + # compute mean pitch for voiced values + if len(symbol_pitch) != 0: + symbols_pitch.append(f'{np.mean(symbol_pitch):.3f}\n') + else: + symbols_pitch.append(f'{0.:.3f}\n') + idx += int_dur + else: + symbols_pitch.append(f'{0.:.3f}\n') + + return symbols_pitch + + +def extract_energy(mel_spec): + ''' Extract energy of each mel-spec frame + mel_spec = NumPy array of shape (nb_mel_spec_channels, nb_mel_spec_frames) + ''' + energy = np.linalg.norm(mel_spec, axis=0) + return energy + + +def get_symbols_energy(energy, markers): + ''' Compute mean energy per symbol + + energy = NumPy array of shape (nb_mel_spec_frames, ) + markers = [[begin, end, int_dur, symbol, word, word_idx], ...] 
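+
+    returns a list of strings, one mean energy value per symbol
+    (symbols with no mel-spec frames assigned get an energy of 0.)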
+ ''' + idx = 0 + symbols_energy = [] + for marker in markers: + # number of mel-spec frames assigned to the symbol + int_dur = int(marker[2]) + if int_dur != 0: + # compute mean energy + symbol_energy = energy[idx: idx + int_dur] + symbol_energy = np.mean(symbol_energy) + symbols_energy.append(f'{symbol_energy:.3f}\n') + idx += int_dur + else: + symbols_energy.append(f'{0.:.3f}\n') + + return symbols_energy + + +def mel_spectrogram_HiFi(wav, hparams): + ''' Mel-Spectrogram extraction as it is performed by HiFi-GAN + ''' + # convert to PyTorch float tensor + wav = torch.FloatTensor(wav) # (T, ) + # extract hparams + fmin = hparams.mel_fmin + fmax = hparams.mel_fmax + center = hparams.centered + hop_size = hparams.hop_length + n_fft = hparams.filter_length + num_mels = hparams.n_mel_channels + sampling_rate = hparams.sampling_rate + min_clipping = hparams.min_clipping + # get mel filter bank + mel_filter_bank = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) # (n_mels, 1 + n_fft/2) + mel_filter_bank = torch.from_numpy(mel_filter_bank).float() # (n_mels, 1 + n_fft/2) + # build hann window + hann_window = torch.hann_window(n_fft) + # extract amplitude spectrogram + spec = torch.stft(wav, n_fft, hop_length=hop_size, win_length=n_fft, window=hann_window, + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + # convert to mels and pass to log + mel_spec = torch.matmul(mel_filter_bank, spec) + mel_spec = torch.log(torch.clamp(mel_spec, min=min_clipping)) + # transform to numpy array + mel_spec = mel_spec.squeeze().numpy() + + return mel_spec + + +def rescale_wav_to_float32(x): + ''' Rescale audio array between -1.f and 1.f based on the current format + ''' + # convert + if x.dtype == 'int16': + y = x / 32768.0 + elif x.dtype == 'int32': + y = x / 2147483648.0 + elif x.dtype == 'uint8': + y = ((x / 255.0) - 0.5)*2 + elif x.dtype == 'float32' or x.dtype == 'float64': + y = x + else: + raise TypeError(f"could not normalize wav, unsupported sample type {x.dtype}") + # check amplitude is correct + y = y.astype('float32') + max_ampl = np.max(np.abs(y)) + if max_ampl > 1.0: + pass # the error should be raised but librosa returns values bigger than 1 sometimes + # raise ValueError(f'float32 wav contains samples not in the range [-1., 1.] 
-- ' + # f'max amplitude: {max_ampl}') + + return y + + +def _extract_features(files, features_dir, hparams, log_queue): + ''' Extract mel-spectrogram and markers with int duration + ''' + # create logger from logging queue + qh = logging.handlers.QueueHandler(log_queue) + root = logging.getLogger() + if not root.hasHandlers(): + root.setLevel(logging.INFO) + root.addHandler(qh) + logger = logging.getLogger(f"worker{str(uuid.uuid4())}") + + # check files exist + markers_file, wav_file = files + assert(os.path.isfile(markers_file)), logger.error(f'There is no such file: {markers_file}') + assert(os.path.isfile(wav_file)), logger.error(f'There is no such file: {wav_file}') + # read markers lines + with open(markers_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + + # check min phone duration is coherent + # min phone duration must be >= filter_length // 2 + # in order to have at least one mel-spec frame attributed to the phone + min_phone_dur = get_min_phone_duration(lines) + fft_length = hparams.filter_length / hparams.sampling_rate + assert(min_phone_dur > fft_length / 2), \ + logger.error(f'Min phone duration = {min_phone_dur} -- filter_length / 2 = {fft_length / 2}') + + # extract sentence duration + # leading and tailing silences have been removed in markers.py script + sent_begin = float(lines[0].strip().split(sep='\t')[0]) + sent_end = float(lines[-1].strip().split(sep='\t')[1]) + sent_dur = sent_end - sent_begin + + # ignore audio if length is inferior to min wav duration + if sent_dur >= hparams.minimum_wav_duration / 1000: + # read wav file to range [-1, 1] in np.float32 + wav, fs = librosa.load(wav_file, sr=hparams.sampling_rate) + wav = rescale_wav_to_float32(wav) + # remove leading and tailing silences + wav = wav[int(sent_begin * fs): int(sent_end * fs)] + + # extract mel-spectrogram + mel_spec = mel_spectrogram_HiFi(wav, hparams) + # get number of mel-spec frames + nb_mel_spec_frames = mel_spec.shape[1] + + # convert phoneme durations to integer frame durations + float_durations = [[float(x[0]) - sent_begin, float(x[1]) - sent_begin] + for x in [line.strip().split(sep='\t') for line in lines]] + int_durations = duration_to_integer(float_durations, hparams, nb_samples=len(wav)) + assert(len(int_durations) == len(lines)), logger.error(f'{markers_file} -- ({len(int_durations)}, {len(lines)})') + assert(sum(int_durations) == nb_mel_spec_frames), logger.error(f'{markers_file} -- ({sum(int_durations)}, {nb_mel_spec_frames})') + assert(0 not in int_durations), logger.error(f'{markers_file} -- {int_durations}') + + # update markers: + # change timings to start from 0 + # add punctuation or whitespace at word boundaries + # add EOS token at end of sentence + # add int durations + markers_dir = os.path.dirname(markers_file) + file_name = os.path.basename(markers_file).replace('.markers', '') + sentence_file = os.path.join(markers_dir, f'{file_name}.lab') + assert(os.path.isfile(sentence_file)), logger.error(f'There is no such file: {sentence_file}') + with open(sentence_file, 'r', encoding='utf-8') as f: + sentence = f.readline() + markers = update_markers(file_name, lines, sentence, sent_begin, int_durations, hparams, logger) + + if markers is not None: + # save mel-spectrogram -- (n_mel_channels, T) + np.save(os.path.join(features_dir, f'{file_name}.npy'), mel_spec) + + # save markers + # each line has the format: [begin, end, int_dur, symbol, word, word_idx] + markers_file = os.path.join(features_dir, f'{file_name}.markers') + with open(markers_file, 'w', 
encoding='utf-8') as f: + f.writelines(['\t'.join(x) + '\n' for x in markers]) + + # extract energy for each mel-spec frame + mel_spec = np.exp(mel_spec) # remove log + frames_energy = extract_energy(mel_spec) + # save frames energy values + energy_file = os.path.join(features_dir, f'{file_name}.frames_nrg') + with open(energy_file, 'w', encoding='utf-8') as f: + for val in frames_energy: + f.write(f'{val:.3f}\n') + # extract energy on the symbol level + # we use average energy value per symbol + symbols_energy = get_symbols_energy(frames_energy, markers) + # save symbols energy + energy_file = os.path.join(features_dir, f'{file_name}.symbols_nrg') + with open(energy_file, 'w', encoding='utf-8') as f: + f.writelines(symbols_energy) + + # extract log pitch for each mel-spec frame + frames_pitch = extract_pitch(wav, fs, hparams) + assert(len(frames_pitch) == nb_mel_spec_frames), logger.error(f'{markers_file} -- ({len(frames_pitch)}, {nb_mel_spec_frames})') + # save frames pitch values + pitch_file = os.path.join(features_dir, f'{file_name}.frames_f0') + with open(pitch_file, 'w', encoding='utf-8') as f: + for val in frames_pitch: + f.write(f'{val:.3f}\n') + # extract pitch on the symbol level + # we use average pitch value per symbol + symbols_pitch = get_symbols_pitch(frames_pitch, markers) + # save symbols pitch values + pitch_file = os.path.join(features_dir, f'{file_name}.symbols_f0') + with open(pitch_file, 'w', encoding='utf-8') as f: + f.writelines(symbols_pitch) + else: + logger.warning(f'Ignoring {wav_file} -- audio has length inferior to {hparams.minimum_wav_duration / 1000}s after trimming') + + +def get_files_for_features_extraction(line, markers_dir, log_queue): + ''' Return file name if .markers file exists + ''' + # check if markers file exist for the corresponding line + line = line.strip().split(sep='|') # [file_name, text] + file_name = line[0].strip() + markers = os.path.join(markers_dir, f'{file_name}.markers') + if os.path.isfile(markers): + return file_name + else: + return None + + +def extract_features(dataset_dir, features_dir, hparams, n_jobs): + ''' Extract features for training + ''' + # iterate over speakers + _logger.info('--' * 30) + _logger.info('Extracting Features'.upper()) + _logger.info('--' * 30) + for speaker in hparams.speakers: + _logger.info(f'Speaker: "{speaker}"') + # check wavs and markers dir exist + wavs_dir = os.path.join(dataset_dir, speaker, 'wavs') + markers_dir = os.path.join(dataset_dir, speaker, 'align') + assert(os.path.isdir(wavs_dir)), _logger.error(f'There is no such directory: {wavs_dir}') + assert(os.path.isdir(markers_dir)), _logger.error(f'There is no such directory: {markers_dir}') + # check metadata file exist + spk_features_dir = os.path.join(features_dir, speaker) + metadata = os.path.join(spk_features_dir, 'metadata.csv') + assert(os.path.isfile(metadata)), _logger.error(f'There is no such file: {metadata}') + + # get all files that can be used for features extraction + with open(metadata, 'r', encoding='utf-8') as f: + lines = f.readlines() + file_names = launch_multi_process(iterable=lines, func=get_files_for_features_extraction, + n_jobs=n_jobs, markers_dir=markers_dir, timer_verbose=False) + file_names = [x for x in file_names if x is not None] + + # check current files that exist in features dir + # avoid to process files that already have been processed in a previous features extraction + curr_files = [x.replace('.symbols_f0', '').strip() for x in os.listdir(spk_features_dir) if x.endswith('.symbols_f0')] + missing_files 
= [x for x in file_names if x not in curr_files] + _logger.info(f'{len(curr_files)} files already processed. {len(missing_files)} new files need to be processed') + + # extract features + files = [(os.path.join(markers_dir, f'{x}.markers'), os.path.join(wavs_dir, f'{x}.wav')) for x in missing_files] + launch_multi_process(iterable=files, func=_extract_features, n_jobs=n_jobs, + features_dir=spk_features_dir, hparams=hparams) + + # save config used to perform features extraction + hparams.save_hyper_params(os.path.join(spk_features_dir, 'config.json')) + _logger.info('') + # remove tmp directory + rmtree(TMP_DIR, ignore_errors=True) diff --git a/src/daft_exprt/features_stats.py b/src/daft_exprt/features_stats.py new file mode 100644 index 0000000..acee8ff --- /dev/null +++ b/src/daft_exprt/features_stats.py @@ -0,0 +1,165 @@ +import collections +import logging +import logging.handlers +import os +import uuid + +import numpy as np + +from daft_exprt.utils import launch_multi_process + + +_logger = logging.getLogger(__name__) + + +def get_symbols_durations(markers_file, hparams, log_queue): + ''' extract symbols durations in markers file + ''' + # create logger from logging queue + qh = logging.handlers.QueueHandler(log_queue) + root = logging.getLogger() + if not root.hasHandlers(): + root.setLevel(logging.INFO) + root.addHandler(qh) + logger = logging.getLogger(f"worker{str(uuid.uuid4())}") + + # check file exists + assert(os.path.isfile(markers_file)), logger.error(f'There is no such file "{markers_file}"') + # read markers lines + with open(markers_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + markers = [line.strip().split(sep='\t') for line in lines] # [[begin, end, nb_frames, symbol, word, word_idx], ...] + + # extract duration for each symbol that is in markers + symbols_durations = [] + for marker in markers: + begin, end, _, symbol, _, _ = marker + assert(symbol in hparams.symbols), logger.error(f'{markers_file} -- Symbol "{symbol}" does not exist') + begin, end = float(begin), float(end) + symbols_durations.append([symbol, end - begin]) + + return symbols_durations + + +def get_non_zero_energy_values(energy_file, log_queue): + ''' Extract non-zero energy values in energy file + ''' + # create logger from logging queue + qh = logging.handlers.QueueHandler(log_queue) + root = logging.getLogger() + if not root.hasHandlers(): + root.setLevel(logging.INFO) + root.addHandler(qh) + logger = logging.getLogger(f"worker{str(uuid.uuid4())}") + + # check file exists + assert(os.path.isfile(energy_file)), logger.error(f'There is no such file "{energy_file}"') + # read energy lines + with open(energy_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + energy_vals = [float(line.strip()) for line in lines] + # remove non-zero energy values + energy_vals = list(filter(lambda a: a != 0., energy_vals)) + + return energy_vals + + +def get_voiced_pitch_values(pitch_file, log_queue): + ''' Extract voiced pitch values in pitch file + ''' + # create logger from logging queue + qh = logging.handlers.QueueHandler(log_queue) + root = logging.getLogger() + if not root.hasHandlers(): + root.setLevel(logging.INFO) + root.addHandler(qh) + logger = logging.getLogger(f"worker{str(uuid.uuid4())}") + + # check file exists + assert(os.path.isfile(pitch_file)), logger.error(f'There is no such file "{pitch_file}"') + # read pitch lines + with open(pitch_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + pitch_vals = [float(line.strip()) for line in lines] + # remove unvoiced pitch 
values + pitch_vals = list(filter(lambda a: a != 0., pitch_vals)) + + return pitch_vals + + +def extract_features_stats(hparams, n_jobs): + ''' Extract features stats for training and inference + ''' + # only use the training set to extract features stats + with open(hparams.training_files, 'r', encoding='utf-8') as f: + lines = f.readlines() + training_files = [line.strip().split(sep='|') for line in lines] # [[features_dir, features_file, speaker_id], ...] + + # iterate over speakers + _logger.info('--' * 30) + _logger.info('Extracting Features Stats'.upper()) + _logger.info('--' * 30) + symbols_durations = [] + speaker_stats = {f'spk {id}': {'energy': [], 'pitch': []} + for id in set(hparams.speakers_id)} + for speaker_id in set(hparams.speakers_id): + _logger.info(f'Speaker ID: {speaker_id}') + # extract all files associated to speaker ID + spk_training_files = [[x[0], x[1]] for x in training_files if int(x[2]) == speaker_id] + + # extract symbol durations + markers_files = [os.path.join(x[0], f'{x[1]}.markers') for x in spk_training_files] + symbols_durs = launch_multi_process(iterable=markers_files, func=get_symbols_durations, + n_jobs=n_jobs, hparams=hparams, timer_verbose=False) + symbols_durs = [y for x in symbols_durs for y in x] + symbols_durations.extend(symbols_durs) + + # extract non-zero energy values + energy_files = [os.path.join(x[0], f'{x[1]}.symbols_nrg') for x in spk_training_files] + energy_vals = launch_multi_process(iterable=energy_files, func=get_non_zero_energy_values, + n_jobs=n_jobs, timer_verbose=False) + energy_vals = [y for x in energy_vals for y in x] + speaker_stats[f'spk {speaker_id}']['energy'].extend(energy_vals) + + # extract voiced symbols pitch values + pitch_files = [os.path.join(x[0], f'{x[1]}.symbols_f0') for x in spk_training_files] + pitch_vals = launch_multi_process(iterable=pitch_files, func=get_voiced_pitch_values, + n_jobs=n_jobs, timer_verbose=False) + pitch_vals = [y for x in pitch_vals for y in x] + speaker_stats[f'spk {speaker_id}']['pitch'].extend(pitch_vals) + _logger.info('') + + # compute symbols durations stats + symbols_stats = collections.defaultdict(list) + for item in symbols_durations: + symbol, duration = item + symbols_stats[symbol].append(duration) + for symbol in symbols_stats: + min, max = np.min(symbols_stats[symbol]), np.max(symbols_stats[symbol]) + mean, std = np.mean(symbols_stats[symbol]), np.std(symbols_stats[symbol]) + symbols_stats[symbol] = { + 'dur_min': min, 'dur_max': max, + 'dur_mean': mean, 'dur_std': std + } + # compute energy and pitch stats for each speaker + for speaker, vals in speaker_stats.items(): + energy_vals, pitch_vals = vals['energy'], vals['pitch'] + speaker_stats[speaker] = { + 'energy': { + 'mean': np.mean(energy_vals), + 'std': np.std(energy_vals), + 'min': np.min(energy_vals), + 'max': np.max(energy_vals) + }, + 'pitch': { + 'mean': np.mean(pitch_vals), + 'std': np.std(pitch_vals), + 'min': np.min(pitch_vals), + 'max': np.max(pitch_vals) + } + } + # merge stats + stats = {**speaker_stats} + stats['symbols'] = symbols_stats + + return stats diff --git a/src/daft_exprt/fine_tune.py b/src/daft_exprt/fine_tune.py new file mode 100644 index 0000000..3f313cd --- /dev/null +++ b/src/daft_exprt/fine_tune.py @@ -0,0 +1,184 @@ +import argparse +import json +import logging +import os +import time + +import librosa +import numpy as np +import torch + +from scipy.io.wavfile import write + +from daft_exprt.data_loader import prepare_data_loaders +from daft_exprt.extract_features import 
mel_spectrogram_HiFi, rescale_wav_to_float32 +from daft_exprt.hparams import HyperParams +from daft_exprt.model import DaftExprt +from daft_exprt.utils import estimate_required_time + + +_logger = logging.getLogger(__name__) + + +def fine_tuning(hparams): + ''' Extract mel-specs and audio files for Vocoder fine-tuning + + :param hparams: hyper-params used for pre-processing and training + ''' + # --------------------------------------------------------- + # create model + # --------------------------------------------------------- + # load model on GPU + torch.cuda.set_device(0) + model = DaftExprt(hparams).cuda(0) + + # --------------------------------------------------------- + # load checkpoint + # --------------------------------------------------------- + assert(hparams.checkpoint != ""), _logger.error(f'No checkpoint specified -- {hparams.checkpoint}') + checkpoint_dict = torch.load(hparams.checkpoint, map_location=f'cuda:{0}') + state_dict = {k.replace('module.', ''): v for k, v in checkpoint_dict['state_dict'].items()} + model.load_state_dict(state_dict) + + # --------------------------------------------------------- + # prepare Data Loaders + # --------------------------------------------------------- + hparams.multiprocessing_distributed = False + train_loader, _, _, _ = \ + prepare_data_loaders(hparams, num_workers=0, drop_last=False) + + # --------------------------------------------------------- + # create folders to store fine-tuning data set + # --------------------------------------------------------- + experiment_root = os.path.dirname(hparams.training_files) + ft_data_set = os.path.join(experiment_root, 'fine_tuning_dataset') + hparams.ft_data_set = ft_data_set + for speaker in hparams.speakers: + os.makedirs(os.path.join(ft_data_set, speaker), exist_ok=True) + + # ============================================== + # MAIN LOOP + # ============================================== + model.eval() # set eval mode + start = time.time() + with torch.no_grad(): + # iterate over examples of train set + for idx, batch in enumerate(train_loader): + estimate_required_time(nb_items_in_list=len(train_loader), current_index=idx, + time_elapsed=time.time() - start, interval=1) + inputs, _, file_ids = model.parse_batch(0, batch) + feature_dirs, feature_files = file_ids # (B, ) and (B, ) + + outputs = model(inputs) + _, _, _, decoder_preds, _ = outputs + mel_spec_preds, output_lengths = decoder_preds + mel_spec_preds = mel_spec_preds.detach().cpu().numpy() # (B, nb_mels, T_max) + output_lengths = output_lengths.detach().cpu().numpy() # (B, ) + + # iterate over examples in the batch + for idx in range(mel_spec_preds.shape[0]): + mel_spec_pred = mel_spec_preds[idx] # (nb_mels, T_max) + output_length = output_lengths[idx] + feature_dir = feature_dirs[idx] + feature_file = feature_files[idx] + # crop mel-spec prediction to the correct size + mel_spec_pred = mel_spec_pred[:, :output_length] # (nb_mels, T) + # extract speaker name + speaker_name = [speaker for speaker in hparams.speakers if feature_dir.endswith(speaker)] + assert(len(speaker_name) == 1), _logger.error(f'{feature_dir} -- {feature_file} -- {speaker_name}') + speaker_name = speaker_name[0] + # read wav file to range [-1, 1] in np.float32 + wav_file = os.path.join(hparams.data_set_dir, speaker_name, 'wavs', f'{feature_file}.wav') + wav, fs = librosa.load(wav_file, sr=hparams.sampling_rate) + wav = rescale_wav_to_float32(wav) + # crop audio to remove tailing silences based on markers file + markers_file = 
os.path.join(hparams.data_set_dir, speaker_name, 'align', f'{feature_file}.markers') + with open(markers_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + sent_begin = float(lines[0].strip().split(sep='\t')[0]) + sent_end = float(lines[-1].strip().split(sep='\t')[1]) + wav = wav[int(sent_begin * fs): int(sent_end * fs)] + # check target and predicted mel-spec have the same size + mel_spec_tgt = mel_spectrogram_HiFi(wav, hparams) + assert(mel_spec_tgt.shape == mel_spec_pred.shape), \ + _logger.error(f'{feature_dir} -- {feature_file} -- {mel_spec_tgt.shape} -- {mel_spec_pred.shape}') + # save audio and mel-spec if they have the correct size (superior to 1s) + if len(wav) >= fs: + # convert to int16 + wav = wav * 32768.0 + wav = wav.astype('int16') + # store files in fine-tuning data set + mel_spec_file = os.path.join(hparams.ft_data_set, speaker_name, f'{feature_file}.npy') + wav_file = os.path.join(hparams.ft_data_set, speaker_name, f'{feature_file}.wav') + try: + np.save(mel_spec_file, mel_spec_pred) + write(wav_file, fs, wav) + except Exception as e: + _logger.error(f'{feature_dir} -- {feature_file} -- {e}') + if os.path.isfile(mel_spec_file): + os.remove(mel_spec_file) + if os.path.isfile(wav_file): + os.remove(wav_file) + else: + _logger.warning(f'{feature_dir} -- {feature_file} -- Ignoring because audio is < 1s') + + +def launch_fine_tuning(data_set_dir, config_file, log_file): + ''' Launch fine-tuning + ''' + # set logger config + logging.basicConfig( + handlers=[ + logging.StreamHandler(), + logging.FileHandler(log_file) + ], + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO + ) + + # get hyper-parameters + with open(config_file) as f: + data = f.read() + config = json.loads(data) + hparams = HyperParams(verbose=False, **config) + + # update hparams + hparams.data_set_dir = data_set_dir + hparams.config_file = config_file + + # save hyper-params to config.json + hparams.save_hyper_params(hparams.config_file) + + # define cudnn variables + torch.manual_seed(0) + torch.backends.cudnn.enabled = hparams.cudnn_enabled + torch.backends.cudnn.benchmark = hparams.cudnn_benchmark + torch.backends.cudnn.deterministic = hparams.cudnn_deterministic + + # display fine-tuning setup info + _logger.info(f'PyTorch version -- {torch.__version__}') + _logger.info(f'CUDA version -- {torch.version.cuda}') + _logger.info(f'CUDNN version -- {torch.backends.cudnn.version()}') + _logger.info(f'CUDNN enabled = {torch.backends.cudnn.enabled}') + _logger.info(f'CUDNN deterministic = {torch.backends.cudnn.deterministic}') + _logger.info(f'CUDNN benchmark = {torch.backends.cudnn.benchmark}\n') + + # create fine-tuning data set + fine_tuning(hparams) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--data_set_dir', type=str, required=True, + help='Data set containing .wav files') + parser.add_argument('--config_file', type=str, required=True, + help='JSON configuration file to initialize hyper-parameters for fine-tuning') + parser.add_argument('--log_file', type=str, required=True, + help='path to save logger outputs') + + args = parser.parse_args() + + # launch fine-tuning + launch_fine_tuning(args.data_set_dir, args.config_file, args.log_file) diff --git a/src/daft_exprt/generate.py b/src/daft_exprt/generate.py new file mode 100644 index 0000000..2f9bd36 --- /dev/null +++ b/src/daft_exprt/generate.py @@ -0,0 +1,494 @@ +import collections +import logging +import logging.handlers +import os +import random 
+import re +import time +import uuid + +import librosa +import numpy as np +import torch + +from scipy.io import wavfile +from shutil import rmtree + +from daft_exprt.cleaners import collapse_whitespace, text_cleaner +from daft_exprt.extract_features import extract_energy, extract_pitch, mel_spectrogram_HiFi, rescale_wav_to_float32 +from daft_exprt.griffin_lim import griffin_lim_reconstruction_from_mel_spec +from daft_exprt.symbols import ascii, eos, punctuation, whitespace +from daft_exprt.utils import chunker, launch_multi_process, plot_2d_data + + +_logger = logging.getLogger(__name__) +FILE_ROOT = os.path.dirname(os.path.realpath(__file__)) + + +def phonemize_sentence(sentence, hparams, log_queue): + ''' Phonemize sentence using MFA + ''' + # get MFA variables + dictionary = hparams.mfa_dictionary + g2p_model = hparams.mfa_g2p_model + # load dictionary and extract word transcriptions + word_trans = collections.defaultdict(list) + with open(dictionary, 'r', encoding='utf-8') as f: + lines = [line.strip().split() for line in f.readlines()] + for line in lines: + word_trans[line[0].lower()].append(line[1:]) + # characters to consider in the sentence + if hparams.language == 'english': + all_chars = ascii + punctuation + else: + raise NotImplementedError() + + # clean sentence + # "that's, an 'example! ' of a sentence. '" + sentence = text_cleaner(sentence.strip(), hparams.language).lower().strip() + # split sentence: + # [',', "that's", ',', 'an', "example'", '!', "'", 'of', 'a', 'sentence', '.', '.', '.', "'"] + sent_words = re.findall(f"[\w']+|[{punctuation}]", sentence.lower().strip()) + # remove characters that are not letters or punctuation: + # [',', "that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence', '.', '.', '.'] + sent_words = [x for x in sent_words if len(re.sub(f'[^{all_chars}]', '', x)) != 0] + # be sure to begin the sentence with a word and not a punctuation + # ["that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence', '.', '.', '.'] + while sent_words[0] in punctuation: + sent_words.pop(0) + # keep only one punctuation type at the end + # ["that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence'] + punctuation_end = None + while sent_words[-1] in punctuation: + punctuation_end = sent_words.pop(-1) + sent_words.append(punctuation_end) + + # phonemize words and add word boundaries + sentence_phonemized, unk_words = [], [] + while len(sent_words) != 0: + word = sent_words.pop(0) + if word in word_trans: + phones = random.choice(word_trans[word]) + sentence_phonemized.append(phones) + else: + unk_words.append(word) + sentence_phonemized.append('') + # at this point we pass to the next word + # we must add a word boundary between two consecutive words + if len(sent_words) != 0: + word_bound = sent_words.pop(0) if sent_words[0] in punctuation else whitespace + sentence_phonemized.append(word_bound) + # add EOS token + sentence_phonemized.append(eos) + + # use MFA g2p model to phonemize unknown words + if len(unk_words) != 0: + rand_name = str(uuid.uuid4()) + oovs = os.path.join(FILE_ROOT, f'{rand_name}_oovs.txt') + with open(oovs, 'w', encoding='utf-8') as f: + for word in unk_words: + f.write(f'{word}\n') + # generate transcription for unknown words + oovs_trans = os.path.join(FILE_ROOT, f'{rand_name}_oovs_trans.txt') + tmp_dir = os.path.join(FILE_ROOT, f'{rand_name}') + os.system(f'mfa g2p {g2p_model} {oovs} {oovs_trans} -t {tmp_dir}') + # extract transcriptions + with open(oovs_trans, 'r', encoding='utf-8') as f: + lines = [line.strip().split() for line in 
f.readlines()] + for line in lines: + transcription = line[1:] + unk_idx = sentence_phonemized.index('') + sentence_phonemized[unk_idx] = transcription + # remove files + os.remove(oovs) + os.remove(oovs_trans) + rmtree(tmp_dir, ignore_errors=True) + + return sentence_phonemized + + +def save_mel_spec_plot_and_audio(item, output_dir, hparams, log_queue): + ''' Save mel-outputs/alignment plots and generate an audio using Griffin-Lim algorithm + + :param item: (n_mel_channels, T + 1) -- mel-spectrogram numpy array + :param alignments: (L, T + 1) -- alignment numpy array + :param output_dir: directory to save plots and audio + :param file_name: filename to save plots and audio + :param hparams: hyper-parameters used for pre-processing and training + :param log_queue: logging queue for multi-processing + ''' + # create logger from logging queue + qh = logging.handlers.QueueHandler(log_queue) + root = logging.getLogger() + if not root.hasHandlers(): + root.setLevel(logging.INFO) + root.addHandler(qh) + logger = logging.getLogger(f"worker{str(uuid.uuid4())}") + + # extract items + file_name, mel_spec, weight = item + # create a figure from the output + plot_2d_data(data=(mel_spec, weight), + x_labels=('Mel-Spec Prediction', 'Alignments'), + filename=os.path.join(output_dir, file_name + '.png')) + # generate audio using Griffin-Lim + waveform = griffin_lim_reconstruction_from_mel_spec(mel_spec, hparams, logger) + if waveform != []: + wavfile.write(os.path.join(output_dir, file_name + '.wav'), hparams.sampling_rate, waveform) + + +def collate_tensors(batch_sentences, batch_dur_factors, batch_energy_factors, + batch_pitch_factors, pitch_transform, batch_refs, + batch_speaker_ids, batch_file_names, hparams): + ''' Extract PyTorch tensors for each sentence and collate them for batch generation + ''' + # gather batch + batch = [] + for sentence, dur_factors, energy_factors, pitch_factors, refs in \ + zip(batch_sentences, batch_dur_factors, batch_energy_factors, batch_pitch_factors, batch_refs): + # encode input text as a sequence of int symbols + symbols = [] + for item in sentence: + if isinstance(item, list): # correspond to phonemes of a word + symbols += [hparams.symbols.index(phone) for phone in item] + else: # correspond to word boundaries + symbols.append(hparams.symbols.index(item)) + symbols = torch.IntTensor(symbols) # (L, ) + # extract duration factors + if dur_factors is None: + dur_factors = [1. for _ in range(len(symbols))] + dur_factors = torch.FloatTensor(dur_factors) # (L, ) + assert(len(dur_factors) == len(symbols)), \ + _logger.error(f'{len(dur_factors)} duration factors whereas there a {len(symbols)} symbols') + # extract energy factors + if energy_factors is None: + energy_factors = [1. for _ in range(len(symbols))] + energy_factors = torch.FloatTensor(energy_factors) # (L, ) + assert(len(energy_factors) == len(symbols)), \ + _logger.error(f'{len(energy_factors)} energy factors whereas there a {len(symbols)} symbols') + # extract pitch factors + if pitch_factors is None: + if pitch_transform == 'add': + pitch_factors = [0. for _ in range(len(symbols))] + elif pitch_transform == 'multiply': + pitch_factors = [1. 
for _ in range(len(symbols))] + pitch_factors = torch.FloatTensor(pitch_factors) # (L, ) + assert(len(pitch_factors) == len(symbols)), \ + _logger.error(f'{len(pitch_factors)} pitch factors whereas there a {len(symbols)} symbols') + # extract references + refs = np.load(refs) + energy_ref, pitch_ref, mel_spec_ref = refs['energy'], refs['pitch'], refs['mel_spec'] + energy_ref = torch.from_numpy(energy_ref).float() # (T_ref, ) + pitch_ref = torch.from_numpy(pitch_ref).float() # (T_ref, ) + mel_spec_ref = torch.from_numpy(mel_spec_ref).float() # (n_mel_channels, T_ref) + # gather data + batch.append([symbols, dur_factors, energy_factors, pitch_factors, energy_ref, pitch_ref, mel_spec_ref]) + + # find symbols sequence max length + input_lengths, ids_sorted_decreasing = \ + torch.sort(torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True) + max_input_len = input_lengths[0] + # right pad sequences to max input length + symbols = torch.LongTensor(len(batch), max_input_len).zero_() + dur_factors = 1 + torch.FloatTensor(len(batch), max_input_len).zero_() + energy_factors = 1 + torch.FloatTensor(len(batch), max_input_len).zero_() + if pitch_transform == 'add': + pitch_factors = torch.FloatTensor(len(batch), max_input_len).zero_() + elif pitch_transform == 'multiply': + pitch_factors = 1 + torch.FloatTensor(len(batch), max_input_len).zero_() + # fill padded arrays + for i in range(len(ids_sorted_decreasing)): + # extract batch sequences + symbols_seq = batch[ids_sorted_decreasing[i]][0] + dur_factors_seq = batch[ids_sorted_decreasing[i]][1] + energy_factors_seq = batch[ids_sorted_decreasing[i]][2] + pitch_factors_seq = batch[ids_sorted_decreasing[i]][3] + # add sequences to padded arrays + symbols[i, :symbols_seq.size(0)] = symbols_seq + dur_factors[i, :dur_factors_seq.size(0)] = dur_factors_seq + energy_factors[i, :energy_factors_seq.size(0)] = energy_factors_seq + pitch_factors[i, :pitch_factors_seq.size(0)] = pitch_factors_seq + + # find reference max length + max_ref_len = max([x[6].size(1) for x in batch]) + # right zero-pad references to max output length + energy_refs = torch.FloatTensor(len(batch), max_ref_len).zero_() + pitch_refs = torch.FloatTensor(len(batch), max_ref_len).zero_() + mel_spec_refs = torch.FloatTensor(len(batch), hparams.n_mel_channels, max_ref_len).zero_() + ref_lengths = torch.LongTensor(len(batch)) + # fill padded arrays + for i in range(len(ids_sorted_decreasing)): + # extract batch sequences + energy_ref_seq = batch[ids_sorted_decreasing[i]][4] + pitch_ref_seq = batch[ids_sorted_decreasing[i]][5] + mel_spec_ref_seq = batch[ids_sorted_decreasing[i]][6] + # add sequences to padded arrays + energy_refs[i, :energy_ref_seq.size(0)] = energy_ref_seq + pitch_refs[i, :pitch_ref_seq.size(0)] = pitch_ref_seq + mel_spec_refs[i, :, :mel_spec_ref_seq.size(1)] = mel_spec_ref_seq + ref_lengths[i] = mel_spec_ref_seq.size(1) + + # reorganize speaker IDs and file names + file_names = [] + speaker_ids = torch.LongTensor(len(batch)) + for i in range(len(ids_sorted_decreasing)): + file_names.append(batch_file_names[ids_sorted_decreasing[i]]) + speaker_ids[i] = batch_speaker_ids[ids_sorted_decreasing[i]] + + return symbols, dur_factors, energy_factors, pitch_factors, input_lengths, \ + energy_refs, pitch_refs, mel_spec_refs, ref_lengths, speaker_ids, file_names + + +def generate_batch_mel_specs(model, batch_sentences, batch_refs, batch_dur_factors, + batch_energy_factors, batch_pitch_factors, pitch_transform, + batch_speaker_ids, batch_file_names, output_dir, hparams, + 
n_jobs, use_griffin_lim=True): + ''' Generate batch mel-specs using Daft-Exprt + ''' + # add speaker info to file name + for idx, file_name in enumerate(batch_file_names): + file_name += f'_spk_{batch_speaker_ids[idx]}' + file_name += f'_ref_{os.path.basename(batch_refs[idx]).replace(".npz", "")}' + batch_file_names[idx] = file_name + _logger.info(f'Generating "{batch_sentences[idx]}" as "{file_name}"') + # collate batch tensors + symbols, dur_factors, energy_factors, pitch_factors, input_lengths, \ + energy_refs, pitch_refs, mel_spec_refs, ref_lengths, speaker_ids, file_names = \ + collate_tensors(batch_sentences, batch_dur_factors, batch_energy_factors, + batch_pitch_factors, pitch_transform, batch_refs, + batch_speaker_ids, batch_file_names, hparams) + # put tensors on GPU + gpu = next(model.parameters()).device + symbols = symbols.cuda(gpu, non_blocking=True).long() # (B, L_max) + dur_factors = dur_factors.cuda(gpu, non_blocking=True).float() # (B, L_max) + energy_factors = energy_factors.cuda(gpu, non_blocking=True).float() # (B, L_max) + pitch_factors = pitch_factors.cuda(gpu, non_blocking=True).float() # (B, L_max) + input_lengths = input_lengths.cuda(gpu, non_blocking=True).long() # (B, ) + energy_refs = energy_refs.cuda(gpu, non_blocking=True).float() # (B, T_max) + pitch_refs = pitch_refs.cuda(gpu, non_blocking=True).float() # (B, T_max) + mel_spec_refs = mel_spec_refs.cuda(gpu, non_blocking=True).float() # (B, n_mel_channels, T_max) + ref_lengths = ref_lengths.cuda(gpu, non_blocking=True).long() # (B, ) + speaker_ids = speaker_ids.cuda(gpu, non_blocking=True).long() # (B, ) + # perform inference + inputs = (symbols, dur_factors, energy_factors, pitch_factors, input_lengths, + energy_refs, pitch_refs, mel_spec_refs, ref_lengths, speaker_ids) + try: + encoder_preds, decoder_preds, alignments = model.inference(inputs, pitch_transform, hparams) + except: + encoder_preds, decoder_preds, alignments = model.module.inference(inputs, pitch_transform, hparams) + # parse outputs + duration_preds, durations_int, energy_preds, pitch_preds, input_lengths = encoder_preds + mel_spec_preds, output_lengths = decoder_preds + weights = alignments + # transfer data to cpu and convert to numpy array + duration_preds = duration_preds.detach().cpu().numpy() # (B, L_max) + durations_int = durations_int.detach().cpu().numpy() # (B, L_max) + energy_preds = energy_preds.detach().cpu().numpy() # (B, L_max) + pitch_preds = pitch_preds.detach().cpu().numpy() # (B, L_max) + input_lengths = input_lengths.detach().cpu().numpy() # (B, ) + mel_spec_preds = mel_spec_preds.detach().cpu().numpy() # (B, n_mel_channels, T_max) + output_lengths = output_lengths.detach().cpu().numpy() # (B) + weights = weights.detach().cpu().numpy() # (B, L_max, T_max) + + # save preds for each element in the batch + predictions = {} + for line_idx in range(mel_spec_preds.shape[0]): + # crop prosody preds to the correct length + duration_pred = duration_preds[line_idx, :input_lengths[line_idx]] # (L, ) + duration_int = durations_int[line_idx, :input_lengths[line_idx]] # (L, ) + energy_pred = energy_preds[line_idx, :input_lengths[line_idx]] # (L, ) + pitch_pred = pitch_preds[line_idx, :input_lengths[line_idx]] # (L, ) + # crop mel-spec to the correct length + mel_spec_pred = mel_spec_preds[line_idx, :, :output_lengths[line_idx]] # (n_mel_channels, T) + # crop weights to the correct length + weight = weights[line_idx, :input_lengths[line_idx], :output_lengths[line_idx]] + # save generated spectrogram + file_name = file_names[line_idx] + 
np.savez(os.path.join(output_dir, f'{file_name}.npz'), mel_spec=mel_spec_pred) + # store predictions + predictions[f'{file_name}'] = [duration_pred, duration_int, energy_pred, pitch_pred, mel_spec_pred, weight] + + # save plots and generate audio using Griffin-Lim + if use_griffin_lim: + items = [[file_name, mel_spec, weight] for file_name, (_, _, _, _, mel_spec, weight) in predictions.items()] + launch_multi_process(iterable=items, func=save_mel_spec_plot_and_audio, n_jobs=n_jobs, + timer_verbose=False, output_dir=output_dir, hparams=hparams) + + return predictions + + +def generate_mel_specs(model, sentences, file_names, speaker_ids, refs, output_dir, hparams, + dur_factors=None, energy_factors=None, pitch_factors=None, batch_size=1, + n_jobs=1, use_griffin_lim=False, get_time_perf=False): + ''' Generate mel-specs using Daft-Exprt + + sentences = [ + sentence_1, + ... + sentence_N + ] + each sentence is a list of symbols: + sentence = [ + [symbols_word_1], + word_boundary_symbol, + [symbols_word_2], + word_boundary_symbol, + ... + ] + for example, here is a sentence of 5 words, 6 word boundaries and a total of 17 symbols: + sentence = [['IH0', 'Z'], ' ', ['IH0', 'T'], ',', ['AH0'], ' ', ['G', 'UH1', 'D'], ' ', ['CH', 'OY1', 'S'], '?', '~'] + + file_names = [ + file_name_1, + ... + file_name_N + ] + + speaker_ids = [ + speaker_id_1, + ... + speaker_id_N + ] + + refs = [ + path_to_ref_1.npz, + ... + path_to_ref_N.npz + ] + + dur_factors = [ + [factor_sentence_1_symbol_1, factor_sentence_1_symbol_2, ...], + ... + [factor_sentence_N_symbol_1, factor_sentence_N_symbol_2, ...] + ] + if None, duration predictions are not modified + + energy_factors = [ + [factor_sentence_1_symbol_1, factor_sentence_1_symbol_2, ...], + ... + [factor_sentence_N_symbol_1, factor_sentence_N_symbol_2, ...] + ] + if None, energy predictions are not modified + + pitch_factors = [ + "transform", + [ + [factor_sentence_1_symbol_1, factor_sentence_1_symbol_2, ...], + ... + [factor_sentence_N_symbol_1, factor_sentence_N_symbol_2, ...] 
+ ] + ] + There are 2 types of transforms: + - pitch shift: "add" + - pitch multiply: "multiply" + if None, pitch predictions are not modified + ''' + # set default values if prosody factors are None + dur_factors = [None for _ in range(len(sentences))] if dur_factors is None else dur_factors + energy_factors = [None for _ in range(len(sentences))] if energy_factors is None else energy_factors + pitch_factors = ['add', [None for _ in range(len(sentences))]] if pitch_factors is None else pitch_factors + # get pitch transform + pitch_transform = pitch_factors[0].lower() + pitch_factors = pitch_factors[1] + assert(pitch_transform in ['add', 'multiply']), _logger.error(f'Pitch transform "{pitch_transform}" is not currently supported') + # check lists have the same size + assert (len(file_names) == len(sentences)), _logger.error(f'{len(file_names)} filenames but there are {len(sentences)} sentences to generate') + assert (len(speaker_ids) == len(sentences)), _logger.error(f'{len(speaker_ids)} speaker IDs but there are {len(sentences)} sentences to generate') + assert (len(refs) == len(sentences)), _logger.error(f'{len(refs)} references but there are {len(sentences)} sentences to generate') + assert (len(dur_factors) == len(sentences)), _logger.error(f'{len(dur_factors)} duration factors but there are {len(sentences)} sentences to generate') + assert (len(energy_factors) == len(sentences)), _logger.error(f'{len(energy_factors)} energy factors but there are {len(sentences)} sentences to generate') + assert (len(pitch_factors) == len(sentences)), _logger.error(f'{len(pitch_factors)} pitch factors but there are {len(sentences)} sentences to generate') + + # we don't need computational graph for inference + model.eval() # set eval mode + os.makedirs(output_dir, exist_ok=True) + predictions, time_per_batch = {}, [] + with torch.no_grad(): + for batch_sentences, batch_refs, batch_dur_factors, batch_energy_factors, \ + batch_pitch_factors, batch_speaker_ids, batch_file_names in \ + zip(chunker(sentences, batch_size), chunker(refs, batch_size), + chunker(dur_factors, batch_size), chunker(energy_factors, batch_size), + chunker(pitch_factors, batch_size), chunker(speaker_ids, batch_size), + chunker(file_names, batch_size)): + sentence_begin = time.time() if get_time_perf else None + batch_predictions = generate_batch_mel_specs(model, batch_sentences, batch_refs, batch_dur_factors, + batch_energy_factors, batch_pitch_factors, pitch_transform, + batch_speaker_ids, batch_file_names, output_dir, hparams, + n_jobs, use_griffin_lim) + predictions.update(batch_predictions) + time_per_batch += [time.time() - sentence_begin] if get_time_perf else [] + + # display overall time performance + if get_time_perf: + # get duration of each sentence + durations = [] + for prediction in predictions.values(): + _, _, _, _, mel_spec, _ = prediction + nb_frames = mel_spec.shape[1] + nb_wav_samples = (nb_frames - 1) * hparams.hop_length + hparams.filter_length + if hparams.centered: + nb_wav_samples -= 2 * int(hparams.filter_length / 2) + duration = nb_wav_samples / hparams.sampling_rate + durations.append(duration) + _logger.info(f'') + _logger.info(f'{len(predictions)} sentences ({sum(durations):.2f}s) generated in {sum(time_per_batch):.2f}s') + _logger.info(f'DaftExprt RTF: {sum(durations)/sum(time_per_batch):.2f}') + + return predictions + + +def extract_reference_parameters(audio_ref, output_dir, hparams): + ''' Extract energy, pitch and mel-spectrogram parameters from audio + Save numpy arrays to .npz file + ''' + # check 
if file name already exists + os.makedirs(output_dir, exist_ok=True) + file_name = os.path.basename(audio_ref).replace('.wav', '') + ref_file = os.path.join(output_dir, f'{file_name}.npz') + if not os.path.isfile(ref_file): + # read wav file to range [-1, 1] in np.float32 + wav, fs = librosa.load(audio_ref, sr=hparams.sampling_rate) + wav = rescale_wav_to_float32(wav) + # get log pitch + pitch = extract_pitch(wav, fs, hparams) + # extract mel-spectrogram + mel_spec = mel_spectrogram_HiFi(wav, hparams) + # get energy + energy = extract_energy(np.exp(mel_spec)) + # check sizes are correct + assert(len(pitch) == mel_spec.shape[1]), f'{len(pitch)} -- {mel_spec.shape[1]}' + assert(len(energy) == mel_spec.shape[1]), f'{len(energy)} -- {mel_spec.shape[1]}' + # save references to .npz file + np.savez(ref_file, energy=energy, pitch=pitch, mel_spec=mel_spec) + + +def prepare_sentences_for_inference(text_file, output_dir, hparams, n_jobs): + ''' Phonemize and format sentences to synthesize + ''' + # create output directory or delete everything if it already exists + if os.path.exists(output_dir): + rmtree(output_dir) + os.makedirs(output_dir, exist_ok=False) + + # extract sentences to synthesize + assert(os.path.isfile(text_file)), _logger.error(f'There is no such file {text_file}') + with open(os.path.join(text_file), 'r', encoding='utf-8') as f: + sentences = [line.strip() for line in f] + file_names = [f'{os.path.basename(text_file)}_line{idx}' for idx in range(len(sentences))] + # phonemize + hparams.update_mfa_paths() + sentences = launch_multi_process(iterable=sentences, func=phonemize_sentence, + n_jobs=n_jobs, timer_verbose=False, hparams=hparams) + + # save the sentences in a file + with open(os.path.join(output_dir, 'sentences_to_generate.txt'), 'w', encoding='utf-8') as f: + for sentence, file_name in zip(sentences, file_names): + text = '' + for item in sentence: + if isinstance(item, list): # corresponds to phonemes of a word + item = '{' + ' '.join(item) + '}' + text = f'{text} {item} ' + text = collapse_whitespace(text).strip() + f.write(f'{file_name}|{text}\n') + + return sentences, file_names diff --git a/src/daft_exprt/griffin_lim.py b/src/daft_exprt/griffin_lim.py new file mode 100644 index 0000000..bf2371b --- /dev/null +++ b/src/daft_exprt/griffin_lim.py @@ -0,0 +1,198 @@ +import logging + +import numpy as np +import scipy + +from librosa.filters import mel as librosa_mel_fn + + +_logger = logging.getLogger(__name__) + + +def _nnls_obj(x, shape, A, B): + ''' Compute the objective and gradient for NNLS + ''' + # scipy's lbfgs flattens all arrays, so we first reshape + # the iterate x + x = x.reshape(shape) + + # compute the difference matrix + diff = np.dot(A, x) - B + + # compute the objective value + value = 0.5 * np.sum(diff ** 2) + + # and the gradient + grad = np.dot(A.T, diff) + + # flatten the gradient + return value, grad.flatten() + + +def _nnls_lbfgs_block(A, B, x_init=None, **kwargs): + ''' Solve the constrained problem over a single block + + :param A: the basis matrix -- shape = (m, d) + :param B: the regression targets -- shape = (m, N) + :param x_init: initial guess -- shape = (d, N) + :param kwargs: additional keyword arguments to `scipy.optimize.fmin_l_bfgs_b` + + :return: non-negative matrix x such that Ax ~= B -- shape = (d, N) + ''' + # if we don't have an initial point, start at the projected + # least squares solution + if x_init is None: + x_init = np.linalg.lstsq(A, B, rcond=None)[0] + np.clip(x_init, 0, None, out=x_init) + + # adapt the hessian 
approximation to the dimension of the problem + kwargs.setdefault("m", A.shape[1]) + + # construct non-negative bounds + bounds = [(0, None)] * x_init.size + shape = x_init.shape + + # optimize + x, obj_value, diagnostics = scipy.optimize.fmin_l_bfgs_b( + _nnls_obj, x_init, args=(shape, A, B), bounds=bounds, **kwargs + ) + # reshape the solution + return x.reshape(shape) + + +def nnls(A, B, **kwargs): + ''' Non-negative least squares. + Given two matrices A and B, find a non-negative matrix X + that minimizes the sum squared error: + err(X) = sum_i,j ((AX)[i,j] - B[i, j])^2 + + :param A: the basis matrix -- shape = (m, n) + :param B: the target matrix -- shape = (m, N) + :param kwargs: additional keyword arguments to `scipy.optimize.fmin_l_bfgs_b` + + :return: non-negative matrix X minimizing ``|AX - B|^2`` -- shape = (n, N) + ''' + # if B is a single vector, punt up to the scipy method + if B.ndim == 1: + return scipy.optimize.nnls(A, B)[0] + + # constrain block sizes to 256 KB + MAX_MEM_BLOCK = 2 ** 8 * 2 ** 10 + n_columns = MAX_MEM_BLOCK // (A.shape[-1] * A.itemsize) + n_columns = max(n_columns, 1) + + # process in blocks + if B.shape[-1] <= n_columns: + return _nnls_lbfgs_block(A, B, **kwargs).astype(A.dtype) + + x = np.linalg.lstsq(A, B, rcond=None)[0].astype(A.dtype) + np.clip(x, 0, None, out=x) + x_init = x + + for bl_s in range(0, x.shape[-1], n_columns): + bl_t = min(bl_s + n_columns, B.shape[-1]) + x[:, bl_s:bl_t] = _nnls_lbfgs_block( + A, B[:, bl_s:bl_t], x_init=x_init[:, bl_s:bl_t], **kwargs + ) + return x + + +def mel_to_linear(mel_spectrogram, hparams): + ''' Convert a mel-spectrogram to a linear spectrogram + + :param mel_spectrogram: Numpy array of the input mel spectrogram -- shape = (n_mels, T) + :param hparams: hyper-parameters used for pre-processing and training + + :return: numpy array containing the spectrogram in linear frequency space -- shape = (n_fft // 2 + 1, T) + ''' + # find the number of mel components + n_mels = mel_spectrogram.shape[0] + # get filter parameters -- (n_mels, 1 + n_fft//2) + mel_filter_bank = librosa_mel_fn(hparams.sampling_rate, hparams.filter_length, n_mels, hparams.mel_fmin, hparams.mel_fmax) + + # solve the non-linear least squares problem + return nnls(mel_filter_bank, mel_spectrogram) + + +def reconstruct_signal_griffin_lim(magnitude_spectrogram, step_size, iterations, logger): + ''' Reconstruct an audio signal from a magnitude spectrogram + + Given a magnitude spectrogram as input, reconstruct the audio signal and return it using + the Griffin-Lim algorithm + From the paper: "Signal estimation from modified short-time fourier transform" by Griffin and Lim, in IEEE + transactions on Acoustics, Speech, and Signal Processing. Vol ASSP-32, No. 2, April 1984. 
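+
+    A minimal sketch of the per-iteration update, mirroring the loop below: keep the target
+    magnitude |S| and re-use only the phase of the current estimate's STFT:
+        x_{k+1} = iSTFT( |S| * exp(j * angle(STFT(x_k))) )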
+ + :param magnitude_spectrogram: Numpy array magnitude spectrogram -- shape = (n_fft // 2 + 1, T) + The rows correspond to frequency bins and the columns correspond to time slices + :param step_size: length (in samples) between successive analysis windows + :param iterations: Number of iterations for the Griffin-Lim algorithm + Typically a few hundred is sufficient + :param logger: logger object + + :return: the reconstructed time domain signal as a 1-dim Numpy array and the spectrogram that was used + to produce the signal + ''' + # shape = (T, n_fft // 2 + 1) + magnitude_spectrogram = np.transpose(magnitude_spectrogram) + + # find the number of samples used in the FFT window and extract the time steps + n_fft = (magnitude_spectrogram.shape[1] - 1) * 2 + time_slices = magnitude_spectrogram.shape[0] + + # compute the number of samples needed + len_samples = int(time_slices * step_size + n_fft) + + # initialize the reconstructed signal to noise + x_reconstruct = np.random.randn(len_samples) + window = np.hanning(n_fft) + n = iterations # number of iterations of Griffin-Lim algorithm + + while n > 0: + # decrement and compute FFT + n -= 1 + reconstruction_spectrogram = np.array([np.fft.rfft(window * x_reconstruct[i: i + n_fft]) + for i in range(0, len(x_reconstruct) - n_fft, step_size)]) + + # Discard magnitude part of the reconstruction and use the supplied magnitude spectrogram instead + proposal_spectrogram = magnitude_spectrogram * np.exp(1.0j * np.angle(reconstruction_spectrogram)) + + # store previous reconstructed signal and create a new one by iFFT + prev_x = x_reconstruct + x_reconstruct = np.zeros(len_samples) + + for i, j in enumerate(range(0, len(x_reconstruct) - n_fft, step_size)): + x_reconstruct[j: j + n_fft] += window * np.real(np.fft.irfft(proposal_spectrogram[i])) + + # normalise signal due to overlap add + x_reconstruct = x_reconstruct / (n_fft / step_size / 2) + + # compute diff between two signals and report progress + diff = np.sqrt(sum((x_reconstruct - prev_x) ** 2) / x_reconstruct.size) + logger.debug(f'Reconstruction iteration: {iterations - n}/{iterations} -- RMSE: {diff * 1e6:.3f}e-6') + + return x_reconstruct, proposal_spectrogram + + +def griffin_lim_reconstruction_from_mel_spec(mel_spec, hparams, logger): + ''' Convert a mel-spectrogram into an audio waveform using Griffin-Lim algorithm + + :param mel_spec: mel-spectrogram corresponding to the audio to generate + :param hparams: hyper-parameters used for pre-processing and training + :param logger: logger object + + :return: the reconstructed audio waveform + ''' + # remove np.log + mel_spec = np.exp(mel_spec) + + # pass from mel to linear + linear_spec = mel_to_linear(mel_spec, hparams) + + # use Griffin-Lim algorithm + waveform = [] + if len(linear_spec.shape) == 2: + waveform, _ = reconstruct_signal_griffin_lim(linear_spec[:, :-2], hparams.hop_length, + iterations=30, logger=logger) + waveform = waveform / np.max(abs(waveform)) + + return waveform diff --git a/src/daft_exprt/hparams.py b/src/daft_exprt/hparams.py new file mode 100644 index 0000000..f0de873 --- /dev/null +++ b/src/daft_exprt/hparams.py @@ -0,0 +1,244 @@ +import json +import logging +import os +import sys + +from pathlib import Path + +from daft_exprt.symbols import pad, symbols_english + + +_logger = logging.getLogger(__name__) + + +''' + Hyper-parameters used for pre-processing and training +''' + + +class HyperParams(object): + def __init__(self, verbose=True, **kwargs): + ''' Initialize hyper-parameter values for data pre-processing and 
training + + :param verbose: whether to display logger info/warnings or not + :param kwargs: keyword arguments to modify hyper-params values + ''' + # display some logger info + if verbose: + _logger.info('--' * 30) + _logger.info('Setting Hyper-Parameters'.upper()) + _logger.info('--' * 30) + + ########################################### + #### hard-coded hyper-parameter values #### + ########################################### + # misc hyper-parameters + self.minimum_wav_duration = 1000 # minimum duration (ms) of the audio files used for training + + # mel-spec extraction hyper-parameters + self.centered = True # extraction window is centered on the time step when doing FFT + self.min_clipping = 1e-5 # min clipping value when creating mel-specs + self.sampling_rate = 22050 # sampling rate of the audios in the data set + self.mel_fmin = 0 # lowest frequency (in Hz) of the mel-spectrogram + self.mel_fmax = 8000 # highest frequency (in Hz) of the mel-spectrogram + self.n_mel_channels = 80 # number of mel bands to generate + self.filter_length = 1024 # FFT window length (in samples) + self.hop_length = 256 # length (in samples) between successive analysis windows for FFT + + # REAPER pitch extraction hyper-parameters + self.f0_interval = 0.005 + self.min_f0 = 40 + self.max_f0 = 500 + self.uv_interval = 0.01 + self.uv_cost = 0.9 + self.order = 1 + self.cutoff = 25 + + # training hyper-parameters + self.seed = 1234 # seed used to initialize weights + self.cudnn_enabled = True # parameter used when initializing training + self.cudnn_benchmark = False # parameter used when initializing training + self.cudnn_deterministic = True # parameter used when initializing training + self.dist_backend = 'nccl' # parameter used to perform distributed training + self.nb_iterations = 370000 # total number of iterations to perform during training + self.iters_per_checkpoint = 10000 # number of iterations between successive checkpoints + self.iters_check_for_model_improvement = 5000 # number of iterations between successive evaluation on the validation set + self.batch_size = 16 # batch size per GPU card + self.accumulation_steps = 3 # number of iterations before updating model parameters (gradient accumulation) + self.checkpoint = '' # checkpoint to use to restart training at a specific place + + # loss weigths hyper-parameters + self.lambda_reversal = 1. # lambda multiplier used in reversal gradient layer + self.adv_max_weight = 1e-2 # max weight to apply on speaker adversarial loss + self.post_mult_weight = 1e-3 # weight to apply on FiLM scalar post-multipliers + self.dur_weight = 1. # weight to apply on duration loss + self.energy_weight = 1. # weight to apply on energy loss + self.pitch_weight = 1. # weight to apply on pitch loss + self.mel_spec_weight = 1. 
# weight to apply on mel-spec loss + + # optimizer hyper-parameters + self.optimizer = 'adam' # optimizer to use for training + self.betas = (0.9, 0.98) # betas coefficients in Adam + self.epsilon = 1e-9 # used for numerical stability in Adam + self.weight_decay = 1e-6 # weight decay (L2 regularization) to use in the optimizer + self.initial_learning_rate = 1e-4 # value of learning rate at iteration 0 + self.max_learning_rate = 1e-3 # max value of learning rate during training + self.warmup_steps = 10000 # linearly increase the learning rate for the first warmup steps + self.grad_clip_thresh = float('inf') # gradient clipping threshold to stabilize training + + # Daft-Exprt module hyper-parameters + self.prosody_encoder = { + 'nb_blocks': 4, + 'hidden_embed_dim': 128, + 'attn_nb_heads': 8, + 'attn_dropout': 0.1, + 'conv_kernel': 3, + 'conv_channels': 1024, + 'conv_dropout': 0.1 + } + + self.phoneme_encoder = { + 'nb_blocks': 4, + 'hidden_embed_dim': 128, + 'attn_nb_heads': 2, + 'attn_dropout': 0.1, + 'conv_kernel': 3, + 'conv_channels': 1024, + 'conv_dropout': 0.1 + } + + self.local_prosody_predictor = { + 'nb_blocks': 1, + 'conv_kernel': 3, + 'conv_channels': 256, + 'conv_dropout': 0.1, + } + + self.gaussian_upsampling_module = { + 'conv_kernel': 3 + } + + self.frame_decoder = { + 'nb_blocks': 4, + 'attn_nb_heads': 2, + 'attn_dropout': 0.1, + 'conv_kernel': 3, + 'conv_channels': 1024, + 'conv_dropout': 0.1 + } + + ###################################################################### + #### hyper-parameter values that have to be specified in **kwargs #### + ###################################################################### + self.training_files = None # path to training files + self.validation_files = None # path to validation files + self.output_directory = None # path to save training outputs (checkpoints, config files, audios, logging ...) 
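+        # NOTE: purely illustrative usage sketch (paths and speaker names are placeholders,
+        # not values from this repository; also assumes the MFA pretrained models expected
+        # by update_mfa_paths() are installed):
+        #   hparams = HyperParams(training_files='trainings/train.txt',
+        #                         validation_files='trainings/validation.txt',
+        #                         output_directory='trainings/daft_exprt',
+        #                         language='english',
+        #                         speakers=['speaker_1', 'speaker_2'])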
+ + self.language = None # spoken language of the speaker(s) + self.speakers = None # speakers we want to use for training or transfer learning + + ########################################################################################## + #### hyper-parameter inferred from other hyper-params values or specified in **kwargs #### + ########################################################################################## + self.stats = {} # features stats used during training and inference + self.symbols = [] # list of symbols used in the specified language + + self.n_speakers = 0 # number of speakers to use with a lookup table + self.speakers_id = [] # ID associated to each speaker -- starts from 0 + + ######################################################## + #### update hyper-parameter variables with **kwargs #### + ######################################################## + for key, value in kwargs.items(): + if hasattr(self, key) and getattr(self, key) is not None and getattr(self, key) != value and verbose: + _logger.warning(f'Changing parameter "{key}" = {value} (was {getattr(self, key)})') + setattr(self, key, value) + + # check if all hyper-params have an assigned value + for param, value in self.__dict__.items(): + assert(value is not None), _logger.error(f'Hyper-parameter "{param}" is None -- please specify a value') + + # give a default value to hyper-parameters that have not been specified in **kwargs + self._set_default_hyper_params(verbose=verbose) + + def _set_default_hyper_params(self, verbose): + ''' Give a default value to hyper-parameters that have not been specified in **kwargs + + :param verbose: whether to display logger info/warnings or not + ''' + # update MFA paths + self.update_mfa_paths() + # set stats if not already set + stats_file = os.path.join(self.output_directory, 'stats.json') + if len(self.stats) == 0 and os.path.isfile(stats_file): + with open(stats_file) as f: + data = f.read() + stats = json.loads(data) + self.stats = stats + + # set symbols if not already set + if len(self.symbols) == 0: + if self.language == 'english': + self.symbols = symbols_english + else: + _logger.error(f'Language: {self.language} -- No default value for "symbols" -- please specify a value') + sys.exit(1) + if verbose: + _logger.info(f'Language: {self.language} -- {len(self.symbols)} symbols used') + # set number of symbols + self.n_symbols = len(self.symbols) + # check padding symbol is at index 0 + # zero padding is used in the DataLoader and Daft-Exprt model + assert(self.symbols.index(pad) == 0), _logger.error(f'Padding symbol "{pad}" must be at index 0') + + # set speakers ID if not already set + if len(self.speakers_id) == 0: + self.speakers_id = [i for i in range(len(self.speakers))] + if verbose: + _logger.info(f'Nb speakers: {len(self.speakers)} -- Changed "speakers_id" to {self.speakers_id}') + # set n_speakers if not already set + if self.n_speakers == 0: + self.n_speakers = len(set(self.speakers_id)) + 1 + if verbose: + _logger.info(f'Nb speakers: {len(set(self.speakers_id))} -- Changed "n_speakers" to {self.n_speakers}\n') + + # check number of speakers is coherent + assert (self.n_speakers >= len(set(self.speakers_id))), \ + _logger.error(f'Parameter "n_speakers" must be superior or equal to the number of speakers -- ' + f'"n_speakers" = {self.n_speakers} -- Number of speakers = {len(set(self.speakers_id))}') + # check items in the lists are unique and have the same size + assert (len(self.speakers) == len(set(self.speakers))), \ + _logger.error(f'Speakers are not 
unique: {len(self.speakers)} -- {len(set(self.speakers))}') + assert (len(self.speakers) == len(self.speakers_id)), \ + _logger.error(f'Parameters "speakers" and "speakers_id" don\'t have the same length: ' + f'{len(self.speakers)} -- {len(self.speakers_id)}') + + # check FFT/Mel-Spec extraction parameters are correct + assert(self.filter_length % self.hop_length == 0), _logger.error(f'filter_length must be a multiple of hop_length') + + def update_mfa_paths(self): + ''' Update MFA paths to match the ones in the current environment + ''' + # paths used by MFA + home = str(Path.home()) + self.mfa_dictionary = os.path.join(home, 'Documents', 'MFA', 'pretrained_models', 'dictionary', f'{self.language}.dict') + self.mfa_g2p_model = os.path.join(home, 'Documents', 'MFA', 'pretrained_models', 'g2p', f'{self.language}_g2p.zip') + self.mfa_acoustic_model = os.path.join(home, 'Documents', 'MFA', 'pretrained_models', 'acoustic', f'{self.language}.zip') + # check MFA files exist + assert(os.path.isfile(self.mfa_dictionary)), _logger.error(f'There is no such file "{self.mfa_dictionary}"') + assert(os.path.isfile(self.mfa_g2p_model)), _logger.error(f'There is no such file "{self.mfa_g2p_model}"') + assert(os.path.isfile(self.mfa_acoustic_model)), _logger.error(f'There is no such file "{self.mfa_acoustic_model}"') + + def save_hyper_params(self, json_file): + ''' Save hyper-parameters to JSON file + + :param json_file: path of the JSON file to store hyper-parameters + ''' + # create directory if it does not exists + dirname = os.path.dirname(json_file) + os.makedirs(dirname, exist_ok=True) + # extract hyper-parameters used + hyper_params = self.__dict__.copy() + # save hyper-parameters to JSON file + with open(json_file, 'w') as f: + json.dump(hyper_params, f, indent=4, sort_keys=True) diff --git a/src/daft_exprt/logger.py b/src/daft_exprt/logger.py new file mode 100644 index 0000000..3882c0f --- /dev/null +++ b/src/daft_exprt/logger.py @@ -0,0 +1,157 @@ +import random + +import numpy as np +import torch + +from torch.utils.tensorboard import SummaryWriter + +from daft_exprt.extract_features import duration_to_integer +from daft_exprt.utils import histogram_plot, plot_2d_data, scatter_plot + + +class DaftExprtLogger(SummaryWriter): + def __init__(self, logdir): + super(DaftExprtLogger, self).__init__(logdir) + + def log_training(self, loss, indiv_loss, grad_norm, learning_rate, duration, iteration): + ''' Log training info + + :param loss: training batch loss + :param indiv_loss: training batch individual losses + :param grad_norm: norm of the gradient + :param learning_rate: learning rate + :param duration: duration per iteration + :param iteration: current training iteration + ''' + self.add_scalars("DaftExprt.optimization", {'grad_norm': grad_norm, 'learning_rate': learning_rate, + 'duration': duration}, iteration) + self.add_scalars("DaftExprt.training", {'training_loss': loss}, iteration) + for key in indiv_loss: + if indiv_loss[key] != 0: + if 'loss' in key: + self.add_scalars(f"DaftExprt.training", {f'{key}': indiv_loss[key]}, iteration) + + def log_validation(self, val_loss, val_indiv_loss, val_targets, val_outputs, model, hparams, iteration): + ''' Log validation info + + :param val_loss: validation loss + :param val_indiv_loss: individual validation losses + :param val_targets: list of ground-truth values on the valid set + :param val_outputs: list of predicted values on the valid set + :param model: model used for training/validation + :param hparams: hyper-parameters used for training + 
:param iteration: current training iteration + ''' + # plot validation losses + self.add_scalars("DaftExprt.validation", {"validation_loss": val_loss}, iteration) + for key in val_indiv_loss: + self.add_scalars("DaftExprt.validation", {f'{key}': val_indiv_loss[key]}, iteration) + + # # plot distribution of model parameters + # for tag, value in model.named_parameters(): + # tag = tag.replace('.', '/') + # self.add_histogram(tag, value.detach().cpu().numpy(), iteration) + + # choose random index to extract batch of targets and outputs + idx = random.randint(0, len(val_targets) - 1) + targets = val_targets[idx] + outputs = val_outputs[idx] + # extract predicted outputs and ground-truth values + duration_targets, energy_targets, pitch_targets, mel_spec_targets, _ = targets + _, _, encoder_preds, decoder_preds, alignments = outputs + duration_preds, energy_preds, pitch_preds, input_lengths = encoder_preds + mel_spec_preds, output_lengths = decoder_preds + weights = alignments + # choose random index in the batch + idx = random.randint(0, mel_spec_preds.size(0) - 1) + # extract corresponding sequence length + input_length = input_lengths[idx].item() + output_length = output_lengths[idx].item() + # transfer data to cpu and convert to numpy array + duration_targets = duration_targets[idx, :input_length].detach().cpu().numpy() # (L, ) + duration_preds = duration_preds[idx, :input_length].detach().cpu().numpy() # (L, ) + energy_targets = energy_targets[idx, :input_length].detach().cpu().numpy() # (L, ) + energy_preds = energy_preds[idx, :input_length].detach().cpu().numpy() # (L, ) + pitch_targets = pitch_targets[idx, :input_length].detach().cpu().numpy() # (L, ) + pitch_preds = pitch_preds[idx, :input_length].detach().cpu().numpy() # (L, ) + mel_spec_targets = mel_spec_targets[idx, :, :output_length].detach().cpu().numpy() # (n_mel_channels, T) + mel_spec_preds = mel_spec_preds[idx, :, :output_length].detach().cpu().numpy() # (n_mel_channels, T) + weights = weights[idx, :input_length, :output_length].detach().cpu().numpy() # (L, T) + + # convert target float durations to int durations + duration_int_targets = np.zeros(len(duration_targets), dtype='int32') # (L, ) + end_prev, symbols_idx, durations_float = 0., [], [] + for symbol_id in range(len(duration_targets)): + symb_dur = duration_targets[symbol_id] + if symb_dur != 0.: # ignore 0 durations + symbols_idx.append(symbol_id) + durations_float.append([end_prev, end_prev + symb_dur]) + end_prev += symb_dur + int_durs = duration_to_integer(durations_float, hparams) # (L, ) + duration_int_targets[symbols_idx] = int_durs # (L, ) + # extract target alignments + col_idx = 0 + alignment_targets = np.zeros((len(duration_int_targets), mel_spec_targets.shape[1])) # (L, T) + for symbol_id in range(alignment_targets.shape[0]): + nb_frames = duration_int_targets[symbol_id] + alignment_targets[symbol_id, col_idx: col_idx + nb_frames] = 1. 
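+            # advance the frame cursor by the duration of the current symbol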
+ col_idx += nb_frames + + # extract all FiLM parameter predictions on the validation set + # FiLM parameters for Encoder Module + encoder_film = [output[1][1] for output in val_outputs] # (B, nb_blocks, nb_film_params) + encoder_film = torch.cat(encoder_film, dim=0) # (B_tot, nb_blocks, nb_film_params) + encoder_film = encoder_film.detach().cpu().numpy() # (B_tot, nb_blocks, nb_film_params) + # FiLM parameters for Prosody Predictor Module + prosody_pred_film = [output[1][2] for output in val_outputs] # (B, nb_blocks, nb_film_params) + prosody_pred_film = torch.cat(prosody_pred_film, dim=0) # (B_tot, nb_blocks, nb_film_params) + prosody_pred_film = prosody_pred_film.detach().cpu().numpy() # (B_tot, nb_blocks, nb_film_params) + # FiLM parameters for Decoder Module + decoder_film = [output[1][3] for output in val_outputs] # (B, nb_blocks, nb_film_params) + decoder_film = torch.cat(decoder_film, dim=0) # (B_tot, nb_blocks, nb_film_params) + decoder_film = decoder_film.detach().cpu().numpy() # (B_tot, nb_blocks, nb_film_params) + + # plot histograms of gammas and betas for each block of each module + for module, tensor in zip(['encoder', 'prosody_predictor', 'decoder'], + [encoder_film, prosody_pred_film, decoder_film]): + nb_blocks = tensor.shape[1] + nb_gammas = int(tensor.shape[2] / 2) + gammas = histogram_plot(data=[tensor[:, i, :nb_gammas] for i in range(nb_blocks)], + x_labels=[f'Value -- Block {i}' for i in range(nb_blocks)], + y_labels=['Frequency' for _ in range(nb_blocks)]) + self.add_figure(tag=f'{module} -- FiLM gammas', figure=gammas, global_step=iteration) + betas = histogram_plot(data=[tensor[:, i, nb_gammas:] for i in range(nb_blocks)], + x_labels=[f'Value -- Block {i}' for i in range(nb_blocks)], + y_labels=['Frequency' for _ in range(nb_blocks)]) + self.add_figure(tag=f'{module} -- FiLM betas', figure=betas, global_step=iteration) + # plot duration target and duration pred + durations = scatter_plot(data=(duration_targets, duration_preds), + colors=('blue', 'red'), + labels=('ground-truth', 'predicted'), + x_label='Symbol ID', + y_label='Duration (sec)') + self.add_figure(tag='durations', figure=durations, global_step=iteration) + # plot energy target and energy pred + energies = scatter_plot(data=(energy_targets, energy_preds), + colors=('blue', 'red'), + labels=('ground-truth', 'predicted'), + x_label='Symbol ID', + y_label='Energy (normalized)') + self.add_figure(tag='energies', figure=energies, global_step=iteration) + # plot pitch target and pitch pred + pitch = scatter_plot(data=(pitch_targets, pitch_preds), + colors=('blue', 'red'), + labels=('ground-truth', 'predicted'), + x_label='Symbol ID', + y_label='Pitch (normalized)') + self.add_figure(tag='pitch', figure=pitch, global_step=iteration) + # plot mel-spec target and mel-spec pred + mel_specs = plot_2d_data(data=(mel_spec_targets, mel_spec_preds), + x_labels=('Frames -- Ground Truth', 'Frames -- Predicted'), + y_labels=('Frequencies', 'Frequencies')) + self.add_figure(tag='mel-spectrogram', figure=mel_specs, global_step=iteration) + # plot alignment target and alignment pred + alignments = plot_2d_data(data=(alignment_targets, weights), + x_labels=('Frames -- Ground Truth', 'Frames -- Predicted (from Ground Truth)'), + y_labels=('Symbol ID', 'Symbol ID')) + self.add_figure(tag='alignments', figure=alignments, global_step=iteration) diff --git a/src/daft_exprt/loss.py b/src/daft_exprt/loss.py new file mode 100644 index 0000000..0965015 --- /dev/null +++ b/src/daft_exprt/loss.py @@ -0,0 +1,106 @@ +import torch + 
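+
+# NOTE: added summary comment -- DaftExprtLoss (defined below) combines an adversarial
+# speaker loss, an L2 penalty on the FiLM scalar post-multipliers, MSE losses on duration,
+# energy and pitch, and L1 + L2 mel-spectrogram reconstruction losses.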
+from torch import nn + + +class DaftExprtLoss(nn.Module): + def __init__(self, gpu, hparams): + super(DaftExprtLoss, self).__init__() + self.nb_channels = hparams.n_mel_channels + self.warmup_steps = hparams.warmup_steps + self.adv_max_weight = hparams.adv_max_weight + self.post_mult_weight = hparams.post_mult_weight + self.dur_weight = hparams.dur_weight + self.energy_weight = hparams.energy_weight + self.pitch_weight = hparams.pitch_weight + self.mel_spec_weight = hparams.mel_spec_weight + + self.L1Loss = nn.L1Loss(reduction='none').cuda(gpu) + self.MSELoss = nn.MSELoss(reduction='none').cuda(gpu) + self.CrossEntropy = nn.CrossEntropyLoss().cuda(gpu) + + def update_adversarial_weight(self, iteration): + ''' Update adversarial weight value based on iteration + ''' + weight_iter = iteration * self.warmup_steps ** -1.5 * self.adv_max_weight / self.warmup_steps ** -0.5 + weight = min(self.adv_max_weight, weight_iter) + + return weight + + def forward(self, outputs, targets, iteration): + ''' Compute training loss + + :param outputs: outputs predicted by the model + :param targets: ground-truth targets + :param iteration: current training iteration + ''' + # extract ground-truth targets + # targets are already zero padded + duration_targets, energy_targets, pitch_targets, mel_spec_targets, speaker_ids = targets + duration_targets.requires_grad = False + energy_targets.requires_grad = False + pitch_targets.requires_grad = False + mel_spec_targets.requires_grad = False + speaker_ids.requires_grad = False + + # extract predictions + # predictions are already zero padded + speaker_preds, film_params, encoder_preds, decoder_preds, _ = outputs + post_multipliers, _, _, _ = film_params + duration_preds, energy_preds, pitch_preds, input_lengths = encoder_preds + mel_spec_preds, output_lengths= decoder_preds + + # compute adversarial speaker objective + speaker_loss = self.CrossEntropy(speaker_preds, speaker_ids) + + # compute L2 penalized loss on FiLM scalar post-multipliers + if self.post_mult_weight != 0.: + post_mult_loss = torch.norm(post_multipliers, p=2) + else: + post_mult_loss = torch.tensor([0.]).cuda(speaker_loss.device, non_blocking=True).float() + + # compute duration loss + duration_loss = self.MSELoss(duration_preds, duration_targets) # (B, L_max) + # divide by length of each sequence in the batch + duration_loss = torch.sum(duration_loss, dim=1) / input_lengths # (B, ) + duration_loss = torch.mean(duration_loss) + + # compute energy loss + energy_loss = self.MSELoss(energy_preds, energy_targets) # (B, L_max) + # divide by length of each sequence in the batch + energy_loss = torch.sum(energy_loss, dim=1) / input_lengths # (B, ) + energy_loss = torch.mean(energy_loss) + + # compute pitch loss + pitch_loss = self.MSELoss(pitch_preds, pitch_targets) # (B, L_max) + # divide by length of each sequence in the batch + pitch_loss = torch.sum(pitch_loss, dim=1) / input_lengths # (B, ) + pitch_loss = torch.mean(pitch_loss) + + # compute mel-spec loss + mel_spec_l1_loss = self.L1Loss(mel_spec_preds, mel_spec_targets) # (B, n_mel_channels, T_max) + mel_spec_l2_loss = self.MSELoss(mel_spec_preds, mel_spec_targets) # (B, n_mel_channels, T_max) + # divide by length of each sequence in the batch + mel_spec_l1_loss = torch.sum(mel_spec_l1_loss, dim=(1, 2)) / (self.nb_channels * output_lengths) # (B, ) + mel_spec_l1_loss = torch.mean(mel_spec_l1_loss) + mel_spec_l2_loss = torch.sum(mel_spec_l2_loss, dim=(1, 2)) / (self.nb_channels * output_lengths) # (B, ) + mel_spec_l2_loss = 
torch.mean(mel_spec_l2_loss) + + # add weights + speaker_weight = self.update_adversarial_weight(iteration) + speaker_loss = speaker_weight * speaker_loss + post_mult_loss = self.post_mult_weight * post_mult_loss + duration_loss = self.dur_weight * duration_loss + energy_loss = self.energy_weight * energy_loss + pitch_loss = self.pitch_weight * pitch_loss + mel_spec_l1_loss = self.mel_spec_weight * mel_spec_l1_loss + mel_spec_l2_loss = self.mel_spec_weight * mel_spec_l2_loss + + loss = speaker_loss + post_mult_loss + duration_loss + energy_loss + pitch_loss + mel_spec_l1_loss + mel_spec_l2_loss + + # create individual loss tracker + individual_loss = {'speaker_loss': speaker_loss.item(), 'post_mult_loss': post_mult_loss.item(), + 'duration_loss': duration_loss.item(), 'energy_loss': energy_loss.item(), 'pitch_loss': pitch_loss.item(), + 'mel_spec_l1_loss': mel_spec_l1_loss.item(), 'mel_spec_l2_loss': mel_spec_l2_loss.item()} + + return loss, individual_loss diff --git a/src/daft_exprt/mfa.py b/src/daft_exprt/mfa.py new file mode 100644 index 0000000..cf051a4 --- /dev/null +++ b/src/daft_exprt/mfa.py @@ -0,0 +1,255 @@ +import logging +import logging.handlers +import os +import uuid + +import tgt + +from shutil import move, rmtree + +from daft_exprt.cleaners import text_cleaner +from daft_exprt.symbols import MFA_SIL_WORD_SYMBOL, MFA_SIL_PHONE_SYMBOLS, MFA_UNK_WORD_SYMBOL, \ + MFA_UNK_PHONE_SYMBOL, SIL_WORD_SYMBOL, SIL_PHONE_SYMBOL +from daft_exprt.utils import launch_multi_process + + +_logger = logging.getLogger(__name__) + + +''' + Align speaker corpuses using MFA + https://montreal-forced-aligner.readthedocs.io/en/latest/ +''' + + +def move_file(file, src_dir, dst_dir, log_queue): + ''' Dummy function to move a file in multi-processing mode + ''' + move(os.path.join(src_dir, file), os.path.join(dst_dir, file)) + + +def prepare_corpus(corpus_dir, language): + ''' Prepare corpus for MFA + Create .lab files for each audio file + ''' + # check wavs directory and speaker metadata file exist + wavs_dir = os.path.join(corpus_dir, 'wavs') + metadata = os.path.join(corpus_dir, 'metadata.csv') + assert(os.path.isdir(wavs_dir)), _logger.error(f'There is no such directory: {wavs_dir}') + assert(os.path.isfile(metadata)), _logger.error(f'There is no such file: {metadata}') + + # extract lines from metadata.csv + with open(metadata, 'r', encoding='utf-8') as f: + lines = f.readlines() + lines = [x.strip().split(sep='|') for x in lines] # [[file_name, text], ...] 
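+    # e.g. a metadata.csv line is expected to look like (illustrative values):
+    #   my_utterance_001|This is the corresponding transcription.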
+ # check there is only 1 pipe "|" separator + for line in lines: + assert(len(line) == 2), _logger.error(f'Problem in metadata file: {corpus_dir} -- Line: {line}') + # extract file names and corresponding text + file_names = [line[0].strip() for line in lines] + texts = [line[1].strip() for line in lines] + + # create .lab file for each audio file + wavs = [os.path.join(wavs_dir, x) for x in os.listdir(wavs_dir) if x.endswith('.wav')] + for wav in wavs: + # search metadata lines associated to wav file + wav_name = os.path.basename(wav).replace('.wav', '').strip() + lines_idx = [idx for idx, file_name in enumerate(file_names) if wav_name == file_name] + # only create .lab if ONE line is associated to wav file + if len(lines_idx) == 1: + # get corresponding text and clean it + text = texts[lines_idx[0]] + text = text_cleaner(text, language).strip() + # save it to .lab file + with open(os.path.join(wavs_dir, f'{wav_name}.lab'), 'w', encoding='utf-8') as f: + f.write(text) + # remove lines for computational efficiency + for i, idx in enumerate(lines_idx): + del file_names[idx - i] + del texts[idx - i] + + +def _extract_markers(text_grid_file, log_queue): + ''' Extract word/phone alignment markers from .TextGrid file + ''' + # create logger from logging queue + qh = logging.handlers.QueueHandler(log_queue) + root = logging.getLogger() + if not root.hasHandlers(): + root.setLevel(logging.INFO) + root.addHandler(qh) + logger = logging.getLogger(f"worker{str(uuid.uuid4())}") + + # load text grid + text_grid = tgt.io.read_textgrid(text_grid_file, include_empty_intervals=True) + # extract word and phone tiers + words_tier = text_grid.get_tier_by_name("words") + words = [[word.start_time, word.end_time, word.text] for word in words_tier._objects] + phones_tier = text_grid.get_tier_by_name("phones") + phones = [[phone.start_time, phone.end_time, phone.text] for phone in phones_tier._objects] + # set silence symbol according to chosen nomenclature + for marker in words: + _, _, word = marker + if word == MFA_SIL_WORD_SYMBOL: + marker[-1] = SIL_WORD_SYMBOL + for marker in phones: + _, _, phone = marker + if phone in MFA_SIL_PHONE_SYMBOLS: + marker[-1] = SIL_PHONE_SYMBOL + # merge subsequent silences on phoneme level + # for example, it is possible to have: AH0 - SIL - SIL - OW0 + # this should be merged as follows: AH0 - SIL - OW0 + phones_old = phones.copy() + phones = [phones_old[0]] + for marker in phones_old[1:]: + _, end, phone = marker + prev_phone = phones[-1][2] + if prev_phone == phone == SIL_PHONE_SYMBOL: + phones[-1][1] = end + else: + phones.append(marker) + + # gather words and phones markers together + # ignore if an unknown word/phone is detected + # ignore if a silence is detected withing the word + silence_error = False + all_words = [word for _, _, word in words] + all_phones = [phone for _, _, phone in phones] + if MFA_UNK_WORD_SYMBOL not in all_words and MFA_UNK_PHONE_SYMBOL not in all_phones: + markers = [] + for word_idx, word_marker in enumerate(words): + begin_word, end_word, word = word_marker + for phone_marker in phones: + begin_phone, end_phone, phone = phone_marker + if begin_word <= begin_phone and end_phone <= end_word: + # check silent word and phoneme have a one to one correspondance + if word == SIL_WORD_SYMBOL: + assert(phone == SIL_PHONE_SYMBOL and begin_word == begin_phone and end_word == end_phone), \ + logger.error(f'{text_grid_file} -- error with silence -- word number {word_idx}') + else: # check there are no silence errors + if phone == SIL_PHONE_SYMBOL: + 
logger.warning(f'{text_grid_file} -- silence within word -- word number {word_idx} -- Ignoring file') + silence_error = True + # add to list + markers.append([f'{begin_phone:.3f}', f'{end_phone:.3f}', phone, word, str(word_idx)]) + else: + # check phone does not overlap with word + assert(end_phone <= begin_word or end_word <= begin_phone), \ + logger.error(f'{text_grid_file} -- word and phoneme overlap -- word number {word_idx}') + + if not silence_error: + # trim leading and tailing silences + phone_lead, phone_tail = markers[0][2], markers[-1][2] + if phone_lead == SIL_PHONE_SYMBOL: + markers.pop(0) + if phone_tail == SIL_PHONE_SYMBOL: + markers.pop(-1) + # check everything is correct with trimming + phone_lead, phone_tail = markers[0][2], markers[-1][2] + assert(phone_lead != SIL_PHONE_SYMBOL and phone_tail != SIL_PHONE_SYMBOL), \ + logger.error(f'{text_grid_file} -- problem with sentence triming') + # check timings are correct + for marker_curr, marker_next in zip(markers[:-1], markers[1:]): + begin_curr, end_curr = marker_curr[0], marker_curr[1] + begin_next, end_next = marker_next[0], marker_next[1] + assert(float(end_curr) == float(begin_next)), logger.error(f'{text_grid_file} -- problem with sentence timing') + assert(float(begin_curr) < float(end_curr)), logger.error(f'{text_grid_file} -- problem with sentence timing') + assert(float(begin_next) < float(end_next)), logger.error(f'{text_grid_file} -- problem with sentence timing') + + # save file in .markers format + text_grid_dir = os.path.dirname(text_grid_file) + file_name = os.path.basename(text_grid_file).replace('.TextGrid', '') + with open(os.path.join(text_grid_dir, f'{file_name}.markers'), 'w', encoding='utf-8') as f: + f.writelines(['\t'.join(x) + '\n' for x in markers]) + + +def extract_markers(text_grid_dir, n_jobs): + ''' Extract word/phone alignment markers from .TextGrid files contained in TextGrid directory + ''' + # get all .TextGrid files contained in the directory that do not have .markers files + all_grid_files = [os.path.join(text_grid_dir, x) for x in os.listdir(text_grid_dir) if x.endswith('.TextGrid')] + grid_files_to_process = [x for x in all_grid_files if not os.path.isfile(x.replace('.TextGrid', '.markers'))] + _logger.info(f'Folder: {text_grid_dir} -- {len(all_grid_files) - len(grid_files_to_process)} TextGrid files already processed -- ' + f'{len(grid_files_to_process)} TextGrid files need to be processed') + + # extract markers for words and phones + launch_multi_process(iterable=grid_files_to_process, func=_extract_markers, n_jobs=n_jobs, timer_verbose=False) + + +def mfa(dataset_dir, hparams, n_jobs): + ''' Run MFA on every speaker data set and extract timing markers for words and phones + ''' + _logger.info('--' * 30) + _logger.info('Running MFA for each speaker data set'.upper()) + _logger.info('--' * 30) + + # perform alignment for each speaker + for speaker in hparams.speakers: + _logger.info(f'Speaker: "{speaker}"') + # check if alignment has already been performed + corpus_dir = os.path.join(dataset_dir, speaker) + align_out_dir = os.path.join(corpus_dir, 'align') + if not os.path.isdir(align_out_dir): + # initialize variables + language = hparams.language + dictionary = hparams.mfa_dictionary + g2p_model = hparams.mfa_g2p_model + acoustic_model = hparams.mfa_acoustic_model + temp_dir = os.path.join(corpus_dir, 'tmp') + + # create .lab files for each audio file + _logger.info('Preparing MFA corpus') + prepare_corpus(corpus_dir, language) + + # # uncomment if you need to validate your corpus 
before MFA alignment + # # validate corpuses to ensure there are no issues with the data format + # _logger.info('Validating corpus') + # tmp_dir = os.path.join(temp_dir, 'validate') + # os.system(f'mfa validate {corpus_dir} {dictionary} ' + # f'{acoustic_model} -t {tmp_dir} -j {n_jobs}') + # # use a G2P model to generate a pronunciation dictionary for unknown words + # # this can later be added manually to the dictionary + # oovs = os.path.join(tmp_dir, os.path.basename(speaker), 'corpus_data', 'oovs_found.txt') + # if os.path.isfile(oovs): + # _logger.info('Generating transcriptions for unknown words') + # oovs_trans = os.path.join(corpus_dir, 'oovs_transcriptions.txt') + # os.system(f'mfa g2p {g2p_model} {oovs} {oovs_trans}') + + # perform forced alignment with a pretrained acoustic model + _logger.info('Performing forced alignment using a pretrained model') + tmp_dir = os.path.join(temp_dir, 'align') + os.system(f'mfa align {corpus_dir} {dictionary} {acoustic_model} ' + f'{align_out_dir} -t {tmp_dir} -j {n_jobs} -v -c') + + # extract word/phone alignment markers from .TextGrid files + _logger.info('Extracting markers') + text_grid_dir = os.path.join(align_out_dir, 'wavs') + assert(os.path.isdir(text_grid_dir)), _logger.error(f'There is no such dir {text_grid_dir}') + all_files = [x for x in os.listdir(text_grid_dir)] + launch_multi_process(iterable=all_files, func=move_file, n_jobs=n_jobs, + src_dir=text_grid_dir, dst_dir=align_out_dir, timer_verbose=False) + rmtree(text_grid_dir, ignore_errors=True) + extract_markers(align_out_dir, n_jobs) + # move .lab files to markers dir + _logger.info('Moving .lab files to markers directory') + wavs_dir = os.path.join(corpus_dir, 'wavs') + lab_files = [x for x in os.listdir(wavs_dir) if x.endswith('.lab')] + launch_multi_process(iterable=lab_files, func=move_file, n_jobs=n_jobs, + src_dir=wavs_dir, dst_dir=align_out_dir, timer_verbose=False) + # remove temp dir + rmtree(temp_dir, ignore_errors=True) + # display stats + wavs = [x for x in os.listdir(wavs_dir) if x.endswith('.wav')] + markers = [x for x in os.listdir(align_out_dir) if x.endswith('.markers')] + _logger.info(f'{len(markers) / len(wavs) * 100:.2f}% of the data set aligned') + else: + # extract word/phone alignment markers from .TextGrid files + _logger.info('MFA alignment already performed') + _logger.info('Extracting markers') + extract_markers(align_out_dir, n_jobs) + # display stats + wavs_dir = os.path.join(corpus_dir, 'wavs') + wavs = [x for x in os.listdir(wavs_dir) if x.endswith('.wav')] + markers = [x for x in os.listdir(align_out_dir) if x.endswith('.markers')] + _logger.info(f'{len(markers) / len(wavs) * 100:.2f}% of the data set aligned') + _logger.info('') diff --git a/src/daft_exprt/model.py b/src/daft_exprt/model.py new file mode 100644 index 0000000..b53728d --- /dev/null +++ b/src/daft_exprt/model.py @@ -0,0 +1,923 @@ +import numpy as np +import torch + +from collections import namedtuple + +from torch import nn +from torch.autograd import Function +from torch.distributions import Normal +from torch.nn.parameter import Parameter + +from daft_exprt.extract_features import duration_to_integer + + +def get_mask_from_lengths(lengths): + ''' Create a masked tensor from given lengths + + :param lengths: torch.tensor of size (B, ) -- lengths of each example + + :return mask: torch.tensor of size (B, max_length) -- the masked tensor + ''' + max_len = torch.max(lengths) + ids = torch.arange(0, max_len).cuda(lengths.device, non_blocking=True).long() + mask = (ids < 
lengths.unsqueeze(1)).bool() + return mask + + +class GradientReversalFunction(Function): + @staticmethod + def forward(ctx, x, lambda_): + ctx.lambda_ = lambda_ + return x.clone() + + @staticmethod + def backward(ctx, grads): + lambda_ = ctx.lambda_ + lambda_ = grads.new_tensor(lambda_) + dx = -lambda_ * grads + return dx, None + + +class GradientReversal(torch.nn.Module): + ''' Gradient Reversal Layer + Y. Ganin, V. Lempitsky, + "Unsupervised Domain Adaptation by Backpropagation", + in ICML, 2015. + Forward pass is the identity function + In the backward pass, upstream gradients are multiplied by -lambda (i.e. gradient are reversed) + ''' + def __init__(self, hparams): + super(GradientReversal, self).__init__() + self.lambda_ = hparams.lambda_reversal + + def forward(self, x): + return GradientReversalFunction.apply(x, self.lambda_) + + +class LinearNorm(nn.Module): + ''' Linear Norm Module: + - Linear Layer + ''' + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) + nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + ''' Forward function of Linear Norm + x = (*, in_dim) + ''' + x = self.linear_layer(x) # (*, out_dim) + + return x + + +class ConvNorm1D(nn.Module): + ''' Conv Norm 1D Module: + - Conv 1D + ''' + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear'): + super(ConvNorm1D, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, bias=bias) + nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + ''' Forward function of Conv Norm 1D + x = (B, L, in_channels) + ''' + x = x.transpose(1, 2) # (B, in_channels, L) + x = self.conv(x) # (B, out_channels, L) + x = x.transpose(1, 2) # (B, L, out_channels) + + return x + + +class ConvNorm2D(nn.Module): + ''' Conv Norm 2D Module: + - Conv 2D + ''' + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=0, dilation=1, bias=True, w_init_gain='linear'): + super(ConvNorm2D, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, bias=bias) + nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + ''' Forward function of Conv Norm 2D: + x = (B, H, W, in_channels) + ''' + x = x.permute(0, 3, 1, 2) # (B, in_channels, H, W) + x = self.conv(x) # (B, out_channels, H, W) + x = x.permute(0, 2, 3, 1) # (B, H, W, out_channels) + + return x + + +class PositionalEncoding(nn.Module): + ''' Positional Encoding Module: + - Sinusoidal Positional Embedding + ''' + def __init__(self, embed_dim, max_len=5000, timestep=10000.): + super(PositionalEncoding, self).__init__() + self.embed_dim = embed_dim + pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # (max_len, 1) + div_term = torch.exp(torch.arange(0, self.embed_dim, 2).float() * (-np.log(timestep) / self.embed_dim)) # (embed_dim // 2, ) + self.pos_enc = torch.FloatTensor(max_len, self.embed_dim).zero_() # (max_len, embed_dim) + self.pos_enc[:, 0::2] = torch.sin(pos * div_term) + self.pos_enc[:, 1::2] = torch.cos(pos * div_term) + + def forward(self, x): + ''' Forward function of Positional Encoding: + x = (B, N) -- Long or 
Int tensor + ''' + # initialize tensor + nb_frames_max = torch.max(torch.cumsum(x, dim=1)) + pos_emb = torch.FloatTensor(x.size(0), nb_frames_max, self.embed_dim).zero_() # (B, nb_frames_max, embed_dim) + pos_emb = pos_emb.cuda(x.device, non_blocking=True).float() # (B, nb_frames_max, embed_dim) + + # can be used for absolute or relative positioning + for line_idx in range(x.size(0)): + pos_idx = [] + for column_idx in range(x.size(1)): + idx = x[line_idx, column_idx] + pos_idx.extend([i for i in range(idx)]) + emb = self.pos_enc[pos_idx] # (nb_frames, embed_dim) + pos_emb[line_idx, :emb.size(0), :] = emb + + return pos_emb + + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention Module: + - Multi-Head Attention + A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser and I. Polosukhin + "Attention is all you need", + in NeurIPS, 2017. + - Dropout + - Residual Connection + - Layer Normalization + ''' + def __init__(self, hparams): + super(MultiHeadAttention, self).__init__() + self.multi_head_attention = nn.MultiheadAttention(hparams.hidden_embed_dim, + hparams.attn_nb_heads, + hparams.attn_dropout) + self.dropout = nn.Dropout(hparams.attn_dropout) + self.layer_norm = nn.LayerNorm(hparams.hidden_embed_dim) + + def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): + ''' Forward function of Multi-Head Attention: + query = (B, L_max, hidden_embed_dim) + key = (B, T_max, hidden_embed_dim) + value = (B, T_max, hidden_embed_dim) + key_padding_mask = (B, T_max) if not None + attn_mask = (L_max, T_max) if not None + ''' + # compute multi-head attention + # attn_outputs = (L_max, B, hidden_embed_dim) + # attn_weights = (B, L_max, T_max) + attn_outputs, attn_weights = self.multi_head_attention(query.transpose(0, 1), + key.transpose(0, 1), + value.transpose(0, 1), + key_padding_mask=key_padding_mask, + attn_mask=attn_mask) + attn_outputs = attn_outputs.transpose(0, 1) # (B, L_max, hidden_embed_dim) + # apply dropout + attn_outputs = self.dropout(attn_outputs) # (B, L_max, hidden_embed_dim) + # add residual connection and perform layer normalization + attn_outputs = self.layer_norm(attn_outputs + query) # (B, L_max, hidden_embed_dim) + + return attn_outputs, attn_weights + + +class PositionWiseConvFF(nn.Module): + ''' Position Wise Convolutional Feed-Forward Module: + - 2x Conv 1D with ReLU + - Dropout + - Residual Connection + - Layer Normalization + - FiLM conditioning (if film_params is not None) + ''' + def __init__(self, hparams): + super(PositionWiseConvFF, self).__init__() + self.convs = nn.Sequential( + ConvNorm1D(hparams.hidden_embed_dim, hparams.conv_channels, + kernel_size=hparams.conv_kernel, stride=1, + padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.ReLU(), + ConvNorm1D(hparams.conv_channels, hparams.hidden_embed_dim, + kernel_size=hparams.conv_kernel, stride=1, + padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='linear'), + nn.Dropout(hparams.conv_dropout) + ) + self.layer_norm = nn.LayerNorm(hparams.hidden_embed_dim) + + def forward(self, x, film_params): + ''' Forward function of PositionWiseConvFF: + x = (B, L_max, hidden_embed_dim) + film_params = (B, nb_film_params) + ''' + # pass through convs + outputs = self.convs(x) # (B, L_max, hidden_embed_dim) + # add residual connection and perform layer normalization + outputs = self.layer_norm(outputs + x) # (B, L_max, hidden_embed_dim) + # add FiLM transformation + if film_params is not None: + nb_gammas = 
int(film_params.size(1) / 2) + assert(nb_gammas == outputs.size(2)) + gammas = film_params[:, :nb_gammas].unsqueeze(1) # (B, 1, hidden_embed_dim) + betas = film_params[:, nb_gammas:].unsqueeze(1) # (B, 1, hidden_embed_dim) + outputs = gammas * outputs + betas # (B, L_max, hidden_embed_dim) + + return outputs + + +class FFTBlock(nn.Module): + ''' FFT Block Module: + - Multi-Head Attention + - Position Wise Convolutional Feed-Forward + - FiLM conditioning (if film_params is not None) + ''' + def __init__(self, hparams): + super(FFTBlock, self).__init__() + self.attention = MultiHeadAttention(hparams) + self.feed_forward = PositionWiseConvFF(hparams) + + def forward(self, x, film_params, mask): + ''' Forward function of FFT Block: + x = (B, L_max, hidden_embed_dim) + film_params = (B, nb_film_params) + mask = (B, L_max) + ''' + # attend + attn_outputs, _ = self.attention(x, x, x, key_padding_mask=mask) # (B, L_max, hidden_embed_dim) + attn_outputs = attn_outputs.masked_fill(mask.unsqueeze(2), 0) # (B, L_max, hidden_embed_dim) + # feed-forward pass + outputs = self.feed_forward(attn_outputs, film_params) # (B, L_max, hidden_embed_dim) + outputs = outputs.masked_fill(mask.unsqueeze(2), 0) # (B, L_max, hidden_embed_dim) + + return outputs + + +class SpeakerClassifier(nn.Module): + ''' Speaker Classifier Module: + - 3x Linear Layers with ReLU + ''' + def __init__(self, hparams): + super(SpeakerClassifier, self).__init__() + nb_speakers = hparams.n_speakers - 1 + embed_dim = hparams.prosody_encoder['hidden_embed_dim'] + + self.classifier = nn.Sequential( + GradientReversal(hparams), + LinearNorm(embed_dim, embed_dim, w_init_gain='relu'), + nn.ReLU(), + LinearNorm(embed_dim, embed_dim, w_init_gain='relu'), + nn.ReLU(), + LinearNorm(embed_dim, nb_speakers, w_init_gain='linear') + ) + + def forward(self, x): + ''' Forward function of Speaker Classifier: + x = (B, embed_dim) + ''' + # pass through classifier + outputs = self.classifier(x) # (B, nb_speakers) + + return outputs + + +class ProsodyEncoder(nn.Module): + ''' Prosody Encoder Module: + - Positional Encoding + - Energy Embedding: + - 1x Conv 1D + - Pitch Embedding: + - 1x Conv 1D + - Mel-Spec PreNet: + - 3x Conv 1D + - 4x FFT Blocks + - Speaker Embedding + - Linear Projection Layer + + This module predicts FiLM parameters to condition the Core Acoustic Model + References: + - E. Perez, F. Strub, H. de Vries, V. Dumoulin and A. Courville, + "FiLM: Visual Reasoning with a General Conditioning Layer", in AAAI, 2018. + - https://ml-retrospectives.github.io/neurips2019/accepted_retrospectives/2019/film/ + - https://distill.pub/2018/feature-wise-transformations/ + - B.N. Oreshkin, P. Rodriguez and A. Lacoste, + "TADAM: Task dependent adaptive metric for improved few-shot learning", arXiv:1805.10123, 2018. 
+ ''' + def __init__(self, hparams): + super(ProsodyEncoder, self).__init__() + n_speakers = hparams.n_speakers + nb_mels = hparams.n_mel_channels + self.post_mult_weight = hparams.post_mult_weight + self.module_params = { + 'encoder': (hparams.phoneme_encoder['nb_blocks'], hparams.phoneme_encoder['hidden_embed_dim']), + 'prosody_predictor': (hparams.local_prosody_predictor['nb_blocks'], hparams.local_prosody_predictor['conv_channels']), + 'decoder': (hparams.frame_decoder['nb_blocks'], hparams.phoneme_encoder['hidden_embed_dim']) + } + Tuple = namedtuple('Tuple', hparams.prosody_encoder) + hparams = Tuple(**hparams.prosody_encoder) + + # positional encoding + self.pos_enc = PositionalEncoding(hparams.hidden_embed_dim) + # energy embedding + self.energy_embedding = ConvNorm1D(1, hparams.hidden_embed_dim, kernel_size=hparams.conv_kernel, + stride=1, padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='linear') + # pitch embedding + self.pitch_embedding = ConvNorm1D(1, hparams.hidden_embed_dim, kernel_size=hparams.conv_kernel, + stride=1, padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='linear') + # mel-spec pre-net convolutions + self.convs = nn.Sequential( + ConvNorm1D(nb_mels, hparams.conv_channels, + kernel_size=hparams.conv_kernel, stride=1, + padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.ReLU(), + nn.LayerNorm(hparams.conv_channels), + nn.Dropout(hparams.conv_dropout), + ConvNorm1D(hparams.conv_channels, hparams.conv_channels, + kernel_size=hparams.conv_kernel, stride=1, + padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.ReLU(), + nn.LayerNorm(hparams.conv_channels), + nn.Dropout(hparams.conv_dropout), + ConvNorm1D(hparams.conv_channels, hparams.hidden_embed_dim, + kernel_size=hparams.conv_kernel, stride=1, + padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.ReLU(), + nn.LayerNorm(hparams.hidden_embed_dim), + nn.Dropout(hparams.conv_dropout) + ) + # FFT blocks + blocks = [] + for _ in range(hparams.nb_blocks): + blocks.append(FFTBlock(hparams)) + self.blocks = nn.ModuleList(blocks) + # speaker embedding + self.spk_embedding = nn.Embedding(n_speakers, hparams.hidden_embed_dim) + torch.nn.init.xavier_uniform_(self.spk_embedding.weight.data) + # projection layers for FiLM parameters + nb_tot_film_params = 0 + for _, module_params in self.module_params.items(): + nb_blocks, conv_channels = module_params + nb_tot_film_params += nb_blocks * conv_channels + self.gammas_predictor = LinearNorm(hparams.hidden_embed_dim, nb_tot_film_params, w_init_gain='linear') + self.betas_predictor = LinearNorm(hparams.hidden_embed_dim, nb_tot_film_params, w_init_gain='linear') + # initialize L2 penalized scalar post-multipliers + # one (gamma, beta) scalar post-multiplier per FiLM layer, i.e per block + if self.post_mult_weight != 0.: + nb_post_multipliers = 0 + for _, module_params in self.module_params.items(): + nb_blocks, _ = module_params + nb_post_multipliers += nb_blocks + self.post_multipliers = Parameter(torch.empty(2, nb_post_multipliers)) # (2, nb_post_multipliers) + nn.init.xavier_uniform_(self.post_multipliers, gain=nn.init.calculate_gain('linear')) # (2, nb_post_multipliers) + else: + self.post_multipliers = 1. 
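+    # FiLM bookkeeping (illustrative, with hypothetical widths): with 4 encoder blocks, 2 prosody
+    # predictor blocks and 4 decoder blocks, all of width 256, nb_tot_film_params = (4 + 2 + 4) * 256 = 2560,
+    # so gammas_predictor and betas_predictor each output a (B, 2560) tensor that forward() slices per
+    # module and reshapes to (B, nb_blocks, channels). Each FiLM-ed block then receives 2 * channels
+    # parameters (its gammas concatenated with its betas), and post_multipliers has shape (2, 10):
+    # row 0 scales the gamma deltas (gamma = post * delta + 1), row 1 scales the betas.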
+ + def forward(self, frames_energy, frames_pitch, mel_specs, speaker_ids, output_lengths): + ''' Forward function of Prosody Encoder: + frames_energy = (B, T_max) + frames_pitch = (B, T_max) + mel_specs = (B, nb_mels, T_max) + speaker_ids = (B, ) + output_lengths = (B, ) + ''' + # compute positional encoding + pos = self.pos_enc(output_lengths.unsqueeze(1)) # (B, T_max, hidden_embed_dim) + # encode energy sequence + frames_energy = frames_energy.unsqueeze(2) # (B, T_max, 1) + energy = self.energy_embedding(frames_energy) # (B, T_max, hidden_embed_dim) + # encode pitch sequence + frames_pitch = frames_pitch.unsqueeze(2) # (B, T_max, 1) + pitch = self.pitch_embedding(frames_pitch) # (B, T_max, hidden_embed_dim) + # pass through convs + mel_specs = mel_specs.transpose(1, 2) # (B, T_max, nb_mels) + outputs = self.convs(mel_specs) # (B, T_max, hidden_embed_dim) + # create mask + mask = ~get_mask_from_lengths(output_lengths) # (B, T_max) + # add encodings and mask tensor + outputs = outputs + energy + pitch + pos # (B, T_max, hidden_embed_dim) + outputs = outputs.masked_fill(mask.unsqueeze(2), 0) # (B, T_max, hidden_embed_dim) + # pass through FFT blocks + for _, block in enumerate(self.blocks): + outputs = block(outputs, None, mask) # (B, T_max, hidden_embed_dim) + # average pooling on the whole time sequence + outputs = torch.sum(outputs, dim=1) / output_lengths.unsqueeze(1) # (B, hidden_embed_dim) + # store prosody embeddings + prosody_embeddings = outputs # (B, hidden_embed_dim) + # encode speaker IDs and add + speaker_ids = self.spk_embedding(speaker_ids) # (B, hidden_embed_dim) + outputs = outputs + speaker_ids # (B, hidden_embed_dim) + + # project outputs to predict all FiLM parameters + gammas = self.gammas_predictor(outputs) # (B, nb_tot_film_params) + betas = self.betas_predictor(outputs) # (B, nb_tot_film_params) + # split FiLM parameters per FiLM-ed module + modules_film_params = [] + column_idx, block_idx = 0, 0 + for _, module_params in self.module_params.items(): + nb_blocks, conv_channels = module_params + module_nb_film_params = nb_blocks * conv_channels + module_gammas = gammas[:, column_idx: column_idx + module_nb_film_params] # (B, module_nb_film_params) + module_betas = betas[:, column_idx: column_idx + module_nb_film_params] # (B, module_nb_film_params) + # split FiLM parameters for each block in the module + B = module_gammas.size(0) + module_gammas = module_gammas.view(B, nb_blocks, -1) # (B, nb_blocks, block_nb_film_params) + module_betas = module_betas.view(B, nb_blocks, -1) # (B, nb_blocks, block_nb_film_params) + # predict gammas in the delta regime, i.e. 
predict deviation from unity + # add gamma scalar L2 penalized post-multiplier for each block + if self.post_mult_weight != 0.: + gamma_post = self.post_multipliers[0, block_idx: block_idx + nb_blocks] # (nb_blocks, ) + gamma_post = gamma_post.unsqueeze(0).unsqueeze(-1) # (1, nb_blocks, 1) + else: + gamma_post = self.post_multipliers + module_gammas = gamma_post * module_gammas + 1 # (B, nb_blocks, block_nb_film_params) + # add betas scalar L2 penalized post-multiplier for each block + if self.post_mult_weight != 0.: + beta_post = self.post_multipliers[1, block_idx: block_idx + nb_blocks] # (nb_blocks, ) + beta_post = beta_post.unsqueeze(0).unsqueeze(-1) # (1, nb_blocks, 1) + else: + beta_post = self.post_multipliers + module_betas = beta_post * module_betas # (B, nb_blocks, block_nb_film_params) + # concatenate tensors and append to list + module_film_params = torch.cat((module_gammas, module_betas), dim=2) # (B, nb_blocks, nb_film_params) + modules_film_params.append(module_film_params) + # increment variables + block_idx += nb_blocks + column_idx += module_nb_film_params + encoder_film, prosody_pred_film, decoder_film = modules_film_params + + return prosody_embeddings, encoder_film, prosody_pred_film, decoder_film + + +class PhonemeEncoder(nn.Module): + ''' Phoneme Encoder Module: + - Symbols Embedding + - Positional Encoding + - 4x FFT Blocks with FiLM conditioning + ''' + def __init__(self, hparams): + super(PhonemeEncoder, self).__init__() + n_symbols = hparams.n_symbols + embed_dim = hparams.phoneme_encoder['hidden_embed_dim'] + Tuple = namedtuple('Tuple', hparams.phoneme_encoder) + hparams = Tuple(**hparams.phoneme_encoder) + + # symbols embedding and positional encoding + self.symbols_embedding = nn.Embedding(n_symbols, embed_dim) + torch.nn.init.xavier_uniform_(self.symbols_embedding.weight.data) + self.pos_enc = PositionalEncoding(embed_dim) + # FFT blocks + blocks = [] + for _ in range(hparams.nb_blocks): + blocks.append(FFTBlock(hparams)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x, film_params, input_lengths): + ''' Forward function of Phoneme Encoder: + x = (B, L_max) + film_params = (B, nb_blocks, nb_film_params) + input_lengths = (B, ) + ''' + # compute symbols embedding + x = self.symbols_embedding(x) # (B, L_max, hidden_embed_dim) + # compute positional encoding + pos = self.pos_enc(input_lengths.unsqueeze(1)) # (B, L_max, hidden_embed_dim) + # create mask + mask = ~get_mask_from_lengths(input_lengths) # (B, L_max) + # add and mask + x = x + pos # (B, L_max, hidden_embed_dim) + x = x.masked_fill(mask.unsqueeze(2), 0) # (B, L_max, hidden_embed_dim) + # pass through FFT blocks + for idx, block in enumerate(self.blocks): + x = block(x, film_params[:, idx, :], mask) # (B, L_max, hidden_embed_dim) + + return x + + +class LocalProsodyPredictor(nn.Module): + ''' Local Prosody Predictor Module: + - 2x Conv 1D + - FiLM conditioning + - Linear projection + ''' + def __init__(self, hparams): + super(LocalProsodyPredictor, self).__init__() + embed_dim = hparams.phoneme_encoder['hidden_embed_dim'] + Tuple = namedtuple('Tuple', hparams.local_prosody_predictor) + hparams = Tuple(**hparams.local_prosody_predictor) + + # conv1D blocks + blocks = [] + for idx in range(hparams.nb_blocks): + in_channels = embed_dim if idx == 0 else hparams.conv_channels + convs = nn.Sequential( + ConvNorm1D(in_channels, hparams.conv_channels, + kernel_size=hparams.conv_kernel, stride=1, + padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.ReLU(), + 
nn.LayerNorm(hparams.conv_channels), + nn.Dropout(hparams.conv_dropout), + ConvNorm1D(hparams.conv_channels, hparams.conv_channels, + kernel_size=hparams.conv_kernel, stride=1, + padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.ReLU(), + nn.LayerNorm(hparams.conv_channels), + nn.Dropout(hparams.conv_dropout) + ) + blocks.append(convs) + self.blocks = nn.ModuleList(blocks) + # linear projection for prosody prediction + self.projection = LinearNorm(hparams.conv_channels, 3, w_init_gain='linear') + + def forward(self, x, film_params, input_lengths): + ''' Forward function of Local Prosody Predictor: + x = (B, L_max, hidden_embed_dim) + film_params = (B, nb_blocks, nb_film_params) + input_lengths = (B, ) + ''' + # pass through blocks and mask tensor + for idx, block in enumerate(self.blocks): + x = block(x) # (B, L_max, conv_channels) + # add FiLM transformation + block_film_params = film_params[:, idx, :] # (B, nb_film_params) + nb_gammas = int(block_film_params.size(1) / 2) + assert(nb_gammas == x.size(2)) + gammas = block_film_params[:, :nb_gammas].unsqueeze(1) # (B, 1, conv_channels) + betas = block_film_params[:, nb_gammas:].unsqueeze(1) # (B, 1, conv_channels) + x = gammas * x + betas # (B, L_max, conv_channels) + mask = ~get_mask_from_lengths(input_lengths) # (B, L_max) + x = x.masked_fill(mask.unsqueeze(2), 0) # (B, L_max, conv_channels) + # predict prosody params and mask tensor + prosody_preds = self.projection(x) # (B, L_max, 3) + prosody_preds = prosody_preds.masked_fill(mask.unsqueeze(2), 0) # (B, L_max, 3) + # extract prosody params + durations = prosody_preds[:, :, 0] # (B, L_max) + energies = prosody_preds[:, :, 1] # (B, L_max) + pitch = prosody_preds[:, :, 2] # (B, L_max) + + return durations, energies, pitch + + +class GaussianUpsamplingModule(nn.Module): + ''' Gaussian Upsampling Module: + - Duration Projection + - Energy Projection + - Pitch Projection + - Ranges Projection Layer + - Gaussian Upsampling + ''' + def __init__(self, hparams): + super(GaussianUpsamplingModule, self).__init__() + embed_dim = hparams.phoneme_encoder['hidden_embed_dim'] + Tuple = namedtuple('Tuple', hparams.gaussian_upsampling_module) + hparams = Tuple(**hparams.gaussian_upsampling_module) + + # duration, energy and pitch projection layers + self.duration_projection = ConvNorm1D(1, embed_dim, kernel_size=hparams.conv_kernel, + stride=1, padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='linear') + self.energy_projection = ConvNorm1D(1, embed_dim, kernel_size=hparams.conv_kernel, + stride=1, padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='linear') + self.pitch_projection = ConvNorm1D(1, embed_dim, kernel_size=hparams.conv_kernel, + stride=1, padding=int((hparams.conv_kernel - 1) / 2), + dilation=1, w_init_gain='linear') + # ranges predictor + self.projection = nn.Sequential( + LinearNorm(embed_dim, 1, w_init_gain='relu'), + nn.Softplus() + ) + + def forward(self, x, durations_float, durations_int, energies, pitch, input_lengths): + ''' Forward function of Gaussian Upsampling Module: + x = (B, L_max, hidden_embed_dim) + durations_float = (B, L_max) + durations_int = (B, L_max) + energies = (B, L_max) + pitch = (B, L_max) + input_lengths = (B, ) + ''' + # project durations + durations = durations_float.unsqueeze(2) # (B, L_max, 1) + durations = self.duration_projection(durations) # (B, L_max, hidden_embed_dim) + # project energies + energies = energies.unsqueeze(2) # (B, L_max, 1) + energies = 
self.energy_projection(energies) # (B, L_max, hidden_embed_dim) + # project pitch + pitch = pitch.unsqueeze(2) # (B, L_max, 1) + pitch = self.pitch_projection(pitch) # (B, L_max, hidden_embed_dim) + + # add energy and pitch to encoded input symbols + x = x + energies + pitch # (B, L_max, hidden_embed_dim) + + # predict ranges for each symbol and mask tensor + # use mask_value = 1. because ranges will be used as stds in Gaussian upsampling + # mask_value = 0. would cause NaN values + range_inputs = x + durations # (B, L_max, hidden_embed_dim) + ranges = self.projection(range_inputs) # (B, L_max, 1) + ranges = ranges.squeeze(2) # (B, L_max) + mask = ~get_mask_from_lengths(input_lengths) # (B, L_max) + ranges = ranges.masked_fill(mask, 1) # (B, L_max) + + # perform Gaussian upsampling + # compute Gaussian means + means = durations_int.float() / 2 # (B, L_max) + cumsum = torch.cumsum(durations_int, dim=1) # (B, L_max) + means[:, 1:] += cumsum[:, :-1] # (B, L_max) + # compute Gaussian distributions + means = means.unsqueeze(-1) # (B, L_max, 1) + stds = ranges.unsqueeze(-1) # (B, L_max, 1) + gaussians = Normal(means, stds) # (B, L_max, 1) + # create frames idx tensor + nb_frames_max = torch.max(cumsum) # T_max + frames_idx = torch.FloatTensor([i + 0.5 for i in range(nb_frames_max)]) # (T_max, ) + frames_idx = frames_idx.cuda(x.device, non_blocking=True).float() # (T_max, ) + # compute probs + probs = torch.exp(gaussians.log_prob(frames_idx)) # (B, L_max, T_max) + # apply mask to set probs out of sequence length to 0 + probs = probs.masked_fill(mask.unsqueeze(2), 0) # (B, L_max, T_max) + # compute weights + weights = probs / (torch.sum(probs, dim=1, keepdim=True) + 1e-20) # (B, L_max, T_max) + # compute upsampled embedding + x_upsamp = torch.sum(x.unsqueeze(-1) * weights.unsqueeze(2), dim=1) # (B, input_dim, T_max) + x_upsamp = x_upsamp.permute(0, 2, 1) # (B, T_max, input_dim) + + return x_upsamp, weights + + +class FrameDecoder(nn.Module): + ''' Frame Decoder Module: + - Positional Encoding + - 4x FFT Blocks with FiLM conditioning + - Linear projection + ''' + def __init__(self, hparams): + super(FrameDecoder, self).__init__() + nb_mels = hparams.n_mel_channels + embed_dim = hparams.phoneme_encoder['hidden_embed_dim'] + hparams.frame_decoder['hidden_embed_dim'] = embed_dim + Tuple = namedtuple('Tuple', hparams.frame_decoder) + hparams = Tuple(**hparams.frame_decoder) + + # positional encoding + self.pos_enc = PositionalEncoding(embed_dim) + # FFT blocks + blocks = [] + for _ in range(hparams.nb_blocks): + blocks.append(FFTBlock(hparams)) + self.blocks = nn.ModuleList(blocks) + # linear projection for mel-spec prediction + self.projection = LinearNorm(embed_dim, nb_mels, w_init_gain='linear') + + def forward(self, x, film_params, output_lengths): + ''' Forward function of Decoder Embedding: + x = (B, T_max, hidden_embed_dim) + film_params = (B, nb_blocks, nb_film_params) + output_lengths = (B, ) + ''' + # compute positional encoding + pos = self.pos_enc(output_lengths.unsqueeze(1)) # (B, T_max, hidden_embed_dim) + # create mask + mask = ~get_mask_from_lengths(output_lengths) # (B, T_max) + # add and mask + x = x + pos # (B, T_max, hidden_embed_dim) + x = x.masked_fill(mask.unsqueeze(2), 0) # (B, T_max, hidden_embed_dim) + # pass through FFT blocks + for idx, block in enumerate(self.blocks): + x = block(x, film_params[:, idx, :], mask) # (B, T_max, hidden_embed_dim) + # predict mel-spec frames and mask tensor + mel_specs = self.projection(x) # (B, T_max, nb_mels) + mel_specs = 
mel_specs.masked_fill(mask.unsqueeze(2), 0) # (B, T_max, nb_mels) + mel_specs = mel_specs.transpose(1, 2) # (B, nb_mels, T_max) + + return mel_specs + + +class DaftExprt(nn.Module): + ''' DaftExprt model from J. Zaïdi, H. Seuté, B. van Niekerk, M.A. Carbonneau + "DaftExprt: Robust Prosody Transfer Across Speakers for Expressive Speech Synthesis" + arXiv:2108.02271, 2021. + ''' + def __init__(self, hparams): + super(DaftExprt, self).__init__() + self.prosody_encoder = ProsodyEncoder(hparams) + self.speaker_classifier = SpeakerClassifier(hparams) + self.phoneme_encoder = PhonemeEncoder(hparams) + self.prosody_predictor = LocalProsodyPredictor(hparams) + self.gaussian_upsampling = GaussianUpsamplingModule(hparams) + self.frame_decoder = FrameDecoder(hparams) + + def parse_batch(self, gpu, batch): + ''' Parse input batch + ''' + # extract tensors + symbols, durations_float, durations_int, symbols_energy, symbols_pitch, input_lengths, \ + frames_energy, frames_pitch, mel_specs, output_lengths, speaker_ids, feature_dirs, feature_files = batch + + # transfer tensors to specified GPU + symbols = symbols.cuda(gpu, non_blocking=True).long() # (B, L_max) + durations_float = durations_float.cuda(gpu, non_blocking=True).float() # (B, L_max) + durations_int = durations_int.cuda(gpu, non_blocking=True).long() # (B, L_max) + symbols_energy = symbols_energy.cuda(gpu, non_blocking=True).float() # (B, L_max) + symbols_pitch = symbols_pitch.cuda(gpu, non_blocking=True).float() # (B, L_max) + input_lengths = input_lengths.cuda(gpu, non_blocking=True).long() # (B, ) + frames_energy = frames_energy.cuda(gpu, non_blocking=True).float() # (B, T_max) + frames_pitch = frames_pitch.cuda(gpu, non_blocking=True).float() # (B, T_max) + mel_specs = mel_specs.cuda(gpu, non_blocking=True).float() # (B, n_mel_channels, T_max) + output_lengths = output_lengths.cuda(gpu, non_blocking=True).long() # (B, ) + speaker_ids = speaker_ids.cuda(gpu, non_blocking=True).long() # (B, ) + + # create inputs and targets + inputs = (symbols, durations_float, durations_int, symbols_energy, symbols_pitch, input_lengths, + frames_energy, frames_pitch, mel_specs, output_lengths, speaker_ids) + targets = (durations_float, symbols_energy, symbols_pitch, mel_specs, speaker_ids) + file_ids = (feature_dirs, feature_files) + + return inputs, targets, file_ids + + def forward(self, inputs): + ''' Forward function of DaftExprt + ''' + # extract inputs + symbols, durations_float, durations_int, symbols_energy, symbols_pitch, input_lengths, \ + frames_energy, frames_pitch, mel_specs, output_lengths, speaker_ids = inputs + input_lengths, output_lengths = input_lengths.detach(), output_lengths.detach() + + # extract FiLM parameters from reference and speaker ID + # (B, nb_blocks, nb_film_params) + prosody_embed, encoder_film, prosody_pred_film, decoder_film = self.prosody_encoder(frames_energy, frames_pitch, mel_specs, speaker_ids, output_lengths) + # pass through speaker classifier + spk_preds = self.speaker_classifier(prosody_embed) # (B, nb_speakers) + # embed phoneme symbols, add positional encoding and encode input sequence + enc_outputs = self.phoneme_encoder(symbols, encoder_film, input_lengths) # (B, L_max, hidden_embed_dim) + # predict prosody parameters + duration_preds, energy_preds, pitch_preds = self.prosody_predictor(enc_outputs, prosody_pred_film, input_lengths) # (B, L_max) + # perform Gaussian upsampling on symbols sequence + # use prosody ground-truth values for training + # symbols_upsamp = (B, T_max, hidden_embed_dim) + # weights = (B, 
L_max, T_max) + symbols_upsamp, weights = self.gaussian_upsampling(enc_outputs, durations_float, durations_int, symbols_energy, symbols_pitch, input_lengths) + # decode output sequence and predict mel-specs + mel_spec_preds = self.frame_decoder(symbols_upsamp, decoder_film, output_lengths) # (B, nb_mels, T_max) + + # parse outputs + speaker_preds = spk_preds + film_params = [self.prosody_encoder.post_multipliers, encoder_film, prosody_pred_film, decoder_film] + encoder_preds = [duration_preds, energy_preds, pitch_preds, input_lengths] + decoder_preds = [mel_spec_preds, output_lengths] + alignments = weights + + return speaker_preds, film_params, encoder_preds, decoder_preds, alignments + + def get_int_durations(self, duration_preds, hparams): + ''' Convert float durations to integer frame durations + ''' + # min float duration to have at least one mel-spec frame attributed to the symbol + fft_length = hparams.filter_length / hparams.sampling_rate + dur_min = fft_length / 2 + # set duration under min duration to 0. + duration_preds[duration_preds < dur_min] = 0. # (B, L_max) + # convert to int durations for each element in the batch + durations_int = torch.LongTensor(duration_preds.size(0), duration_preds.size(1)).zero_() # (B, L_max) + for line_idx in range(duration_preds.size(0)): + end_prev, symbols_idx, durations_float = 0., [], [] + for symbol_id in range(duration_preds.size(1)): + symb_dur = duration_preds[line_idx, symbol_id].item() + if symb_dur != 0.: # ignore 0 durations + symbols_idx.append(symbol_id) + durations_float.append([end_prev, end_prev + symb_dur]) + end_prev += symb_dur + int_durs = torch.LongTensor(duration_to_integer(durations_float, hparams)) # (L_max, ) + durations_int[line_idx, symbols_idx] = int_durs + # put on GPU + durations_int = durations_int.cuda(duration_preds.device, non_blocking=True).long() # (B, L_max) + + return duration_preds, durations_int + + def pitch_shift(self, pitch_preds, pitch_factors, hparams, speaker_ids): + ''' Pitch shift pitch predictions + Pitch factors are assumed to be in Hz + ''' + # keep track of unvoiced idx + zero_idxs = (pitch_preds == 0.).nonzero() # (N, 2) + # pitch factors are F0 shifts in Hz + # pitch_factors = [[+50, -20, ...], ..., [+30, -10, ...]] + for line_idx in range(pitch_preds.size(0)): + speaker_id = speaker_ids[line_idx].item() + pitch_mean = hparams.stats[f'spk {speaker_id}']['pitch']['mean'] + pitch_std = hparams.stats[f'spk {speaker_id}']['pitch']['std'] + pitch_preds[line_idx] = torch.exp(pitch_std * pitch_preds[line_idx] + pitch_mean) # (L_max) + # perform pitch shift in Hz domain + pitch_preds[line_idx] += pitch_factors[line_idx] # (L_max) + # go back to log and re-normalize using pitch training stats + pitch_preds[line_idx] = (torch.log(pitch_preds[line_idx]) - pitch_mean) / pitch_std # (L_max) + # set unvoiced idx to zero + pitch_preds[zero_idxs[:, 0], zero_idxs[:, 1]] = 0. 
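+        # note on the round trip above: pitch values are stored as speaker-normalized log-F0, so the
+        # shift is applied in Hz (exp to denormalize, add the factor, log and renormalize), and positions
+        # that were unvoiced (0.) before the transform are restored afterwards.
+        # illustrative example with hypothetical speaker stats (mean=5.0, std=0.25): a normalized value
+        # of 0.4 maps to exp(0.25 * 0.4 + 5.0) ~= 164 Hz; a +50 Hz factor gives 214 Hz, i.e.
+        # (log(214) - 5.0) / 0.25 ~= 1.46 after renormalization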
+ + return pitch_preds + + def pitch_multiply(self, pitch_preds, pitch_factors): + ''' Apply multiply transform to pitch prediction with respect to the mean + + Effects of factor values on the pitch: + ]0, +inf[ amplify + 0 no effect + ]-1, 0[ de-amplify + -1 flatten + ]-2, -1[ invert de-amplify + -2 invert + ]-inf, -2[ invert amplify + ''' + # multiply pitch for each element in the batch + for line_idx in range(pitch_preds.size(0)): + # keep track of voiced and unvoiced idx + non_zero_idxs = pitch_preds[line_idx].nonzero() # (M, ) + zero_idxs = (pitch_preds[line_idx] == 0.).nonzero() # (N, ) + # compute mean of voiced values + mean_pitch = torch.mean(pitch_preds[line_idx, non_zero_idxs]) + # compute deviation to the mean for each pitch prediction + pitch_deviation = pitch_preds[line_idx] - mean_pitch # (L_max) + # multiply factors to pitch deviation + pitch_deviation *= pitch_factors[line_idx] # (L_max) + # add deviation to pitch predictions + pitch_preds[line_idx] += pitch_deviation # (L_max) + # reset unvoiced values to 0 + pitch_preds[line_idx, zero_idxs] = 0. + + return pitch_preds + + def inference(self, inputs, pitch_transform, hparams): + ''' Inference function of DaftExprt + ''' + # symbols = (B, L_max) + # dur_factors = (B, L_max) + # energy_factors = (B, L_max) + # pitch_factors = (B, L_max) + # input_lengths = (B, ) + # energy_refs = (B, T_max) + # pitch_refs = (B, T_max) + # mel_spec_refs = (B, n_mel_channels, T_max) + # ref_lengths = (B, ) + # speaker_ids = (B, ) + symbols, dur_factors, energy_factors, pitch_factors, input_lengths, \ + energy_refs, pitch_refs, mel_spec_refs, ref_lengths, speaker_ids = inputs + + # extract FiLM parameters from reference and speaker ID + # (B, nb_blocks, nb_film_params) + _, encoder_film, prosody_pred_film, decoder_film = self.prosody_encoder(energy_refs, pitch_refs, mel_spec_refs, speaker_ids, ref_lengths) + # embed phoneme symbols, add positional encoding and encode input sequence + enc_outputs = self.phoneme_encoder(symbols, encoder_film, input_lengths) # (B, L_max, hidden_embed_dim) + # predict prosody parameters + duration_preds, energy_preds, pitch_preds = self.prosody_predictor(enc_outputs, prosody_pred_film, input_lengths) # (B, L_max) + + # multiply durations by duration factors and extract int durations + duration_preds *= dur_factors # (B, L_max) + duration_preds, durations_int = self.get_int_durations(duration_preds, hparams) # (B, L_max) + # add energy factors to energies + # set 0 energy for symbols with 0 duration + energy_preds *= energy_factors # (B, L_max) + energy_preds[durations_int == 0] = 0. # (B, L_max) + # set unvoiced pitch for symbols with 0 duration + # apply pitch factors using specified transformation + pitch_preds[durations_int == 0] = 0. 
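+        # pitch_factors semantics depend on the requested transform:
+        #   - 'add': factors are F0 offsets in Hz, applied after per-speaker denormalization (see pitch_shift)
+        #   - 'multiply': factors amplify/flatten/invert the deviation around the utterance mean
+        #     (0 = no effect, -1 = flatten, -2 = invert, see pitch_multiply)
+        # any other value raises NotImplementedError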
+ if pitch_transform == 'add': + pitch_preds = self.pitch_shift(pitch_preds, pitch_factors, hparams, speaker_ids) # (B, L_max) + elif pitch_transform == 'multiply': + pitch_preds = self.pitch_multiply(pitch_preds, pitch_factors) # (B, L_max) + else: + raise NotImplementedError + + # perform Gaussian upsampling on symbols sequence + # symbols_upsamp = (B, T_max, hidden_embed_dim) + # weights = (B, L_max, T_max) + symbols_upsamp, weights = self.gaussian_upsampling(enc_outputs, duration_preds, durations_int, energy_preds, pitch_preds, input_lengths) + # get sequence output length for each element in the batch + output_lengths = torch.sum(durations_int, dim=1) # (B, ) + output_lengths = output_lengths.cuda(symbols_upsamp.device, non_blocking=True).long() # (B, ) + assert(torch.max(output_lengths) == symbols_upsamp.size(1)) + # decode output sequence and predict mel-specs + mel_spec_preds = self.frame_decoder(symbols_upsamp, decoder_film, output_lengths) # (B, nb_mels, T_max) + + # parse outputs + encoder_preds = [duration_preds, durations_int, energy_preds, pitch_preds, input_lengths] + decoder_preds = [mel_spec_preds, output_lengths] + alignments = weights + + return encoder_preds, decoder_preds, alignments diff --git a/src/daft_exprt/normalize_numbers.py b/src/daft_exprt/normalize_numbers.py new file mode 100644 index 0000000..29359c7 --- /dev/null +++ b/src/daft_exprt/normalize_numbers.py @@ -0,0 +1,74 @@ +import inflect +import re + + +''' +from https://github.com/keithito/tacotron +''' + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, 
_expand_number, text) + return text \ No newline at end of file diff --git a/src/daft_exprt/symbols.py b/src/daft_exprt/symbols.py new file mode 100644 index 0000000..fe4eacd --- /dev/null +++ b/src/daft_exprt/symbols.py @@ -0,0 +1,36 @@ +import string + + +# silence symbols and unknown word symbols used by MFA in ".TextGrid" files +MFA_SIL_WORD_SYMBOL = '' +MFA_SIL_PHONE_SYMBOLS = ['', 'sp', 'sil'] +MFA_UNK_WORD_SYMBOL = '' +MFA_UNK_PHONE_SYMBOL = 'spn' + +# silence symbols used in ".markers" files +# allows to only have 1 silence symbol instead of 3 +SIL_WORD_SYMBOL = '' +SIL_PHONE_SYMBOL = 'SIL' + +# PAD and EOS token +pad = '_' +eos = '~' + +# whitespace character +whitespace = ' ' + +# punctuation to consider in input sentence +punctuation = ',.!?' + +# Arpabet stressed phonetic set +arpabet_stressed = ['AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0', 'AO1', 'AO2', 'AW0', + 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', + 'ER2', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', + 'K', 'L', 'M', 'N', 'NG', 'OW0', 'OW1', 'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', + 'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'] + +# ascii letters +ascii = string.ascii_lowercase.upper() + string.ascii_lowercase + +# symbols used by Daft-Exprt in english language +symbols_english = list(pad + eos + whitespace + punctuation) + arpabet_stressed diff --git a/src/daft_exprt/train.py b/src/daft_exprt/train.py new file mode 100644 index 0000000..e59faf8 --- /dev/null +++ b/src/daft_exprt/train.py @@ -0,0 +1,638 @@ +import matplotlib +matplotlib.use('Agg') + +import argparse +import json +import logging +import math +import os +import random +import time + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from dateutil.relativedelta import relativedelta +from shutil import copyfile + +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Adam + +from daft_exprt.data_loader import prepare_data_loaders +from daft_exprt.extract_features import FEATURES_HPARAMS, check_features_config_used +from daft_exprt.generate import extract_reference_parameters, prepare_sentences_for_inference, generate_mel_specs +from daft_exprt.hparams import HyperParams +from daft_exprt.logger import DaftExprtLogger +from daft_exprt.loss import DaftExprtLoss +from daft_exprt.model import DaftExprt +from daft_exprt.utils import get_nb_jobs + + +_logger = logging.getLogger(__name__) + + +def check_train_config(hparams): + ''' Check hyper-parameters used for training are the same than the one used to extract features + + :param hparams: hyper-parameters currently used for training + ''' + # extract features dirs used for training + with open(hparams.training_files, 'r', encoding='utf-8') as f: + lines = f.readlines() + features_dirs = [line.strip().split(sep='|')[0] for line in lines] + features_dirs = list(set(features_dirs)) + + # compare hyper-params + _logger.info('--' * 30) + _logger.info(f'Comparing training config with the one used to extract features'.upper()) + for features_dir in features_dirs: + same_config = check_features_config_used(features_dir, hparams) + assert(same_config), _logger.error(f'Parameters used for feature extraction in "{features_dir}" ' + f'mismatch with current training parameters.') + _logger.info('--' * 30 + '\n') + + +def save_checkpoint(model, optimizer, hparams, learning_rate, + iteration, 
best_val_loss=None, filepath=None):
+    ''' Save a model/optimizer state and store additional training info
+
+    :param model: current model state
+    :param optimizer: current optimizer state
+    :param hparams: hyper-parameters used for training
+    :param learning_rate: current learning rate value
+    :param iteration: current training iteration
+    :param best_val_loss: current best validation loss
+    :param filepath: path to save the checkpoint
+    '''
+    # get output directory where checkpoint is saved and make directory if it doesn't exist
+    output_directory = os.path.dirname(filepath)
+    os.makedirs(output_directory, exist_ok=True)
+    # save checkpoint
+    _logger.info(f'Saving model and optimizer state at iteration "{iteration}" to "{filepath}"')
+    torch.save({'iteration': iteration,
+                'learning_rate': learning_rate,
+                'best_val_loss': best_val_loss,
+                'state_dict': model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'config_params': hparams.__dict__.copy()}, filepath)
+
+
+def load_checkpoint(checkpoint_path, gpu, model, optimizer, hparams):
+    ''' Load a model/optimizer state and additional training info
+
+    :param checkpoint_path: path of the checkpoint to load
+    :param gpu: GPU ID that hosts the model
+    :param model: current model state we want to update with checkpoint
+    :param optimizer: current optimizer state we want to update with checkpoint
+    :param hparams: hyper-parameters used for training
+
+    :return: model/optimizer and additional training info
+    '''
+    # load checkpoint dict
+    # map model to be loaded to specified single gpu
+    assert os.path.isfile(checkpoint_path), \
+        _logger.error(f'Checkpoint "{checkpoint_path}" does not exist')
+    _logger.info(f'Loading checkpoint "{checkpoint_path}"')
+    checkpoint_dict = torch.load(checkpoint_path, map_location=f'cuda:{gpu}')
+    # compare current hparams with the ones used in checkpoint
+    hparams_checkpoint = HyperParams(verbose=False, **checkpoint_dict['config_params'])
+    params_to_compare = hparams.__dict__.copy()
+    for param in params_to_compare:
+        if param in FEATURES_HPARAMS:
+            assert(getattr(hparams, param) == getattr(hparams_checkpoint, param)), \
+                _logger.error(f'Parameter "{param}" is different between current config and the one used in checkpoint -- '
+                              f'Was {getattr(hparams_checkpoint, param)} in checkpoint and now is {getattr(hparams, param)}')
+        else:
+            if not hasattr(hparams, param):
+                _logger.warning(f'Parameter "{param}" does not exist in the current training config but existed in checkpoint config')
+            elif not hasattr(hparams_checkpoint, param):
+                _logger.warning(f'Parameter "{param}" exists in the current training config but did not exist in checkpoint config')
+            elif getattr(hparams, param) != getattr(hparams_checkpoint, param):
+                _logger.warning(f'Parameter "{param}" has changed -- Was {getattr(hparams_checkpoint, param)} '
+                                f'in checkpoint and now is {getattr(hparams, param)}')
+
+    # assign checkpoint weights to the model
+    try:
+        model.load_state_dict(checkpoint_dict['state_dict'])
+    except RuntimeError as e:
+        _logger.error(f'Error when trying to load the checkpoint -- "{e}"\n')
+
+    # check if the optimizers are compatible
+    k_new = optimizer.param_groups
+    k_loaded = checkpoint_dict['optimizer']['param_groups']
+    if len(k_loaded) != len(k_new):
+        _logger.warning(f'The optimizer in the loaded checkpoint does not have the same number of parameters '
+                        f'as the blank optimizer -- Creating a new optimizer.')
+    else:
+        optimizer.load_state_dict(checkpoint_dict['optimizer'])
+
+    # load additional values
+    iteration = 
checkpoint_dict['iteration'] + learning_rate = checkpoint_dict['learning_rate'] + best_val_loss = checkpoint_dict['best_val_loss'] + _logger.info(f'Loaded checkpoint "{checkpoint_path}" from iteration "{iteration}"\n') + + return model, optimizer, iteration, learning_rate, best_val_loss + + +def update_learning_rate(hparams, iteration): + ''' Increase the learning rate linearly for the first warmup_steps training steps, + and decrease it thereafter proportionally to the inverse square root of the step number + ''' + initial_learning_rate = hparams.initial_learning_rate + max_learning_rate = hparams.max_learning_rate + warmup_steps = hparams.warmup_steps + if iteration < warmup_steps: + learning_rate = (max_learning_rate - initial_learning_rate) / warmup_steps * iteration + initial_learning_rate + else: + learning_rate = iteration ** -0.5 * max_learning_rate / warmup_steps ** -0.5 + + return learning_rate + + +def generate_benchmark_sentences(model, hparams, output_dir): + ''' Generate benchmark sentences using Daft-Exprt model + + :param model: model to use for synthesis + :param hparams: hyper-params used for training/synthesis + :param output_dir: directory to store synthesized files + ''' + # set random speaker id + speaker_id = random.choice(hparams.speakers_id) + # choose reference for style transfer + with open(hparams.validation_files, 'r', encoding='utf-8') as f: + references = [line.strip().split('|') for line in f] + reference = random.choice(references) + reference_path, file_name = reference[0], reference[1] + speaker_name = [speaker for speaker in hparams.speakers if reference_path.endswith(speaker)][0] + audio_ref = f'{os.path.join(hparams.data_set_dir, speaker_name, "wavs", file_name)}.wav' + # display info + _logger.info('\nGenerating benchmark sentences with the following parameters:') + _logger.info(f'speaker_id = {speaker_id}') + _logger.info(f'audio_ref = {audio_ref}\n') + + # prepare benchmark sentences + n_jobs = get_nb_jobs('max') + text_file = os.path.join(hparams.benchmark_dir, hparams.language, 'sentences.txt') + sentences, file_names = \ + prepare_sentences_for_inference(text_file, output_dir, hparams, n_jobs) + # extract reference prosody parameters + extract_reference_parameters(audio_ref, output_dir, hparams) + # duplicate reference parameters + file_name = os.path.basename(audio_ref).replace('.wav', '') + refs = [os.path.join(output_dir, f'{file_name}.npz') for _ in range(len(sentences))] + # generate mel_specs and audios with Griffin-Lim + speaker_ids = [speaker_id for _ in range(len(sentences))] + generate_mel_specs(model, sentences, file_names, speaker_ids, refs, + output_dir, hparams, use_griffin_lim=True) + # copy audio ref + copyfile(audio_ref, os.path.join(output_dir, f'{file_name}.wav')) + + +def validate(gpu, model, criterion, val_loader, hparams): + ''' Handles all the validation scoring and printing + + :param gpu: GPU ID that hosts the model + :param model: model to evaluate + :param criterion: criterion used for training + :param val_loader: validation Data Loader + :param hparams: hyper-params used for training + + :return: validation loss score + ''' + # initialize variables + val_loss = 0. + val_indiv_loss = { + 'duration_loss': 0., 'energy_loss':0., 'pitch_loss': 0., + 'mel_spec_l1_loss': 0., 'mel_spec_l2_loss': 0. 
+ } + val_targets, val_outputs = [], [] + + # set eval mode + model.eval() + with torch.no_grad(): + # iterate over validation set + for i, batch in enumerate(val_loader): + if hparams.multiprocessing_distributed: + inputs, targets, _ = model.module.parse_batch(gpu, batch) + else: + inputs, targets, _ = model.parse_batch(gpu, batch) + outputs = model(inputs) + loss, individual_loss = criterion(outputs, targets, iteration=0) + val_targets.append(targets) + val_outputs.append(outputs) + val_loss += loss.item() + for key in val_indiv_loss: + val_indiv_loss[key] += individual_loss[key] + # normalize losses + val_loss = val_loss / (i + 1) + for key in val_indiv_loss: + val_indiv_loss[key] = val_indiv_loss[key] / (i + 1) + + return val_loss, val_indiv_loss, val_targets, val_outputs + + +def train(gpu, hparams, log_file): + ''' Train Daft-Exprt model + + :param gpu: GPU ID to host the model + :param hparams: hyper-params used for training + :param log_file: file path for logging + ''' + # --------------------------------------------------------- + # initialize distributed group + # --------------------------------------------------------- + if hparams.multiprocessing_distributed: + # for multiprocessing distributed training, rank needs to be the + # global rank among all the processes + hparams.rank = hparams.rank * hparams.ngpus_per_node + gpu + dist.init_process_group(backend=hparams.dist_backend, init_method=hparams.dist_url, + world_size=hparams.world_size, rank=hparams.rank) + + # --------------------------------------------------------- + # create loggers + # --------------------------------------------------------- + # set logger config + # we log INFO to file only from rank0, node0 to avoid unnecessary log duplication + if hparams.rank == 0: + logging.basicConfig( + handlers=[ + logging.StreamHandler(), + logging.FileHandler(log_file) + ], + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO + ) + # create tensorboard logger + log_dir = os.path.dirname(log_file) + tensorboard_logger = DaftExprtLogger(log_dir) + else: + logging.basicConfig( + handlers=[ + logging.StreamHandler(), + logging.FileHandler(log_file) + ], + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.ERROR + ) + + # --------------------------------------------------------- + # create model + # --------------------------------------------------------- + # load model on GPU + torch.cuda.set_device(gpu) + model = DaftExprt(hparams).cuda(gpu) + + # for multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices + if hparams.multiprocessing_distributed: + model = DDP(model, device_ids=[gpu]) + + # --------------------------------------------------------- + # define training loss and optimizer + # --------------------------------------------------------- + criterion = DaftExprtLoss(gpu, hparams) + optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), + betas=hparams.betas, eps=hparams.epsilon, + weight_decay=hparams.weight_decay, amsgrad=False) + + # --------------------------------------------------------- + # optionally resume from a checkpoint + # --------------------------------------------------------- + iteration, best_val_loss = 1, float('inf') + if hparams.checkpoint != "": + model, optimizer, iteration, learning_rate, best_val_loss = \ + load_checkpoint(hparams.checkpoint, gpu, model, 
optimizer, hparams) + iteration += 1 # next iteration is iteration + 1 + + # --------------------------------------------------------- + # set learning rate + # --------------------------------------------------------- + learning_rate = update_learning_rate(hparams, iteration) + for param_group in optimizer.param_groups: + if param_group['lr'] is not None: + param_group['lr'] = learning_rate + + # --------------------------------------------------------- + # prepare Data Loaders + # --------------------------------------------------------- + train_loader, train_sampler, val_loader, nb_training_examples = \ + prepare_data_loaders(hparams, num_workers=8) + + # --------------------------------------------------------- + # display training info + # --------------------------------------------------------- + # compute the number of epochs + nb_iterations_per_epoch = int(len(train_loader) / hparams.accumulation_steps) + epoch_offset = max(0, int(iteration / nb_iterations_per_epoch)) + epochs = int(hparams.nb_iterations / nb_iterations_per_epoch) + 1 + + _logger.info('**' * 40) + _logger.info(f"Batch size: {hparams.batch_size * hparams.accumulation_steps * hparams.world_size:_}") + _logger.info(f"Nb examples: {nb_training_examples:_}") + _logger.info(f"Nb iterations per epoch: {nb_iterations_per_epoch:_}") + _logger.info(f"Nb total of epochs: {epochs:_}") + _logger.info(f"Started at epoch: {epoch_offset:_}") + _logger.info('**' * 40 + '\n') + + # ========================================================= + # MAIN TRAINNIG LOOP + # ========================================================= + # set variables + tot_loss = 0. + indiv_loss = { + 'speaker_loss': 0., 'post_mult_loss': 0., + 'duration_loss': 0., 'energy_loss':0., 'pitch_loss': 0., + 'mel_spec_l1_loss': 0., 'mel_spec_l2_loss': 0. + } + total_time = 0. 
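+    # gradient accumulation: each training iteration accumulates gradients over 'accumulation_steps'
+    # micro-batches (the per-batch loss is scaled by 1 / accumulation_steps before backward()), and the
+    # optimizer only steps once the accumulated batch is complete, so the effective batch size is
+    # batch_size * accumulation_steps * world_size, as reported in the training info block above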
+ start = time.time() + accumulation_step = 0 + + model.train() # set training mode + model.zero_grad() # set gradients to 0 + for epoch in range(epoch_offset, epochs): + _logger.info(30 * '=') + _logger.info(f"| Epoch: {epoch}") + _logger.info(30 * '=' + '\n') + + # shuffle dataset + if hparams.multiprocessing_distributed: + train_sampler.set_epoch(epoch) + + # iterate over examples + for batch in train_loader: + # --------------------------------------------------------- + # forward pass + # --------------------------------------------------------- + if hparams.multiprocessing_distributed: + inputs, targets, _ = model.module.parse_batch(gpu, batch) + else: + inputs, targets, _ = model.parse_batch(gpu, batch) + + outputs = model(inputs) + loss, individual_loss = criterion(outputs, targets, iteration) # loss / batch_size + loss = loss / hparams.accumulation_steps # loss / (batch_size * accumulation_steps) + + # track losses + tot_loss += loss.item() + for key in individual_loss: + # individual losses are already detached from the graph + # individual_loss / (batch_size * accumulation_steps) + indiv_loss[key] += individual_loss[key] / hparams.accumulation_steps + + # --------------------------------------------------------- + # backward pass + # --------------------------------------------------------- + loss.backward() + accumulation_step += 1 + + # --------------------------------------------------------- + # accumulate gradient + # --------------------------------------------------------- + if accumulation_step == hparams.accumulation_steps: + # clip gradients + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh) + # update weights + optimizer.step() + + # --------------------------------------------------------- + # reporting + # --------------------------------------------------------- + if not math.isnan(tot_loss): + if hparams.rank == 0: + # get current learning rate + for param_group in optimizer.param_groups: + if param_group['lr'] is not None: + learning_rate = param_group['lr'] + break + # display iteration stats + duration = time.time() - start + total_time += duration + _logger.info(f'Train loss [{iteration}]: {tot_loss:.6f} Grad Norm {grad_norm:.6f} ' + f'{duration:.2f}s/it (LR {learning_rate:.6f})') + # update tensorboard logging + tensorboard_logger.log_training(tot_loss, indiv_loss, grad_norm, + learning_rate, duration, iteration) + # barrier for distributed processes + if hparams.multiprocessing_distributed: + dist.barrier() + + # --------------------------------------------------------- + # model evaluation + # --------------------------------------------------------- + if iteration % hparams.iters_check_for_model_improvement == 0: + # validate model + _logger.info('Validating....') + val_loss, val_indiv_loss, val_targets, val_outputs = validate(gpu, model, criterion, val_loader, hparams) + if hparams.rank == 0: + # display remaining time + _logger.info(f"Validation loss {iteration}: {val_loss:.6f} ") + _logger.info("estimated required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}". + format(relativedelta(seconds=int((hparams.nb_iterations - iteration) * + (total_time / hparams.iters_check_for_model_improvement))))) + total_time = 0 + # log validation loss + tensorboard_logger.log_validation(val_loss, val_indiv_loss, val_targets, + val_outputs, model, hparams, iteration) + + # save as the best model + if val_loss < best_val_loss: + # update validation loss + _logger.info('Congrats!!! A new best model. 
You are the best!') + best_val_loss = val_loss + # save checkpoint and generate benchmark sentences + checkpoint_path = os.path.join(hparams.output_directory, 'checkpoints', 'DaftExprt_best') + save_checkpoint(model, optimizer, hparams, learning_rate, + iteration, best_val_loss, checkpoint_path) + output_dir = os.path.join(hparams.output_directory, 'checkpoints', 'best_checkpoint') + generate_benchmark_sentences(model, hparams, output_dir) + # barrier for distributed processes + if hparams.multiprocessing_distributed: + dist.barrier() + + # --------------------------------------------------------- + # save the model + # --------------------------------------------------------- + if iteration % hparams.iters_per_checkpoint == 0: + if hparams.rank == 0: + checkpoint_path = os.path.join(hparams.output_directory, 'checkpoints', f'DaftExprt_{iteration}') + save_checkpoint(model, optimizer, hparams, learning_rate, + iteration, best_val_loss, checkpoint_path) + output_dir = os.path.join(hparams.output_directory, 'checkpoints', f'chk_{iteration}') + generate_benchmark_sentences(model, hparams, output_dir) + # barrier for distributed processes + if hparams.multiprocessing_distributed: + dist.barrier() + + # --------------------------------------------------------- + # reset variables + # --------------------------------------------------------- + iteration += 1 + tot_loss = 0. + indiv_loss = { + 'speaker_loss': 0., 'post_mult_loss': 0., + 'duration_loss': 0., 'energy_loss':0., 'pitch_loss': 0., + 'mel_spec_l1_loss': 0., 'mel_spec_l2_loss': 0. + } + start = time.time() + accumulation_step = 0 + + model.train() # set training mode + model.zero_grad() # set gradients to 0 + + # --------------------------------------------------------- + # adjust learning rate + # --------------------------------------------------------- + learning_rate = update_learning_rate(hparams, iteration) + for param_group in optimizer.param_groups: + if param_group['lr'] is not None: + param_group['lr'] = learning_rate + + +def launch_training(data_set_dir, config_file, benchmark_dir, log_file, world_size=1, rank=0, + multiprocessing_distributed=True, master='tcp://localhost:54321'): + ''' Launch training in distributed mode or on a single GPU + PyTorch distributed training is performed using DistributedDataParrallel API + Inspired from https://github.com/pytorch/examples/blob/master/imagenet/main.py + + - multiprocessing_distributed=False: + Training is performed using only GPU 0 on the machine + + - multiprocessing_distributed=True: + Multi-processing distributed training is performed with DistributedDataParrallel API. + X distributed processes are launched on the machine, with X the total number of GPUs + on the machine. Each process replicates the same model to a unique GPU, and each GPU + consumes a different partition of the input data. DistributedDataParrallel takes care + of gradient averaging and model parameter update on all GPUs. This is the go-to method + when model can fit on one GPU card. + - world_size=1: + One machine is used for distributed training. The machine launches X distributed processes. + - world_size=N: + N machines are used for distributed training. Each machine launches X distributed processes. 
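+
+    Example invocation (single node, placeholder paths):
+        python src/daft_exprt/train.py --multiprocessing_distributed \
+            --data_set_dir ./datasets --config_file ./trainings/config.json \
+            --benchmark_dir ./scripts/benchmark --log_file ./trainings/train.log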
+ ''' + # set logger config + if rank == 0: + logging.basicConfig( + handlers=[ + logging.StreamHandler(), + logging.FileHandler(log_file) + ], + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO + ) + else: + logging.basicConfig( + handlers=[ + logging.StreamHandler(), + logging.FileHandler(log_file) + ], + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.ERROR + ) + + # get hyper-parameters + with open(config_file) as f: + data = f.read() + config = json.loads(data) + hparams = HyperParams(verbose=False, **config) + + # count number of GPUs on the machine + ngpus_per_node = torch.cuda.device_count() + + # set default values + if multiprocessing_distributed: + hparams.dist_url = f'{master}' + # since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + # here we assume that each node has the same number of GPUs + world_size = ngpus_per_node * world_size + else: + rank, gpu = 0, 0 + + # update hparams + hparams.data_set_dir = data_set_dir + hparams.config_file = config_file + hparams.benchmark_dir = benchmark_dir + + hparams.rank = rank + hparams.world_size = world_size + hparams.ngpus_per_node = ngpus_per_node + hparams.multiprocessing_distributed = multiprocessing_distributed + + # check that config used for training is the same than the one used for features extraction + check_train_config(hparams) + # save hyper-params to config.json + if rank == 0: + hparams.save_hyper_params(hparams.config_file) + + # check if multiprocessing distributed is deactivated but feasible + if not multiprocessing_distributed and ngpus_per_node > 1: + _logger.warning(f'{ngpus_per_node} GPUs detected but distributed training is not set. ' + f'Training on only 1 GPU.\n') + + # define cudnn variables + torch.manual_seed(0) + torch.backends.cudnn.enabled = hparams.cudnn_enabled + torch.backends.cudnn.benchmark = hparams.cudnn_benchmark + torch.backends.cudnn.deterministic = hparams.cudnn_deterministic + if hparams.seed is not None: + random.seed(hparams.seed) + torch.manual_seed(hparams.seed) + torch.backends.cudnn.deterministic = True + _logger.warning('You have chosen to seed training. This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! You may see unexpected behavior when ' + 'restarting from checkpoints.\n') + + # display training setup info + _logger.info(f'PyTorch version -- {torch.__version__}') + _logger.info(f'CUDA version -- {torch.version.cuda}') + _logger.info(f'CUDNN version -- {torch.backends.cudnn.version()}') + _logger.info(f'CUDNN enabled = {torch.backends.cudnn.enabled}') + _logger.info(f'CUDNN deterministic = {torch.backends.cudnn.deterministic}') + _logger.info(f'CUDNN benchmark = {torch.backends.cudnn.benchmark}\n') + + # clear handlers + _logger.handlers.clear() + + # launch multi-processing distributed training + if multiprocessing_distributed: + # use torch.multiprocessing.spawn to launch distributed processes + mp.spawn(train, nprocs=ngpus_per_node, args=(hparams, log_file)) + # simply call train function + else: + train(gpu, hparams, log_file) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--multiprocessing_distributed', action='store_true', + help='Use multi-processing distributed training to launch N processes per ' + 'node, which has N GPUs. 
This is the fastest way to use PyTorch for ' + 'either single node or multi node data parallel training') + parser.add_argument('--world_size', type=int, default=1, + help='number of nodes for distributed training') + parser.add_argument('--rank', type=int, default=0, + help='node rank for distributed training') + parser.add_argument('--master', type=str, default='tcp://localhost:54321', + help='url used to set up distributed training') + parser.add_argument('--data_set_dir', type=str, required=True, + help='Data set containing .wav files') + parser.add_argument('--config_file', type=str, required=True, + help='JSON configuration file to initialize hyper-parameters for training') + parser.add_argument('--benchmark_dir', type=str, required=True, + help='directory to load benchmark sentences') + parser.add_argument('--log_file', type=str, required=True, + help='path to save logger outputs') + + args = parser.parse_args() + + # launch training + launch_training(args.data_set_dir, args.config_file, args.benchmark_dir, args.log_file, + args.world_size, args.rank, args.multiprocessing_distributed, args.master) diff --git a/src/daft_exprt/utils.py b/src/daft_exprt/utils.py new file mode 100644 index 0000000..40e0f12 --- /dev/null +++ b/src/daft_exprt/utils.py @@ -0,0 +1,227 @@ +import logging +import multiprocessing as mp +import sys +import threading +import time + +import matplotlib.pyplot as plt + +from functools import partial +from multiprocessing import Pool + +from dateutil.relativedelta import relativedelta + + +_logger = logging.getLogger(__name__) + + +def histogram_plot(data, x_labels, y_labels, figsize=(16, 4)): + ''' Histogram plot for different set of data + ''' + # create subplot + fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False) + # create an histogram plot for each item in data + for i in range(len(data)): + # (B, N) --> (B * N) + data_vals = data[i].ravel() + # plot histogram + axes[0, i].hist(data_vals, bins=50, density=True) + # add axis labels + if x_labels is not None: + axes[0, i].set(xlabel=x_labels[i]) + if y_labels is not None: + axes[0, i].set(ylabel=y_labels[i]) + plt.close(fig) + + return fig + + +def scatter_plot(data, colors, labels, x_label=None, y_label=None, figsize=(16, 4)): + ''' Scatter plots of different data points + ''' + # create subplot + fig, axes = plt.subplots(1, 1, figsize=figsize, squeeze=False) + # fill with data + for item, color in zip(data, colors): + axes[0, 0].scatter(range(len(item)), item, color=color, marker='o') + # add plots labels + axes[0, 0].legend(labels=labels) + # add axis labels + if x_label is not None: + axes[0, 0].set(xlabel=x_label) + if y_label is not None: + axes[0, 0].set(ylabel=y_label) + plt.close(fig) + + return fig + + +def plot_2d_data(data, x_labels=None, y_labels=None, filename=None, figsize=(16, 4)): + ''' Create several 2D plots for each item given by data + + :param data: sequence of numpy arrays -- length (L, ) + :param x_labels: labels to give to each plot on the x axis -- length (L, ) if not None + :param y_labels: labels to give to each plot on the y axis -- length (L, ) if not None + :param filename: file to save the figure + :param figsize: size of the plots + + :return: the 2D plot + ''' + # initialize the subplot -- put squeeze to false to avoid errors when data is of length 1 + fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False) + + # create a plot for each item given by data + for i in range(len(data)): + if len(data[i].shape) == 1: + axes[0, 
diff --git a/src/daft_exprt/utils.py b/src/daft_exprt/utils.py
new file mode 100644
index 0000000..40e0f12
--- /dev/null
+++ b/src/daft_exprt/utils.py
@@ -0,0 +1,227 @@
+import logging
+import multiprocessing as mp
+import sys
+import threading
+import time
+
+import matplotlib.pyplot as plt
+
+from functools import partial
+from multiprocessing import Pool
+
+from dateutil.relativedelta import relativedelta
+
+
+_logger = logging.getLogger(__name__)
+
+
+def histogram_plot(data, x_labels, y_labels, figsize=(16, 4)):
+    ''' Histogram plots for different sets of data
+    '''
+    # create subplot
+    fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False)
+    # create a histogram plot for each item in data
+    for i in range(len(data)):
+        # (B, N) --> (B * N)
+        data_vals = data[i].ravel()
+        # plot histogram
+        axes[0, i].hist(data_vals, bins=50, density=True)
+        # add axis labels
+        if x_labels is not None:
+            axes[0, i].set(xlabel=x_labels[i])
+        if y_labels is not None:
+            axes[0, i].set(ylabel=y_labels[i])
+    plt.close(fig)
+
+    return fig
+
+
+def scatter_plot(data, colors, labels, x_label=None, y_label=None, figsize=(16, 4)):
+    ''' Scatter plots of different data points
+    '''
+    # create subplot
+    fig, axes = plt.subplots(1, 1, figsize=figsize, squeeze=False)
+    # fill with data
+    for item, color in zip(data, colors):
+        axes[0, 0].scatter(range(len(item)), item, color=color, marker='o')
+    # add plots labels
+    axes[0, 0].legend(labels=labels)
+    # add axis labels
+    if x_label is not None:
+        axes[0, 0].set(xlabel=x_label)
+    if y_label is not None:
+        axes[0, 0].set(ylabel=y_label)
+    plt.close(fig)
+
+    return fig
+
+
+def plot_2d_data(data, x_labels=None, y_labels=None, filename=None, figsize=(16, 4)):
+    ''' Create several 2D plots, one for each item in data
+
+    :param data: sequence of numpy arrays -- length (L, )
+    :param x_labels: labels to give to each plot on the x axis -- length (L, ) if not None
+    :param y_labels: labels to give to each plot on the y axis -- length (L, ) if not None
+    :param filename: file to save the figure
+    :param figsize: size of the plots
+
+    :return: the 2D plot
+    '''
+    # initialize the subplot -- set squeeze to False to avoid errors when data is of length 1
+    fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False)
+
+    # create a plot for each item given by data
+    for i in range(len(data)):
+        if len(data[i].shape) == 1:
+            axes[0, i].scatter(range(len(data[i])), data[i], alpha=0.5, marker='.', s=10)
+        elif len(data[i].shape) == 2:
+            axes[0, i].imshow(data[i], aspect='auto', origin='lower', interpolation='none')
+        if x_labels is not None:
+            axes[0, i].set(xlabel=x_labels[i])
+        if y_labels is not None:
+            axes[0, i].set(ylabel=y_labels[i])
+
+    # save the figure and return it
+    if filename is not None:
+        fig.savefig(filename)
+
+    plt.close(fig)
+    return fig
+
+
+def chunker(seq, size):
+    ''' Create a generator of chunks from a sequence
+        https://stackoverflow.com/a/434328
+
+    :param seq: the sequence we want to create chunks from
+    :param size: size of the chunks
+    '''
+    return (seq[pos: pos + size] for pos in range(0, len(seq), size))
+
+
+def prog_bar(i, n, bar_size=16):
+    """ Create a progress bar to estimate remaining time
+
+    :param i: current iteration
+    :param n: total number of iterations
+    :param bar_size: size of the bar
+
+    :return: a visualisation of the progress bar
+    """
+    bar = ''
+    done = (i * bar_size) // n
+
+    for j in range(bar_size):
+        bar += '█' if j <= done else '░'
+
+    message = f'{bar} {i}/{n}'
+    return message
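+
+
+# Example output of prog_bar (illustrative, not part of the original module):
+#   prog_bar(50, 100) -> '█████████░░░░░░░ 50/100'  (9 filled cells out of 16)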
+
+
+def estimate_required_time(nb_items_in_list, current_index, time_elapsed, interval=100):
+    """ Compute a remaining time estimation to process all items contained in a list
+
+    :param nb_items_in_list: total number of items that have to be processed
+    :param current_index: current list index, contained in [0, nb_items_in_list - 1]
+    :param time_elapsed: time elapsed to process current_index items in the list
+    :param interval: estimate remaining time when (current_index % interval) == 0
+    """
+    current_index += 1  # increment current_index by 1
+    if current_index % interval == 0 or current_index == nb_items_in_list:
+        # make time estimation and put to string format
+        seconds = (nb_items_in_list - current_index) * (time_elapsed / current_index)
+        time_estimation = relativedelta(seconds=int(seconds))
+        time_estimation_string = f'{time_estimation.hours:02}:{time_estimation.minutes:02}:{time_estimation.seconds:02}'
+
+        # extract progress bar
+        progress_bar = prog_bar(i=current_index, n=nb_items_in_list)
+
+        # display info
+        if current_index == nb_items_in_list:
+            sys.stdout.write(f'\r{progress_bar} -- estimated required time = {time_estimation_string} -- Finished!')
+        else:
+            sys.stdout.write(f'\r{progress_bar} -- estimated required time = {time_estimation_string} -- ')
+
+
+def get_nb_jobs(n_jobs):
+    """ Return the number of parallel jobs specified by n_jobs
+
+    :param n_jobs: the number of jobs the user wants to use in parallel
+
+    :return: the number of parallel jobs
+    """
+    # set nb_jobs to max by default
+    nb_jobs = mp.cpu_count()
+    if n_jobs != 'max':
+        if int(n_jobs) > mp.cpu_count():
+            _logger.warning(f'Max number of parallel jobs is "{mp.cpu_count()}" but received "{int(n_jobs)}" -- '
+                            f'setting nb of parallel jobs to {nb_jobs}')
+        else:
+            nb_jobs = int(n_jobs)
+
+    return nb_jobs
+
+
+def logger_thread(q):
+    ''' Thread logger to listen to log outputs in multi-processing mode
+    '''
+    while True:
+        log_record = q.get()
+        if log_record is None:
+            break
+        _logger.handle(log_record)
+
+
+def launch_multi_process(iterable, func, n_jobs, chunksize=1, ordered=True, timer_verbose=True, **kwargs):
+    """ Call a function on each item of an iterable, using a pool of worker processes
+        https://guangyuwu.wordpress.com/2018/01/12/python-differences-between-imap-imap_unordered-and-map-map_async/
+
+    :param iterable: items to process with function func
+    :param func: function to multi-process
+    :param n_jobs: number of parallel jobs to use
+    :param chunksize: size of chunks given to each worker
+    :param ordered: True: outputs are returned in the order of the input iterable
+                    False: outputs are returned in completion order -- better performance
+    :param timer_verbose: display time estimation when set to True
+    :param kwargs: additional keyword arguments taken by function func
+
+    :return: function outputs
+    """
+    # set up a queue and listen to log messages on it in another thread
+    m = mp.Manager()
+    q = m.Queue()
+    lp = threading.Thread(target=logger_thread, args=(q, ))
+    lp.start()
+
+    # define pool of workers
+    pool = Pool(processes=n_jobs)
+    # define partial function and pool function
+    func = partial(func, log_queue=q, **kwargs)
+    pool_func = pool.imap if ordered else pool.imap_unordered
+
+    # initialize variables
+    func_returns = []
+    nb_items_in_list = len(iterable) if timer_verbose else None
+    start = time.time() if timer_verbose else None
+    # iterate over iterable
+    for i, func_return in enumerate(pool_func(func, iterable, chunksize=chunksize)):
+        # store function output
+        func_returns.append(func_return)
+        # compute remaining time
+        if timer_verbose:
+            estimate_required_time(nb_items_in_list=nb_items_in_list, current_index=i,
+                                   time_elapsed=time.time() - start)
+    if timer_verbose:
+        sys.stdout.write('\n')
+
+    # wait for all workers to finish and close the pool
+    pool.close()
+    pool.join()
+
+    # put a null message in the queue so that it stops the logging thread
+    q.put(None)
+    lp.join()
+
+    return func_returns
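+
+
+# Usage sketch (illustrative only -- `extract_features` and its arguments are made up,
+# not part of this repository). Workers must accept the `log_queue` keyword argument
+# that launch_multi_process() injects via functools.partial:
+#
+#   def extract_features(wav_file, log_queue, sampling_rate):
+#       ...  # process one file; use log_queue to forward log records if needed
+#
+#   outputs = launch_multi_process(wav_files, extract_features,
+#                                  n_jobs=get_nb_jobs('max'), sampling_rate=22050)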