My Experiments with Finetuning Seamless M4T #459
-
Thanks for the great work! I was following your scripts, but I ran them on a multi-GPU setup (4 GPUs with 16 GB VRAM each) and came across the issue below when running your notebook on English. Do you have any ideas on this? It would be great if you could share a bit more. Thanks!

```
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.34 GiB. GPU 0 has a total capacity of 14.75 GiB of which 3.09 GiB is free. Process 241510 has 11.66 GiB memory in use. Of the allocated memory 8.14 GiB is allocated by PyTorch, and 3.38 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/bin/m4t_finetune", line 8, in <module>
    sys.exit(main())
  File "/content/seamless_communication/src/seamless_communication/cli/m4t/finetune/finetune.py", line 212, in main
    finetune.run(stop_at=args.max_batches)
  File "/content/seamless_communication/src/seamless_communication/cli/m4t/finetune/trainer.py", line 401, in run
    self._eval_model()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/content/seamless_communication/src/seamless_communication/cli/m4t/finetune/trainer.py", line 341, in _eval_model
    loss = self.calc_loss(batch, *self.model(batch))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/seamless_communication/src/seamless_communication/cli/m4t/finetune/trainer.py", line 104, in forward
    with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
  File "/usr/lib/python3.10/contextlib.py", line 153, in __exit__
    self.gen.throw(typ, value, traceback)
AttributeError: 'list_iterator' object has no attribute 'throw'
```
-
When I use m4t_evaluate to load checkpoints from the fine-tuning step,
-
Hi again! I have another question. How is it possible to fine-tune using torchrun with a fixed batch size? When I fine-tune on a single GPU, I can use a batch size of 10. However, when I use two GPUs, I can't use a batch size of 10. Thanks in advance.
-
Hi, I am working on TTS for Indic languages, since some of them don't have speech support. But while fine-tuning, I am getting the following error. I don't know what is happening, since I am following the documentation.
-
Hello, I'm fine-tuning SeamlessStreaming and would like to ask if you're working on something similar that you could share?
-
Any luck fine-tuning on a new language?
-
I started working on finetuning Seamless M4T, and in this post I will describe how I got on and what the process was like. If you are unaware, finetuning is the term LLM circles use for what used to be called transfer learning. The idea is that a pretrained model may not perform at its best on every task it supports, so to improve performance on a specific task we train it further on examples of that task.

SeamlessM4T supports several tasks, such as Speech-to-Text Translation (S2TT), Text-to-Text Translation (T2TT), and transcription, also called Automatic Speech Recognition (ASR). In this post I will focus on the ASR task.
Installation and Setup
The simplest way to get the Seamless CLIs installed and running is with the commands below.
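Roughly, the setup I used installs the package from source, which also registers the CLI entry points (fairseq2 is pulled in as a dependency); your environment may need extra steps:

```bash
# Install seamless_communication from source; this registers the
# m4t_* CLI entry points (m4t_prepare_dataset, m4t_finetune, m4t_evaluate).
git clone https://github.com/facebookresearch/seamless_communication.git
cd seamless_communication
pip install .
```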
Preparing a Dataset for Finetuning
The first thing to do is prepare a dataset on which we will run our experiments. There are two ways to do this: use one of the datasets supported by the `m4t_prepare_dataset` CLI, or build your own manifest file.
Use Supported Datasets
Currently the `m4t_prepare_dataset` CLI supports two datasets: `google/fleurs` and `speechcolab/gigaspeech`.
Run this to download Google FLEURS into a folder:

```bash
mkdir -p fleurs/english
m4t_prepare_dataset \
  --name google/fleurs \
  --split test \
  --source_lang eng \
  --target_lang eng \
  --save_dir fleurs/english
```
To download Gigaspeech, change the argument `--name google/fleurs` to `--name speechcolab/gigaspeech`. However, you will first need to go to the Gigaspeech page, fill in the access form, and get a Hugging Face token to supply as an argument.

After running these commands you will notice that the CLI generates a `<...>_manifest.json` file. This is the file that the evaluation and finetune CLIs can digest.

Build your Own
To use your own dataset you need to write some code to generate the `manifest.json` file. This is just a long file containing the locations of the audio files and some metadata, one JSON object per line. Below I have shown an example for Gigaspeech; to use this for another dataset, you need to know its columns and build the JSON objects accordingly.
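A hedged sketch of what this can look like. The field names here are my reading of the manifests that `m4t_prepare_dataset` writes, so inspect one of its output files and adjust if your version differs:

```python
# Sketch: build a finetune manifest from Gigaspeech (one JSON object per line).
# The "source"/"target" field names are assumptions based on prepared manifests;
# verify them against a file produced by m4t_prepare_dataset.
import json
from datasets import load_dataset  # Hugging Face datasets library

# "xs" is the smallest Gigaspeech config; a gated dataset, so a token is needed.
ds = load_dataset("speechcolab/gigaspeech", "xs", split="validation", token="<HF_TOKEN>")

with open("gigaspeech_validation_manifest.json", "w") as fp:
    for idx, row in enumerate(ds):
        sample = {
            "source": {
                "id": str(idx),
                "lang": "eng",
                "text": row["text"],
                "audio_local_path": row["audio"]["path"],
                "sampling_rate": row["audio"]["sampling_rate"],
            },
            # For ASR the target is the transcription in the same language.
            "target": {
                "id": str(idx),
                "lang": "eng",
                "text": row["text"],
            },
        }
        fp.write(json.dumps(sample) + "\n")
```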
Getting Baseline Performance Figures
The next thing we want to do is get some performance metrics for the task we want to finetune on, so that we have numbers to compare results against. We simply run the model and see how well it performs out of the box.

Continuing from the Google FLEURS example, we will now start training a model for ASR. We will use the `seamlessM4T_medium` model from now on, because it is not too big to manage; the same steps apply to `seamlessM4T_small` and `seamlessM4T_v2_large`.
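For the baseline, I ran the `m4t_evaluate` CLI over the test manifest prepared earlier. A hedged sketch of the invocation; the flag names follow the CLI's help and the manifest filename is whatever the prepare step produced, so double-check both against your installed version:

```bash
# Evaluate the stock medium model on ASR using the prepared FLEURS manifest.
m4t_evaluate \
  --data_file fleurs/english/test_manifest.json \
  --task ASR \
  --tgt_lang eng \
  --model_name seamlessM4T_medium \
  --output_path eval_results
```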
The output shows the Word Error Rate (WER), which is the standard metric for evaluating ASR.
I ran the same evaluation on 6 languages; you can find the results and details in my notebook.
Running a Training Session
We are now ready to run our first round of training. For this we need to have already downloaded a dataset of our choice; we will use the `manifest.json` files of its train and validation splits.

There are various modes in which we can run our finetune:
Training the Full Model
Run the `m4t_finetune` CLI to start training a model, as in the sketch below. We use a small batch size and a very small learning rate, and we use a patience of 10 evaluations, which means that if the loss does not improve after 10 evaluations we stop training (early stopping).

This method trains the whole model, so we need to make sure we don't use a very high learning rate, as that can lead to the model forgetting its pre-training and focusing only on the finetune examples.
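A hedged sketch of the invocation; the flag names follow the finetune CLI's help, and the learning rate, batch size, and paths are illustrative, so verify them against your version:

```bash
# Full-model finetuning on the prepared FLEURS manifests.
m4t_finetune \
  --mode SPEECH_TO_TEXT \
  --train_dataset fleurs/english/train_manifest.json \
  --eval_dataset fleurs/english/validation_manifest.json \
  --model_name seamlessM4T_medium \
  --batch_size 10 \
  --learning_rate 1e-6 \
  --patience 10 \
  --save_model_to checkpoint.pt
```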
Training with Frozen Layers
Alternatively, you can choose to train only some parts of the model that you expect to improve performance. Usually the later layers are trained for this purpose, as it is assumed that these will have learned the task-specific details that we want to finetune.
Run the `m4t_finetune` CLI with the `--freeze_layers` option and the names of the layers to freeze.

Naturally you may not know the names of all the layers in an M4T model. To get them, all you need to do is iterate over `model.named_parameters()`, as in the sketch below. This will give you a pretty long list that includes all the nested modules, but you can use the name of a parent module to freeze everything inside it. For example, `--freeze_layers model.text_decoder` freezes all of `model.text_decoder.layers[0...6]`.
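A minimal sketch of listing the layer names. The `load_unity_model` import is my assumption about how to load the model outside the trainer, so adjust it to your version of the repo:

```python
import torch
# Assumed loader; the finetune trainer loads the model in a similar way.
from seamless_communication.models.unity import load_unity_model

model = load_unity_model("seamlessM4T_medium", device=torch.device("cpu"))

# Print every parameter name; parent prefixes (e.g. "text_decoder")
# can be passed to --freeze_layers to freeze whole submodules.
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
```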
You can find the names of the layers in `seamlessM4T_medium` in my notebook.
Evaluating the Trained Model
Now that we have a trained model checkpoint, we can evaluate its performance again to see how much better the new model is compared to the stock model.
Run the `m4t_evaluate` CLI with the `--load_checkpoint` option to load the checkpoint, as in the sketch below.

However, this takes quite a while because it evaluates over the whole test dataset. Alternatively, you can run a mini-evaluation, which you will find in my Finetuning with Frozen Layers notebook. I recommend the mini-evaluation, because it makes iterations faster.
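A hedged sketch of the full evaluation; the flags mirror the baseline run, with `--load_checkpoint` pointing at the weights saved by `m4t_finetune`:

```bash
# Re-run the ASR evaluation, this time loading the finetuned checkpoint.
m4t_evaluate \
  --data_file fleurs/english/test_manifest.json \
  --task ASR \
  --tgt_lang eng \
  --model_name seamlessM4T_medium \
  --load_checkpoint checkpoint.pt \
  --output_path eval_results_finetuned
```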
Here is an example of the output from the evaluation.
Results
Here I will also mention the results I got from the different modes of finetuning.
Though I know I could have gotten a lower error rate with full tuning, the interesting fact is that partial tuning gave almost the same score: I got nearly identical results while training only around 40% of the model weights.
Notes and Observations
Further Work - LoRA Finetuning
Low-Rank Adaptation (LoRA) is a fine-tuning method that can use even less memory than freezing layers. It is yet to be implemented in the finetune CLI, although the basic components already exist in FairSeq2; the implementation would just be adding another option to the CLI to choose LoRA finetuning and set its parameters.
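To give a flavour of the idea, here is a generic PyTorch sketch of a LoRA layer (my own illustration, not the FairSeq2 API): the pretrained weight is frozen and only a low-rank update is trained, which is where the memory savings come from.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update B @ A."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # freeze the pretrained weights
            p.requires_grad_(False)
        # A starts as small noise and B as zeros, so training begins
        # from the unmodified pretrained behaviour.
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base projection plus the scaled low-rank correction.
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
```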
Training Hardware
All of this was run on an A100 GPU provided by Google Colab Pro+, with 40 GB of GPU RAM and 85 GB of system RAM. The `seamlessM4T_v2_large` model needs more memory than this, so I based all of this on the medium model. However, it may be possible to run the large model with frozen layers.
Notebooks