Merge pull request #1004 from Simnol22/sb_tutorial
SpeechBrain Tutorial
Showing 5 changed files with 990 additions and 0 deletions.

@@ -0,0 +1,112 @@

********************
SpeechBrain
********************

In this short tutorial, we're going to demonstrate how Oríon can be integrated with a `SpeechBrain
<https://speechbrain.github.io/>`_ speech recognition model.
The files mentioned in this tutorial are available in the `Oríon
<https://github.com/Epistimio/orion/tree/develop/examples>`_ repository.

Installation and setup
======================

Make sure Oríon is installed (:doc:`/install/core`).

Then install SpeechBrain using ``$ pip install speechbrain``.
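
As a quick sanity check (a minimal sketch, assuming both packages were installed into the
active environment), you can verify that the two libraries import correctly:

.. code-block:: python

    # Both imports should succeed once Oríon and SpeechBrain are installed.
    import speechbrain

    from orion.client import report_objective

    print(speechbrain.__version__)
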
Code used in this tutorial
==========================

In this tutorial, we are going to use some code from the `SpeechBrain
<https://github.com/speechbrain/speechbrain>`_ repository, more specifically a speech
recognition template provided as an example. We will repurpose this template to adapt it for
Oríon. The template used for creating this tutorial can be found `here
<https://github.com/speechbrain/speechbrain/tree/develop/templates/speech_recognition/ASR>`_.
The code modified for this example is also available directly in
``examples/speechbrain_tutorial``.

We kept the ``train.py`` file, but created a ``main.py`` containing the ``main`` function,
which we slightly modified to optimize the hyperparameters with Oríon.

Adapting SpeechBrain for Oríon
==============================

The adaptation for using Oríon is quite simple.

1) We first need to import ``orion.report_objective()`` into the project.

.. code-block:: python

    from orion.client import report_objective

2) We then need to change the evaluation from the training data to the validation data.
   The evaluation method should look like the following; it returns the validation loss.

.. literalinclude:: /../../examples/speechbrain_tutorial/main.py
   :language: python
   :lines: 75-80
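
For reference, these lines correspond to the evaluation block near the end of ``main.py``:

.. code-block:: python

    # Load best checkpoint for evaluation
    valid_stats = asr_brain.evaluate(
        test_set=datasets["valid"],
        min_key="WER",
        test_loader_kwargs=hparams["valid_dataloader_opts"],
    )
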
3) Finally, we call ``report_objective`` at the end to report the final objective value,
   the validation loss, to Oríon.

.. code-block:: python

    report_objective(valid_stats)

The code is now adapted and ready to be used with Oríon.

Execution
=========

We are now going to call the ``orion hunt`` command.
Notice that we still need to give the ``train.yaml``
file to SpeechBrain, since the general configuration is in there. However, we specify
the hyperparameters that we want to optimize on the command line: when an argument
is defined both in the yaml configuration file and on the command line, SpeechBrain
gives precedence to the command-line value. Defining the hyperparameters on the
command line for Oríon thus automatically overrides the values set in ``train.yaml``.

.. code-block:: bash

    orion hunt \
        --enable-evc -n <experiment_name> \
        python main.py train.yaml \
        --lr~'loguniform(0.05, 0.2)' \
        --ctc_weight~'loguniform(0.25, 0.75)' \
        --label_smoothing~'loguniform(1e-10, 10e-5)' \
        --coverage_penalty~'loguniform(1.0, 2.0)' \
        --temperature~'loguniform(1.0, 1.5)' \
        --temperature_lm~'loguniform(1.0, 1.5)'
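
Each ``--<name>~'<prior>'`` argument tells Oríon to sample that hyperparameter from the
given prior before passing the value through to SpeechBrain. The same search space can also
be defined through Oríon's Python API; the following is a minimal sketch (assuming the
priors above and the default storage configuration):

.. code-block:: python

    from orion.client import build_experiment

    # Same priors as in the `orion hunt` command above.
    experiment = build_experiment(
        name="<experiment_name>",
        space={
            "lr": "loguniform(0.05, 0.2)",
            "ctc_weight": "loguniform(0.25, 0.75)",
            "label_smoothing": "loguniform(1e-10, 10e-5)",
            "coverage_penalty": "loguniform(1.0, 2.0)",
            "temperature": "loguniform(1.0, 1.5)",
            "temperature_lm": "loguniform(1.0, 1.5)",
        },
    )

Trials can then be suggested and observed through the returned experiment client.
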
Results
=======

When an experiment reaches its termination criterion, typically ``max-trials``,
you can see the results using the following command:

.. code-block:: bash

    $ orion info -n <experiment_name>

This outputs the following statistics:

.. code-block:: bash

    Stats
    =====
    completed: True
    trials completed: 209
    best trial:
      id: 8675cfcfba768243e1ed1ac7825c69b6
      evaluation: 0.13801406680803444
      params:
        /coverage_penalty: 1.396
        /ctc_weight: 0.389
        /label_smoothing: 2.044e-10
        /lr: 0.06462
        /temperature: 1.175
        /temperature_lm: 1.087
    start time: 2022-09-29 14:37:41.048314
    finish time: 2022-09-30 20:08:07.384765
    duration: 1 day, 5:30:26.336451
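
The same information is also available programmatically; a minimal sketch, assuming
``get_experiment`` from ``orion.client`` and the default storage configuration:

.. code-block:: python

    from orion.client import get_experiment

    # Load a read-only view of the finished experiment and print its statistics.
    experiment = get_experiment(name="<experiment_name>")
    print(experiment.stats)
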
examples/speechbrain_tutorial/main.py
@@ -0,0 +1,82 @@

import logging
import sys

import speechbrain as sb
import torch
from hyperpyyaml import load_hyperpyyaml
from mini_librispeech_prepare import prepare_mini_librispeech
from speechbrain.utils.distributed import run_on_main
from train import ASR, dataio_prepare

from orion.client import report_objective

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Initialize ddp (useful only for multi-GPU DDP training)
    sb.utils.distributed.ddp_init_group(run_opts)

    # Load hyperparameters file with command-line overrides
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Data preparation, to be run on only one process.
    sb.utils.distributed.run_on_main(
        prepare_mini_librispeech,
        kwargs={
            "data_folder": hparams["data_folder"],
            "save_json_train": hparams["train_annotation"],
            "save_json_valid": hparams["valid_annotation"],
            "save_json_test": hparams["test_annotation"],
        },
    )

    # We can now directly create the datasets for training, valid, and test
    datasets = dataio_prepare(hparams)

    # In this case, pre-training is essential because mini-librispeech is not
    # big enough to train an end-to-end model from scratch. With a bigger dataset
    # you can train from scratch and avoid this step.
    # We download the pretrained LM from HuggingFace (or elsewhere depending on
    # the path given in the YAML file). The tokenizer is loaded at the same time.
    run_on_main(hparams["pretrainer"].collect_files)
    hparams["pretrainer"].load_collected(device=torch.device("cpu"))

    # Trainer initialization
    asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # The `fit()` method iterates the training loop, calling the methods
    # necessary to update the parameters of the model. Since all objects
    # with changing state are managed by the Checkpointer, training can be
    # stopped at any point, and will be resumed on next call.
    asr_brain.fit(
        asr_brain.hparams.epoch_counter,
        datasets["train"],
        datasets["valid"],
        train_loader_kwargs=hparams["train_dataloader_opts"],
        valid_loader_kwargs=hparams["valid_dataloader_opts"],
    )

    # Load best checkpoint for evaluation
    valid_stats = asr_brain.evaluate(
        test_set=datasets["valid"],
        min_key="WER",
        test_loader_kwargs=hparams["valid_dataloader_opts"],
    )

    report_objective(valid_stats)