diff --git a/examples/aishell/NST/README.md b/examples/aishell/NST/README.md new file mode 100644 index 000000000..75fb1e43a --- /dev/null +++ b/examples/aishell/NST/README.md @@ -0,0 +1,146 @@ +# Recipe to run Noisy Student Training with LM filter in WeNet + +Noisy Student Training (NST) has recently demonstrated extremely strong performance in Automatic Speech Recognition (ASR). + +Here, we provide a recipe to run NST with the `LM filter` strategy from [this paper](https://arxiv.org/abs/2211.04717), using AISHELL-1 as supervised data and WenetSpeech as unsupervised data: hypotheses are generated with and without a Language Model, and the CER difference between them is used as a filtering threshold to improve ASR performance on non-target-domain data. + +## Table of Contents + +- [Guideline](#guideline) + - [Data preparation](#data-preparation) + - [Initial supervised teacher](#initial-supervised-teacher) + - [Noisy student iterations](#noisy-student-iterations) +- [Performance Record](#performance-record) + - [Supervised baseline and standard NST](#supervised-baseline-and-standard-nst) + - [Supervised AISHELL-1 and unsupervised 1khr WenetSpeech](#supervised-aishell-1-and-unsupervised-1khr-wenetspeech) + - [Supervised AISHELL-2 and unsupervised 4khr WenetSpeech](#supervised-aishell-2-and-unsupervised-4khr-wenetspeech) +- [Citations](#citations) + +## Guideline + + +First, you have to prepare supervised and unsupervised data for NST. Then, in stage 1 of `run.sh`, you train an initial supervised teacher and generate pseudo labels for the unsupervised data. +After that, you can run noisy student training iteratively in stage 2. The whole pipeline is illustrated in the following picture. + +![plot](local/NST_plot.png) + +### Data preparation + +To run this recipe, you should follow the steps from the [WeNet examples](https://github.com/wenet-e2e/wenet/tree/main/examples) to prepare [AISHELL-1](https://github.com/wenet-e2e/wenet/tree/main/examples/aishell/s0) and [WenetSpeech](https://github.com/wenet-e2e/wenet/tree/main/examples/wenetspeech/s0) data. +We extract 1khr of data from WenetSpeech, and the data should be prepared and stored in the following format: + +``` +data/ +├── train/ +├──── data_aishell.list +├──── wenet_1khr.list +├──── wav_dir/ +├──── utter_time.json (optional) +├── dev/ +└── test/ + +``` +- Files `*.list` contain the paths of all the data shards used for training. +- A JSON file containing the audio durations should be prepared as `utter_time.json` if you want to apply the `speaking rate` filter. +- `wav_dir` contains all the audio files (`id.wav`) and optional labels (`id.txt`) for the unsupervised data. + +### Initial supervised teacher + +To train an initial supervised teacher model, run the following command: + +```bash +bash run.sh --stage 1 --stop-stage 1 +``` + +Full arguments are listed below; check `run.sh` and `run_nst.sh` for more information about the steps in each stage and their arguments. We used `num_split=60` and generated the shards on different CPUs for the experiments in our paper, which saved a lot of inference and shard-generation time.
+ +```bash +bash run.sh --stage 1 --stop-stage 1 --dir exp/conformer_test_fully_supervised --supervised_data_list data_aishell.list --enable_nst 0 --num_split 1 --unsupervised_data_list wenet_1khr.list --dir_split wenet_split_60_test/ --job_num 0 --hypo_name hypothesis_nst0.txt --label 1 --wav_dir data/train/wenet_1k_untar/ --cer_hypo_dir wenet_cer_hypo --cer_label_dir wenet_cer_label --label_file label.txt --cer_hypo_threshold 10 --speak_rate_threshold 0 --utter_time_file utter_time.json --untar_dir data/train/wenet_1khr_untar/ --tar_dir data/train/wenet_1khr_tar/ --out_data_list data/train/wenet_1khr.list +``` +- `dir` is the experiment directory that stores the model and training parameters. +- `data_list` is the data list file used for training; it contains the paths of the training shards. +- `supervised_data_list` contains the paths of the supervised data shards. +- `unsupervised_data_list` contains the paths of the unsupervised data shards used for inference. +- `dir_split` is the directory that stores the split unsupervised data for parallel computing. +- `out_data_list` is the output path of the pseudo-label data list. +- `enable_nst` indicates whether we train with pseudo labels and split data; for the initial teacher we set it to 0. +- This recipe uses the default `num_split=1`, but we strongly recommend using a larger number to decrease the inference and shard-generation time. +> **HINTS** If `num_split` is set to an N larger than 1, you need to modify steps 4-8 in `run_nst.sh` to submit N tasks to your own cluster (such as Slurm, NGC, etc.). +> We strongly recommend doing so, since inference and pseudo-data generation are time-consuming. + +### Noisy student iterations + +After finishing the initial fully supervised baseline, we have a mixed list containing both supervised and pseudo data, `wenet_1khr_nst0.list`. +We use it as the `data_list` in the training step, and the `data_list` for the next NST iteration is generated at the end of the iteration. + +Here is an example command: + +```bash +bash run.sh --stage 2 --stop-stage 2 --iter_num 2 +``` + +Here we add the extra argument `iter_num` for the number of NST iterations. Intermediate files are named with `iter_num` as a suffix. +Please check the `run.sh` and `run_nst.sh` scripts for more information about each stage and their arguments. + +## Performance Record + +### Supervised baseline and standard NST +* Non-streaming conformer model with attention rescoring decoder. +* Without the filter strategy, first iteration. +* Feature info: using FBANK feature, dither, cmvn, online speed perturb +* Training info: lr 0.002, batch size 32, 8 gpu, acc_grad 4, 240 epochs, dither 0.1 +* Decoding info: ctc_weight 0.3, average_num 30 + + +| Supervised | Unsupervised | Test CER | +|--------------------------|--------------|----------| +| AISHELL-1 Only | ---- | 4.85 | +| AISHELL-1+WenetSpeech | ---- | 3.54 | +| AISHELL-1+AISHELL-2 | ---- | 1.01 | +| AISHELL-1 (standard NST) | WenetSpeech | 5.52 | + + + +### Supervised AISHELL-1 and unsupervised 1khr WenetSpeech +* Non-streaming conformer model with attention rescoring decoder.
+* Feature info: using FBANK feature +* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1 +* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75 + +| # nst iteration | AISHELL-1 test CER | Pseudo CER| Filtered CER | Filtered hours | +|----------------|--------------------|-----------|--------------|----------------| +| 0 | 4.85 | 47.10 | 25.18 | 323 | +| 1 | 4.86 | 37.02 | 20.93 | 436 | +| 2 | 4.75 | 31.81 | 19.74 | 540 | +| 3 | 4.69 | 28.27 | 17.85 | 592 | +| 4 | 4.48 | 26.64 | 14.76 | 588 | +| 5 | 4.41 | 24.70 | 15.86 | 670 | +| 6 | 4.34 | 23.64 | 15.40 | 669 | +| 7 | 4.31 | 23.79 | 15.75 | 694 | + +### Supervised AISHELL-2 and unsupervised 4khr WenetSpeech +* Non-streaming conformer model with attention rescoring decoder. +* Feature info: using FBANK feature +* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1 +* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75 + +| # nst iteration | AISHELL-2 test CER | Pseudo CER | Filtered CER | Filtered hours | +|----------------|--------------------|------------|--------------|----------------| +| 0 | 5.48 | 30.10 | 11.73 | 1637 | +| 1 | 5.09 | 28.31 | 9.39 | 2016 | +| 2 | 4.88 | 25.38 | 9.99 | 2186 | +| 3 | 4.74 | 22.47 | 10.66 | 2528 | +| 4 | 4.73 | 22.23 | 10.43 | 2734 | + + + +## Citations + +``` bibtex + +@article{chen2022NST, + title={Improving Noisy Student Training on Non-target Domain Data for Automatic Speech Recognition}, + author={Chen, Yu and Wen, Ding and Lai, Junjie}, + journal={arXiv preprint arXiv:2203.15455}, + year={2022} +} diff --git a/examples/aishell/NST/conf/train_conformer.yaml b/examples/aishell/NST/conf/train_conformer.yaml new file mode 100644 index 000000000..8499de2e9 --- /dev/null +++ b/examples/aishell/NST/conf/train_conformer.yaml @@ -0,0 +1,77 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 15 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +dataset_conf: + filter_conf: + max_length: 1200 + min_length: 0 + token_max_length: 200 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 0.1 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 16 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 240 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.002 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: 
+ warmup_steps: 25000 diff --git a/examples/aishell/NST/local/NST_plot.png b/examples/aishell/NST/local/NST_plot.png new file mode 100644 index 000000000..c652c62ca Binary files /dev/null and b/examples/aishell/NST/local/NST_plot.png differ diff --git a/examples/aishell/NST/local/generate_data_list.py b/examples/aishell/NST/local/generate_data_list.py new file mode 100644 index 000000000..684e7cb68 --- /dev/null +++ b/examples/aishell/NST/local/generate_data_list.py @@ -0,0 +1,66 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random + +def get_args(): + parser = argparse.ArgumentParser(description='generate data.list file ') + parser.add_argument('--tar_dir', help='path for tar file') + parser.add_argument('--supervised_data_list', + help='path for supervised data list') + parser.add_argument('--pseudo_data_ratio', + type=float, + help='ratio of pseudo data, ' + '0 means none pseudo data, ' + '1 means all using pseudo data.') + parser.add_argument('--out_data_list', help='output path for data list') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + target_dir = args.tar_dir + pseudo_data_list = os.listdir(target_dir) + output_file = args.out_data_list + pseudo_data_ratio = args.pseudo_data_ratio + supervised_path = args.supervised_data_list + with open(supervised_path, "r") as reader: + supervised_data_list = reader.readlines() + pseudo_len = len(pseudo_data_list) + supervised_len = len(supervised_data_list) + random.shuffle(pseudo_data_list) + random.shuffle(supervised_data_list) + + cur_ratio = pseudo_len / (pseudo_len + supervised_len) + if cur_ratio < pseudo_data_ratio: + pseudo_to_super_datio = pseudo_data_ratio / (1 - pseudo_data_ratio) + supervised_len = int(pseudo_len / pseudo_to_super_datio) + elif cur_ratio > pseudo_data_ratio: + super_to_pseudo_datio = (1 - pseudo_data_ratio) / pseudo_data_ratio + pseudo_len = int(supervised_len / super_to_pseudo_datio) + + for i in range(len(pseudo_data_list)): + pseudo_data_list[i] = target_dir + "/" + pseudo_data_list[i] + "\n" + + fused_list = pseudo_data_list[:pseudo_len] + supervised_data_list[:supervised_len] + + with open(output_file, "w") as writer: + for line in fused_list: + writer.write(line) + + +if __name__ == '__main__': + main() diff --git a/examples/aishell/NST/local/generate_filtered_pseudo_label.py b/examples/aishell/NST/local/generate_filtered_pseudo_label.py new file mode 100644 index 000000000..2a8ee83c3 --- /dev/null +++ b/examples/aishell/NST/local/generate_filtered_pseudo_label.py @@ -0,0 +1,215 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import tarfile +import time +import json + + +def get_args(): + parser = argparse.ArgumentParser(description='generate filter pseudo label') + parser.add_argument('--dir_num', required=True, help='split directory number') + parser.add_argument('--cer_hypo_dir', required=True, + help='prefix for cer_hypo_dir') + parser.add_argument('--utter_time_file', required=True, + help='the json file that contains audio time infos ') + parser.add_argument('--cer_hypo_threshold', required=True, type=float, + help='the cer-hypo threshold used to filter') + parser.add_argument('--speak_rate_threshold', type=float, + help='the cer threshold we use to filter') + parser.add_argument('--dir', required=True, help='dir for the experiment ') + # output untar and tar + parser.add_argument('--untar_dir', required=True, + help='the output path, ' + 'eg: data/train/wenet_untar_cer_hypo_nst1/') + parser.add_argument('--tar_dir', required=True, + help='the tar file path, ' + 'eg: data/train/wenet_tar_cer_hypo_leq_10_nst1/') + parser.add_argument('--wav_dir', required=True, + help='dir to store wav files, ' + 'eg "data/train/wenet_1k_untar/"') + parser.add_argument('--start_tar_id', default=0 , type=int, + help='the initial tar id (for debugging)') + args = parser.parse_args() + return args + + +def make_tarfile(output_filename, source_dir): + with tarfile.open(output_filename, "w") as tar: + tar.add(source_dir, arcname=os.path.basename(source_dir)) + + +def main(): + args = get_args() + dir_num = args.dir_num + dir_name = args.dir + output_dir = args.untar_dir + cer_hypo_threshold = args.cer_hypo_threshold + speak_rate_threshold = args.speak_rate_threshold + utter_time_file = args.utter_time_file + tar_dir = args.tar_dir + wav_dir = args.wav_dir + start_tar_id = args.start_tar_id + os.makedirs(tar_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) + cer_hypo_name = args.cer_hypo_dir + print("start tar id is", start_tar_id) + print("make dirs") + + utter_time_enable = True + dataset = "wenet" + + utter_time = {} + if utter_time_enable: + + if dataset == "wenet": + print("wenet") + with open(utter_time_file, encoding='utf-8') as fh: + utter_time = json.load(fh) + + if dataset == "aishell2": + aishell2_jason = utter_time_file + print("aishell2") + with open(aishell2_jason, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + data_audio = data["audio_filepath"] + t_id = data_audio.split("/")[-1].split(".")[0] + data_duration = data["duration"] + utter_time[t_id] = data_duration + + print(time.time(), "start time ") + cer_dict = {} + print("dir_num = ", dir_num) + cer_hypo_path = dir_name + "/Hypo_LM_diff10/" + cer_hypo_name + cer_hypo_path = cer_hypo_path + "_" + dir_num + "/wer" + with open(cer_hypo_path, 'r', encoding="utf-8") as reader: + data = reader.readlines() + + for i in range(len(data)): + line = data[i] + if line[:3] == 'utt': + wer_list = data[i + 1].split() + wer_pred_lm = float(wer_list[1]) + n_hypo = int(wer_list[3].split("=")[1]) + + utt_list = line.split() + lab_list = data[i + 2].split() + rec_list = data[i + 3].split() + + utt_id = 
utt_list[1] + pred_no_lm = "".join(lab_list[1:]) + pred_lm = "".join(rec_list[1:]) + prediction = "".join(lab_list[1:]) + + if utter_time_enable: + + utt_time = utter_time[utt_id] + + cer_dict[utt_id] = [pred_no_lm, pred_lm, wer_pred_lm, + utt_time, n_hypo, prediction] + else: + cer_dict[utt_id] = [pred_no_lm, pred_lm, + wer_pred_lm, -1, -1, prediction] + + c = 0 + cer_preds = [] + uttr_len = [] + speak_rates = [] + num_lines = 0 + data_filtered = [] + + for key, item in cer_dict.items(): + + cer_pred = item[2] + speak_rate = item[4] / item[3] # char per second + + if cer_pred <= cer_hypo_threshold and speak_rate > speak_rate_threshold: + + num_lines += 1 + c += 1 + cer_preds.append(cer_pred) + uttr_len.append(item[4]) + speak_rates.append(speak_rate) + pred = item[1] + utt_id = key + filtered_line = [utt_id, pred] + data_filtered.append(filtered_line) + + num_uttr = 1000 + len_data = len(data_filtered) + print("total sentences after filter ") + cur_id = start_tar_id * 1000 + end_id = cur_id + num_uttr + if cur_id < len_data < end_id: + end_id = len_data + tar_id = start_tar_id + + not_exist = [] + while end_id <= len_data: + + tar_s = str(tar_id) + diff = 6 - len(tar_s) + for _ in range(diff): + tar_s = "0" + tar_s + + out_put_dir = output_dir + "dir" + str(dir_num) + out_put_dir = out_put_dir + "_" + "tar" + tar_s + "/" + os.makedirs(out_put_dir, exist_ok=True) + + for i in range(cur_id, end_id): + print("dir:", dir_num, ", " "tar: ", tar_id, + ", ", "progress:", i / len_data) + + t_id, utter = data_filtered[i] + + output_path = out_put_dir + t_id + ".txt" + wav_path = wav_dir + t_id + ".wav" + print(wav_path) + wav_exist = os.path.exists(wav_path) + if wav_exist: + # update .txt + with open(output_path, "w", encoding="utf-8") as writer: + writer.write(utter) + # update .wav + os.system("cp" + " " + wav_path + " " + + out_put_dir + t_id + ".wav") + else: + print(" wav does not exists ! ", wav_path) + not_exist.append(wav_path) + + tar_file_name = tar_dir + "dir" + str(dir_num) + "_" + tar_s + ".tar" + # tar the dir + + make_tarfile(tar_file_name, out_put_dir) + # update index + tar_id += 1 + cur_id += num_uttr + end_id += num_uttr + + if cur_id < len_data < end_id: + end_id = len_data + + print("end, now removing untar files for saving storge space.") + print("rm -rf" + " " + out_put_dir[:-1]) + os.system("rm -rf" + " " + out_put_dir[:-1]) + print("remove done") + + print("There are ", len(not_exist), "wav files not exist") + print(not_exist) + + +if __name__ == '__main__': + main() diff --git a/examples/aishell/NST/local/get_wav_labels.py b/examples/aishell/NST/local/get_wav_labels.py new file mode 100644 index 000000000..fb0c5c2b0 --- /dev/null +++ b/examples/aishell/NST/local/get_wav_labels.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
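+# What this script does (its role in run_nst.sh stage 5): for one data sublist
+# produced by local/split_data_list.py, it reads the hypothesis file generated
+# in stage 4 and writes the wav.scp (and, if --label 1, the label.txt) that
+# tools/decode.sh consumes in stages 6-7.  It assumes wav_dir holds one id.wav
+# (and optionally id.txt) per utterance.
+# A usage sketch with the recipe's default paths (adjust to your setup):
+#   python local/get_wav_labels.py \
+#       --dir_split data/train/wenet_split_60_test/ \
+#       --hypo_name /hypothesis_nst0.txt \
+#       --wav_dir data/train/wenet_1k_untar/ \
+#       --job_num 0 --label 1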
+ +import argparse + + +def get_args(): + parser = argparse.ArgumentParser(description='sum up prediction wer') + parser.add_argument('--job_num', type=int, default=8, + help='number of total split dir') + parser.add_argument('--dir_split', required=True, + help='the path to the data_list dir ' + 'eg data/train/wenet1k_good_split_60/') + parser.add_argument('--label', type=int, default=0, + help='if ture, label file will also be considered.') + parser.add_argument('--hypo_name', type=str, required=True, + help='the hypothesis path. eg. /hypothesis_0.txt ') + parser.add_argument('--wav_dir', type=str, required=True, + help='the wav dir path. eg. data/train/wenet_1k_untar/ ') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + data_list_dir = args.dir_split + num_lists = args.job_num + hypo = args.hypo_name + # wav_dir is the directory where your pair of ID.scp + # (the audio file ) and ID.txt (the optional label file ) file stored. + # We assumed that you have generated this dir in data processing steps. + wav_dir = args.wav_dir + label = args.label + + print("data_list_path is", data_list_dir) + print("num_lists is", num_lists) + print("hypo is", hypo) + print("wav_dir is", wav_dir) + + i = num_lists + c = 0 + hypo_path = data_list_dir + "data_sublist" + str(i) + hypo + output_wav = data_list_dir + "data_sublist" + str(i) + "/wav.scp" + output_label = data_list_dir + "data_sublist" + str(i) + "/label.txt" + # bad lines are just for debugging + output_bad_lines = data_list_dir + "data_sublist" + str(i) + "/bad_line.txt" + + with open(hypo_path, 'r', encoding="utf-8") as reader: + hypo_lines = reader.readlines() + + wavs = [] + labels = [] + bad_files = [] + for x in hypo_lines: + c += 1 + file_id = x.split()[0] + + label_path = wav_dir + file_id + ".txt" + wav_path = wav_dir + file_id + ".wav\n" + wav_line = file_id + " " + wav_path + wavs.append(wav_line) + if label: + try: + with open(label_path, 'r', encoding="utf-8") as reader1: + label_line = reader1.readline() + except OSError as e: + bad_files.append(label_path) + + label_line = file_id + " " + label_line + "\n" + labels.append(label_line) + + with open(output_wav, 'w', encoding="utf-8") as writer2: + for wav in wavs: + writer2.write(wav) + with open(output_bad_lines, 'w', encoding="utf-8") as writer4: + for line in bad_files: + writer4.write(line) + if label: + with open(output_label, 'w', encoding="utf-8") as writer3: + for label in labels: + writer3.write(label) + + +if __name__ == '__main__': + main() diff --git a/examples/aishell/NST/local/split_data_list.py b/examples/aishell/NST/local/split_data_list.py new file mode 100644 index 000000000..17d507cb7 --- /dev/null +++ b/examples/aishell/NST/local/split_data_list.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
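+# What this script does (run_nst.sh stage 3): it divides the unsupervised
+# data.list into --job_nums roughly equal sublists, written to
+# <output_dir>/data_sublist<i>/data_list, so that stages 4-8 can decode and
+# filter each sublist in parallel (one job per sublist).
+# A usage sketch; the list name and split count are examples only:
+#   python local/split_data_list.py \
+#       --job_nums 60 \
+#       --data_list_path data/train/wenet_1khr.list \
+#       --output_dir data/train/wenet_split_60_test/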
+ +import os +import argparse + + +def get_args(): + parser = argparse.ArgumentParser(description='') + parser.add_argument('--job_nums', type=int, default=8, + help='number of total split jobs') + parser.add_argument('--data_list_path', required=True, + help='the path to the data.list file') + parser.add_argument('--output_dir', required=True, + help='path to output dir, ' + 'eg --output_dir=data/train/aishell_split_60') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + data_list_path = args.data_list_path + num_lists = args.job_nums + output_dir = args.output_dir + + print("data_list_path is", data_list_path) + print("num_lists is", num_lists) + print("output_dir is", output_dir) + os.makedirs(output_dir, exist_ok=True) + + with open(data_list_path, 'r', encoding="utf-8") as reader: + data_list_we = reader.readlines() + + # divide data.list equally + len_d = int(len(data_list_we) / num_lists) + rest_lines = data_list_we[num_lists * len_d:] + rest_len = len(rest_lines) + print("total num of lines", len(data_list_we) , "rest len is", rest_len) + + # generate N sublist + for i in range(num_lists): + print("current dir num", i) + out_put_sub_dir = output_dir + "/" + "data_sublist" + str(i) + "/" + os.makedirs(out_put_sub_dir, exist_ok=True) + output_list = out_put_sub_dir + "data_list" + + with open(output_list, 'w', encoding="utf-8") as writer: + + new_list = data_list_we[i * len_d: (i + 1) * len_d] + if i < rest_len: + new_list.append(rest_lines[i]) + for x in new_list: + # output list + writer.write(x) + + +if __name__ == '__main__': + main() diff --git a/examples/aishell/NST/path.sh b/examples/aishell/NST/path.sh new file mode 100644 index 000000000..5ddca76cc --- /dev/null +++ b/examples/aishell/NST/path.sh @@ -0,0 +1,8 @@ +export WENET_DIR=$PWD/../../.. +export BUILD_DIR=${WENET_DIR}/runtime/server/x86/build +export OPENFST_PREFIX_DIR=${BUILD_DIR}/../fc_base/openfst-subbuild/openfst-populate-prefix +export PATH=$PWD:${BUILD_DIR}:${BUILD_DIR}/kaldi:${OPENFST_PREFIX_DIR}/bin:$PATH + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=../../../:$PYTHONPATH diff --git a/examples/aishell/NST/run.sh b/examples/aishell/NST/run.sh new file mode 100644 index 000000000..258f5061f --- /dev/null +++ b/examples/aishell/NST/run.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +iter_num=2 +stage=1 +stop_stage=1 +pseudo_data_ratio=0.75 +dir=exp/conformer_test_fully_supervised +data_list=data_aishell.list +supervised_data_list=data_aishell.list +unsupervised_data_list=wenet_1khr.list +dir_split=wenet_split_60_test/ +out_data_list=data/train/wenet_1khr_nst0.list +num_split=1 +. tools/parse_options.sh || exit 1; + +# Stage 1 trains the initial teacher and generates initial pseudo-labels. 
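+# With enable_nst=0, run_nst.sh runs its stages 1-8 end to end: train on the
+# supervised shards, average and export the model, split the unsupervised list,
+# decode it with and without the LM, filter by the CER between the two
+# hypotheses, and write the pseudo-label list given by --out_data_list.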
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "******** stage 1 training the initial teacher ********" + bash run_nst.sh --dir $dir \ + --data_list $data_list \ + --supervised_data_list $supervised_data_list \ + --unsupervised_data_list $unsupervised_data_list \ + --dir_split $dir_split \ + --out_data_list $out_data_list \ + --enable_nst 0 \ + --pseudo_data_ratio $pseudo_data_ratio \ + --num_split $num_split + +fi + +# Stage 2 runs the NST iterations. +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + + for ((i = 0; i < $iter_num; ++i)); do + { + echo "******** stage 2 training nst iteration number $i ********" + bash run_nst.sh --dir exp/conformer_nst$((i+1)) \ + --supervised_data_list data_aishell.list \ + --data_list wenet_1khr_nst${i}.list \ + --enable_nst 1 \ + --job_num 0 \ + --num_split $num_split \ + --hypo_name hypothesis_nst$((i+1)).txt \ + --untar_dir wenet_1khr_untar_nst$((i+1))/ \ + --tar_dir wenet_1khr_tar_nst$((i+1))/ \ + --out_data_list wenet_1khr_nst$((i+1)).list \ + --pseudo_data_ratio $pseudo_data_ratio + + } + done + +fi diff --git a/examples/aishell/NST/run_nst.sh b/examples/aishell/NST/run_nst.sh new file mode 100644 index 000000000..877d55ddd --- /dev/null +++ b/examples/aishell/NST/run_nst.sh @@ -0,0 +1,409 @@ +#!/bin/bash + +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This is an augmented version of the aishell-1 "run.sh", modified to be compatible with noisy student training. + +. ./path.sh || exit 1; + +# Use this to control how many GPUs you use. It's 1-GPU training if you specify +# just one GPU, otherwise it is multi-GPU training based on DDP in PyTorch. +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +# The NCCL_SOCKET_IFNAME variable specifies which IP interface to use for nccl +# communication. More details can be found in +# https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html +# export NCCL_SOCKET_IFNAME=ens4f1 +export NCCL_DEBUG=INFO +stage=1 # start from 0 if you need to start from data preparation +stop_stage=8 + +# here are extra parameters used in NST +cer_out_dir="" +dir="" +supervised_data_list="" +checkpoint= +unsupervised_data_list="" +data_list="" + +hypo_name="" +out_data_list="" +# parameters with default values: +label=0 +average_num=30 +nj=16 +num_split=1 +cer_hypo_threshold=10 +speak_rate_threshold=0 +label_file="label.txt" +utter_time_file="utter_time.json" +enable_nst=1 +job_num=0 +dir_split="wenet_split_60_test/" +hypo_name="hypothesis_nst${job_num}.txt" +wav_dir="data/train/wenet_1k_untar/" +tar_dir="data/train/wenet_1khr_tar/" +untar_dir="data/train/wenet_1khr_untar/" +cer_hypo_dir="wenet_cer_hypo" +cer_label_dir="wenet_cer_label" +pseudo_data_ratio=0.75 + +# The num of machines(nodes) for multi-machine training, 1 is for one machine. +# NFS is required if num_nodes > 1. + +num_nodes=1 + +# The rank of each node or machine, which ranges from 0 to `num_nodes - 1`.
+# You should set the node_ranHk=0 on the first machine, set the node_rank=1 +# on the second machine, and so on. +node_rank=0 +dict=data/dict/lang_char.txt + +# data_type can be `raw` or `shard`. Typically, raw is used for small dataset, +# `shard` is used for large dataset which is over 1k hours, and `shard` is +# faster on reading data and training. +data_type=shard +num_utts_per_shard=1000 +train_set=train +train_config=conf/train_conformer.yaml +cmvn=true +average_checkpoint=true +target_pt=80 +decode_checkpoint=$dir/$target_pt.pt + +# here we only use attention_rescoring for NST +decode_modes="attention_rescoring" + +. tools/parse_options.sh || exit 1; + +# print the settings +echo "setting for this run:" +echo "dir is ${dir}" +echo "data list is ${data_list}" +echo "job_num is ${job_num}" +echo "cer_out_dir is ${cer_out_dir}" +echo "average_num is ${average_num}" +echo "checkpoint is ${checkpoint} " +echo "enable_nst is ${enable_nst} " + +# we assumed that you have finished the data pre-process steps from -1 to 3 in aishell1/s0/run.sh . +# You can modify the "--train_data_supervised" to match your supervised data list. +# Here i used wenetspeech as the unsupervised data, you can run the data pre-process steps from -1 to 3 in +# wenetspeech/s0/run.sh ; you can modify "--train_data_supervised" to match your unsupervised data list. +# you can follow this process to generate your own dataset. +# I have also included my code for extracting data in local/... + +# stage 1 is for training +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "********step 1 start time : $now ********" + mkdir -p $dir + # You have to rm `INIT_FILE` manually when you resume or restart a + # multi-machine training. + rm $dir/ddp_init + INIT_FILE=$dir/ddp_init + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + # Use "nccl" if it works, otherwise use "gloo" + dist_backend="gloo" + world_size=`expr $num_gpus \* $num_nodes` + echo "total gpus is: $world_size" + + # the global_cmvn file need to be calculated by combining both supervised/unsupervised datasets, + # and it should be positioned at data/${train_set}/global_cmvn . + cmvn_opts= + $cmvn && cp data/${train_set}/global_cmvn $dir/global_cmvn + $cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn" + + # train.py rewrite $train_config to $dir/train.yaml with model input + # and output dimension, and $dir/train.yaml will be used for inference + # and export. + echo "checkpoint is " ${checkpoint} + for ((i = 0; i < $num_gpus; ++i)); do + { + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + echo "gpu number $i " + # Rank of each gpu/process used for knowing whether it is + # the master of a worker. + + rank=`expr $node_rank \* $num_gpus + $i` + python wenet/bin/train.py --gpu $gpu_id \ + --config $train_config \ + --data_type $data_type \ + --symbol_table $dict \ + --train_data data/$train_set/$data_list \ + --cv_data data/dev/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + --model_dir $dir \ + --ddp.init_method $init_method \ + --ddp.world_size $world_size \ + --ddp.rank $rank \ + --ddp.dist_backend $dist_backend \ + --num_workers 1 \ + $cmvn_opts \ + --pin_memory + } & + done + wait +fi + +# In stage 2, we get the averaged final checkpoint and calculate the test and dev accuracy +# please make sure your test and valid data.list are in the proper location. 
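+# The averaged checkpoint avg_${average_num}.pt and the exported final.zip
+# produced here are reused later: stage 4 decodes the unsupervised splits with
+# avg_${average_num}.pt, and stages 6-7 run tools/decode.sh with final.zip.
+# Example invocation (the experiment directory name is illustrative):
+#   bash run_nst.sh --stage 2 --stop-stage 2 --dir exp/conformer_test_fully_supervised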
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # Test model, please specify the model you want to test by --checkpoint + # stage 5 we test with aishell dataset, + echo "******** step 2 start time : $now ********" + if [ ${average_checkpoint} == true ]; then + decode_checkpoint=$dir/avg_${average_num}.pt + echo "do model average and final checkpoint is $decode_checkpoint" + python wenet/bin/average_model.py \ + --dst_model $decode_checkpoint \ + --src_path $dir \ + --num ${average_num} \ + --val_best + fi + + # export model + python wenet/bin/export_jit.py \ + --config $dir/train.yaml \ + --checkpoint $dir/avg_${average_num}.pt \ + --output_file $dir/final.zip \ + --output_quant_file $dir/final_quant.zip + # Please specify decoding_chunk_size for unified streaming and + # non-streaming model. The default value is -1, which is full chunk + # for non-streaming inference. + decoding_chunk_size= + ctc_weight=0.5 + reverse_weight=0.0 + + # test_wer + for mode in ${decode_modes}; do + { + #test_dir=$dir/test_${mode}_${target_pt}pt # for target pt + test_dir=$dir/test_${mode}${average_num}pt # for average pt + mkdir -p $test_dir + python wenet/bin/recognize.py --gpu 0 \ + --mode $mode \ + --config $dir/train.yaml \ + --data_type $data_type \ + --test_data data/test/data.list \ + --checkpoint $decode_checkpoint \ + --beam_size 10 \ + --batch_size 1 \ + --penalty 0.0 \ + --dict $dict \ + --ctc_weight $ctc_weight \ + --reverse_weight $reverse_weight \ + --result_file $test_dir/text \ + ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} + echo "before compute-wer" + python tools/compute-wer.py --char=1 --v=1 \ + data/test/text $test_dir/text > $test_dir/wer + } & + done + +# dev_wer + for mode in ${decode_modes}; do + { + #test_dir=$dir/test_${mode}_${target_pt}pt # for target pt + dev_dir=$dir/dev_${mode}${average_num}pt # for average pt + mkdir -p $dev_dir + python wenet/bin/recognize.py --gpu 0 \ + --mode $mode \ + --config $dir/train.yaml \ + --data_type $data_type \ + --test_data data/dev/data.list \ + --checkpoint $decode_checkpoint \ + --beam_size 10 \ + --batch_size 1 \ + --penalty 0.0 \ + --dict $dict \ + --ctc_weight $ctc_weight \ + --reverse_weight $reverse_weight \ + --result_file $dev_dir/text \ + ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} + echo "before compute-wer" + python tools/compute-wer.py --char=1 --v=1 \ + data/dev/text $dev_dir/text > $dev_dir/wer + } & + done + wait +fi + + +# split the (unsupervised) datalist into N sublists, where N depends on the number of available cpu in your cluster. +# when making inference, we compute N sublist in parallel. +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ] && [ ${enable_nst} -eq 0 ]; then + echo "********step 3 start time : $now ********" + python local/split_data_list.py \ + --job_nums $num_split \ + --data_list_path data/train/$unsupervised_data_list \ + --output_dir data/train/$dir_split + +fi + + +# stage 4 will perform inference without language model on the given sublist(job num) +# here is example usages: +# bash run_nst.sh --stage 4 --stop-stage 4 --job_num $i --dir_split data/train/wenet_4khr_split_60/ +# --hypo_name hypothesis_0.txt --dir exp/conformer_aishell2_wenet4k_nst4 +# You need to specify the "job_num" n (n <= N), "dir_split" which is the dir path for split data +# "hypo_name" is the path for output hypothesis and "dir" is the path where we train and store the model. +# For each gpu, you can run with different job_num to perform data-wise parallel computing. 
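+# A hedged sketch of submitting the $num_split decoding jobs in parallel;
+# the actual submission command depends on your scheduler (e.g. Slurm or NGC)
+# and is not provided by this recipe:
+#   for n in $(seq 0 $((num_split - 1))); do
+#     sbatch --gres=gpu:1 --wrap="bash run_nst.sh --stage 4 --stop-stage 4 --job_num $n --dir_split $dir_split --hypo_name $hypo_name --dir $dir"
+#   done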
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "********step 4 start time : $now ********" + # we assume you have run stage 2 so that avg_${average_num}.pt exists + decode_checkpoint=$dir/avg_${average_num}.pt + # Please specify decoding_chunk_size for unified streaming and + # non-streaming model. The default value is -1, which is full chunk + # for non-streaming inference. + decoding_chunk_size= + ctc_weight=0.5 + reverse_weight=0.0 + mode="attention_rescoring" + gpu_id=0 + echo "job number ${job_num} " + echo "data_list dir is ${dir_split}" + echo "hypo name is " $hypo_name + echo "dir is ${dir}" + + python wenet/bin/recognize.py --gpu $gpu_id \ + --mode $mode \ + --config $dir/train.yaml \ + --data_type $data_type \ + --test_data data/train/${dir_split}data_sublist${job_num}/data_list \ + --checkpoint $decode_checkpoint \ + --beam_size 10 \ + --batch_size 1 \ + --penalty 0.0 \ + --dict $dict \ + --ctc_weight $ctc_weight \ + --reverse_weight $reverse_weight \ + --result_file data/train/${dir_split}data_sublist${job_num}/${hypo_name} \ + ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} + echo "end time : $now" + +fi + + +# Generate wav.scp file and label.txt file(optional) for each sublist we generated in step 3. +# the wav_dir should be prepared in data processing step as we mentioned. +#You need to specify the "job_num" n (n <= N), "dir_split" which is the dir path for split data, +# "hypo_name" is the path for output hypothesis and "dir" is the path where we train and store the model. +# wav_dir is the directory that stores raw wav file and possible labels. +# if you have label for unsupervised dataset, set label = 1 other wise keep it 0 +# For each gpu or cpu, you can run with different job_num to perform data-wise parallel computing. +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ] && [ ${enable_nst} -eq 0 ]; then + echo "********step 5 start time : $now ********" + python local/get_wav_labels.py \ + --dir_split data/train/${dir_split} \ + --hypo_name /$hypo_name \ + --wav_dir $wav_dir\ + --job_num $job_num \ + --label $label +fi + +# Calculate cer-hypo between hypothesis with and without language model. +# We assumed that you have finished language model +# training using the wenet aishell-1 pipline. (You should have data/lang/words.txt , data/lang/TLG.fst files ready.) +# Here is an exmaple usage: +# bash run_nst.sh --stage 5 --stop-stage 5 --job_num n --dir_split data/train/wenet1k_redo_split_60/ +# --cer_hypo_dir wenet1k_cer_hypo --hypo_name hypothesis_nst.txt --dir exp/conformer_no_filter_redo_nst6 +# You need to specify the "job_num" n (n <= N), "dir_split" which is the dir path for split data +# "hypo_name" is the path for output hypothesis and "dir" is the path where we train and store the model. +# For each gpu, you can run with different job_num to perform data-wise parallel computing. +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + echo "********step 6 start time : $now ********" + chunk_size=-1 + mode="attention_rescoring" + test_dir=$dir/test_${mode}_${job_num} + now=$(date +"%T") + echo "start time : $now" + echo "GPU dir is " $job_num "dir_split is " data/train/${dir_split} + echo "nj is" $nj "hypo_file is" $hypo_name "cer out is" $cer_hypo_dir "lm is 4gram" + echo "dir is " $dir + if [ ! 
-f data/train/${dir_split}data_sublist${job_num}/${hypo_name} ]; then + echo "text file does not exists" + exit 1; + fi + + ./tools/decode.sh --nj 16 \ + --beam 15.0 --lattice_beam 7.5 --max_active 7000 \ + --blank_skip_thresh 0.98 --ctc_weight 0.5 --rescoring_weight 1.0 \ + --chunk_size $chunk_size \ + --fst_path data/lang_test/TLG.fst \ + data/train/${dir_split}data_sublist${job_num}/wav.scp \ + data/train/${dir_split}data_sublist${job_num}/${hypo_name} $dir/final.zip \ + data/lang_test/words.txt $dir/Hypo_LM_diff10/${cer_hypo_dir}_${job_num} + now=$(date +"%T") + echo "end time : $now" +fi + +# (optional, only run this stage if you have true label for unsupervised data.) +# Calculate cer-label between true label and hypothesis with language model. +# You can use the output cer to evaluate NST's performance. +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ] && [ ${label} -eq 1 ]; then + echo "********step 7 start time : $now ********" + chunk_size=-1 + mode="attention_rescoring" + test_dir=$dir/test_${mode}_${job_num} + now=$(date +"%T") + echo "start time : $now" + echo "GPU dir is " $job_num "dir_split is " data/train/${dir_split} + echo "nj is" $nj "label_file is" $label_file "cer out is" $cer_label_dir "lm is 4gram" + echo "dir is " $dir + echo "label_file " data/train/${dir_split}data_sublist${job_num}/${label_file} + if [ ! -f data/train/${dir_split}data_sublist${job_num}/${label_file} ]; then + echo "text file does not exists" + exit 1; + fi + + ./tools/decode.sh --nj 16 \ + --beam 15.0 --lattice_beam 7.5 --max_active 7000 \ + --blank_skip_thresh 0.98 --ctc_weight 0.5 --rescoring_weight 1.0 \ + --chunk_size $chunk_size \ + --fst_path data/lang_test/TLG.fst \ + data/train/${dir_split}data_sublist${job_num}/wav.scp \ + data/train/${dir_split}data_sublist${job_num}/${label_file} $dir/final.zip \ + data/lang_test/words.txt $dir/Hypo_LM_diff10/${cer_label_dir}_${job_num} + now=$(date +"%T") + echo "end time : $now" +fi + + +if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then + echo "********step 8 start time : $now ********" + python local/generate_filtered_pseudo_label.py \ + --cer_hypo_dir $cer_hypo_dir \ + --untar_dir data/train/$untar_dir \ + --wav_dir $wav_dir \ + --dir_num $job_num \ + --cer_hypo_threshold $cer_hypo_threshold \ + --speak_rate_threshold $speak_rate_threshold \ + --dir $dir \ + --tar_dir data/train/$tar_dir \ + --utter_time_file $utter_time_file + + python local/generate_data_list.py \ + --tar_dir data/train/$tar_dir \ + --out_data_list data/train/$out_data_list \ + --supervised_data_list data/train/$supervised_data_list \ + --pseudo_data_ratio $pseudo_data_ratio + +fi + + +
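+# Summary of how the stages chain into one NST iteration (driven by run.sh):
+#   stage 1-2 : train on --data_list, average the best checkpoints and export final.zip
+#   stage 3   : split the unsupervised data list into --num_split sublists (enable_nst=0 only)
+#   stage 4-5 : decode each sublist without LM and collect wav.scp / optional label.txt
+#               (stage 5 also runs only when enable_nst=0; later iterations reuse these files)
+#   stage 6-7 : decode with the LM via tools/decode.sh to get CER-hypo (and CER-label if true labels exist)
+#   stage 8   : filter by --cer_hypo_threshold and --speak_rate_threshold, tar the kept utterances
+#               into shards, and write --out_data_list mixing pseudo and supervised shards
+#               according to --pseudo_data_ratio for the next iteration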