Noisy student training for wenet #1600
Merged
Changes from 23 commits

Commits (46):
220d7b2 add NST module (NevermoreCY)
fd2cbeb Delete test.text (NevermoreCY)
92974fd Delete text.md (NevermoreCY)
4d68c70 Update README (wd929)
2f84cd7 Update README.md (wd929)
bd36149 Update README.md (wd929)
1c4043b Update README.md (wd929)
1683016 add example toy data (NevermoreCY)
0496b1d add few lines in README (NevermoreCY)
7595ac9 add toy wav and txt (NevermoreCY)
b3c880c add toy data (NevermoreCY)
0b47162 Update README.md (NevermoreCY)
aa5bd02 fixed some formatting issue (NevermoreCY)
5663ec0 Merge remote-tracking branch 'origin/main' (NevermoreCY)
b4dfc71 fixed some formatting issue (NevermoreCY)
9a059a2 fixed some formatting issue (NevermoreCY)
145fe9c fixed some formatting issue (NevermoreCY)
dc53142 fix mis-comment on cv_configs (NevermoreCY)
463d09a fix some formatting issue (NevermoreCY)
974ada1 fix some formatting issue (NevermoreCY)
3b6e69b Merge branch 'wenet-e2e:main' into main (NevermoreCY)
0918842 commit after rebase (NevermoreCY)
2289c79 Merge branch 'main' of https://github.com/NevermoreCY/wenet (NevermoreCY)
555edae fixed bugs mentioned by @xingchensong, add data list fusion. (NevermoreCY)
a31a7d1 modify run_nst.sh (NevermoreCY)
06e58a6 add config (NevermoreCY)
d6751d3 add run.sh (NevermoreCY)
7c8a859 fix bug in generate_data_list.py (NevermoreCY)
68d974a fix bug in generate_data_list.py (NevermoreCY)
91baa96 fix bug in generate_data_list.py (NevermoreCY)
98eced6 fix bug in generate_data_list.py (NevermoreCY)
898ccd0 fix bug in generate_data_list.py (NevermoreCY)
bd341f6 fix bug in generate_data_list.py (NevermoreCY)
eb58e5b format issue (NevermoreCY)
52dcdfe format issue (NevermoreCY)
6460e82 delete train_nst.py and executor_nst.py (NevermoreCY)
52c4783 Merge branch 'wenet-e2e:main' into main (NevermoreCY)
b553f01 formatting (NevermoreCY)
40b7b61 Merge remote-tracking branch 'origin/main' (NevermoreCY)
62e8cf1 formatting (NevermoreCY)
de8bbe2 fix comments (NevermoreCY)
0141767 Update README.md (wd929)
908752c Update README.md (wd929)
7f5ee03 delete example data, fix run_nst.sh (NevermoreCY)
4b19bb8 fix readme (NevermoreCY)
74caa24 fix config (NevermoreCY)
@@ -0,0 +1,139 @@
# Recipe to run Noisy Student Training with LM filter in WeNet

Noisy Student Training (NST) has recently demonstrated extremely strong performance in Automatic Speech Recognition (ASR).

Here, we provide a recipe to run NST with the `LM filter` strategy from [this paper](https://arxiv.org/abs/2211.04717), using AISHELL-1 as supervised data and WenetSpeech as unsupervised data: hypotheses are generated with and without a language model, and the CER difference between them is used as a filtering threshold to improve ASR performance on non-target-domain data.
## Table of Contents

- [Guideline](#guideline)
  - [Data preparation](#data-preparation)
  - [Initial supervised teacher](#initial-supervised-teacher)
  - [Noisy student iterations](#noisy-student-iterations)
- [Performance Record](#performance-record)
  - [Supervised baseline and standard NST](#supervised-baseline-and-standard-nst)
  - [Supervised AISHELL-1 and unsupervised 1khr WenetSpeech](#supervised-aishell-1-and-unsupervised-1khr-wenetspeech)
  - [Supervised AISHELL-2 and unsupervised 4khr WenetSpeech](#supervised-aishell-2-and-unsupervised-4khr-wenetspeech)
- [Citations](#citations)

## Guideline

First, prepare the supervised and unsupervised data for NST and train an initial supervised teacher. After that, run noisy student iterations until the model converges.
### Data preparation

To run this recipe, follow the steps in the [WeNet examples](https://github.com/wenet-e2e/wenet/tree/main/examples) to prepare the [AISHELL-1](https://github.com/wenet-e2e/wenet/tree/main/examples/aishell/s0) and [WenetSpeech](https://github.com/wenet-e2e/wenet/tree/main/examples/wenetspeech/s0) data.
We extract 1khr of data from WenetSpeech; the data should be prepared and stored in the following format:

```
data_dir/
├── train
│   ├── data_aishell.list
│   ├── wenet_1khr.list
│   ├── wav_dir/
│   └── utter_time.json (optional)
```

- `*.list` files contain the paths of all the data shards used for training.
- A JSON file with per-utterance audio lengths must be prepared as `utter_time.json` if you want to apply the `speaking rate` filter (see the sketch below).
- `wav_dir` contains all the audio files (`id.wav`) and, optionally, their labels (`id.txt`) for the unsupervised data.

> **HINTS** We include a tiny example under `local/example` to make reproduction clearer.
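
Below is a minimal sketch of how such a `utter_time.json` could be produced. It is not part of the recipe; it assumes mono PCM wav files named `<utt_id>.wav` under `wav_dir`, and it writes durations in seconds:

```python
# Illustrative helper (not shipped with the recipe) to build utter_time.json.
# Assumes PCM wav files named <utt_id>.wav, as in the local/example data.
import json
import os
import wave


def build_utter_time(wav_dir, out_json="utter_time.json"):
    durations = {}
    for name in os.listdir(wav_dir):
        if not name.endswith(".wav"):
            continue
        with wave.open(os.path.join(wav_dir, name)) as wav:
            # duration in seconds = number of frames / sample rate
            durations[name[:-4]] = round(wav.getnframes() / wav.getframerate(), 2)
    with open(out_json, "w") as fout:
        json.dump(durations, fout, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    build_utter_time("data/train/wav_dir")  # hypothetical path
```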

### Initial supervised teacher

To train an initial supervised teacher model, run the following command:

```bash
bash run_nst.sh --dir exp/conformer_test_fully_supervised --supervised_data_list data_aishell.list --data_list wenet_1khr.list --dir_split wenet_split_60_test/ --out_data_list data/train/wenet_1khr_nst0.list --enable_nst 0
```

- `dir` contains the training parameters.
- `supervised_data_list` contains the paths of the supervised data shards.
- `data_list` contains the paths of the unsupervised data shards used for inference.
- `dir_split` is the directory that stores the split unsupervised data for parallel computing.
- `out_data_list` is the output path of the pseudo-label data list.
- `enable_nst` indicates whether we train with pseudo labels; set it to 0 for the initial teacher.
- This recipe uses the default `num_split=1`, but we strongly recommend a larger number to reduce the inference and shard-generation time.

Full arguments are listed below; see `run_nst.sh` for more information about each stage and its arguments. For the experiments in our paper we used `num_split=60` and generated shards on different CPUs, which saved a lot of inference and shard-generation time.

```bash
bash run_nst.sh --stage 1 --stop-stage 8 --dir exp/conformer_test_fully_supervised --supervised_data_list data_aishell.list --enable_nst 0 --num_split 1 --data_list wenet_1khr.list --dir_split wenet_split_60_test/ --job_num 0 --hypo_name hypothesis_nst0.txt --label 1 --wav_dir data/train/wenet_1k_untar/ --cer_hypo_dir wenet_cer_hypo --cer_label_dir wenet_cer_label --label_file label.txt --cer_hypo_threshold 10 --speak_rate_threshold 0 --utter_time_file utter_time.json --untar_dir data/train/wenet_1khr_untar/ --tar_dir data/train/wenet_1khr_tar/ --out_data_list data/train/wenet_1khr.list
```
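
As a rough illustration of what `num_split` and `dir_split` do: the unsupervised data list is divided into `num_split` pieces so that decoding jobs (selected by `job_num`) can run in parallel. The sketch below is hypothetical; the sublist layout is an assumption, not what `run_nst.sh` actually writes.

```python
# Rough sketch of splitting a data list for parallel inference.
# The data_sublist{i}/data_list naming is an assumption.
import os


def split_data_list(data_list_path, dir_split, num_split):
    with open(data_list_path) as fin:
        shards = [line.strip() for line in fin if line.strip()]
    for i in range(num_split):
        sub_dir = os.path.join(dir_split, f"data_sublist{i}")
        os.makedirs(sub_dir, exist_ok=True)
        with open(os.path.join(sub_dir, "data_list"), "w") as fout:
            # round-robin assignment keeps the pieces roughly equal in size
            fout.write("\n".join(shards[i::num_split]) + "\n")


split_data_list("data/train/wenet_1khr.list", "wenet_split_60_test/", 60)
```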

### Noisy student iterations

After finishing the initial fully supervised baseline, we have the pseudo-label data list, which is `wenet_1khr_nst0.list` if you followed the guideline. We use it as the `pseudo_data` in the training step, and the pseudo labels for the next NST iteration will be generated.

Here is an example command:

```bash
bash run_nst.sh --dir exp/conformer_nst1 --supervised_data_list data_aishell.list --pseudo_data_list wenet_1khr_nst0.list --enable_nst 1 --job_num 0 --hypo_name hypothesis_nst1.txt --untar_dir data/train/wenet_1khr_untar_nst1/ --tar_dir data/train/wenet_1khr_tar_nst1/ --out_data_list data/train/wenet_1khr_nst1.list
```

Most of the arguments are the same as for the initial teacher training; here we add the extra argument `pseudo_data_list`, the path of the pseudo-label data list. `enable_nst` must be set to 1 if you want to train with pseudo data. The indices in `hypo_name` and `tar_dir` need to be changed if you don't want to overwrite previously generated data.
The `out_data_list` can be used as the `pseudo_data_list` input for the next NST iteration, as sketched below.
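
For example, a minimal driver that chains several iterations could look like this (the number of rounds and the paths are assumptions; the recipe itself documents one iteration at a time):

```python
# Sketch: run several NST rounds, feeding each round's out_data_list
# back in as the next round's pseudo_data_list.
import subprocess

pseudo_list = "wenet_1khr_nst0.list"  # produced by the initial teacher
for i in range(1, 4):  # number of rounds is an assumption
    subprocess.run([
        "bash", "run_nst.sh",
        "--dir", f"exp/conformer_nst{i}",
        "--supervised_data_list", "data_aishell.list",
        "--pseudo_data_list", pseudo_list,
        "--enable_nst", "1",
        "--job_num", "0",
        "--hypo_name", f"hypothesis_nst{i}.txt",
        "--untar_dir", f"data/train/wenet_1khr_untar_nst{i}/",
        "--tar_dir", f"data/train/wenet_1khr_tar_nst{i}/",
        "--out_data_list", f"data/train/wenet_1khr_nst{i}.list",
    ], check=True)
    pseudo_list = f"wenet_1khr_nst{i}.list"  # consumed by the next round
```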

Full arguments are listed below; see `run_nst.sh` for more information about each stage and its arguments:

```bash
bash run_nst.sh --stage 1 --stop-stage 8 --dir exp/conformer_nst1 --supervised_data_list data_aishell.list --pseudo_data_list wenet_1khr_nst0.list --enable_nst 1 --num_split 1 --dir_split wenet_split_60_test/ --job_num 0 --hypo_name hypothesis_nst1.txt --label 0 --wav_dir data/train/wenet_1k_untar/ --cer_hypo_dir wenet_cer_hypo --cer_label_dir wenet_cer_label --label_file label.txt --cer_hypo_threshold 10 --speak_rate_threshold 0 --utter_time_file utter_time.json --untar_dir data/train/wenet_1khr_untar_nst1/ --tar_dir data/train/wenet_1khr_tar_nst1/ --out_data_list data/train/wenet_1khr_nst1.list
```
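
The `--cer_hypo_threshold` argument drives the `LM filter`: each utterance is decoded once with and once without the LM, and only utterances whose hypothesis-to-hypothesis CER stays below the threshold keep their pseudo labels. The following is a self-contained sketch of the idea, for illustration only; the recipe's actual filtering lives in the stages of `run_nst.sh`, and the `hypo_with_lm` / `hypo_without_lm` dictionaries are assumed inputs.

```python
# Sketch of the LM-filter idea: keep utterances where decodes with and
# without the LM agree (low CER between the two hypotheses).
def edit_distance(ref, hyp):
    # single-row dynamic programming over character edits
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution
        prev = cur
    return prev[-1]


def lm_filter(hypo_with_lm, hypo_without_lm, cer_hypo_threshold=10.0):
    """hypo_* map utterance IDs to decoded character strings (assumed format)."""
    kept = []
    for utt, with_lm in hypo_with_lm.items():
        # CER-hypo: treat the with-LM decode as the reference
        cer = 100.0 * edit_distance(with_lm, hypo_without_lm[utt]) / max(len(with_lm), 1)
        if cer < cer_hypo_threshold:
            kept.append(utt)  # the two decodes agree; trust this pseudo label
    return kept
```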

## Performance Record

### Supervised baseline and standard NST

* Without filter strategy, first iteration
* Feature info: using FBANK feature, dither, CMVN, online speed perturb
* Training info: lr 0.002, batch size 32, 8 GPUs, acc_grad 4, 240 epochs, dither 0.1
* Decoding info: ctc_weight 0.3, average_num 30

| Supervised               | Unsupervised | Test CER |
|--------------------------|--------------|----------|
| AISHELL-1 only           | ----         | 4.85     |
| AISHELL-1 + WenetSpeech  | ----         | 3.54     |
| AISHELL-1 + AISHELL-2    | ----         | 1.01     |
| AISHELL-1 (standard NST) | WenetSpeech  | 5.52     |
### Supervised AISHELL-1 and unsupervised 1khr WenetSpeech

* Feature info: using FBANK feature
* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1
* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75

| # NST iteration | AISHELL-1 test CER | Pseudo CER | Filtered CER | Filtered hours |
|-----------------|--------------------|------------|--------------|----------------|
| 0               | 4.85               | 47.10      | 25.18        | 323            |
| 1               | 4.86               | 37.02      | 20.93        | 436            |
| 2               | 4.75               | 31.81      | 19.74        | 540            |
| 3               | 4.69               | 28.27      | 17.85        | 592            |
| 4               | 4.48               | 26.64      | 14.76        | 588            |
| 5               | 4.41               | 24.70      | 15.86        | 670            |
| 6               | 4.34               | 23.64      | 15.40        | 669            |
| 7               | 4.31               | 23.79      | 15.75        | 694            |
### Supervised AISHELL-2 and unsupervised 4khr WenetSpeech

* Feature info: using FBANK feature
* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1
* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75

| # NST iteration | AISHELL-2 test CER | Pseudo CER | Filtered CER | Filtered hours |
|-----------------|--------------------|------------|--------------|----------------|
| 0               | 5.48               | 30.10      | 11.73        | 1637           |
| 1               | 5.09               | 28.31      | 9.39         | 2016           |
| 2               | 4.88               | 25.38      | 9.99         | 2186           |
| 3               | 4.74               | 22.47      | 10.66        | 2528           |
| 4               | 4.73               | 22.23      | 10.43        | 2734           |
## Citations

``` bibtex
@article{chen2022NST,
  title={Improving Noisy Student Training on Non-target Domain Data for Automatic Speech Recognition},
  author={Chen, Yu and Ding, Wen and Lai, Junjie},
  journal={arXiv preprint arXiv:2211.04717},
  year={2022}
}
```
@@ -0,0 +1,143 @@
# network architecture

> Review comment: This file is not required, right?
# encoder related
encoder: conformer # ~ 30M parameters
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false

dataset_conf:
    filter_conf:
        max_length: 1200        # 40960
        min_length: 10          # 0
        token_max_length: 100   # 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: false        # true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500          # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static'    # static or dynamic
        batch_size: 32

supervised_dataset_conf:
    filter_conf:
        max_length: 1200        # 40960
        min_length: 10          # 0
        token_max_length: 100   # 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: false        # true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500          # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static'    # static or dynamic
        batch_size: 32

unsupervised_dataset_conf:
    filter_conf:
        max_length: 1200        # 40960
        min_length: 10          # 0
        token_max_length: 100   # 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: false        # true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500          # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static'    # static or dynamic
        batch_size: 32

grad_clip: 5
accum_grad: 4
max_epoch: 30
log_interval: 100
# for fully supervised training, just set pseudo_ratio to 0
pseudo_ratio: 0.75

optim: adam
optim_conf:
    lr: 0.002
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 20000     # 25000
11 changes: 11 additions & 0 deletions
examples/aishell/NST/local/example/train/data_aishell.list
@@ -0,0 +1,11 @@
data/train/shards/shards_000000000.tar

> Review comment: These example files can be deleted.

data/train/shards/shards_000000001.tar
data/train/shards/shards_000000002.tar
data/train/shards/shards_000000003.tar
data/train/shards/shards_000000004.tar
data/train/shards/shards_000000005.tar
data/train/shards/shards_000000006.tar
data/train/shards/shards_000000007.tar
data/train/shards/shards_000000008.tar
data/train/shards/shards_000000009.tar
data/train/shards/shards_000000010.tar
@@ -0,0 +1 @@
你好你好
Binary file not shown.
@@ -0,0 +1,2 @@
data/train/wenet_1khr_tar//dir0_000000.tar
data/train/wenet_1khr_tar//dir0_000001.tar
@@ -0,0 +1,5 @@
{
    "ID001": 2.05,
    "ID002": 2.75,
    "ID003": 3.36
}
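
For illustration, the `speaking rate` filter mentioned in the README could consume this JSON roughly as below. The exact criterion lives in the recipe's scripts; the `labels` dictionary and the characters-per-second rule here are assumptions.

```python
# Sketch: drop utterances whose characters-per-second is at or below a threshold.
import json

with open("utter_time.json") as fin:
    utter_time = json.load(fin)  # utterance ID -> duration in seconds

labels = {"ID001": "你好你好"}  # hypothetical pseudo labels keyed by utterance ID


def pass_speak_rate(utt, speak_rate_threshold=0.0):
    rate = len(labels[utt]) / utter_time[utt]  # characters per second
    return rate > speak_rate_threshold


kept = [u for u in labels if u in utter_time and pass_speak_rate(u)]
print(kept)
```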
@@ -0,0 +1,26 @@
import argparse
import os


def get_args():
    parser = argparse.ArgumentParser(description='generate data.list file')
    parser.add_argument('--tar_dir', help='path for tar file')
    parser.add_argument('--out_data_list', help='output path for data list')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    target_dir = args.tar_dir
    # list every shard in tar_dir and write one "<tar_dir>/<shard>" path per line
    data_list = os.listdir(target_dir)
    output_file = args.out_data_list
    with open(output_file, "w") as writer:
        for line in data_list:
            writer.write(target_dir + "/" + line + "\n")


if __name__ == '__main__':
    main()
Review comment:

> According to your .yaml configuration, this is a non-streaming model, and the 4.85 CER is obtained from attention rescoring. I do think we'd better mention this info directly in the README.
Reply:

> Thanks for the review, I will add this to the README file. I am trying to write a different data fusion method (e.g., a coarse fusion that combines the supervised and unsupervised data.list files in a given ratio), as sketched below. Then I can replace train_nst.py and executor_nst.py with the standard ones, which should be easier to maintain in the future.
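
A minimal sketch of the coarse fusion described above, assuming it simply concatenates shard lists so that pseudo shards form a given fraction of the result (the function and file names are hypothetical, not the PR's eventual implementation):

```python
# Sketch: fuse supervised and pseudo-label shard lists at a target ratio.
import random


def fuse_data_lists(sup_list, pseudo_list, out_list, pseudo_ratio=0.75, seed=0):
    assert 0 <= pseudo_ratio < 1
    with open(sup_list) as f:
        sup = [line.strip() for line in f if line.strip()]
    with open(pseudo_list) as f:
        pseudo = [line.strip() for line in f if line.strip()]
    # keep all supervised shards; sample pseudo shards so that they make up
    # roughly pseudo_ratio of the fused list
    n_pseudo = min(len(pseudo), int(len(sup) * pseudo_ratio / (1 - pseudo_ratio)))
    random.seed(seed)
    fused = sup + random.sample(pseudo, n_pseudo)
    random.shuffle(fused)
    with open(out_list, "w") as f:
        f.write("\n".join(fused) + "\n")


fuse_data_lists("data/train/data_aishell.list",
                "data/train/wenet_1khr_nst1.list",
                "data/train/fused_nst1.list")
```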