Noisy student training for wenet #1600

Merged
merged 46 commits into from
Dec 13, 2022
220d7b2
add NST module
NevermoreCY Dec 5, 2022
fd2cbeb
Delete test.text
NevermoreCY Dec 5, 2022
92974fd
Delete text.md
NevermoreCY Dec 5, 2022
4d68c70
Update README
wd929 Dec 5, 2022
2f84cd7
Update README.md
wd929 Dec 5, 2022
bd36149
Update README.md
wd929 Dec 5, 2022
1c4043b
Update README.md
wd929 Dec 5, 2022
1683016
add example toy data
NevermoreCY Dec 5, 2022
0496b1d
add few lines in README
NevermoreCY Dec 5, 2022
7595ac9
add toy wav and txt
NevermoreCY Dec 5, 2022
b3c880c
add toy data
NevermoreCY Dec 5, 2022
0b47162
Update README.md
NevermoreCY Dec 5, 2022
aa5bd02
fixed some formatting issue
NevermoreCY Dec 5, 2022
5663ec0
Merge remote-tracking branch 'origin/main'
NevermoreCY Dec 5, 2022
b4dfc71
fixed some formatting issue
NevermoreCY Dec 5, 2022
9a059a2
fixed some formatting issue
NevermoreCY Dec 5, 2022
145fe9c
fixed some formatting issue
NevermoreCY Dec 5, 2022
dc53142
fix mis-comment on cv_configs
NevermoreCY Dec 7, 2022
463d09a
fix some formatting issue
NevermoreCY Dec 7, 2022
974ada1
fix some formatting issue
NevermoreCY Dec 7, 2022
3b6e69b
Merge branch 'wenet-e2e:main' into main
NevermoreCY Dec 7, 2022
0918842
commit after rebase
NevermoreCY Dec 7, 2022
2289c79
Merge branch 'main' of https://github.com/NevermoreCY/wenet
NevermoreCY Dec 7, 2022
555edae
fixed bugs mentioned by @xingchensong, add data list fusion.
NevermoreCY Dec 9, 2022
a31a7d1
modify run_nst.sh
NevermoreCY Dec 11, 2022
06e58a6
add config
NevermoreCY Dec 11, 2022
d6751d3
add run.sh
NevermoreCY Dec 11, 2022
7c8a859
fix bug in generate_data_list.py
NevermoreCY Dec 12, 2022
68d974a
fix bug in generate_data_list.py
NevermoreCY Dec 12, 2022
91baa96
fix bug in generate_data_list.py
NevermoreCY Dec 12, 2022
98eced6
fix bug in generate_data_list.py
NevermoreCY Dec 12, 2022
898ccd0
fix bug in generate_data_list.py
NevermoreCY Dec 12, 2022
bd341f6
fix bug in generate_data_list.py
NevermoreCY Dec 12, 2022
eb58e5b
format issue
NevermoreCY Dec 12, 2022
52dcdfe
format issue
NevermoreCY Dec 12, 2022
6460e82
delete train_nst.py and executor_nst.py
NevermoreCY Dec 12, 2022
52c4783
Merge branch 'wenet-e2e:main' into main
NevermoreCY Dec 12, 2022
b553f01
formatting
NevermoreCY Dec 12, 2022
40b7b61
Merge remote-tracking branch 'origin/main'
NevermoreCY Dec 12, 2022
62e8cf1
formatting
NevermoreCY Dec 12, 2022
de8bbe2
fix comments
NevermoreCY Dec 12, 2022
0141767
Update README.md
wd929 Dec 13, 2022
908752c
Update README.md
wd929 Dec 13, 2022
7f5ee03
delete example data, fix run_nst.sh
NevermoreCY Dec 13, 2022
4b19bb8
fix readme
NevermoreCY Dec 13, 2022
74caa24
fix config
NevermoreCY Dec 13, 2022
146 changes: 146 additions & 0 deletions examples/aishell/NST/README.md
@@ -0,0 +1,146 @@
# Recipe to run Noisy Student Training with LM filter in WeNet

Noisy Student Training (NST) has recently demonstrated extremely strong performance in Automatic Speech Recognition (ASR).

Here, we provide a recipe to run NST with the `LM filter` strategy from [this paper](https://arxiv.org/abs/2211.04717), using AISHELL-1 as supervised data and WenetSpeech as unsupervised data. Hypotheses are generated with and without a language model, and the CER difference between them is used as a filter threshold to improve ASR performance on non-target domain data.
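The filtering idea can be sketched in a few lines of Python. This is only an illustration, not the code used in `run_nst.sh`: the function names, the choice of the with-LM hypothesis as the reference, and the strict `<` comparison are all assumptions. The default threshold of 10 mirrors the `--cer_hypo_threshold 10` value in the example command of this recipe.

```python
def edit_distance(ref, hyp):
    """Character-level Levenshtein distance between two strings."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution or match
        prev = cur
    return prev[-1]


def cer(ref, hyp):
    """CER in percent of `hyp` measured against `ref`."""
    if not ref:
        return 0.0 if not hyp else 100.0
    return 100.0 * edit_distance(ref, hyp) / len(ref)


def lm_filter(hypos_with_lm, hypos_without_lm, threshold=10.0):
    """Keep utterances whose two hypotheses differ by less than `threshold` percent.

    Both arguments are hypothetical dicts mapping utterance id to text.
    """
    kept = {}
    for utt_id, with_lm in hypos_with_lm.items():
        without_lm = hypos_without_lm.get(utt_id)
        if without_lm is None:
            continue
        if cer(with_lm, without_lm) < threshold:
            kept[utt_id] = with_lm
    return kept
```

The intuition is that when decoding with and without a language model gives nearly the same text, the pseudo label is more likely to be reliable, so only those utterances are kept for the next student.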

## Table of Contents

- [Guideline](#guideline)
  - [Data preparation](#data-preparation)
  - [Initial supervised teacher](#initial-supervised-teacher)
  - [Noisy student iterations](#noisy-student-iterations)
- [Performance Record](#performance-record)
  - [Supervised baseline and standard NST](#supervised-baseline-and-standard-nst)
  - [Supervised AISHELL-1 and unsupervised 1khr WenetSpeech](#supervised-aishell-1-and-unsupervised-1khr-wenetspeech)
  - [Supervised AISHELL-2 and unsupervised 4khr WenetSpeech](#supervised-aishell-2-and-unsupervised-4khr-wenetspeech)
- [Citations](#citations)

## Guideline


First, you have to prepare supervised and unsupervised data for NST. Then in stage 1 of `run.sh`, you will train an initial supervised teacher and generate pseudo labels for unsupervised data.
After that, you can run the noisy student training iteratively in stage 2. The whole pipeline is illustrated in the following picture.

![plot](local/NST_plot.png)

### Data preparation

To run this recipe, follow the steps in the [WeNet examples](https://github.com/wenet-e2e/wenet/tree/main/examples) to prepare the [AISHELL-1](https://github.com/wenet-e2e/wenet/tree/main/examples/aishell/s0) and [WenetSpeech](https://github.com/wenet-e2e/wenet/tree/main/examples/wenetspeech/s0) data.
We extract 1khr of data from WenetSpeech. The data should be prepared and stored in the following layout:

```
data/
├── train/
├──── data_aishell.list
├──── wenet_1khr.list
├──── wav_dir/
├──── utter_time.json (optional)
├── dev/
└── test/

```
- The `*.list` files contain the paths of all data shards used for training.
- If you want to apply the `speaking rate` filter, prepare a JSON file with the audio lengths as `utter_time.json`.
- `wav_dir` contains all the audio files (`id.wav`) and, optionally, the labels (`id.txt`) for the unsupervised data.
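The recipe does not spell out the schema of `utter_time.json`, so the following is only a sketch under an assumed format: a flat mapping from utterance id (the wav file stem) to duration in seconds. The helper names `collect_durations` and `write_utter_time` are hypothetical; check `run_nst.sh` for the format the speaking rate filter actually expects.

```python
import json
import wave
from pathlib import Path


def collect_durations(wav_dir):
    """Return {utterance_id: duration_in_seconds} for every *.wav in wav_dir."""
    durations = {}
    for wav_path in sorted(Path(wav_dir).glob("*.wav")):
        with wave.open(str(wav_path), "rb") as wav:
            durations[wav_path.stem] = wav.getnframes() / wav.getframerate()
    return durations


def write_utter_time(wav_dir, out_json):
    """Write the assumed id -> seconds mapping to out_json and return it."""
    durations = collect_durations(wav_dir)
    with open(out_json, "w", encoding="utf-8") as writer:
        json.dump(durations, writer, ensure_ascii=False, indent=2)
    return durations
```

The standard-library `wave` module only reads uncompressed PCM WAV files, which should be enough for 16 kHz training audio.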

### Initial supervised teacher

To train an initial supervised teacher model, run the following command:

```bash
bash run.sh --stage 1 --stop-stage 1
```

The full set of arguments is listed below. Check `run.sh` and `run_nst.sh` for more information about the steps in each stage and their arguments. For the experiments in our paper, we used `num_split = 60` and generated the shards on different CPUs, which saved a lot of inference and shard generation time.

```bash
bash run.sh --stage 1 --stop-stage 1 --dir exp/conformer_test_fully_supervised --supervised_data_list data_aishell.list --enable_nst 0 --num_split 1 --unsupervised_data_list wenet_1khr.list --dir_split wenet_split_60_test/ --job_num 0 --hypo_name hypothesis_nst0.txt --label 1 --wav_dir data/train/wenet_1k_untar/ --cer_hypo_dir wenet_cer_hypo --cer_label_dir wenet_cer_label --label_file label.txt --cer_hypo_threshold 10 --speak_rate_threshold 0 --utter_time_file utter_time.json --untar_dir data/train/wenet_1khr_untar/ --tar_dir data/train/wenet_1khr_tar/ --out_data_list data/train/wenet_1khr.list
```
- `dir` is the directory that holds the training parameters.
- `data_list` contains the path to the training data list.
- `supervised_data_list` contains the paths of the supervised data shards.
- `unsupervised_data_list` contains the paths of the unsupervised data shards, which are used for inference.
- `dir_split` is the directory that stores the split unsupervised data for parallel computing.
- `out_data_list` is the path of the pseudo label data list file.
- `enable_nst` indicates whether we train with pseudo labels and split data. For the initial teacher, set it to 0.
- This recipe uses `num_split=1` by default, but we strongly recommend a larger number to reduce the inference and shard generation time.
> **HINTS** If `num_split` is set to N larger than 1, you need to modify steps 4-8 in `run_nst.sh` to submit N tasks to your own cluster (such as Slurm or NGC).
> We strongly recommend doing so, since inference and pseudo-data generation are time-consuming.
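To make the effect of `num_split` concrete, here is a minimal sketch of dividing a list of unsupervised utterances into N roughly equal parts, one per job. The function name and the round-robin strategy are assumptions for illustration; `run_nst.sh` may split the data differently.

```python
def split_evenly(items, num_split):
    """Split `items` into `num_split` round-robin parts whose sizes differ by at most 1."""
    if num_split < 1:
        raise ValueError("num_split must be at least 1")
    return [items[i::num_split] for i in range(num_split)]


# Each part would be handled by one job, selected with something like `--job_num i`.
parts = split_evenly([f"utt{i}" for i in range(7)], num_split=3)
for job_num, part in enumerate(parts):
    print(job_num, part)
```

With N parts, inference and shard generation for each part can run as an independent cluster job, which is where the time saving comes from.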

### Noisy student iterations

After the initial fully supervised baseline is finished, we have a mixed list, `wenet_1khr_nst0.list`, that contains both supervised and pseudo data.
We use it as the `data_list` in the training step, and the `data_list` for the next NST iteration will be generated.

Here is an example command:

```bash
bash run.sh --stage 2 --stop-stage 2 --iter_num 2
```

Here we add an extra argument, `iter_num`, for the number of NST iterations. Intermediate files are named with `iter_num` as a suffix.
Please check the `run.sh` and `run_nst.sh` scripts for more information about each stage and their arguments.

## Performance Record

### Supervised baseline and standard NST
* Non-streaming conformer model with attention rescoring decoder.
* Without filter strategy, first iteration
* Feature info: using FBANK feature, dither, cmvn, online speed perturb
* Training info: lr 0.002, batch size 32, 8 gpu, acc_grad 4, 240 epochs, dither 0.1
* Decoding info: ctc_weight 0.3, average_num 30


| Supervised | Unsupervised | Test CER |
|--------------------------|--------------|----------|
| AISHELL-1 Only | ---- | 4.85 |
| AISHELL-1+WenetSpeech | ---- | 3.54 |
| AISHELL-1+AISHELL-2 | ---- | 1.01 |
| AISHELL-1 (standard NST) | WenetSpeech | 5.52 |



### Supervised AISHELL-1 and unsupervised 1khr WenetSpeech
* Non-streaming conformer model with attention rescoring decoder.
* Feature info: using FBANK feature
* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1
* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75

| # nst iteration | AISHELL-1 test CER | Pseudo CER| Filtered CER | Filtered hours |
|----------------|--------------------|-----------|--------------|----------------|
| 0 | 4.85 | 47.10 | 25.18 | 323 |
| 1 | 4.86 | 37.02 | 20.93 | 436 |
| 2 | 4.75 | 31.81 | 19.74 | 540 |
| 3 | 4.69 | 28.27 | 17.85 | 592 |
| 4 | 4.48 | 26.64 | 14.76 | 588 |
| 5 | 4.41 | 24.70 | 15.86 | 670 |
| 6 | 4.34 | 23.64 | 15.40 | 669 |
| 7 | 4.31 | 23.79 | 15.75 | 694 |

### Supervised AISHELL-2 and unsupervised 4khr WenetSpeech
* Non-streaming conformer model with attention rescoring decoder.
* Feature info: using FBANK feature
* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1
* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75

| # nst iteration | AISHELL-2 test CER | Pseudo CER | Filtered CER | Filtered hours |
|----------------|--------------------|------------|--------------|----------------|
| 0 | 5.48 | 30.10 | 11.73 | 1637 |
| 1 | 5.09 | 28.31 | 9.39 | 2016 |
| 2 | 4.88 | 25.38 | 9.99 | 2186 |
| 3 | 4.74 | 22.47 | 10.66 | 2528 |
| 4 | 4.73 | 22.23 | 10.43 | 2734 |
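As a quick reading aid for the two iteration tables above, the relative CER reduction between iteration 0 and the last iteration can be worked out directly from the reported numbers. The helper name `relative_reduction` is introduced here for illustration only.

```python
def relative_reduction(baseline, final):
    """Relative reduction in percent when going from `baseline` CER to `final` CER."""
    return 100.0 * (baseline - final) / baseline


# AISHELL-1 test CER: iteration 0 vs. iteration 7.
aishell1 = relative_reduction(4.85, 4.31)
# AISHELL-2 test CER: iteration 0 vs. iteration 4.
aishell2 = relative_reduction(5.48, 4.73)

print(f"AISHELL-1: {aishell1:.1f}% relative CER reduction")
print(f"AISHELL-2: {aishell2:.1f}% relative CER reduction")
```

This gives roughly 11% for the AISHELL-1 setup and roughly 14% for the AISHELL-2 setup.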



## Citations

``` bibtex
@article{chen2022NST,
  title={Improving Noisy Student Training on Non-target Domain Data for Automatic Speech Recognition},
  author={Chen, Yu and Wen, Ding and Lai, Junjie},
  journal={arXiv preprint arXiv:2211.04717},
  year={2022}
}
```
77 changes: 77 additions & 0 deletions examples/aishell/NST/conf/train_conformer.yaml
@@ -0,0 +1,77 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false

dataset_conf:
    filter_conf:
        max_length: 1200
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16

grad_clip: 5
accum_grad: 4
max_epoch: 240
log_interval: 100

optim: adam
optim_conf:
    lr: 0.002
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
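
To give a feel for what `lr: 0.002` with `warmup_steps: 25000` means in practice, the sketch below evaluates a Noam-style warmup schedule. It assumes that `warmuplr` follows the common rule `lr * warmup_steps^0.5 * min(step^-0.5, step * warmup_steps^-1.5)`; this is an assumption and is not taken from the code in this PR, and the function name `warmup_lr` is hypothetical.

```python
def warmup_lr(step, base_lr=0.002, warmup_steps=25000):
    """Learning rate at `step` under the assumed Noam-style warmup rule."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    scale = min(step ** -0.5, step * warmup_steps ** -1.5)
    return base_lr * warmup_steps ** 0.5 * scale


for step in (1, 12500, 25000, 100000):
    print(step, warmup_lr(step))
```

Under this rule the learning rate rises linearly to the configured `lr` at `warmup_steps` and then decays with the inverse square root of the step count.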
Binary file added examples/aishell/NST/local/NST_plot.png
66 changes: 66 additions & 0 deletions examples/aishell/NST/local/generate_data_list.py
@@ -0,0 +1,66 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import random

def get_args():
    parser = argparse.ArgumentParser(description='generate data.list file ')
    parser.add_argument('--tar_dir', help='path for tar file')
    parser.add_argument('--supervised_data_list',
                        help='path for supervised data list')
    parser.add_argument('--pseudo_data_ratio',
                        type=float,
                        help='ratio of pseudo data, '
                             '0 means none pseudo data, '
                             '1 means all using pseudo data.')
    parser.add_argument('--out_data_list', help='output path for data list')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    target_dir = args.tar_dir
    pseudo_data_list = os.listdir(target_dir)
    output_file = args.out_data_list
    pseudo_data_ratio = args.pseudo_data_ratio
    supervised_path = args.supervised_data_list
    with open(supervised_path, "r") as reader:
        supervised_data_list = reader.readlines()
    pseudo_len = len(pseudo_data_list)
    supervised_len = len(supervised_data_list)
    random.shuffle(pseudo_data_list)
    random.shuffle(supervised_data_list)

    # Trim whichever list is over-represented so that the fused list
    # approaches the requested pseudo data ratio.
    cur_ratio = pseudo_len / (pseudo_len + supervised_len)
    if cur_ratio < pseudo_data_ratio:
        # Too little pseudo data: keep all of it and trim the supervised list.
        # pseudo_data_ratio > 0 on this branch, so a ratio of 1 is safe.
        supervised_len = int(pseudo_len * (1 - pseudo_data_ratio) /
                             pseudo_data_ratio)
    elif cur_ratio > pseudo_data_ratio:
        # Too much pseudo data: keep all supervised data, trim the pseudo list.
        # pseudo_data_ratio < 1 on this branch, so a ratio of 0 is safe.
        pseudo_len = int(supervised_len * pseudo_data_ratio /
                         (1 - pseudo_data_ratio))

    # Turn the bare tar file names into paths, one per line.
    pseudo_data_list = [
        os.path.join(target_dir, name) + "\n" for name in pseudo_data_list
    ]

    fused_list = (pseudo_data_list[:pseudo_len] +
                  supervised_data_list[:supervised_len])

    with open(output_file, "w") as writer:
        for line in fused_list:
            writer.write(line)


if __name__ == '__main__':
    main()
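
The trimming rule in `main()` can be isolated as a small pure function to see what it does to the list lengths. This is a sketch, not part of the PR: the name `balance_lengths` is hypothetical, and the arithmetic is written in an algebraically equivalent form that avoids an intermediate ratio variable.

```python
def balance_lengths(pseudo_len, supervised_len, pseudo_data_ratio):
    """Return (pseudo_len, supervised_len) after trimming toward the target ratio."""
    cur_ratio = pseudo_len / (pseudo_len + supervised_len)
    if cur_ratio < pseudo_data_ratio:
        # Not enough pseudo data: keep all of it, shrink the supervised part.
        supervised_len = int(pseudo_len * (1 - pseudo_data_ratio) /
                             pseudo_data_ratio)
    elif cur_ratio > pseudo_data_ratio:
        # Too much pseudo data: keep all supervised data, shrink the pseudo part.
        pseudo_len = int(supervised_len * pseudo_data_ratio /
                         (1 - pseudo_data_ratio))
    return pseudo_len, supervised_len


# 100 pseudo shards and 100 supervised shards with a target ratio of 0.75.
print(balance_lengths(100, 100, 0.75))  # (100, 33)
# 300 pseudo shards and 50 supervised shards with a target ratio of 0.5.
print(balance_lengths(300, 50, 0.5))    # (50, 50)
```

In the first case the fused list has 100 of 133 shards from pseudo data, which is about 0.75. Note that the rule only ever discards shards; it never duplicates data to reach the ratio.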