diff --git a/README.md b/README.md
index f560949f1fe..c6e9fc2093b 100644
--- a/README.md
+++ b/README.md
@@ -178,6 +178,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update
+- 👑 2023.05.31: Add [WavLM ASR-en](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/librispeech/asr5), WavLM fine-tuning for ASR on LibriSpeech.
- 👑 2023.05.04: Add [HuBERT ASR-en](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/librispeech/asr4), HuBERT fine-tuning for ASR on LibriSpeech.
- ⚡ 2023.04.28: Fix [0-d tensor](https://github.com/PaddlePaddle/PaddleSpeech/pull/3214), with the upgrade of paddlepaddle==2.5, the problem of modifying 0-d tensor has been solved.
- 👑 2023.04.25: Add [AMP for U2 conformer](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167).
diff --git a/README_cn.md b/README_cn.md
index 25a716d2961..eabb2ead45f 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -183,6 +183,8 @@
- 🧩 级联模型应用: 作为传统语音任务的扩展,我们结合了自然语言处理、计算机视觉等任务,实现更接近实际需求的产业级应用。
### 近期更新
+- 👑 2023.05.31: 新增 [WavLM ASR-en](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/librispeech/asr5), 基于WavLM的英语识别微调,使用LibriSpeech数据集。
+- 👑 2023.05.04: 新增 [HuBERT ASR-en](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/librispeech/asr4), 基于HuBERT的英语识别微调,使用LibriSpeech数据集。
- ⚡ 2023.04.28: 修正 [0-d tensor](https://github.com/PaddlePaddle/PaddleSpeech/pull/3214), 配合PaddlePaddle2.5升级修改了0-d tensor的问题。
- 👑 2023.04.25: 新增 [U2 conformer 的 AMP 训练](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167).
- 👑 2023.04.06: 新增 [srt格式字幕生成功能](./demos/streaming_asr_server)。
diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md
index 937cd95a365..ef9b2237d38 100644
--- a/demos/speech_ssl/README.md
+++ b/demos/speech_ssl/README.md
@@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
Arguments:
- `input`(required): Audio file to recognize.
- - `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert].
+ - `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert, wavlm].
- `task`: Output type. Default: `asr`.
- `lang`: Model language. Default: `en`.
- `sample_rate`: Sample rate of the model. Default: `16000`.
diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md
index 8455d2c77ac..a18c778a7d7 100644
--- a/demos/speech_ssl/README_cn.md
+++ b/demos/speech_ssl/README_cn.md
@@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
参数:
- `input`(必须输入):用于识别的音频文件。
- - `model`:ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert]。
+ - `model`:ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert, wavlm]。
- `task`:输出类别,默认值:`asr`。
- `lang`:模型语言,默认值:`en`。
- `sample_rate`:音频采样率,默认值:`16000`。
diff --git a/examples/librispeech/asr5/README.md b/examples/librispeech/asr5/README.md
new file mode 100644
index 00000000000..826c33cec71
--- /dev/null
+++ b/examples/librispeech/asr5/README.md
@@ -0,0 +1,197 @@
+# WavLMASR with LibriSpeech
+This example contains code used to fine-tune the [WavLM](https://arxiv.org/abs/2110.13900) model with the [LibriSpeech dataset](http://www.openslr.org/resources/12).
+## Overview
+All the scripts you need are in `run.sh`. There are several stages in `run.sh`, and each stage has its function.
+| Stage | Function |
+|:---- |:----------------------------------------------------------- |
+| 0 | Process data. It includes: (1) Download the dataset (2) Calculate the CMVN of the train dataset (3) Get the vocabulary file (4) Get the manifest files of the train, development and test dataset (5) Download the pretrained WavLM model |
+| 1 | Train the model |
+| 2 | Get the final model by averaging the top-k models; setting k = 1 means choosing the best model |
+| 3 | Test the final model performance |
+| 4 | Infer the single audio file |
+
+
+You can choose to run a range of stages by setting `stage` and `stop_stage`.
+
+For example, if you want to execute the code in stage 2 and stage 3, you can run this script:
+```bash
+bash run.sh --stage 2 --stop_stage 3
+```
+Or you can set `stage` equal to `stop_stage` to run only one stage.
+For example, if you only want to run `stage 0`, you can use the script below:
+```bash
+bash run.sh --stage 0 --stop_stage 0
+```
+The sections below describe the scripts in `run.sh` in detail.
+## The Environment Variables
+The `path.sh` file contains the environment variables.
+```bash
+. ./path.sh
+. ./cmd.sh
+```
+This script needs to be run first. Another script is also needed:
+```bash
+source ${MAIN_ROOT}/utils/parse_options.sh
+```
+It enables passing options in the `--variable value` form to the shell scripts.
+## The Local Variables
+Some local variables are set in `run.sh`.
+`gpus` denotes the GPU devices you want to use. If you set `gpus=`, only the CPU is used.
+`stage` denotes the number of the stage you want to start from in the experiments.
+`stop_stage` denotes the number of the stage you want to end at in the experiments.
+`conf_path` denotes the config path of the model.
+`avg_num` denotes the number K of top-K models you want to average to get the final model.
+`audio_file` denotes the file path of the single audio file you want to infer in stage 4.
+`ckpt` denotes the checkpoint prefix of the model, e.g. "wavlmASR".
+
+You can set the local variables (except `ckpt`) when you use `run.sh`.
+
+For example, you can set the `gpus` and `avg_num` when you use the command line:
+```bash
+bash run.sh --gpus 0,1 --avg_num 20
+```
+## Stage 0: Data Processing
+To use this example, you need to process the data first; stage 0 in `run.sh` does this. The code is shown below:
+```bash
+ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # prepare data
+ bash ./local/data.sh || exit -1
+ fi
+```
+Stage 0 is for processing the data.
+
+If you only want to process the data, you can run
+```bash
+bash run.sh --stage 0 --stop_stage 0
+```
+You can also just run these scripts in your command line.
+```bash
+. ./path.sh
+. ./cmd.sh
+bash ./local/data.sh
+```
+After processing the data, the `data` directory will look like this:
+```bash
+data/
+|-- dev.meta
+|-- lang_char
+|   |-- bpe_unigram_5000.model
+|   |-- bpe_unigram_5000.vocab
+|   `-- vocab.txt
+|-- manifest.dev
+|-- manifest.dev.raw
+|-- manifest.test
+|-- manifest.test.raw
+|-- manifest.train
+|-- manifest.train.raw
+|-- mean_std.json
+|-- test.meta
+`-- train.meta
+```
+
+Stage 0 also downloads the pre-trained [WavLM](https://paddlespeech.bj.bcebos.com/wavlm/wavlm-base-plus.pdparams) model.
+```bash
+mkdir -p exp/wavlm
+wget -P exp/wavlm https://paddlespeech.bj.bcebos.com/wavlm/wavlm-base-plus.pdparams
+```
+## Stage 1: Model Training
+If you want to train the model, you can use stage 1 in `run.sh`. The code is shown below.
+```bash
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # train model, all `ckpt` under `exp` dir
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+ fi
+```
+If you want to train the model, you can use the script below to execute stage 0 and stage 1:
+```bash
+bash run.sh --stage 0 --stop_stage 1
+```
+or you can run these scripts in the command line (CPU only).
+```bash
+. ./path.sh
+. ./cmd.sh
+bash ./local/data.sh
+CUDA_VISIBLE_DEVICES= ./local/train.sh conf/wavlmASR.yaml wavlmASR
+```
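+
+Training can also be resumed from a previously saved checkpoint through the `resume` variable in `run.sh`. A minimal sketch, assuming a checkpoint for epoch 30 already exists under `exp/wavlmASR/checkpoints`:
+```bash
+bash run.sh --stage 1 --stop_stage 1 --resume 30
+```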
+## Stage 2: Top-k Models Averaging
+After training the model, we need to get the final model for testing and inference. In every epoch, the model checkpoint is saved, so we can choose the best model from them based on the validation loss, or we can sort them and average the parameters of the top-k models to get the final model. We can use stage 2 to do this, and the code is shown below. Note: we only train one epoch for wavlmASR, so `avg_num` is set to 1.
+```bash
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ # avg n best model
+ avg.sh best exp/${ckpt}/checkpoints ${avg_num}
+ fi
+```
+The `avg.sh` script is in `../../../utils/`, which is defined in `path.sh`.
+If you want to get the final model, you can use the script below to execute stage 0, stage 1, and stage 2:
+```bash
+bash run.sh --stage 0 --stop_stage 2
+```
+or you can run these scripts in the command line (CPU only).
+
+```bash
+. ./path.sh
+. ./cmd.sh
+bash ./local/data.sh
+CUDA_VISIBLE_DEVICES= ./local/train.sh conf/wavlmASR.yaml wavlmASR
+avg.sh best exp/wavlmASR/checkpoints 1
+```
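+The averaged parameters are written to `exp/wavlmASR/checkpoints/avg_1.pdparams` (following the `avg_${avg_num}` naming used in `avg.sh`), and `avg_1` is the checkpoint prefix loaded by the test stages below.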
+## Stage 3: Model Testing
+The test stage evaluates the model performance. The code of the test stage is shown below:
+```bash
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ # test ckpt avg_n
+ CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+ fi
+```
+If you want to train a model and test it, you can use the script below to execute stage 0, stage 1, stage 2, and stage 3:
+```bash
+bash run.sh --stage 0 --stop_stage 3
+```
+or you can run these scripts in the command line (CPU only).
+```bash
+. ./path.sh
+. ./cmd.sh
+bash ./local/data.sh
+CUDA_VISIBLE_DEVICES= ./local/train.sh conf/wavlmASR.yaml wavlmASR
+avg.sh best exp/wavlmASR/checkpoints 1
+CUDA_VISIBLE_DEVICES= ./local/test.sh conf/wavlmASR.yaml conf/tuning/decode.yaml exp/wavlmASR/checkpoints/avg_1
+```
+## Pretrained Model
+You can get the pretrained wavlmASR model from [the released models list](../../../docs/source/released_model.md).
+
+Use the `tar` command to unpack the model, and then you can use the scripts below to test it.
+
+For example:
+```bash
+wget https://paddlespeech.bj.bcebos.com/wavlm/wavlmASR-base-100h-librispeech_ckpt_1.4.0.model.tar.gz
+tar xzvf wavlmASR-base-100h-librispeech_ckpt_1.4.0.model.tar.gz
+source path.sh
+# If you have already processed the data and generated the manifest files, you can skip the following 2 steps
+bash local/data.sh --stage -1 --stop_stage -1
+bash local/data.sh --stage 2 --stop_stage 2
+CUDA_VISIBLE_DEVICES= ./local/test.sh conf/wavlmASR.yaml conf/tuning/decode.yaml exp/wavlmASR/checkpoints/avg_1
+```
+The performance of the released models is shown [here](./RESULTS.md).
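+
+If the released model has been registered in the PaddleSpeech CLI resources, you can also try it through the `paddlespeech ssl` command line (see `demos/speech_ssl`); a minimal sketch:
+```bash
+paddlespeech ssl --task asr --model wavlm --lang en --sample_rate 16000 --input ./en.wav
+```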
+
+
+## Stage 4: Single Audio File Inference
+In some situations, you want to use the trained model to do inference on a single audio file. You can use stage 4. The code is shown below:
+```bash
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ # test a single .wav file
+ CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
+ fi
+```
+You can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below:
+```bash
+wget https://paddlespeech.bj.bcebos.com/wavlm/wavlm_baseplus_libriclean_100h.tar.gz
+tar xzvf wavlm_baseplus_libriclean_100h.tar.gz
+```
+You can download the audio demo:
+```bash
+wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
+```
+You need to prepare an audio file or use the audio demo above; please make sure the sample rate of the audio is 16 kHz. You can get the result of the audio demo by running the script below.
+```bash
+CUDA_VISIBLE_DEVICES= ./local/test_wav.sh conf/wavlmASR.yaml conf/tuning/decode.yaml exp/wavlmASR/checkpoints/avg_1 data/demo_002_en.wav
+```
diff --git a/examples/librispeech/asr5/RESULTS.md b/examples/librispeech/asr5/RESULTS.md
new file mode 100644
index 00000000000..806b39a1f85
--- /dev/null
+++ b/examples/librispeech/asr5/RESULTS.md
@@ -0,0 +1,9 @@
+# LibriSpeech
+
+## WavLMASR
+Fine-tuning on train-clean-100
+train: Epoch 16, 4*A800-80G, batch size: 16, accum_grad: 8
+
+| Model | Params | Config | Augmentation| Test set | Decode method | WER |
+| --- | --- | --- | --- | --- | --- | --- |
+| WavLMASR | 326.16M | conf/wavlmASR.yaml | spec_aug | test-clean | greedy search | 0.0561 |
diff --git a/examples/librispeech/asr5/avg.sh b/examples/librispeech/asr5/avg.sh
new file mode 100644
index 00000000000..c49b5c25f52
--- /dev/null
+++ b/examples/librispeech/asr5/avg.sh
@@ -0,0 +1,33 @@
+#! /usr/bin/env bash
+
+if [ $# != 3 ]; then
+ echo "usage: ${0} [best|latest] ckpt_dir avg_num"
+ exit -1
+fi
+
+avg_mode=${1} # best,latest
+ckpt_dir=${2}
+average_num=${3}
+decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
+
+if [ $avg_mode == best ];then
+ # best
+ python avg_model.py \
+ --dst_model ${decode_checkpoint} \
+ --ckpt_dir ${ckpt_dir} \
+ --num ${average_num} \
+ --val_best
+else
+ # latest
+ python avg_model.py \
+ --dst_model ${decode_checkpoint} \
+ --ckpt_dir ${ckpt_dir} \
+ --num ${average_num}
+fi
+
+if [ $? -ne 0 ]; then
+ echo "Failed in avg ckpt!"
+ exit 1
+fi
+
+exit 0
diff --git a/examples/librispeech/asr5/cmd.sh b/examples/librispeech/asr5/cmd.sh
new file mode 100644
index 00000000000..7b70ef5e06e
--- /dev/null
+++ b/examples/librispeech/asr5/cmd.sh
@@ -0,0 +1,89 @@
+# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ======
+# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...>
+# e.g.
+# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB
+#
+# Options:
+# --time : Limit the maximum time to execute.
+# --mem : Limit the maximum memory usage.
+# --max-jobs-run : Limit the number of parallel jobs. This is ignored for non-array jobs.
+# --num-threads : Specify the number of CPU cores.
+# --gpu : Specify the number of GPU devices.
+# --config: Change the configuration file from default.
+#
+# "JOB=1:10" is used for "array jobs" and it can control the number of parallel jobs.
+# The left string of "=", i.e. "JOB", is replaced by <N> (Nth job) in the command and the log file name,
+# e.g. "echo JOB" is changed to "echo 3" for the 3rd job and "echo 8" for 8th job respectively.
+# Note that the number must start with a positive number, so you can't use "JOB=0:10" for example.
+#
+# run.pl, queue.pl, slurm.pl, and ssh.pl have a unified interface, not depending on the backend.
+# These options are mapped to specific options for each backend and
+# it is configured by "conf/queue.conf" and "conf/slurm.conf" by default.
+# If jobs failed, your configuration might be wrong for your environment.
+#
+#
+# The official documentation for run.pl, queue.pl, slurm.pl, and ssh.pl:
+# "Parallelization in Kaldi": http://kaldi-asr.org/doc/queue.html
+# =========================================================
+
+
+# Select the backend used by run.sh from "local", "sge", "slurm", or "ssh"
+cmd_backend='local'
+
+# Local machine, without any Job scheduling system
+if [ "${cmd_backend}" = local ]; then
+
+ # The other usage
+ export train_cmd="run.pl"
+ # Used for "*_train.py": "--gpu" is appended optionally by run.sh
+ export cuda_cmd="run.pl"
+ # Used for "*_recog.py"
+ export decode_cmd="run.pl"
+
+# "qsub" (SGE, Torque, PBS, etc.)
+elif [ "${cmd_backend}" = sge ]; then
+ # The default setting is written in conf/queue.conf.
+ # You must change "-q g.q" for the "queue" for your environment.
+ # To know the "queue" names, type "qhost -q"
+ # Note that to use "--gpu *", you have to setup "complex_value" for the system scheduler.
+
+ export train_cmd="queue.pl"
+ export cuda_cmd="queue.pl"
+ export decode_cmd="queue.pl"
+
+# "sbatch" (Slurm)
+elif [ "${cmd_backend}" = slurm ]; then
+ # The default setting is written in conf/slurm.conf.
+ # You must change "-p cpu" and "-p gpu" for the "partion" for your environment.
+ # To know the "partion" names, type "sinfo".
+ # You can use "--gpu * " by default for slurm and it is interpreted as "--gres gpu:*"
+ # The devices are allocated exclusively using "${CUDA_VISIBLE_DEVICES}".
+
+ export train_cmd="slurm.pl"
+ export cuda_cmd="slurm.pl"
+ export decode_cmd="slurm.pl"
+
+elif [ "${cmd_backend}" = ssh ]; then
+ # You have to create ".queue/machines" to specify the host to execute jobs.
+ # e.g. .queue/machines
+ # host1
+ # host2
+ # host3
+    # Assuming you can log in to them without a password, i.e. you have to set up ssh keys.
+
+ export train_cmd="ssh.pl"
+ export cuda_cmd="ssh.pl"
+ export decode_cmd="ssh.pl"
+
+# This is an example of specifying several unique options in the JHU CLSP cluster setup.
+# Users can modify/add their own command options according to their cluster environments.
+elif [ "${cmd_backend}" = jhu ]; then
+
+ export train_cmd="queue.pl --mem 2G"
+ export cuda_cmd="queue-freegpu.pl --mem 2G --gpu 1 --config conf/gpu.conf"
+ export decode_cmd="queue.pl --mem 4G"
+
+else
+ echo "$0: Error: Unknown cmd_backend=${cmd_backend}" 1>&2
+ return 1
+fi
diff --git a/examples/librispeech/asr5/compute_wer.py b/examples/librispeech/asr5/compute_wer.py
new file mode 100644
index 00000000000..5711c725b77
--- /dev/null
+++ b/examples/librispeech/asr5/compute_wer.py
@@ -0,0 +1,558 @@
+# Copyright 2021 Mobvoi Inc. All Rights Reserved.
+# flake8: noqa
+import codecs
+import re
+import sys
+import unicodedata
+
+remove_tag = True
+spacelist = [' ', '\t', '\r', '\n']
+puncts = [
+ '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
+ '《', '》'
+]
+
+
+def characterize(string):
+ res = []
+ i = 0
+ while i < len(string):
+ char = string[i]
+ if char in puncts:
+ i += 1
+ continue
+ cat1 = unicodedata.category(char)
+ #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
+ if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
+ i += 1
+ continue
+ if cat1 == 'Lo': # letter-other
+ res.append(char)
+ i += 1
+ else:
+            # some input looks like: <tag>word, we want to separate it into two words.
+ sep = ' '
+ if char == '<': sep = '>'
+ j = i + 1
+ while j < len(string):
+ c = string[j]
+ if ord(c) >= 128 or (c in spacelist) or (c == sep):
+ break
+ j += 1
+ if j < len(string) and string[j] == '>':
+ j += 1
+ res.append(string[i:j])
+ i = j
+ return res
+
+
+def stripoff_tags(x):
+ if not x: return ''
+ chars = []
+ i = 0
+ T = len(x)
+ while i < T:
+ if x[i] == '<':
+ while i < T and x[i] != '>':
+ i += 1
+ i += 1
+ else:
+ chars.append(x[i])
+ i += 1
+ return ''.join(chars)
+
+
+def normalize(sentence, ignore_words, cs, split=None):
+ """ sentence, ignore_words are both in unicode
+ """
+ new_sentence = []
+ for token in sentence:
+ x = token
+ if not cs:
+ x = x.upper()
+ if x in ignore_words:
+ continue
+ if remove_tag:
+ x = stripoff_tags(x)
+ if not x:
+ continue
+ if split and x in split:
+ new_sentence += split[x]
+ else:
+ new_sentence.append(x)
+ return new_sentence
+
+
+class Calculator:
+ def __init__(self):
+ self.data = {}
+ self.space = []
+ self.cost = {}
+ self.cost['cor'] = 0
+ self.cost['sub'] = 1
+ self.cost['del'] = 1
+ self.cost['ins'] = 1
+
+ def calculate(self, lab, rec):
+ # Initialization
+ lab.insert(0, '')
+ rec.insert(0, '')
+ while len(self.space) < len(lab):
+ self.space.append([])
+ for row in self.space:
+ for element in row:
+ element['dist'] = 0
+ element['error'] = 'non'
+ while len(row) < len(rec):
+ row.append({'dist': 0, 'error': 'non'})
+ for i in range(len(lab)):
+ self.space[i][0]['dist'] = i
+ self.space[i][0]['error'] = 'del'
+ for j in range(len(rec)):
+ self.space[0][j]['dist'] = j
+ self.space[0][j]['error'] = 'ins'
+ self.space[0][0]['error'] = 'non'
+ for token in lab:
+ if token not in self.data and len(token) > 0:
+ self.data[token] = {
+ 'all': 0,
+ 'cor': 0,
+ 'sub': 0,
+ 'ins': 0,
+ 'del': 0
+ }
+ for token in rec:
+ if token not in self.data and len(token) > 0:
+ self.data[token] = {
+ 'all': 0,
+ 'cor': 0,
+ 'sub': 0,
+ 'ins': 0,
+ 'del': 0
+ }
+ # Computing edit distance
+ for i, lab_token in enumerate(lab):
+ for j, rec_token in enumerate(rec):
+ if i == 0 or j == 0:
+ continue
+ min_dist = sys.maxsize
+ min_error = 'none'
+ dist = self.space[i - 1][j]['dist'] + self.cost['del']
+ error = 'del'
+ if dist < min_dist:
+ min_dist = dist
+ min_error = error
+ dist = self.space[i][j - 1]['dist'] + self.cost['ins']
+ error = 'ins'
+ if dist < min_dist:
+ min_dist = dist
+ min_error = error
+ if lab_token == rec_token:
+ dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
+ error = 'cor'
+ else:
+ dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
+ error = 'sub'
+ if dist < min_dist:
+ min_dist = dist
+ min_error = error
+ self.space[i][j]['dist'] = min_dist
+ self.space[i][j]['error'] = min_error
+ # Tracing back
+ result = {
+ 'lab': [],
+ 'rec': [],
+ 'all': 0,
+ 'cor': 0,
+ 'sub': 0,
+ 'ins': 0,
+ 'del': 0
+ }
+ i = len(lab) - 1
+ j = len(rec) - 1
+ while True:
+ if self.space[i][j]['error'] == 'cor': # correct
+ if len(lab[i]) > 0:
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+ self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
+ result['all'] = result['all'] + 1
+ result['cor'] = result['cor'] + 1
+ result['lab'].insert(0, lab[i])
+ result['rec'].insert(0, rec[j])
+ i = i - 1
+ j = j - 1
+ elif self.space[i][j]['error'] == 'sub': # substitution
+ if len(lab[i]) > 0:
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+ self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
+ result['all'] = result['all'] + 1
+ result['sub'] = result['sub'] + 1
+ result['lab'].insert(0, lab[i])
+ result['rec'].insert(0, rec[j])
+ i = i - 1
+ j = j - 1
+ elif self.space[i][j]['error'] == 'del': # deletion
+ if len(lab[i]) > 0:
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+ self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
+ result['all'] = result['all'] + 1
+ result['del'] = result['del'] + 1
+ result['lab'].insert(0, lab[i])
+ result['rec'].insert(0, "")
+ i = i - 1
+ elif self.space[i][j]['error'] == 'ins': # insertion
+ if len(rec[j]) > 0:
+ self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
+ result['ins'] = result['ins'] + 1
+ result['lab'].insert(0, "")
+ result['rec'].insert(0, rec[j])
+ j = j - 1
+ elif self.space[i][j]['error'] == 'non': # starting point
+ break
+ else: # shouldn't reach here
+ print(
+ 'this should not happen , i = {i} , j = {j} , error = {error}'.
+ format(i=i, j=j, error=self.space[i][j]['error']))
+ return result
+
+ def overall(self):
+ result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
+ for token in self.data:
+ result['all'] = result['all'] + self.data[token]['all']
+ result['cor'] = result['cor'] + self.data[token]['cor']
+ result['sub'] = result['sub'] + self.data[token]['sub']
+ result['ins'] = result['ins'] + self.data[token]['ins']
+ result['del'] = result['del'] + self.data[token]['del']
+ return result
+
+ def cluster(self, data):
+ result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
+ for token in data:
+ if token in self.data:
+ result['all'] = result['all'] + self.data[token]['all']
+ result['cor'] = result['cor'] + self.data[token]['cor']
+ result['sub'] = result['sub'] + self.data[token]['sub']
+ result['ins'] = result['ins'] + self.data[token]['ins']
+ result['del'] = result['del'] + self.data[token]['del']
+ return result
+
+ def keys(self):
+ return list(self.data.keys())
+
+
+def width(string):
+ return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
+
+
+def default_cluster(word):
+ unicode_names = [unicodedata.name(char) for char in word]
+ for i in reversed(range(len(unicode_names))):
+ if unicode_names[i].startswith('DIGIT'): # 1
+ unicode_names[i] = 'Number' # 'DIGIT'
+ elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
+ unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
+ # 明 / 郎
+ unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
+ elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
+ unicode_names[i].startswith('LATIN SMALL LETTER')):
+ # A / a
+ unicode_names[i] = 'English' # 'LATIN LETTER'
+ elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
+ unicode_names[i] = 'Japanese' # 'GANA LETTER'
+ elif (unicode_names[i].startswith('AMPERSAND') or
+ unicode_names[i].startswith('APOSTROPHE') or
+ unicode_names[i].startswith('COMMERCIAL AT') or
+ unicode_names[i].startswith('DEGREE CELSIUS') or
+ unicode_names[i].startswith('EQUALS SIGN') or
+ unicode_names[i].startswith('FULL STOP') or
+ unicode_names[i].startswith('HYPHEN-MINUS') or
+ unicode_names[i].startswith('LOW LINE') or
+ unicode_names[i].startswith('NUMBER SIGN') or
+ unicode_names[i].startswith('PLUS SIGN') or
+ unicode_names[i].startswith('SEMICOLON')):
+ # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
+ del unicode_names[i]
+ else:
+ return 'Other'
+ if len(unicode_names) == 0:
+ return 'Other'
+ if len(unicode_names) == 1:
+ return unicode_names[0]
+ for i in range(len(unicode_names) - 1):
+ if unicode_names[i] != unicode_names[i + 1]:
+ return 'Other'
+ return unicode_names[0]
+
+
+def usage():
+ print(
+ "compute-wer.py : compute word error rate (WER) and align recognition results and references."
+ )
+ print(
+ " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
+ )
+
+
+def main():
+ # python utils/compute-wer.py --char=1 --v=1 ref hyp > rsl.error
+ if len(sys.argv) == 1:
+ usage()
+ sys.exit(0)
+ calculator = Calculator()
+ cluster_file = ''
+ ignore_words = set()
+ tochar = False
+ verbose = 1
+ padding_symbol = ' '
+ case_sensitive = False
+ max_words_per_line = sys.maxsize
+ split = None
+ while len(sys.argv) > 3:
+ a = '--maxw='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):]
+ del sys.argv[1]
+ max_words_per_line = int(b)
+ continue
+ a = '--rt='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ remove_tag = (b == 'true') or (b != '0')
+ continue
+ a = '--cs='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ case_sensitive = (b == 'true') or (b != '0')
+ continue
+ a = '--cluster='
+ if sys.argv[1].startswith(a):
+ cluster_file = sys.argv[1][len(a):]
+ del sys.argv[1]
+ continue
+ a = '--splitfile='
+ if sys.argv[1].startswith(a):
+ split_file = sys.argv[1][len(a):]
+ del sys.argv[1]
+ split = dict()
+ with codecs.open(split_file, 'r', 'utf-8') as fh:
+ for line in fh: # line in unicode
+ words = line.strip().split()
+ if len(words) >= 2:
+ split[words[0]] = words[1:]
+ continue
+ a = '--ig='
+ if sys.argv[1].startswith(a):
+ ignore_file = sys.argv[1][len(a):]
+ del sys.argv[1]
+ with codecs.open(ignore_file, 'r', 'utf-8') as fh:
+ for line in fh: # line in unicode
+ line = line.strip()
+ if len(line) > 0:
+ ignore_words.add(line)
+ continue
+ a = '--char='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ tochar = (b == 'true') or (b != '0')
+ continue
+ a = '--v='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ verbose = 0
+ try:
+ verbose = int(b)
+ except:
+ if b == 'true' or b != '0':
+ verbose = 1
+ continue
+ a = '--padding-symbol='
+ if sys.argv[1].startswith(a):
+ b = sys.argv[1][len(a):].lower()
+ del sys.argv[1]
+ if b == 'space':
+ padding_symbol = ' '
+ elif b == 'underline':
+ padding_symbol = '_'
+ continue
+ if True or sys.argv[1].startswith('-'):
+ #ignore invalid switch
+ del sys.argv[1]
+ continue
+
+ if not case_sensitive:
+ ig = set([w.upper() for w in ignore_words])
+ ignore_words = ig
+
+ default_clusters = {}
+ default_words = {}
+
+ ref_file = sys.argv[1]
+ hyp_file = sys.argv[2]
+ rec_set = {}
+ if split and not case_sensitive:
+ newsplit = dict()
+ for w in split:
+ words = split[w]
+ for i in range(len(words)):
+ words[i] = words[i].upper()
+ newsplit[w.upper()] = words
+ split = newsplit
+
+ with codecs.open(hyp_file, 'r', 'utf-8') as fh:
+ for line in fh:
+ if tochar:
+ array = characterize(line)
+ else:
+ array = line.strip().split()
+ if len(array) == 0: continue
+ fid = array[0]
+ rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
+ split)
+
+ # compute error rate on the interaction of reference file and hyp file
+ for line in open(ref_file, 'r', encoding='utf-8'):
+ if tochar:
+ array = characterize(line)
+ else:
+ array = line.rstrip('\n').split()
+ if len(array) == 0: continue
+ fid = array[0]
+ if fid not in rec_set:
+ continue
+ lab = normalize(array[1:], ignore_words, case_sensitive, split)
+ rec = rec_set[fid]
+ if verbose:
+ print('\nutt: %s' % fid)
+
+ for word in rec + lab:
+ if word not in default_words:
+ default_cluster_name = default_cluster(word)
+ if default_cluster_name not in default_clusters:
+ default_clusters[default_cluster_name] = {}
+ if word not in default_clusters[default_cluster_name]:
+ default_clusters[default_cluster_name][word] = 1
+ default_words[word] = default_cluster_name
+
+ result = calculator.calculate(lab, rec)
+ if verbose:
+ if result['all'] != 0:
+ wer = float(result['ins'] + result['sub'] + result[
+ 'del']) * 100.0 / result['all']
+ else:
+ wer = 0.0
+ print('WER: %4.2f %%' % wer, end=' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'],
+ result['ins']))
+ space = {}
+ space['lab'] = []
+ space['rec'] = []
+ for idx in range(len(result['lab'])):
+ len_lab = width(result['lab'][idx])
+ len_rec = width(result['rec'][idx])
+ length = max(len_lab, len_rec)
+ space['lab'].append(length - len_lab)
+ space['rec'].append(length - len_rec)
+ upper_lab = len(result['lab'])
+ upper_rec = len(result['rec'])
+ lab1, rec1 = 0, 0
+ while lab1 < upper_lab or rec1 < upper_rec:
+ if verbose > 1:
+ print('lab(%s):' % fid.encode('utf-8'), end=' ')
+ else:
+ print('lab:', end=' ')
+ lab2 = min(upper_lab, lab1 + max_words_per_line)
+ for idx in range(lab1, lab2):
+ token = result['lab'][idx]
+ print('{token}'.format(token=token), end='')
+ for n in range(space['lab'][idx]):
+ print(padding_symbol, end='')
+ print(' ', end='')
+ print()
+ if verbose > 1:
+ print('rec(%s):' % fid.encode('utf-8'), end=' ')
+ else:
+ print('rec:', end=' ')
+ rec2 = min(upper_rec, rec1 + max_words_per_line)
+ for idx in range(rec1, rec2):
+ token = result['rec'][idx]
+ print('{token}'.format(token=token), end='')
+ for n in range(space['rec'][idx]):
+ print(padding_symbol, end='')
+ print(' ', end='')
+ print('\n', end='\n')
+ lab1 = lab2
+ rec1 = rec2
+
+ if verbose:
+ print(
+ '==========================================================================='
+ )
+ print()
+
+ result = calculator.overall()
+ if result['all'] != 0:
+ wer = float(result['ins'] + result['sub'] + result[
+ 'del']) * 100.0 / result['all']
+ else:
+ wer = 0.0
+ print('Overall -> %4.2f %%' % wer, end=' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'],
+ result['ins']))
+ if not verbose:
+ print()
+
+ if verbose:
+ for cluster_id in default_clusters:
+ result = calculator.cluster(
+ [k for k in default_clusters[cluster_id]])
+ if result['all'] != 0:
+ wer = float(result['ins'] + result['sub'] + result[
+ 'del']) * 100.0 / result['all']
+ else:
+ wer = 0.0
+ print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'],
+ result['ins']))
+ if len(cluster_file) > 0: # compute separated WERs for word clusters
+ cluster_id = ''
+ cluster = []
+ for line in open(cluster_file, 'r', encoding='utf-8'):
+ for token in line.decode('utf-8').rstrip('\n').split():
+                # end of cluster reached, like </Keyword>
+                if token[0:2] == '</' and token[len(token)-1] == '>' and \
+                        token.lstrip('</').rstrip('>') == cluster_id :
+ result = calculator.cluster(cluster)
+ if result['all'] != 0:
+ wer = float(result['ins'] + result['sub'] + result[
+ 'del']) * 100.0 / result['all']
+ else:
+ wer = 0.0
+ print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'],
+ result['del'], result['ins']))
+ cluster_id = ''
+ cluster = []
+                # begin of cluster reached, like <Keyword>
+ elif token[0] == '<' and token[len(token)-1] == '>' and \
+ cluster_id == '' :
+ cluster_id = token.lstrip('<').rstrip('>')
+ cluster = []
+ # general terms, like WEATHER / CAR / ...
+ else:
+ cluster.append(token)
+ print()
+ print(
+ '==========================================================================='
+ )
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/librispeech/asr5/conf/preprocess.yaml b/examples/librispeech/asr5/conf/preprocess.yaml
new file mode 100644
index 00000000000..724782ed65b
--- /dev/null
+++ b/examples/librispeech/asr5/conf/preprocess.yaml
@@ -0,0 +1,3 @@
+process:
+ # use raw audio
+ - type: wav_process
diff --git a/examples/librispeech/asr5/conf/preprocessor_config.json b/examples/librispeech/asr5/conf/preprocessor_config.json
new file mode 100644
index 00000000000..36ebe8b7c1c
--- /dev/null
+++ b/examples/librispeech/asr5/conf/preprocessor_config.json
@@ -0,0 +1,9 @@
+{
+ "do_normalize": true,
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+ "feature_size": 1,
+ "padding_side": "right",
+ "padding_value": 0,
+ "return_attention_mask": true,
+ "sampling_rate": 16000
+}
diff --git a/examples/librispeech/asr5/conf/tuning/decode.yaml b/examples/librispeech/asr5/conf/tuning/decode.yaml
new file mode 100644
index 00000000000..e5781495d46
--- /dev/null
+++ b/examples/librispeech/asr5/conf/tuning/decode.yaml
@@ -0,0 +1,4 @@
+decode_batch_size: 1
+error_rate_type: wer
+decoding_method: "ctc_greedy_search" # 'ctc_greedy_search', 'ctc_prefix_beam_search'
+beam_size: 10
diff --git a/examples/librispeech/asr5/conf/wavlmASR.yaml b/examples/librispeech/asr5/conf/wavlmASR.yaml
new file mode 100644
index 00000000000..25f9643e9d1
--- /dev/null
+++ b/examples/librispeech/asr5/conf/wavlmASR.yaml
@@ -0,0 +1,137 @@
+############################################
+# Network Architecture #
+############################################
+freeze_wavlm: False
+normalize_wav: True
+output_norm: True
+init_type: kaiming_uniform # !Warning: needed for convergence
+enc:
+ input_shape: 768
+ dnn_blocks: 2
+ dnn_neurons: 768
+ activation: True
+ normalization: True
+ dropout_rate: [0.15, 0]
+ctc:
+ enc_n_units: 768
+ blank_id: 0
+ dropout_rate: 0.0
+wavlm_params_path: exp/wavlm/wavlm-base-plus.pdparams
+
+
+task_cfg:
+ label_rate: 50.0
+ sample_rate: 16000
+ normalize: True
+ enable_padding: False
+ max_keep_size: None
+ max_sample_size: 250000
+ min_sample_size: 32000
+ dropout_input: 0.1
+ final_dropout: 0.0
+ dropout: 0.1
+ attention_dropout: 0.0
+ activation_dropout: 0.1
+ apply_mask: True
+ mask_length: 10
+ mask_prob: 0.5
+ mask_selection: static
+ mask_other: 0.0
+ no_mask_overlap: False
+ mask_channel_length: 10
+ mask_channel_prob: 0.0
+ mask_channel_selection: static
+ mask_channel_other: 0.0
+ no_mask_channel_overlap: False
+ feature_grad_mult: 0.0
+ layerdrop: 0.1
+ fp16: True
+ extractor_mode: layer_norm
+ encoder_layers: 12
+ encoder_embed_dim: 768
+ encoder_ffn_embed_dim: 3072
+ encoder_attention_heads: 12
+ activation_fn: gelu
+ encoder_layerdrop: 0.0
+ dropout_features: 0.0
+ final_dim: 768
+ untie_final_proj: True
+ layer_norm_first: True
+ conv_feature_layers: "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
+ conv_bias: False
+ logit_temp: 0.1
+ target_glu: False
+ mask_min_space: 1
+ mask_channel_min_space: 1
+ conv_pos: 128
+ conv_pos_groups: 16
+ latent_temp: [2.0, 0.5, 0.999995]
+ skip_masked: False
+ skip_nomask: True
+
+###########################################
+# Data #
+###########################################
+train_manifest: data/manifest.train
+dev_manifest: data/manifest.dev
+test_manifest: data/manifest.test-clean
+
+###########################################
+# Dataloader #
+###########################################
+vocab_filepath: data/lang_char/vocab.txt
+unit_type: char
+mean_std_filepath: ""
+preprocess_config: conf/preprocess.yaml
+sortagrad: 0 # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+batch_size: 8 # Different batch_size may cause large differences in results
+maxlen_in: 51200000000 # if input length > maxlen_in, batch size is automatically reduced
+maxlen_out: 160000
+minibatches: 0 # for debug
+batch_count: auto
+batch_bins: 0
+batch_frames_in: 0
+batch_frames_out: 0
+batch_frames_inout: 0
+num_workers: 0
+subsampling_factor: 1
+num_encs: 1
+dist_sampler: True
+shortest_first: False
+return_lens_rate: True
+
+############################################
+# Data Augmentation #
+############################################
+audio_augment: # for raw audio
+ sample_rate: 16000
+ speeds: [90, 100, 110]
+
+###########################################
+# Training #
+###########################################
+n_epoch: 10
+accum_grad: 8
+global_grad_clip: 5.0
+model_scheduler: newbobscheduler
+model_scheduler_conf:
+ improvement_threshold: 0.0025
+ annealing_factor: 0.8
+ patient: 0
+model_optim: adam
+model_optim_conf:
+ lr: 0.0001
+ weight_decay: 0.0
+wavlm_optim: adam
+wavlm_optim_conf:
+ lr: 0.00005
+ weight_decay: 0.0
+wavlm_scheduler: constantlr
+wavlm_scheduler_conf:
+ warmup_steps: 1000
+ lr_decay: 1.0
+log_interval: 1
+checkpoint:
+ kbest_n: 50
+ latest_n: 5
diff --git a/examples/librispeech/asr5/local/data.sh b/examples/librispeech/asr5/local/data.sh
new file mode 100644
index 00000000000..8e69dd7694f
--- /dev/null
+++ b/examples/librispeech/asr5/local/data.sh
@@ -0,0 +1,110 @@
+#!/bin/bash
+
+stage=-1
+stop_stage=100
+
+unit_type=char
+dict_dir=data/lang_char
+
+source ${MAIN_ROOT}/utils/parse_options.sh
+
+mkdir -p data
+mkdir -p ${dict_dir}
+TARGET_DIR=${MAIN_ROOT}/dataset
+mkdir -p ${TARGET_DIR}
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ # download data, generate manifests
+ python3 ${TARGET_DIR}/librispeech/librispeech.py \
+ --manifest_prefix="data/manifest" \
+ --target_dir="${TARGET_DIR}/librispeech" \
+ --full_download="False"
+
+ if [ $? -ne 0 ]; then
+ echo "Prepare LibriSpeech failed. Terminated."
+ exit 1
+ fi
+
+ for set in train-clean-100 dev-clean test-clean; do
+ mv data/manifest.${set} data/manifest.${set}.raw
+ done
+
+ rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
+ for set in train-clean-100; do
+ cat data/manifest.${set}.raw >> data/manifest.train.raw
+ done
+
+ for set in dev-clean; do
+ cat data/manifest.${set}.raw >> data/manifest.dev.raw
+ done
+
+ for set in test-clean; do
+ cat data/manifest.${set}.raw >> data/manifest.test.raw
+ done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # compute mean and stddev for normalizer
+ num_workers=$(nproc)
+ python ${MAIN_ROOT}/utils/compute_mean_std.py \
+ --manifest_path="data/manifest.train.raw" \
+ --num_samples=2000 \
+ --spectrum_type="fbank" \
+ --feat_dim=161 \
+ --delta_delta=false \
+ --sample_rate=16000 \
+ --stride_ms=10 \
+ --window_ms=25 \
+ --use_dB_normalization=False \
+ --num_workers=${num_workers} \
+ --output_path="data/mean_std.json"
+
+ if [ $? -ne 0 ]; then
+ echo "Compute mean and stddev failed. Terminated."
+ exit 1
+ fi
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # build vocabulary
+ python3 ${MAIN_ROOT}/utils/build_vocab.py \
+ --unit_type ${unit_type} \
+ --count_threshold=0 \
+ --vocab_path="${dict_dir}/vocab.txt" \
+ --manifest_paths="data/manifest.train.raw"
+
+ if [ $? -ne 0 ]; then
+ echo "Build vocabulary failed. Terminated."
+ exit 1
+ fi
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ # format manifest with tokenids, vocab size
+ for set in train dev test dev-clean test-clean; do
+ {
+ python3 ${MAIN_ROOT}/utils/format_data.py \
+ --cmvn_path "data/mean_std.json" \
+ --unit_type ${unit_type} \
+ --vocab_path="${dict_dir}/vocab.txt" \
+ --manifest_path="data/manifest.${set}.raw" \
+ --output_path="data/manifest.${set}"
+
+ if [ $? -ne 0 ]; then
+ echo "Formt manifest.${set} failed. Terminated."
+ exit 1
+ fi
+ }&
+ done
+ wait
+fi
+
+echo "LibriSpeech Data preparation done."
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ mkdir -p exp/wavlm
+ echo "Pretrained wavlm model download"
+ wget -P exp/wavlm https://paddlespeech.bj.bcebos.com/wavlm/wavlm-base-plus.pdparams
+fi
+
+exit 0
\ No newline at end of file
diff --git a/examples/librispeech/asr5/local/test.sh b/examples/librispeech/asr5/local/test.sh
new file mode 100644
index 00000000000..18158bd5046
--- /dev/null
+++ b/examples/librispeech/asr5/local/test.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+
+set -e
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+expdir=exp
+datadir=data
+
+recog_set="test-clean test-other dev-clean dev-other"
+recog_set="test-clean"
+
+config_path=$1
+decode_config_path=$2
+ckpt_prefix=$3
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+# download language model
+#bash local/download_lm_en.sh
+#if [ $? -ne 0 ]; then
+# exit 1
+#fi
+
+python3 format_rsl.py \
+ --origin_ref data/manifest.test-clean.raw \
+ --trans_ref data/manifest.test-clean.text
+
+
+for type in ctc_greedy_search; do
+ echo "decoding ${type}"
+ batch_size=16
+ python3 -u ${BIN_DIR}/test.py \
+ --ngpu ${ngpu} \
+ --config ${config_path} \
+ --decode_cfg ${decode_config_path} \
+ --result_file ${ckpt_prefix}.${type}.rsl \
+ --checkpoint_path ${ckpt_prefix} \
+ --opts decode.decoding_method ${type} \
+ --opts decode.decode_batch_size ${batch_size}
+
+ if [ $? -ne 0 ]; then
+ echo "Failed in evaluation!"
+ exit 1
+ fi
+ python3 format_rsl.py \
+ --origin_hyp ${ckpt_prefix}.${type}.rsl \
+ --trans_hyp ${ckpt_prefix}.${type}.rsl.text
+
+ python3 compute_wer.py --char=1 --v=1 \
+ data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
+ echo "decoding ${type} done."
+done
+
+for type in ctc_prefix_beam_search; do
+ echo "decoding ${type}"
+ batch_size=1
+ python3 -u ${BIN_DIR}/test.py \
+ --ngpu ${ngpu} \
+ --config ${config_path} \
+ --decode_cfg ${decode_config_path} \
+ --result_file ${ckpt_prefix}.${type}.rsl \
+ --checkpoint_path ${ckpt_prefix} \
+ --opts decode.decoding_method ${type} \
+ --opts decode.decode_batch_size ${batch_size}
+
+ if [ $? -ne 0 ]; then
+ echo "Failed in evaluation!"
+ exit 1
+ fi
+ python3 format_rsl.py \
+ --origin_hyp ${ckpt_prefix}.${type}.rsl \
+ --trans_hyp ${ckpt_prefix}.${type}.rsl.text
+
+ python3 compute_wer.py --char=1 --v=1 \
+ data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
+ echo "decoding ${type} done."
+done
+
+echo "Finished"
+
+exit 0
diff --git a/examples/librispeech/asr5/local/test_wav.sh b/examples/librispeech/asr5/local/test_wav.sh
new file mode 100644
index 00000000000..fdf3589f4ba
--- /dev/null
+++ b/examples/librispeech/asr5/local/test_wav.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+if [ $# != 4 ];then
+ echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
+ exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+config_path=$1
+decode_config_path=$2
+ckpt_prefix=$3
+audio_file=$4
+
+mkdir -p data
+wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
+if [ $? -ne 0 ]; then
+ exit 1
+fi
+
+if [ ! -f ${audio_file} ]; then
+ echo "Plase input the right audio_file path"
+ exit 1
+fi
+
+chunk_mode=false
+if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
+ chunk_mode=true
+fi
+
+# download language model
+#bash local/download_lm_ch.sh
+#if [ $? -ne 0 ]; then
+# exit 1
+#fi
+
+for type in ctc_greedy_search; do
+ echo "decoding ${type}"
+ batch_size=1
+ output_dir=${ckpt_prefix}
+ mkdir -p ${output_dir}
+ python3 -u ${BIN_DIR}/test_wav.py \
+ --ngpu ${ngpu} \
+ --config ${config_path} \
+ --decode_cfg ${decode_config_path} \
+ --result_file ${output_dir}/${type}.rsl \
+ --checkpoint_path ${ckpt_prefix} \
+ --opts decode.decoding_method ${type} \
+ --opts decode.decode_batch_size ${batch_size} \
+ --audio_file ${audio_file}
+
+ if [ $? -ne 0 ]; then
+ echo "Failed in evaluation!"
+ exit 1
+ fi
+done
+exit 0
diff --git a/examples/librispeech/asr5/local/train.sh b/examples/librispeech/asr5/local/train.sh
new file mode 100644
index 00000000000..24776fd1723
--- /dev/null
+++ b/examples/librispeech/asr5/local/train.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+if [ $# -lt 2 ] || [ $# -gt 4 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name resume(optional) ips(optional)"
+ exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+config_path=$1
+ckpt_name=$2
+resume=$3
+ips=$4
+
+if [ ! $ips ];then
+ ips_config=
+else
+ ips_config="--ips="${ips}
+fi
+
+mkdir -p exp
+
+# seed may break model convergence
+seed=1988
+if [ ${seed} != 0 ]; then
+ export FLAGS_cudnn_deterministic=True
+fi
+
+# export FLAGS_cudnn_exhaustive_search=true
+# export FLAGS_conv_workspace_size_limit=4000
+export FLAGS_allocator_strategy=naive_best_fit
+if [ ${ngpu} == 0 ]; then
+python3 -u ${BIN_DIR}/train.py \
+--ngpu ${ngpu} \
+--config ${config_path} \
+--output exp/${ckpt_name} \
+--seed ${seed} \
+--resume ${resume}
+else
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
+--ngpu ${ngpu} \
+--config ${config_path} \
+--output exp/${ckpt_name} \
+--seed ${seed} \
+--resume ${resume}
+fi
+
+if [ ${seed} != 0 ]; then
+ unset FLAGS_cudnn_deterministic
+fi
+
+if [ $? -ne 0 ]; then
+ echo "Failed in training!"
+ exit 1
+fi
+
+exit 0
diff --git a/examples/librispeech/asr5/path.sh b/examples/librispeech/asr5/path.sh
new file mode 100644
index 00000000000..dbf3a9404f6
--- /dev/null
+++ b/examples/librispeech/asr5/path.sh
@@ -0,0 +1,13 @@
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+# export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/wavlm/bin
diff --git a/examples/librispeech/asr5/run.sh b/examples/librispeech/asr5/run.sh
new file mode 100644
index 00000000000..9634bc8c80d
--- /dev/null
+++ b/examples/librispeech/asr5/run.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+set -e
+
+. ./path.sh || exit 1;
+. ./cmd.sh || exit 1;
+
+gpus=0,1,2
+stage=0
+stop_stage=3
+conf_path=conf/wavlmASR.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
+decode_conf_path=conf/tuning/decode.yaml
+avg_num=3
+resume= # xx e.g. 30
+
+. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+audio_file=data/demo_002_en.wav
+
+avg_ckpt=avg_${avg_num}
+ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
+echo "checkpoint name ${ckpt}"
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # prepare data
+ bash ./local/data.sh || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # train model, all `ckpt` under `exp` dir
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # avg n best model
+    ./avg.sh best exp/${ckpt}/checkpoints ${avg_num}
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ # greedy search decoder
+ CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ # test a single .wav file
+ CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
+fi
diff --git a/examples/librispeech/asr5/utils b/examples/librispeech/asr5/utils
new file mode 100644
index 00000000000..973afe674f2
--- /dev/null
+++ b/examples/librispeech/asr5/utils
@@ -0,0 +1 @@
+../../../utils
\ No newline at end of file
diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py
index bc3c632d5e0..9b4b0280314 100644
--- a/paddlespeech/cli/ssl/infer.py
+++ b/paddlespeech/cli/ssl/infer.py
@@ -52,7 +52,7 @@ def __init__(self):
'--model',
type=str,
default='wav2vec2',
- choices=['wav2vec2', 'hubert'],
+            choices=['wav2vec2', 'hubert', 'wavlm'],
help='Choose model type of asr task.')
self.parser.add_argument(
'--task',
@@ -157,6 +157,12 @@ def _init_from_path(self,
elif lang == 'zh':
logger.error("zh hubertASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str
+ elif model_type == 'wavlm':
+ if lang == "en":
+ model_prefix = "wavlmASR_librispeech"
+ elif lang == "zh":
+ logger.error("zh wavlmASR is not supported yet")
+ tag = model_prefix + '-' + lang + '-' + sample_rate_str
else:
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(tag, version=None)
diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py
index 04872c72ec6..6bf9b588e5a 100644
--- a/paddlespeech/resource/model_alias.py
+++ b/paddlespeech/resource/model_alias.py
@@ -25,6 +25,7 @@
"wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"],
"hubertASR": ["paddlespeech.s2t.models.hubert:HubertASR"],
"hubert": ["paddlespeech.s2t.models.hubert:HubertBase"],
+ "wavlmASR": ["paddlespeech.s2t.models.wavlm:WavLMASR"],
# ---------------------------------
# -------------- ASR --------------
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index e56188640dc..e539c001825 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -149,6 +149,16 @@
'exp/hubertASR/checkpoints/avg_1.pdparams',
},
},
+ "wavlmASR_librispeech-en-16k": {
+ "1.0": {
+ "url": "https://paddlespeech.bj.bcebos.com/wavlm/wavlm_baseplus_libriclean_100h.tar.gz",
+ "md5": "f2238e982bb8bcf046e536201f5ea629",
+ "cfg_path": "model.yaml",
+ "ckpt_path": "exp/wavlmASR/checkpoints/46",
+ "model": "exp/wavlmASR/checkpoints/46.pdparams",
+ "params": "exp/wavlmASR/checkpoints/46.pdparams",
+ }
+ }
}
# ---------------------------------
diff --git a/paddlespeech/s2t/exps/wavlm/__init__.py b/paddlespeech/s2t/exps/wavlm/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/paddlespeech/s2t/exps/wavlm/bin/__init__.py b/paddlespeech/s2t/exps/wavlm/bin/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/paddlespeech/s2t/exps/wavlm/bin/test.py b/paddlespeech/s2t/exps/wavlm/bin/test.py
new file mode 100644
index 00000000000..f56b418bc8f
--- /dev/null
+++ b/paddlespeech/s2t/exps/wavlm/bin/test.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation for WavLM model."""
+import cProfile
+
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.exps.wavlm.model import WavLMASRTester as Tester
+from paddlespeech.s2t.training.cli import default_argument_parser
+from paddlespeech.utils.argparse import print_arguments
+
+
+def main_sp(config, args):
+ exp = Tester(config, args)
+ with exp.eval():
+ exp.setup()
+ exp.run_test()
+
+
+def main(config, args):
+ main_sp(config, args)
+
+
+if __name__ == "__main__":
+ parser = default_argument_parser()
+ # save asr result to
+ parser.add_argument(
+ '--dict-path', type=str, default=None, help='dict path.')
+ parser.add_argument(
+ "--result_file", type=str, help="path of save the asr result")
+ args = parser.parse_args()
+ print_arguments(args, globals())
+
+ # https://yaml.org/type/float.html
+ config = CfgNode(new_allowed=True)
+ if args.config:
+ config.merge_from_file(args.config)
+ if args.decode_cfg:
+ decode_confs = CfgNode(new_allowed=True)
+ decode_confs.merge_from_file(args.decode_cfg)
+ config.decode = decode_confs
+ if args.opts:
+ config.merge_from_list(args.opts)
+ config.freeze()
+ print(config)
+ if args.dump_config:
+ with open(args.dump_config, 'w') as f:
+ print(config, file=f)
+
+ # Setting for profiling
+ pr = cProfile.Profile()
+ pr.runcall(main, config, args)
+ pr.dump_stats('test.profile')
diff --git a/paddlespeech/s2t/exps/wavlm/bin/test_wav.py b/paddlespeech/s2t/exps/wavlm/bin/test_wav.py
new file mode 100644
index 00000000000..e6c07629d81
--- /dev/null
+++ b/paddlespeech/s2t/exps/wavlm/bin/test_wav.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation for wavlm model."""
+import os
+import sys
+from pathlib import Path
+
+import paddle
+import soundfile
+from paddlenlp.transformers import AutoTokenizer
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
+from paddlespeech.s2t.training.cli import default_argument_parser
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.s2t.utils.utility import UpdateConfig
+logger = Log(__name__).getlog()
+
+
+class WavLMInfer():
+ def __init__(self, config, args):
+ self.args = args
+ self.config = config
+ self.audio_file = args.audio_file
+ self.tokenizer = config.get("tokenizer", None)
+
+ if self.tokenizer:
+ self.text_feature = AutoTokenizer.from_pretrained(
+ self.config.tokenizer)
+ else:
+ self.text_feature = TextFeaturizer(
+ unit_type=config.unit_type, vocab=config.vocab_filepath)
+
+ paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
+
+ # model
+ model_conf = config
+ with UpdateConfig(model_conf):
+ model_conf.output_dim = self.text_feature.vocab_size
+ model = WavLMASR.from_config(model_conf)
+ self.model = model
+ self.model.eval()
+
+ # load model
+ params_path = self.args.checkpoint_path + ".pdparams"
+ model_dict = paddle.load(params_path)
+ self.model.set_state_dict(model_dict)
+
+ def run(self):
+        check(self.args.audio_file)
+
+ with paddle.no_grad():
+ # read
+ audio, _ = soundfile.read(
+ self.audio_file, dtype="int16", always_2d=True)
+ logger.info(f"audio shape: {audio.shape}")
+ xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
+ decode_config = self.config.decode
+ result_transcripts, result_tokenids = self.model.decode(
+ xs,
+ text_feature=self.text_feature,
+ decoding_method=decode_config.decoding_method,
+ beam_size=decode_config.beam_size,
+ tokenizer=self.tokenizer, )
+ rsl = result_transcripts[0]
+ utt = Path(self.audio_file).name
+ logger.info(f"hyp: {utt} {rsl}")
+ return rsl
+
+
+def check(audio_file):
+ if not os.path.isfile(audio_file):
+ print("Please input the right audio file path")
+ sys.exit(-1)
+
+ logger.info("checking the audio file format......")
+ try:
+ sig, sample_rate = soundfile.read(audio_file)
+ except Exception as e:
+ logger.error(str(e))
+ logger.error(
+ "can not open the wav file, please check the audio file format")
+ sys.exit(-1)
+ logger.info("The sample rate is %d" % sample_rate)
+ assert (sample_rate == 16000)
+ logger.info("The audio file format is right")
+
+
+def main(config, args):
+ WavLMInfer(config, args).run()
+
+
+if __name__ == "__main__":
+ parser = default_argument_parser()
+ # save asr result to
+ parser.add_argument(
+ "--result_file", type=str, help="path of save the asr result")
+ parser.add_argument(
+ "--audio_file", type=str, help="path of the input audio file")
+ args = parser.parse_args()
+
+ config = CfgNode(new_allowed=True)
+
+ if args.config:
+ config.merge_from_file(args.config)
+ if args.decode_cfg:
+ decode_confs = CfgNode(new_allowed=True)
+ decode_confs.merge_from_file(args.decode_cfg)
+ config.decode = decode_confs
+ if args.opts:
+ config.merge_from_list(args.opts)
+ config.freeze()
+ main(config, args)
diff --git a/paddlespeech/s2t/exps/wavlm/bin/train.py b/paddlespeech/s2t/exps/wavlm/bin/train.py
new file mode 100644
index 00000000000..4ad966b7ffc
--- /dev/null
+++ b/paddlespeech/s2t/exps/wavlm/bin/train.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Trainer for wavlm model."""
+import cProfile
+import os
+
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.exps.wavlm.model import WavLMASRTrainer as Trainer
+from paddlespeech.s2t.training.cli import default_argument_parser
+from paddlespeech.utils.argparse import print_arguments, add_arguments
+
+
+def main_sp(config, args):
+ exp = Trainer(config, args)
+ exp.setup()
+ exp.run()
+
+
+def main(config, args):
+ main_sp(config, args)
+
+
+if __name__ == "__main__":
+ parser = default_argument_parser()
+ parser.add_argument(
+ '--resume', type=str, default="", nargs="?", help='resume ckpt path.')
+ args = parser.parse_args()
+ print_arguments(args, globals())
+ # https://yaml.org/type/float.html
+ config = CfgNode(new_allowed=True)
+ if args.config:
+ config.merge_from_file(args.config)
+ if args.opts:
+ config.merge_from_list(args.opts)
+ config.freeze()
+ if args.dump_config:
+ with open(args.dump_config, 'w') as f:
+ print(config, file=f)
+
+ # Setting for profiling
+ pr = cProfile.Profile()
+ pr.runcall(main, config, args)
+ pr.dump_stats(os.path.join(args.output, 'train.profile'))
diff --git a/paddlespeech/s2t/exps/wavlm/model.py b/paddlespeech/s2t/exps/wavlm/model.py
new file mode 100644
index 00000000000..6ed2c5d8796
--- /dev/null
+++ b/paddlespeech/s2t/exps/wavlm/model.py
@@ -0,0 +1,912 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains wavlm model."""
+import json
+import math
+import os
+import re
+import time
+from collections import OrderedDict
+from contextlib import nullcontext
+
+import jsonlines
+import numpy as np
+import paddle
+from hyperpyyaml import load_hyperpyyaml
+from paddle import distributed as dist
+from paddlenlp.transformers import AutoTokenizer
+
+from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
+from paddlespeech.s2t.io.dataloader import DataLoaderFactory
+from paddlespeech.s2t.io.speechbrain import data_pipeline
+from paddlespeech.s2t.io.speechbrain import dataio
+from paddlespeech.s2t.io.speechbrain import dataset
+from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
+from paddlespeech.s2t.models.wavlm.processing.speech_augmentation import TimeDomainSpecAugment
+from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
+from paddlespeech.s2t.training.optimizer import OptimizerFactory
+from paddlespeech.s2t.training.reporter import ObsScope
+from paddlespeech.s2t.training.reporter import report
+from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
+from paddlespeech.s2t.training.timer import Timer
+from paddlespeech.s2t.training.trainer import Trainer
+from paddlespeech.s2t.utils import error_rate
+from paddlespeech.s2t.utils import layer_tools
+from paddlespeech.s2t.utils import mp_tools
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.s2t.utils.utility import UpdateConfig
+
+logger = Log(__name__).getlog()
+
+
+def clip_grad_norm_(
+ parameters,
+ max_norm,
+ norm_type=2.0,
+ error_if_nonfinite=False, ):
+ r"""Clips gradient norm of the iteratable parameters.
+
+ Norms are calculated together on all gradients, just as they are
+ connected into one vector. The gradient will be modified in place.
+
+ This API can only run in dynamic graph mode, not static graph mode.
+
+ Args:
+ parameters (Iterable[paddle.Tensor] or paddle.Tensor): an iterable of
+ Tensors or a single Tensor whose gradients will be normalized
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be `inf` for
+ infinity norm.
+ error_if_nonfinite (bool): if True, throw an error if the total
+ norm of the gradients from :attr:`parameters` is `nan`,
+ `inf`, or `-inf`.
+
+ Returns:
+ Total norm of the parameter gradients (treated as a single vector).
+ Example:
+ .. code-block:: python
+ import paddle
+
+ x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
+ max_norm = float(5.0)
+ linear = paddle.nn.Linear(in_features=10, out_features=10)
+ out = linear(x)
+ loss = paddle.mean(out)
+ loss.backward()
+
+ paddle.nn.utils.clip_grad_norm_(linear.parameters(), max_norm)
+
+ sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
+ sgd.step()
+ """
+ if not paddle.in_dynamic_mode():
+ raise RuntimeError('this API can only run in dynamic mode.')
+
+ if isinstance(parameters, paddle.Tensor):
+ parameters = [parameters]
+
+ support_norm_type = [float("inf"), 0, 1, 2]
+ if norm_type not in support_norm_type:
+ raise ValueError(f'norm_type only support {support_norm_type}')
+
+ grads = [p.grad for p in parameters if p.grad is not None]
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ if len(grads) == 0:
+ return paddle.to_tensor(0.0)
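+ # The total norm is taken as the norm of the per-tensor gradient norms, which
+ # for the supported p-norms (p >= 1 and inf) equals the norm of the
+ # concatenated gradient vector, so gradients never need to be flattened.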
+ if norm_type == float("inf"):
+ norms = [g.detach().abs().max() for g in grads]
+ total_norm = (norms[0]
+ if len(norms) == 1 else paddle.max(paddle.stack(norms)))
+ else:
+ total_norm = paddle.linalg.norm(
+ paddle.stack(
+ [paddle.linalg.norm(g.detach(), norm_type) for g in grads]),
+ norm_type, )
+
+ if error_if_nonfinite and paddle.logical_or(total_norm.isnan(),
+ total_norm.isinf()):
+ raise RuntimeError(
+ f'The total norm of order {norm_type} of the gradients from '
+ '`parameters` is non-finite, so it cannot be clipped. To disable '
+ 'this error and scale the gradients by the non-finite norm anyway, '
+ 'set `error_if_nonfinite=False`.')
+ clip_coef = max_norm / (total_norm + 1e-6)
+ # Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this
+ # avoids the `if clip_coef < 1:` condition.
+ clip_coef_clamped = paddle.clip(clip_coef, max=1.0)
+ with paddle.no_grad():
+ for _, p in enumerate(parameters):
+ g = p.grad
+ if g is not None:
+ p.grad = paddle.multiply(x=g, y=clip_coef_clamped)
+ return total_norm
+
+
+class WavLMASRTrainer(Trainer):
+ def __init__(self, config, args):
+ super().__init__(config, args)
+ self.avg_train_loss = 0.0
+ self.loss_isfinite = True # when False, the loss is NaN or Inf and is excluded from the running average
+ self.use_sb = True # whether to use the SpeechBrain-style dataloader
+
+ def update_average(self, batch_index, loss):
+ """Update running average of the loss.
+ Arguments
+ ---------
+ batch_index : int
+ current batch index
+ loss : paddle.tensor
+ detached loss, a single float value.
+ """
+ if math.isfinite(loss):
+ self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
+ self.avg_train_loss += loss / (batch_index + 1)
+ else:
+ self.loss_isfinite = False
+ logger.info('loss: {} is NaN or Inf, skip updating the average'.format(loss))
+
+ def before_train(self):
+ from_scratch = self.resume_or_scratch()
+ if from_scratch:
+ # scratch: save init model, i.e. 0 epoch
+ self.save(tag='init', infos=None)
+ else:
+ # resume: train next_epoch and next_iteration
+ self.epoch += 1
+ logger.info(
+ f"Resume train: epoch {self.epoch }, step {self.iteration}!")
+
+ self.maybe_batch_sampler_step()
+
+ def train_batch(self, batch_index, batch, msg):
+ train_conf = self.config
+ start = time.time()
+
+ # forward
+ ## sb data pipeline
+ if self.use_sb:
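+ # SpeechBrain-style batches carry (padded data, relative length) pairs; the
+ # relative lengths in [0, 1] are converted back to absolute token lengths below.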
+ wav, wavs_lens_rate = batch['sig']
+ target, target_lens_rate = batch['tokens']
+ target_lens = (target_lens_rate *
+ target.shape[1]).round().astype(paddle.int64)
+ else:
+ utt, wav, wavs_lens, target, target_lens = batch
+ wavs_lens_rate = wavs_lens / wav.shape[1]
+ wav = wav[:, :, 0]
+
+ if hasattr(train_conf, 'audio_augment'):
+ wav = self.speech_augmentation(wav, wavs_lens_rate)
+ loss = self.model(wav, wavs_lens_rate, target, target_lens)
+
+ # loss div by `batch_size * accum_grad`
+ loss /= train_conf.accum_grad
+ # update self.avg_train_loss
+ self.update_average(batch_index, float(loss))
+
+ # loss backward
+ if (batch_index + 1) % train_conf.accum_grad != 0:
+ # Disable gradient synchronizations across DDP processes.
+ # Within this context, gradients will be accumulated on module
+ # variables, which will later be synchronized.
+ # When using cpu w/o DDP, model does not have `no_sync`
+ context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+ self.parallel) else nullcontext
+ else:
+ # Used for single gpu training and DDP gradient synchronization
+ # processes.
+ context = nullcontext
+ with context():
+ loss.backward()
+
+ layer_tools.print_grads(self.model, print_func=None)
+
+ # NOTE: earlier experiments suggested that backward() here is problematic: as more steps are accumulated, the output from WavLM alone becomes identical for all frames.
+ # optimizer step old
+ if (batch_index + 1) % train_conf.accum_grad == 0:
+ #do global grad clip
+ if train_conf.global_grad_clip != 0:
+ clip_grad_norm_(self.model.parameters(),
+ train_conf.global_grad_clip)
+ self.model_optimizer.step()
+ self.model_optimizer.clear_grad()
+ if not train_conf.freeze_wavlm:
+ self.wavlm_optimizer.step()
+ self.wavlm_optimizer.clear_grad()
+ if self.config.model_scheduler != 'newbobscheduler':
+ self.model_lr_scheduler.step()
+ if self.config.wavlm_scheduler != 'newbobscheduler':
+ if not train_conf.freeze_wavlm:
+ self.wavlm_lr_scheduler.step()
+ self.iteration += 1
+
+ losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
+ iteration_time = time.time() - start
+ for k, v in losses_np.items():
+ report(k, v)
+ report("loss_whitoutavg", float(loss))
+ report("batch_size", self.config.batch_size)
+ report("accum", train_conf.accum_grad)
+ report("step_cost", iteration_time)
+
+ if (batch_index + 1) % train_conf.accum_grad == 0:
+ if dist.get_rank() == 0 and self.visualizer:
+ losses_np_v = losses_np.copy()
+ losses_np_v.update({
+ "model_lr": self.model_lr_scheduler(),
+ "wavlm_lr": self.wavlm_lr_scheduler()
+ })
+ for key, val in losses_np_v.items():
+ self.visualizer.add_scalar(
+ tag='train/' + key, value=val, step=self.iteration - 1)
+
+ @paddle.no_grad()
+ def valid(self):
+ self.model.eval()
+ if not self.use_streamdata:
+ logger.info(
+ f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+ valid_losses = {}
+ step = 0
+ total_loss = 0.0
+ num_seen_utts = 1 # a running average is used, so num_seen_utts is not accumulated here
+ for i, batch in enumerate(self.valid_loader):
+ if self.use_sb:
+ wav, wavs_lens_rate = batch['sig']
+ target, target_lens_rate = batch['tokens']
+ target_lens = (target_lens_rate *
+ target.shape[1]).round().astype(paddle.int64)
+ else:
+ utt, wav, wavs_lens, target, target_lens = batch
+ wavs_lens_rate = wavs_lens / wav.shape[1]
+ wav = wav[:, :, 0]
+
+ loss = self.model(wav, wavs_lens_rate, target, target_lens)
+ # use update_average
+ total_loss -= total_loss / (step + 1)
+ total_loss += loss / (step + 1)
+
+ if math.isfinite(float(loss)):
+ step += 1
+ valid_losses['val_loss'] = float(loss)
+ else:
+ logger.info('loss: {} is NaN or Inf, not counted'.format(float(loss)))
+
+ if (i + 1) % self.config.log_interval == 0:
+ valid_losses['val_history_loss'] = float(total_loss)
+
+ # logging
+ msg = f"Valid: Rank: {dist.get_rank()}, "
+ msg += "epoch: {}, ".format(self.epoch)
+ msg += "step: {}, ".format(self.iteration)
+ if not self.use_streamdata:
+ msg += "batch: {}/{}, ".format(i + 1,
+ len(self.valid_loader))
+ msg += ', '.join('{}: {:>.6f}'.format(k, v)
+ for k, v in valid_losses.items())
+ logger.info(msg)
+
+ logger.info(
+ 'Rank {} Val info val_loss {}'.format(dist.get_rank(), total_loss))
+ return total_loss, num_seen_utts
+
+ @mp_tools.rank_zero_only
+ def save(self, tag=None, infos: dict=None):
+ """Save checkpoint (model parameters and optimizer states).
+
+ Args:
+ tag (int or str, optional): if None, the current iteration number is used as the checkpoint name; otherwise the tag (e.g. the epoch) is used. Defaults to None.
+ infos (dict, optional): meta data to save. Defaults to None.
+ """
+
+ infos = infos if infos else dict()
+ infos.update({
+ "epoch": self.epoch,
+ "model_lr": self.model_optimizer.get_lr(),
+ "wavlm_lr": self.wavlm_optimizer.get_lr()
+ })
+
+ checkpoint_path = os.path.join(
+ self.checkpoint_dir,
+ "{}".format(self.iteration if tag is None else tag))
+
+ model_dict = self.model.state_dict()
+ params_path = checkpoint_path + ".pdparams"
+ paddle.save(model_dict, params_path)
+ logger.info("Saved model to {}".format(params_path))
+
+ model_opt_dict = self.model_optimizer.state_dict()
+ wavlm_opt_dict = self.wavlm_optimizer.state_dict()
+
+ opt_dict = {'model': model_opt_dict, 'wavlm': wavlm_opt_dict}
+
+ optimizer_path = checkpoint_path + ".pdopt"
+ paddle.save(opt_dict, optimizer_path)
+ logger.info("Saved optimzier state to {}".format(optimizer_path))
+
+ scheduler_dict = {}
+
+ if self.config.model_scheduler == 'newbobscheduler':
+ scheduler_dict['model'] = self.model_lr_scheduler.save()
+ if self.config.wavlm_scheduler == 'newbobscheduler':
+ scheduler_dict['wavlm'] = self.wavlm_lr_scheduler.save()
+ if scheduler_dict:
+ scheduler_path = checkpoint_path + ".pdlrs"
+ paddle.save(scheduler_dict, scheduler_path)
+ logger.info("Saved scheduler state to {}".format(scheduler_path))
+ info_path = re.sub('.pdparams$', '.json', params_path)
+ infos = {} if infos is None else infos
+ with open(info_path, 'w', encoding='utf8') as fout:
+ data = json.dumps(infos)
+ fout.write(data)
+
+ def resume_or_scratch(self):
+ """Resume from latest checkpoint at checkpoints in the output
+ directory or load a specified checkpoint.
+
+ If ``args.checkpoint_path`` is not None, load the checkpoint, else
+ resume training.
+ """
+ scratch = None
+ if self.args.resume:
+ # just restore ckpt
+ # lr will be restored from the optimizer ckpt
+ resume_json_path = os.path.join(self.checkpoint_dir,
+ self.args.resume + '.json')
+ with open(resume_json_path, 'r', encoding='utf8') as f:
+ resume_json = json.load(f)
+ self.iteration = 0
+ self.epoch = resume_json["epoch"]
+
+ # restore model from *.pdparams
+ params_path = os.path.join(self.checkpoint_dir,
+ "{}".format(self.epoch)) + '.pdparams'
+ model_dict = paddle.load(params_path)
+ self.model.set_state_dict(model_dict)
+
+ # restore optimizer from *.pdopt
+ optimizer_path = os.path.join(self.checkpoint_dir,
+ "{}".format(self.epoch)) + '.pdopt'
+ optimizer_dict = paddle.load(optimizer_path)
+ self.model_optimizer.set_state_dict(optimizer_dict['model'])
+ self.wavlm_optimizer.set_state_dict(optimizer_dict['wavlm'])
+
+ # restore lr_scheduler from *.pdlrs
+ scheduler_path = os.path.join(self.checkpoint_dir,
+ "{}".format(self.epoch)) + '.pdlrs'
+ if os.path.isfile(os.path.join(scheduler_path)):
+ scheduler_dict = paddle.load(scheduler_path)
+ if self.config.model_scheduler == 'newbobscheduler':
+ self.model_lr_scheduler.load(scheduler_dict['model'])
+ if self.config.wavlm_scheduler == 'newbobscheduler':
+ self.wavlm_lr_scheduler.load(scheduler_dict['wavlm'])
+ logger.info(
+ f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
+ scratch = False
+ else:
+ self.iteration = 0
+ self.epoch = 0
+ scratch = True
+ logger.info("Init from scratch!")
+ return scratch
+
+ def do_train(self):
+ """The training process control by step."""
+ # !!!IMPORTANT!!!
+ # Try to export the model by script, if fails, we should refine
+ # the code to satisfy the script export requirements
+ # script_model = paddle.jit.to_static(self.model)
+ # script_model_path = str(self.checkpoint_dir / 'init')
+ # paddle.jit.save(script_model, script_model_path)
+
+ self.before_train()
+ if not self.use_streamdata:
+ logger.info(
+ f"Train Total Examples: {len(self.train_loader.dataset)}")
+ while self.epoch < self.config.n_epoch:
+ with Timer("Epoch-Train Time Cost: {}"):
+ self.model.train()
+ try:
+ data_start_time = time.time()
+ for batch_index, batch in enumerate(self.train_loader):
+ dataload_time = time.time() - data_start_time
+ msg = "Train:"
+ observation = OrderedDict()
+ with ObsScope(observation):
+ report("Rank", dist.get_rank())
+ report("epoch", self.epoch)
+ report('step', self.iteration)
+ report("model_lr", self.model_optimizer.get_lr())
+ report("wavlm_lr",
+ self.wavlm_optimizer.get_lr())
+ self.train_batch(batch_index, batch, msg)
+ self.after_train_batch()
+ report('iter', batch_index + 1)
+ if not self.use_streamdata:
+ report('total', len(self.train_loader))
+ report('reader_cost', dataload_time)
+ observation['batch_cost'] = observation[
+ 'reader_cost'] + observation['step_cost']
+ observation['samples'] = observation['batch_size']
+ observation['ips,samples/s'] = observation[
+ 'batch_size'] / observation['batch_cost']
+ for k, v in observation.items():
+ msg += f" {k.split(',')[0]}: "
+ msg += f"{v:>.8f}" if isinstance(v,
+ float) else f"{v}"
+ msg += f" {k.split(',')[1]}" if len(
+ k.split(',')) == 2 else ""
+ msg += ","
+ msg = msg[:-1] # remove the last ","
+ if (batch_index + 1) % self.config.log_interval == 0:
+ logger.info(msg)
+ data_start_time = time.time()
+ except Exception as e:
+ logger.error(e)
+ raise e
+ with Timer("Eval Time Cost: {}"):
+ total_loss, num_seen_utts = self.valid()
+ if dist.get_world_size() > 1:
+ num_seen_utts = paddle.to_tensor(num_seen_utts)
+ dist.all_reduce(num_seen_utts)
+ total_loss = paddle.to_tensor(total_loss)
+ dist.all_reduce(total_loss)
+ cv_loss = total_loss / num_seen_utts
+ cv_loss = float(cv_loss)
+ else:
+ cv_loss = float(total_loss)
+ logger.info(
+ 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
+ if self.visualizer:
+ self.visualizer.add_scalar(
+ tag='eval/cv_loss', value=cv_loss, step=self.epoch)
+ self.visualizer.add_scalar(
+ tag='eval/model_lr',
+ value=self.model_lr_scheduler(),
+ step=self.epoch)
+ self.visualizer.add_scalar(
+ tag='eval/wavlm_lr',
+ value=self.wavlm_lr_scheduler(),
+ step=self.epoch)
+
+ if self.config.model_scheduler == 'newbobscheduler':
+ self.model_lr_scheduler.step(cv_loss)
+ if self.config.wavlm_scheduler == 'newbobscheduler':
+ if not self.config.freeze_wavlm:
+ self.wavlm_lr_scheduler.step(cv_loss)
+ self.save(tag=self.epoch, infos={'val_loss': cv_loss})
+ self.avg_train_loss = 0.0
+ self.new_epoch()
+
+ def dataio_prepare(self, hparams):
+ """This function prepares the datasets to be used in the brain class.
+ It also defines the data processing pipeline through user-defined functions."""
+ data_folder = hparams["data_folder"]
+
+ train_data = dataset.DynamicItemDataset.from_csv(
+ csv_path=hparams["train_data"],
+ replacements={"data_root": data_folder}, )
+
+ if hparams["sorting"] == "ascending":
+ # we sort training data to speed up training and get better results.
+ train_data = train_data.filtered_sorted(sort_key="duration")
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
+ hparams["train_dataloader_opts"]["shuffle"] = False
+
+ elif hparams["sorting"] == "descending":
+ train_data = train_data.filtered_sorted(
+ sort_key="duration", reverse=True)
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
+ hparams["train_dataloader_opts"]["shuffle"] = False
+
+ elif hparams["sorting"] == "random":
+ pass
+
+ else:
+ raise NotImplementedError(
+ "sorting must be random, ascending or descending")
+
+ valid_data = dataset.DynamicItemDataset.from_csv(
+ csv_path=hparams["valid_data"],
+ replacements={"data_root": data_folder}, )
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
+
+ test_data = dataset.DynamicItemDataset.from_csv(
+ csv_path=hparams["test_data"],
+ replacements={"data_root": data_folder}, )
+ test_data = test_data.filtered_sorted(sort_key="duration")
+
+ datasets = [train_data, valid_data, test_data]
+
+ # Defining tokenizer and loading it
+ tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
+ self.tokenizer = tokenizer
+ # 2. Define audio pipeline:
+ @data_pipeline.takes("wav")
+ @data_pipeline.provides("sig")
+ def audio_pipeline(wav):
+ sig = dataio.read_audio(wav)
+ return sig
+
+ dataset.add_dynamic_item(datasets, audio_pipeline)
+
+ # 3. Define text pipeline:
+ @data_pipeline.takes("transcript")
+ @data_pipeline.provides("wrd", "tokens_list", "tokens")
+ def text_pipeline(wrd):
+ wrd = "".join(wrd.split(" "))
+ yield wrd
+ tokens_list = tokenizer(wrd)["input_ids"]
+ yield tokens_list
+ tokens = np.array(tokens_list, dtype="int64")
+ # tokens = paddle.to_tensor(tokens_list, dtype="int64")
+ yield tokens
+
+ dataset.add_dynamic_item(datasets, text_pipeline)
+
+ # 4. Set output:
+ dataset.set_output_keys(
+ datasets,
+ ["id", "sig", "wrd", "tokens"], )
+
+ # 5. If Dynamic Batching is used, we instantiate the needed samplers.
+ train_batch_sampler = None
+ valid_batch_sampler = None
+ if hparams["dynamic_batching"]:
+ from sampler import DynamicBatchSampler # noqa
+
+ dynamic_hparams = hparams["dynamic_batch_sampler"]
+ num_buckets = dynamic_hparams["num_buckets"]
+
+ train_batch_sampler = DynamicBatchSampler(
+ train_data,
+ dynamic_hparams["max_batch_len"],
+ num_buckets=num_buckets,
+ length_func=lambda x: x["duration"],
+ shuffle=dynamic_hparams["shuffle_ex"],
+ batch_ordering=dynamic_hparams["batch_ordering"], )
+
+ valid_batch_sampler = DynamicBatchSampler(
+ valid_data,
+ dynamic_hparams["max_batch_len"],
+ num_buckets=num_buckets,
+ length_func=lambda x: x["duration"],
+ shuffle=dynamic_hparams["shuffle_ex"],
+ batch_ordering=dynamic_hparams["batch_ordering"], )
+
+ return (train_data, valid_data, test_data, tokenizer,
+ train_batch_sampler, valid_batch_sampler, )
+
+ def setup_dataloader(self):
+ config = self.config.clone()
+ self.use_streamdata = config.get("use_stream_data", False)
+ self.use_sb = config.get("use_sb_pipeline", False)
+ if self.use_sb:
+ hparams_file = config.sb_pipeline_conf
+ with open(hparams_file, 'r', encoding='utf8') as fin:
+ hparams = load_hyperpyyaml(fin, None)
+
+ (train_data, valid_data, test_data, tokenizer, train_bsampler,
+ valid_bsampler, ) = self.dataio_prepare(hparams)
+
+ train_dataloader_opts = hparams["train_dataloader_opts"]
+ valid_dataloader_opts = hparams["valid_dataloader_opts"]
+
+ if train_bsampler is not None:
+ train_dataloader_opts = {
+ "batch_sampler": train_bsampler,
+ "num_workers": hparams["num_workers"],
+ }
+
+ if valid_bsampler is not None:
+ valid_dataloader_opts = {"batch_sampler": valid_bsampler}
+
+ if self.train:
+ self.train_loader = make_dataloader(
+ train_data, stage='train', **train_dataloader_opts)
+ self.valid_loader = make_dataloader(
+ valid_data,
+ stage='val',
+ **valid_dataloader_opts, )
+ logger.info("Setup train/valid Dataloader!")
+ else:
+ self.test_loader = make_dataloader(
+ test_data, stage='test', **hparams["test_dataloader_opts"])
+ else:
+ if self.train:
+ self.train_loader = DataLoaderFactory.get_dataloader(
+ 'train', config, self.args)
+ self.valid_loader = DataLoaderFactory.get_dataloader(
+ 'valid', config, self.args)
+ logger.info("Setup train/valid Dataloader!")
+ else:
+ decode_batch_size = config.get('decode', dict()).get(
+ 'decode_batch_size', 1)
+ self.test_loader = DataLoaderFactory.get_dataloader(
+ 'test', config, self.args)
+ self.align_loader = DataLoaderFactory.get_dataloader(
+ 'align', config, self.args)
+ logger.info("Setup test/align Dataloader!")
+
+ def setup_model(self):
+ config = self.config
+ model_conf = config
+
+ with UpdateConfig(model_conf):
+ if self.use_sb:
+ model_conf.output_dim = self.tokenizer.vocab_size
+ else:
+ if self.train:
+ model_conf.input_dim = self.train_loader.feat_dim
+ model_conf.output_dim = self.train_loader.vocab_size
+ else:
+ model_conf.input_dim = self.test_loader.feat_dim
+ model_conf.output_dim = self.test_loader.vocab_size
+
+ model = WavLMASR.from_config(model_conf)
+
+ model_dict = paddle.load(config.wavlm_params_path)
+ model.wavlm.set_state_dict(model_dict)
+
+ if self.parallel:
+ model = paddle.DataParallel(model, find_unused_parameters=True)
+
+ layer_tools.print_params(model, logger.info)
+ self.model = model
+ logger.info("Setup model!")
+
+ # setup speech augmentation for wavlm
+ if hasattr(config, 'audio_augment') and self.train:
+ self.speech_augmentation = TimeDomainSpecAugment(
+ **config.audio_augment)
+
+ if not self.train:
+ return
+
+ train_config = config
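+ # Two optimizer/scheduler pairs are created: one for the downstream encoder/CTC
+ # head and one for the pretrained WavLM backbone, which typically uses a smaller
+ # learning rate and can be frozen via `freeze_wavlm`.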
+ model_optim_type = train_config.model_optim
+ model_optim_conf = train_config.model_optim_conf
+ logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
+ wavlm_optim_type = train_config.wavlm_optim
+ wavlm_optim_conf = train_config.wavlm_optim_conf
+ logger.info("optim_model:{},{}", wavlm_optim_type,
+ wavlm_optim_conf)
+
+ model_scheduler_type = train_config.model_scheduler
+ model_scheduler_conf = train_config.model_scheduler_conf
+ wavlm_scheduler_type = train_config.wavlm_scheduler
+ wavlm_scheduler_conf = train_config.wavlm_scheduler_conf
+
+ model_scheduler_args = dict(
+ **{"learning_rate": model_optim_conf.lr,
+ "verbose": False}, **(dict(model_scheduler_conf)))
+
+ wavlm_scheduler_args = dict(
+ **{"learning_rate": wavlm_optim_conf.lr,
+ "verbose": False}, **(dict(wavlm_scheduler_conf)))
+
+ model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
+ model_scheduler_args)
+ wavlm_lr_scheduler = LRSchedulerFactory.from_args(
+ wavlm_scheduler_type, wavlm_scheduler_args)
+
+ def optimizer_args(
+ config,
+ optim_type,
+ optim_conf,
+ parameters,
+ lr_scheduler=None, ):
+ optim_arg = dict(optim_conf)
+ optim_arg.update({
+ "learning_rate":
+ lr_scheduler if lr_scheduler else optim_conf.lr,
+ "parameters":
+ parameters
+ })
+ return optim_arg
+
+ model_optimizer_args = optimizer_args(
+ config,
+ model_optim_type,
+ model_optim_conf,
+ [{'params': model._layers.enc.parameters()}, {'params': model._layers.ctc.parameters()}]
+ if self.parallel else
+ [{'params': model.enc.parameters()}, {'params': model.ctc.parameters()}],
+ model_lr_scheduler)
+
+ wavlm_optimizer_args = optimizer_args(
+ config, wavlm_optim_type, wavlm_optim_conf,
+ model._layers.wavlm.parameters() if self.parallel else
+ model.wavlm.parameters(), wavlm_lr_scheduler)
+
+ model_optimizer = OptimizerFactory.from_args(model_optim_type,
+ model_optimizer_args)
+ wavlm_optimizer = OptimizerFactory.from_args(wavlm_optim_type,
+ wavlm_optimizer_args)
+
+ self.model_optimizer = model_optimizer
+ self.wavlm_optimizer = wavlm_optimizer
+ self.model_lr_scheduler = model_lr_scheduler
+ self.wavlm_lr_scheduler = wavlm_lr_scheduler
+ logger.info("Setup optimizer/lr_scheduler!")
+
+
+class WavLMASRTester(WavLMASRTrainer):
+ def __init__(self, config, args):
+ super().__init__(config, args)
+ self.text_featurizer = TextFeaturizer(
+ unit_type=config.unit_type, vocab=config.vocab_filepath)
+ self.vocab_list = self.text_featurizer.vocab_list
+
+ def id2token(self, texts, texts_len):
+ """ ord() id to chr() chr """
+ trans = []
+ for text, n in zip(texts, texts_len):
+ n = n.numpy().item()
+ ids = text[:n]
+ trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist()))
+ return trans
+
+ def compute_metrics(self, id, audio, audio_len, texts, texts_len,
+ fout=None):
+ decode_cfg = self.config.decode
+ errors_sum, len_refs, num_ins = 0.0, 0, 0
+ errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
+ error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
+
+ start_time = time.time()
+ target_transcripts = self.id2token(texts, texts_len)
+ result_transcripts, result_tokenids = self.model.decode(
+ audio,
+ text_feature=self.text_featurizer,
+ decoding_method=decode_cfg.decoding_method,
+ beam_size=decode_cfg.beam_size)
+ decode_time = time.time() - start_time
+
+ for utt, target, result, rec_tids in zip(
+ id, target_transcripts, result_transcripts, result_tokenids):
+ errors, len_ref = errors_func(target, result)
+ errors_sum += errors
+ len_refs += len_ref
+ num_ins += 1
+ if fout:
+ fout.write({
+ "utt": utt,
+ "refs": [target],
+ "hyps": [result],
+ "hyps_tokenid": [rec_tids],
+ })
+ logger.info(f"Utt: {utt}")
+ logger.info(f"Ref: {target}")
+ logger.info(f"Hyp: {result}")
+ logger.info("One example error rate [%s] = %f" % (
+ decode_cfg.error_rate_type, error_rate_func(target, result)))
+
+ return dict(
+ errors_sum=errors_sum,
+ len_refs=len_refs,
+ num_ins=num_ins, # num examples
+ error_rate=errors_sum / len_refs,
+ error_rate_type=decode_cfg.error_rate_type,
+ num_frames=audio_len.sum().numpy().item(),
+ decode_time=decode_time)
+
+ def sb_compute_metrics(self, id, sig, wrd, tokens, fout=None):
+ decode_cfg = self.config.decode
+ errors_sum, len_refs, num_ins = 0.0, 0, 0
+ errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
+ error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
+ start_time = time.time()
+ target_transcripts = wrd
+ result_transcripts, result_tokenids = self.model.decode(
+ sig[0],
+ text_feature=self.tokenizer,
+ decoding_method=decode_cfg.decoding_method,
+ beam_size=decode_cfg.beam_size,
+ sb_pipeline=True)
+ decode_time = time.time() - start_time
+
+ for utt, target, result, rec_tids in zip(
+ id, target_transcripts, result_transcripts, result_tokenids):
+ errors, len_ref = errors_func(target, result)
+ errors_sum += errors
+ len_refs += len_ref
+ num_ins += 1
+ if fout:
+ fout.write({
+ "utt": utt,
+ "refs": [target],
+ "hyps": [result],
+ "hyps_tokenid": [rec_tids],
+ })
+ logger.info(f"Utt: {utt}")
+ logger.info(f"Ref: {target}")
+ logger.info(f"Hyp: {result}")
+ logger.info("One example error rate [%s] = %f" % (
+ decode_cfg.error_rate_type, error_rate_func(target, result)))
+
+ return dict(
+ errors_sum=errors_sum,
+ len_refs=len_refs,
+ num_ins=num_ins, # num examples
+ error_rate=errors_sum / len_refs,
+ error_rate_type=decode_cfg.error_rate_type,
+ num_frames=sig[1].sum().numpy().item(),
+ decode_time=decode_time)
+
+ @mp_tools.rank_zero_only
+ @paddle.no_grad()
+ def test(self):
+ logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+ self.model.eval()
+
+ error_rate_type = None
+ errors_sum, len_refs, num_ins = 0.0, 0, 0
+ num_frames = 0.0
+ num_time = 0.0
+ # Initialized the decoder in model
+ decode_cfg = self.config.decode
+ vocab_list = self.vocab_list
+ decode_batch_size = decode_cfg.decode_batch_size
+
+ with jsonlines.open(self.args.result_file, 'w') as fout:
+ for i, batch in enumerate(self.test_loader):
+ if self.use_sb:
+ metrics = self.sb_compute_metrics(**batch, fout=fout)
+ else:
+ metrics = self.compute_metrics(*batch, fout=fout)
+ num_frames += metrics['num_frames']
+ num_time += metrics["decode_time"]
+ errors_sum += metrics['errors_sum']
+ len_refs += metrics['len_refs']
+ num_ins += metrics['num_ins']
+ error_rate_type = metrics['error_rate_type']
+ rtf = num_time / (num_frames)
+ logger.info(
+ "RTF: %f, Error rate [%s] (%d/?) = %f" %
+ (rtf, error_rate_type, num_ins, errors_sum / len_refs))
+
+ # logging
+ msg = "Test: "
+ msg += "epoch: {}, ".format(self.epoch)
+ msg += "step: {}, ".format(self.iteration)
+ msg += "Final error rate [%s] (%d/%d) = %f" % (
+ error_rate_type, num_ins, num_ins, errors_sum / len_refs)
+ logger.info(msg)
+
+ err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
+ err_type_str = "{}".format(error_rate_type)
+ with open(err_meta_path, 'w', encoding='utf8') as f:
+ data = json.dumps({
+ "epoch":
+ self.epoch,
+ "step":
+ self.iteration,
+ "rtf":
+ rtf,
+ error_rate_type:
+ errors_sum / len_refs,
+ "dataset_hour": (num_frames) / 1000.0 / 3600.0,
+ "process_hour":
+ num_time / 1000.0 / 3600.0,
+ "num_examples":
+ num_ins,
+ "err_sum":
+ errors_sum,
+ "ref_len":
+ len_refs,
+ "decode_method":
+ self.config.decode.decoding_method,
+ })
+ f.write(data + '\n')
diff --git a/paddlespeech/s2t/models/wavlm/__init__.py b/paddlespeech/s2t/models/wavlm/__init__.py
new file mode 100644
index 00000000000..cf69114ea34
--- /dev/null
+++ b/paddlespeech/s2t/models/wavlm/__init__.py
@@ -0,0 +1,2 @@
+from .wavlm_paddle import WavLM, WavLMConfig
+from .wavlm_asr import WavLMASR, WavLMBase
\ No newline at end of file
diff --git a/paddlespeech/s2t/models/wavlm/modules/__init__.py b/paddlespeech/s2t/models/wavlm/modules/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/paddlespeech/s2t/models/wavlm/modules/activations.py b/paddlespeech/s2t/models/wavlm/modules/activations.py
new file mode 100644
index 00000000000..b11dc1a9dc3
--- /dev/null
+++ b/paddlespeech/s2t/models/wavlm/modules/activations.py
@@ -0,0 +1,88 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import paddle
+import paddle.nn.functional as F
+
+
+def _gelu_python(x):
+ """
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in
+ torch.nn.functional Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+ return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
+
+
+
+
+def gelu_new(x):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+ return 0.5 * x * (1.0 + paddle.tanh(
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * paddle.pow(x, 3.0))))
+
+
+def gelu_fast(x):
+ return 0.5 * x * (1.0 + paddle.tanh(x * 0.7978845608 *
+ (1.0 + 0.044715 * x * x)))
+
+gelu = gelu_fast
+
+def _silu_python(x):
+ """
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
+ later.
+ """
+ return x * paddle.nn.functional.sigmoid(x)
+
+
+def mish(x):
+ return x * paddle.tanh(paddle.nn.functional.softplus(x))
+
+
+def linear_act(x):
+ return x
+
+
+ACT2FN = {
+ "relu": F.relu,
+ "silu": _silu_python,
+ "swish": _silu_python,
+ "gelu": gelu,
+ "tanh": paddle.tanh,
+ "gelu_new": gelu_new,
+ "gelu_fast": gelu_fast,
+ "mish": mish,
+ "linear": linear_act,
+ "sigmoid": paddle.nn.functional.sigmoid,
+}
+
+
+def get_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(
+ f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}"
+ )
\ No newline at end of file
diff --git a/paddlespeech/s2t/models/wavlm/modules/functional.py b/paddlespeech/s2t/models/wavlm/modules/functional.py
new file mode 100644
index 00000000000..d2ebdc71b90
--- /dev/null
+++ b/paddlespeech/s2t/models/wavlm/modules/functional.py
@@ -0,0 +1,473 @@
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from typing import Optional, List, Tuple
+import math
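+
+# Paddle re-implementations of the multi-head attention helpers from
+# torch.nn.functional, used by the WavLM attention modules.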
+
+def _mha_shape_check(query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor,
+ key_padding_mask: Optional[paddle.Tensor], attn_mask: Optional[paddle.Tensor], num_heads: int):
+ # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
+ # and returns if the input is batched or not.
+ # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
+
+ # Shape check.
+ if query.dim() == 3:
+ # Batched Inputs
+ is_batched = True
+ assert key.dim() == 3 and value.dim() == 3, \
+ ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+ if key_padding_mask is not None:
+ assert key_padding_mask.dim() == 2, \
+ ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead")
+ if attn_mask is not None:
+ assert attn_mask.dim() in (2, 3), \
+ ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead")
+ elif query.dim() == 2:
+ # Unbatched Inputs
+ is_batched = False
+ assert key.dim() == 2 and value.dim() == 2, \
+ ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.dim() == 1, \
+ ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
+ f" but found {key_padding_mask.dim()}-D tensor instead")
+
+ if attn_mask is not None:
+ assert attn_mask.dim() in (2, 3), \
+ ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+ f" but found {attn_mask.dim()}-D tensor instead")
+ if attn_mask.dim() == 3:
+ expected_shape = (num_heads, query.shape[0], key.shape[0])
+ assert attn_mask.shape == expected_shape, \
+ (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
+ else:
+ raise AssertionError(
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
+
+
+def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
+ """
+ Scaled Dot-Product Attention
+ """
+
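+ # q, k, v are expected as [batch, num_heads, seq_len, head_dim]; attn_mask is
+ # added to the scaled logits before the softmax.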
+ d_key = k.shape[-1]
+ scaled_q = paddle.scale(x=q, scale=d_key ** -0.5)
+ product = paddle.matmul(x=scaled_q, y=k, transpose_y=True)
+ weights = F.softmax(x=product + attn_mask)
+ if dropout_p:
+ weights = F.dropout(
+ weights,
+ p=dropout_p,
+ training=True,
+ mode="upscale_in_train"
+ )
+ out = paddle.matmul(x=weights, y=v)
+ return out
+
+
+def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
+ """
+ A helper function to calculate alpha*(vec1*vec2^T) + beta*input
+ """
+ row = vec1.shape[0]
+ column = vec2.shape[0]
+ vec1 = paddle.unsqueeze(vec1, 0)
+ vec1 = paddle.transpose(vec1, [1, 0])
+ vec1 = paddle.expand(vec1, [row, column])
+ new_vec2 = paddle.zeros([column, column], dtype=vec2.dtype)
+ new_vec2[0, :] = vec2
+ out = alpha * paddle.matmul(vec1, new_vec2)
+ out = beta * input + out
+ return out
+
+def multi_head_attention_forward(
+ x: paddle.Tensor,
+ num_heads: int,
+ q_proj: nn.Linear,
+ k_proj: nn.Linear,
+ v_proj: nn.Linear,
+ c_proj: nn.Linear,
+ attn_mask: Optional[paddle.Tensor] = None,
+):
+ max_len, batch_size, emb_dim = x.shape
+ head_dim = emb_dim // num_heads
+ scaling = float(head_dim) ** -0.5
+ q = q_proj(x) # L, N, E
+ k = k_proj(x) # L, N, E
+ v = v_proj(x) # L, N, E
+
+ v = v.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
+ k = k.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
+ q = q.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
+
+ q = q * scaling
+ qk = paddle.matmul(q, k, transpose_y=True)
+ if attn_mask is not None:
+ if attn_mask.ndim == 2:
+ attn_mask.unsqueeze_(0)
+ assert attn_mask.shape[0] == 1 and attn_mask.shape[1] == max_len and attn_mask.shape[2] == max_len
+ qk += attn_mask
+
+ qk = F.softmax(qk, axis=-1)
+ atten = paddle.bmm(qk, v)
+ atten = atten.transpose((1, 0, 2))
+ atten = atten.reshape((max_len, batch_size, emb_dim))
+ atten = c_proj(atten)
+ return atten
+
+def linear(input, weight, bias=None):
+ # compute y = x A^T + b
+ # Input: (N, in_feature) paddle tensor
+ # weight: (out_feature, in_feature) paddle tensor
+ # bias: (out_feature) paddle tensor
+ if input.dim() == 2 and bias is not None:
+ # fused op is marginally faster
+ return paddle.addmm(bias, input, weight)
+ output = paddle.matmul(input, weight)
+ if bias is not None:
+ output += bias
+ return output
+
+
+def _in_projection_packed(
+ q: paddle.Tensor,
+ k: paddle.Tensor,
+ v: paddle.Tensor,
+ w: paddle.Tensor,
+ b: Optional[paddle.Tensor] = None,
+) -> List[paddle.Tensor]:
+ r"""
+ Performs the in-projection step of the attention operation, using packed weights.
+ Output is a triple containing projection tensors for query, key and value.
+ Args:
+ q, k, v: query, key and value tensors to be projected. For self-attention,
+ these are typically the same tensor; for encoder-decoder attention,
+ k and v are typically the same tensor. (We take advantage of these
+ identities for performance if they are present.) Regardless, q, k and v
+ must share a common embedding dimension; otherwise their shapes may vary.
+ w: projection weights for q, k and v, packed into a single tensor. Weights
+ are packed along dimension 0, in q, k, v order.
+ b: optional projection biases for q, k and v, packed into a single tensor
+ in q, k, v order.
+ Shape:
+ Inputs:
+ - q: :math:`(..., E)` where E is the embedding dimension
+ - k: :math:`(..., E)` where E is the embedding dimension
+ - v: :math:`(..., E)` where E is the embedding dimension
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
+ - b: :math:`E * 3` where E is the embedding dimension
+ Output:
+ - in output list :math:`[q', k', v']`, each output tensor will have the
+ same shape as the corresponding input tensor.
+ """
+ E = q.shape[-1]
+ if k is v:
+ if q is k:
+ # self-attention
+ proj = F.linear(q, w, b)
+ # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+ proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
+ return proj[0], proj[1], proj[2]
+ else:
+ # encoder-decoder attention
+ w_q, w_kv = w.split([E, E * 2])
+ if b is None:
+ b_q = b_kv = None
+ else:
+ b_q, b_kv = b.split([E, E * 2])
+ q_proj = F.linear(q, w_q, b_q)
+ kv_proj = F.linear(k, w_kv, b_kv)
+ # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
+ kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
+ return (q_proj, kv_proj[0], kv_proj[1])
+ else:
+ w_q, w_k, w_v = w.chunk(3)
+ if b is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = b.chunk(3)
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
+
+def _in_projection(
+ q: paddle.Tensor,
+ k: paddle.Tensor,
+ v: paddle.Tensor,
+ w_q: paddle.Tensor,
+ w_k: paddle.Tensor,
+ w_v: paddle.Tensor,
+ b_q: Optional[paddle.Tensor] = None,
+ b_k: Optional[paddle.Tensor] = None,
+ b_v: Optional[paddle.Tensor] = None,
+) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+ A, B, C = F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
+ return A, B, C
+
+def multi_head_attention_forward_paddle(
+ query: paddle.Tensor,
+ key: paddle.Tensor,
+ value: paddle.Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Optional[paddle.Tensor],
+ in_proj_bias: Optional[paddle.Tensor],
+ bias_k: Optional[paddle.Tensor],
+ bias_v: Optional[paddle.Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: paddle.Tensor,
+ out_proj_bias: Optional[paddle.Tensor],
+ training: bool = True,
+ key_padding_mask: Optional[paddle.Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[paddle.Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[paddle.Tensor] = None,
+ k_proj_weight: Optional[paddle.Tensor] = None,
+ v_proj_weight: Optional[paddle.Tensor] = None,
+ static_k: Optional[paddle.Tensor] = None,
+ static_v: Optional[paddle.Tensor] = None,
+ average_attn_weights: bool = True,
+ is_causal: bool = False,
+) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ embed_dim_to_check: total dimension of the model.
+ num_heads: parallel attention heads.
+ in_proj_weight, in_proj_bias: input projection weight and bias.
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ dropout_p: probability of an element to be zeroed.
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
+ training: apply dropout if is ``True``.
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. This is an binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+ is_causal: If specified, applies a causal mask as attention mask, and ignores
+ attn_mask for computing scaled dot product attention.
+ Default: ``False``.
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
+ and value in different forms. If false, in_proj_weight will be used, which is
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+ static_k, static_v: static key and value used for attention operators.
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
+ Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
+ when ``need_weights=True.``. Default: True
+ Shape:
+ Inputs:
+ - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a FloatTensor is provided, it will be directly added to the value.
+ If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+ positions. If a BoolTensor is provided, positions with ``True``
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+ Outputs:
+ - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
+ attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
+ head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
+ """
+
+ is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
+ tgt_len, bsz, embed_dim = query.shape
+ src_len, _, _ = key.shape
+
+ if is_causal:
+ attn_mask = None
+
+ assert embed_dim == embed_dim_to_check, \
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+ if isinstance(embed_dim, paddle.Tensor):
+ # embed_dim can be a tensor when JIT tracing
+ head_dim = embed_dim // num_heads  # floor division also works for 0-d tensors
+ else:
+ head_dim = embed_dim // num_heads
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+ if use_separate_proj_weight:
+ # allow MHA to have different embedding dimensions when separate projection weights are used
+ assert key.shape[:2] == value.shape[:2], \
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+ else:
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
+
+ #
+ # compute in-projection
+ #
+ if not use_separate_proj_weight:
+ assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
+
+ else:
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
+ if in_proj_bias is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
+
+ q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
+
+ # prep attention mask
+
+ if attn_mask is not None:
+ # ensure attn_mask's dim is 3
+ if attn_mask.dim() == 2:
+ correct_2d_size = (tgt_len, src_len)
+ if attn_mask.shape != correct_2d_size:
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
+ attn_mask = attn_mask.unsqueeze(0)
+ elif attn_mask.dim() == 3:
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
+ if tuple(attn_mask.shape) != correct_3d_size:
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
+ else:
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
+
+ # add bias along batch dimension (currently second)
+ if bias_k is not None and bias_v is not None:
+ assert static_k is None, "bias cannot be added to static key."
+ assert static_v is None, "bias cannot be added to static value."
+ k = paddle.concat([k, bias_k.tile([1, bsz, 1])], axis=1)
+ v = paddle.concat([v, bias_v.tile([1, bsz, 1])], axis=1)
+ if attn_mask is not None:
+ # attn_mask = pad(attn_mask, (0, 1))
+ # pad last dim with 0 on one side and 1 on the other
+ attn_mask = paddle.concat([attn_mask, paddle.zeros_like(attn_mask[:, :, -1:])], axis=2)
+ if key_padding_mask is not None:
+ # key_padding_mask = pad(key_padding_mask, (0, 1))
+ # pad last dim with 0 on one side and 1 on the other
+ key_padding_mask = paddle.concat([key_padding_mask, paddle.zeros_like(key_padding_mask[:, -1:])], axis=1)
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ #
+ # reshape q, k, v for multihead attention and make em batch first
+ #
+ q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
+
+
+ if static_k is None:
+ k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
+ else:
+ assert static_k.shape[0] == bsz * num_heads, \
+ f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
+ assert static_k.shape[2] == head_dim, \
+ f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
+ k = static_k
+ if static_v is None:
+ v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
+ else:
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+ assert static_v.shape[0] == bsz * num_heads, \
+ f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
+ assert static_v.shape[2] == head_dim, \
+ f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
+ v = static_v
+
+ # add zero attention along batch dimension (now first)
+ if add_zero_attn:
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
+ k = paddle.concat([k, paddle.zeros(zero_attn_shape, dtype=k.dtype)], axis=1)
+ v = paddle.concat([v, paddle.zeros(zero_attn_shape, dtype=v.dtype)], axis=1)
+ if attn_mask is not None:
+ # attn_mask = pad(attn_mask, (0, 1))
+ attn_mask = paddle.concat([attn_mask, paddle.zeros_like(attn_mask[:, :, -1:])], axis=2)
+ if key_padding_mask is not None:
+ # key_padding_mask = pad(key_padding_mask, (0, 1))
+ key_padding_mask = paddle.concat([key_padding_mask, paddle.zeros_like(key_padding_mask[:, -1:])], axis=1)
+
+ # update source sequence length after adjustments
+ src_len = k.shape[1]
+
+ # merge key padding and attention masks
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (bsz, src_len), \
+ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+ key_padding_mask = key_padding_mask.reshape([bsz, 1, 1, src_len]).expand([-1, num_heads, -1, -1]).reshape([bsz * num_heads, 1, src_len])
+ if attn_mask is None:
+ attn_mask = key_padding_mask
+ else:
+ attn_mask = attn_mask + key_padding_mask
+
+ # adjust dropout probability
+ if not training:
+ dropout_p = 0.0
+
+ #
+ # (deep breath) calculate attention and out projection
+ #
+ if need_weights:
+ B, Nt, E = q.shape
+ q_scaled = q / math.sqrt(E)
+ if attn_mask is not None:
+ # baddbmm equivalent: add the (broadcastable) mask to the scaled query-key scores
+ attn_output_weights = paddle.bmm(q_scaled, k.transpose([0, 2, 1])) + attn_mask
+ else:
+ attn_output_weights = paddle.bmm(q_scaled, k.transpose([0, 2, 1]))
+ attn_output_weights = F.softmax(attn_output_weights, axis=-1)
+ if dropout_p > 0.0:
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p)
+
+ attn_output = paddle.bmm(attn_output_weights, v)
+ attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len * bsz, embed_dim])
+ # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
+ attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
+
+ # optionally average attention weights over heads
+ attn_output_weights = attn_output_weights.reshape([bsz, num_heads, tgt_len, src_len])
+ if average_attn_weights:
+ attn_output_weights = attn_output_weights.mean(axis=1)
+
+ if not is_batched:
+ # squeeze the output if input was unbatched
+ attn_output = attn_output.squeeze(1)
+ attn_output_weights = attn_output_weights.squeeze(0)
+ return attn_output, attn_output_weights
+ else:
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
+ # in order to match the input for SDPA of (N, num_heads, L, S)
+ if attn_mask is not None:
+ if attn_mask.shape[0] == 1 and attn_mask.dim() == 3:
+ attn_mask = attn_mask.unsqueeze(0)
+ else:
+ attn_mask = attn_mask.reshape([bsz, num_heads, -1, src_len])
+
+ q = q.reshape([bsz, num_heads, tgt_len, head_dim])
+ k = k.reshape([bsz, num_heads, src_len, head_dim])
+ v = v.reshape([bsz, num_heads, src_len, head_dim])
+ attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+ attn_output = attn_output.transpose(perm=[2, 0, 1, 3]).reshape([bsz * tgt_len, embed_dim])
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
+ attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
+ return attn_output, None
\ No newline at end of file
diff --git a/paddlespeech/s2t/models/wavlm/modules/modules.py b/paddlespeech/s2t/models/wavlm/modules/modules.py
new file mode 100644
index 00000000000..f14e4016ff7
--- /dev/null
+++ b/paddlespeech/s2t/models/wavlm/modules/modules.py
@@ -0,0 +1,768 @@
+# --------------------------------------------------------
+# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
+# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import warnings
+from typing import Dict, Optional, Tuple
+from .functional import multi_head_attention_forward_paddle
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import Tensor
+
+
+
+class TransposeLast(nn.Layer):
+ def __init__(self, deconstruct_idx=None):
+ super().__init__()
+ self.deconstruct_idx = deconstruct_idx
+
+ def forward(self, x):
+ if self.deconstruct_idx is not None:
+ x = x[self.deconstruct_idx]
+ return paddle.transpose(x, perm=[0, 2, 1])
+
+
+class Fp32LayerNorm(nn.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ # compute the normalization in float32 and cast the result back to the input dtype
+ output = super().forward(input.astype("float32"))
+ return output.astype(input.dtype)
+
+
+class Fp32GroupNorm(nn.GroupNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ # compute the normalization in float32 and cast the result back to the input dtype
+ output = super().forward(input.astype("float32"))
+ return output.astype(input.dtype)
+
+
+
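+# SamePad trims the extra right-side frame produced by an even kernel (or the kernel_size - 1 look-ahead frames in the causal case) so the convolution output length matches its input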
+class SamePad(nn.Layer):
+ def __init__(self, kernel_size, causal=False):
+ super().__init__()
+ if causal:
+ self.remove = kernel_size - 1
+ else:
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ if self.remove > 0:
+ x = x[:, :, : -self.remove]
+ return x
+
+
+class Swish(nn.Layer):
+ """Swish function
+ """
+
+ def __init__(self):
+ """Construct an MultiHeadedAttention object."""
+ super(Swish, self).__init__()
+ self.act = nn.Sigmoid()
+
+ def forward(self, x):
+ return x * self.act(x)
+
+
+class GLU_Linear(nn.Layer):
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+ super(GLU_Linear, self).__init__()
+
+ self.glu_type = glu_type
+ self.output_dim = output_dim
+
+ if glu_type == "sigmoid":
+ self.glu_act = nn.Sigmoid()
+ elif glu_type == "swish":
+ self.glu_act = Swish()
+ elif glu_type == "relu":
+ self.glu_act = nn.ReLU()
+ elif glu_type == "gelu":
+ self.glu_act = nn.GELU()
+
+ if bias_in_glu:
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
+ else:
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
+
+ def forward(self, x):
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
+ x = self.linear(x)
+
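+ # the linear projection doubles the width: the first output_dim channels carry the values, the second output_dim channels feed the gate activation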
+ if self.glu_type == "bilinear":
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
+ else:
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
+
+ return x
+
+
+def gelu_accurate(x):
+ if not hasattr(gelu_accurate, "_a"):
+ gelu_accurate._a = math.sqrt(2 / math.pi)
+ return (
+ 0.5 * x * (1 + paddle.tanh(gelu_accurate._a * (x + 0.044715 * paddle.pow(x, 3))))
+ )
+
+
+def gelu(x: Tensor) -> Tensor:
+ return nn.functional.gelu(x.astype("float32")).astype(x.dtype)
+
+
+def get_activation_fn(activation: str):
+ """Returns the activation function corresponding to `activation`"""
+
+ if activation == "relu":
+ return F.relu
+ elif activation == "gelu":
+ return gelu
+ elif activation == "gelu_fast":
+ warnings.warn(
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
+ )
+ return gelu_accurate
+ elif activation == "gelu_accurate":
+ return gelu_accurate
+ elif activation == "tanh":
+ return paddle.tanh
+ elif activation == "linear":
+ return lambda x: x
+ elif activation == "glu":
+ return lambda x: x
+ else:
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+def quant_noise(module, p, block_size):
+ """
+ Wraps modules and applies quantization noise to the weights for
+ subsequent quantization with Iterative Product Quantization as
+ described in "Training with Quantization Noise for Extreme Model Compression"
+
+ Args:
+ - module: nn.Layer
+ - p: amount of Quantization Noise
+ - block_size: size of the blocks for subsequent quantization with iPQ
+
+ Remarks:
+ - Module weights must have the right sizes wrt the block size
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
+ - For more detail on how to quantize by blocks with convolutional weights,
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+ - We implement the simplest form of noise here as stated in the paper
+ which consists in randomly dropping blocks
+ """
+
+ # if no quantization noise, don't register hook
+ if p <= 0:
+ return module
+
+ # supported modules
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2D))
+
+ # test whether module.weight has the right sizes wrt block_size
+ is_conv = module.weight.ndim == 4
+
+ # 2D matrix
+ if not is_conv:
+ assert (
+ module.weight.shape[0] % block_size == 0  # paddle Linear weight is [in_features, out_features]
+ ), "Input features must be a multiple of block sizes"
+
+ # 4D matrix
+ else:
+ # 1x1 convolutions
+ if tuple(module._kernel_size) == (1, 1):
+ assert (
+ module._in_channels % block_size == 0
+ ), "Input channels must be a multiple of block sizes"
+ # regular convolutions
+ else:
+ k = module._kernel_size[0] * module._kernel_size[1]
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+ def _forward_pre_hook(mod, input):
+ # no noise for evaluation
+ if mod.training:
+ if not is_conv:
+ # gather weight and sizes (paddle Linear weight layout is [in_features, out_features])
+ weight = mod.weight
+ in_features = weight.shape[0]
+ out_features = weight.shape[1]
+
+ # split weight matrix into blocks and randomly drop selected blocks;
+ # the mask is built in [out_features, in_features] layout and then
+ # transposed to match paddle's Linear weight layout
+ mask = paddle.bernoulli(
+ paddle.full([in_features // block_size * out_features], p)
+ )
+ mask = paddle.repeat_interleave(mask, block_size, axis=0)
+ mask = mask.reshape([out_features, in_features]).t()
+
+ else:
+ # gather weight and sizes
+ weight = mod.weight
+ in_channels = mod._in_channels
+ out_channels = mod._out_channels
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ if tuple(mod._kernel_size) == (1, 1):
+ mask = paddle.bernoulli(
+ paddle.full([int(in_channels // block_size * out_channels)], p)
+ )
+ mask = paddle.repeat_interleave(mask, block_size, axis=0)
+ mask = mask.reshape([out_channels, in_channels, 1, 1])
+ else:
+ mask = paddle.bernoulli(
+ paddle.full([weight.shape[0], weight.shape[1]], p)
+ )
+ mask = (
+ mask.unsqueeze(2)
+ .unsqueeze(3)
+ .tile([1, 1, mod._kernel_size[0], mod._kernel_size[1]])
+ )
+
+ # scale the remaining weights and zero out the dropped blocks
+ mask = mask.astype("bool")
+ s = 1 / (1 - p)
+ mod.weight.set_value(s * paddle.where(mask, paddle.zeros_like(weight), weight))
+
+ module.register_forward_pre_hook(_forward_pre_hook)
+ return module
+
+
+class MultiheadAttention(nn.Layer):
+ """Multi-headed attention.
+
+ See "Attention Is All You Need" for more details.
+ """
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ q_noise=0.0,
+ qn_block_size=8,
+ has_relative_attention_bias=True,
+ num_buckets=32,
+ max_distance=128,
+ gru_rel_pos=True,
+ rescale_init=False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout_module = nn.Dropout(dropout)
+
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
+
+ self.head_dim = embed_dim // num_heads
+ self.q_head_dim = self.head_dim
+ self.k_head_dim = self.head_dim
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, (
+ "Self-attention requires query, key and " "value to be of the same size"
+ )
+
+ k_bias = True
+ if rescale_init:
+ k_bias = False
+
+ k_embed_dim = embed_dim
+ q_embed_dim = embed_dim
+
+ self.k_proj = quant_noise(
+ nn.Linear(self.kdim, k_embed_dim, bias_attr=k_bias), q_noise, qn_block_size
+ )
+ self.v_proj = quant_noise(
+ nn.Linear(self.vdim, embed_dim, bias_attr=bias), q_noise, qn_block_size
+ )
+ self.q_proj = quant_noise(
+ nn.Linear(embed_dim, q_embed_dim, bias_attr=bias), q_noise, qn_block_size
+ )
+
+ self.out_proj = quant_noise(
+ nn.Linear(embed_dim, embed_dim, bias_attr=bias), q_noise, qn_block_size
+ )
+
+ if add_bias_kv:
+ self.bias_k = self.create_parameter(
+ shape=[1, 1, embed_dim], dtype="float32"
+ )
+ self.bias_v = self.create_parameter(
+ shape=[1, 1, embed_dim], dtype="float32"
+ )
+
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.gru_rel_pos = gru_rel_pos
+ if self.gru_rel_pos:
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
+ self.grep_a = self.create_parameter(
+ shape=[1, num_heads, 1, 1], dtype="float32"
+ )
+
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ pass
+
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
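+ # T5-style relative position bucketing: half of the buckets cover each direction, nearby offsets get one bucket each, and larger offsets are mapped logarithmically up to max_distance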
+ num_buckets = self.num_buckets
+ max_distance = self.max_distance
+ relative_buckets = 0
+
+ if bidirectional:
+ num_buckets = num_buckets // 2
+ relative_buckets += (relative_positions > 0).astype("int64") * num_buckets
+ relative_positions = paddle.abs(relative_positions)
+ else:
+ relative_positions = -paddle.minimum(relative_positions, paddle.zeros_like(relative_positions))
+
+ max_exact = num_buckets // 2
+ is_small = relative_positions < max_exact
+
+ relative_postion_if_large = max_exact + (
+ paddle.log(relative_positions.astype("float32") / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).astype("int64")
+ relative_postion_if_large = paddle.minimum(
+ relative_postion_if_large, paddle.full_like(relative_postion_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += paddle.where(is_small, relative_positions, relative_postion_if_large)
+ return relative_buckets
+
+ def compute_bias(self, query_length, key_length):
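+ # build the additive relative position bias of shape (num_heads, query_length, key_length) from the bucketed offsets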
+ context_position = paddle.arange(query_length, dtype="int64")[:, None]
+ memory_position = paddle.arange(key_length, dtype="int64")[None, :]
+ relative_position = memory_position - context_position
+ relative_position_bucket = self._relative_positions_bucket(
+ relative_position,
+ bidirectional=True
+ )
+ # relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.transpose([2, 0, 1])
+ return values
+
+ def forward(
+ self,
+ query,
+ key: Optional[Tensor],
+ value: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ need_weights: bool = True,
+ static_kv: bool = False,
+ attn_mask: Optional[Tensor] = None,
+ before_softmax: bool = False,
+ need_head_weights: bool = False,
+ position_bias: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.shape
+ src_len = tgt_len
+ assert embed_dim == self.embed_dim
+ assert list(query.shape) == [tgt_len, bsz, embed_dim]
+ if key is not None:
+ src_len, key_bsz, _ = key.shape
+
+ if self.has_relative_attention_bias and position_bias is None:
+ position_bias = self.compute_bias(tgt_len, src_len)
+ position_bias_ = position_bias.unsqueeze(0)
+ position_bias = paddle.concat([position_bias_ for _ in range(bsz)], axis=0)
+ position_bias = position_bias.reshape([bsz * self.num_heads, tgt_len, src_len])
+ if (
+ incremental_state is None
+ and not static_kv
+ and self.q_head_dim == self.head_dim
+ ):
+ assert key is not None and value is not None
+ assert attn_mask is None
+
+ attn_mask_rel_pos = None
+ if position_bias is not None:
+ attn_mask_rel_pos = position_bias
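+ # gated relative position bias: per-head gates computed from the query rescale the shared position bias before it is used as an additive attention mask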
+ if self.gru_rel_pos:
+ query_layer = query.transpose([1, 0, 2])
+ new_x_shape = query_layer.shape[:-1] + [self.num_heads, -1]
+ query_layer = query_layer.reshape(new_x_shape)
+ query_layer = query_layer.transpose([0, 2, 1, 3])
+ _B, _H, _L, __ = query_layer.shape
+
+ gate_a, gate_b = paddle.nn.functional.sigmoid(self.grep_linear(query_layer).reshape([_B, _H, _L, 2, 4]).sum(-1, keepdim=False)).chunk(2, axis=-1)
+
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+ attn_mask_rel_pos = gate_a_1.reshape([bsz * self.num_heads, -1, 1]) * position_bias
+
+ attn_mask_rel_pos = attn_mask_rel_pos.reshape((-1, tgt_len, tgt_len))
+ k_proj_bias = self.k_proj.bias
+ if k_proj_bias is None:
+ k_proj_bias = paddle.zeros_like(self.q_proj.bias)
+
+
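+ # fast path: delegate to the paddle port of torch.nn.functional.multi_head_attention_forward, feeding the (optionally gated) relative position bias in through attn_mask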
+ x, attn = multi_head_attention_forward_paddle(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ paddle.empty([0]),
+ paddle.concat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias), axis=0),
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout_module.p,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training,
+ key_padding_mask,
+ need_weights,
+ attn_mask_rel_pos,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ )
+
+ return x, attn, position_bias
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if saved_state is not None and "prev_key" in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.q_proj(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.k_proj(key)
+ v = self.v_proj(key)
+
+ else:
+ assert key is not None and value is not None
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = paddle.concat([k, self.bias_k.tile([1, bsz, 1])], axis=0)
+ v = paddle.concat([v, self.bias_v.tile([1, bsz, 1])], axis=0)
+ if attn_mask is not None:
+ attn_mask = paddle.concat(
+ [attn_mask, paddle.zeros([attn_mask.shape[0], 1], dtype=attn_mask.dtype)], axis=1
+ )
+
+ if key_padding_mask is not None:
+ key_padding_mask = paddle.concat(
+ [
+ key_padding_mask,
+ paddle.zeros([key_padding_mask.shape[0], 1], dtype=key_padding_mask.dtype),
+ ],
+ axis=1,
+ )
+
+ q = (
+ q.reshape([tgt_len, bsz * self.num_heads, self.q_head_dim])
+ .transpose([1, 0, 2])
+ )
+ if k is not None:
+ k = (
+ k.reshape([-1, bsz * self.num_heads, self.k_head_dim])
+ .transpose([1, 0, 2])
+ )
+ if v is not None:
+ v = (
+ v.reshape([-1, bsz * self.num_heads, self.head_dim])
+ .transpose([1, 0, 2])
+ )
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if "prev_key" in saved_state:
+ _prev_key = saved_state["prev_key"]
+ assert _prev_key is not None
+ prev_key = _prev_key.reshape([bsz * self.num_heads, -1, self.head_dim])
+ if static_kv:
+ k = prev_key
+ else:
+ assert k is not None
+ k = paddle.concat([prev_key, k], axis=1)
+ src_len = k.shape[1]
+ if "prev_value" in saved_state:
+ _prev_value = saved_state["prev_value"]
+ assert _prev_value is not None
+ prev_value = _prev_value.reshape([bsz * self.num_heads, -1, self.head_dim])
+ if static_kv:
+ v = prev_value
+ else:
+ assert v is not None
+ v = paddle.concat([prev_value, v], axis=1)
+ prev_key_padding_mask: Optional[Tensor] = None
+ if "prev_key_padding_mask" in saved_state:
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+ assert k is not None and v is not None
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=bsz,
+ src_len=k.shape[1],
+ static_kv=static_kv,
+ )
+
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_key_padding_mask"] = key_padding_mask
+ # In this branch incremental_state is never None
+ assert incremental_state is not None
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
+ assert k is not None
+ assert k.shape[1] == src_len
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape[0] == bsz
+ assert key_padding_mask.shape[1] == src_len
+
+ if self.add_zero_attn:
+ assert v is not None
+ src_len += 1
+ k = paddle.concat([k, paddle.zeros([k.shape[0], 1] + list(k.shape[2:]), dtype=k.dtype)], axis=1)
+ v = paddle.concat([v, paddle.zeros([v.shape[0], 1] + list(v.shape[2:]), dtype=v.dtype)], axis=1)
+ if attn_mask is not None:
+ attn_mask = paddle.concat(
+ [attn_mask, paddle.zeros([attn_mask.shape[0], 1], dtype=attn_mask.dtype)], axis=1
+ )
+
+ if key_padding_mask is not None:
+ key_padding_mask = paddle.concat(
+ [
+ key_padding_mask,
+ paddle.zeros(
+ [key_padding_mask.shape[0], 1], dtype=key_padding_mask.dtype
+ ),
+ ],
+ axis=1,
+ )
+
+
+ attn_weights = paddle.matmul(q, k.transpose([0, 2, 1]))
+
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
+ pad_mask = key_padding_mask.unsqueeze(1).unsqueeze(2).astype("bool").expand(
+ [bsz, self.num_heads, tgt_len, src_len]
+ )
+ attn_weights = paddle.where(
+ pad_mask, paddle.full_like(attn_weights, float("-inf")), attn_weights
+ )
+ attn_weights = attn_weights.reshape([bsz * self.num_heads, tgt_len, src_len])
+
+ if before_softmax:
+ return attn_weights, v, position_bias
+
+ if position_bias is not None:
+ if self.gru_rel_pos == 1:
+ query_layer = q.reshape([bsz, self.num_heads, tgt_len, self.q_head_dim])
+ _B, _H, _L, __ = query_layer.shape
+ gate_a, gate_b = F.sigmoid(self.grep_linear(query_layer).reshape(
+ [_B, _H, _L, 2, 4]).sum(-1, keepdim=False)).chunk(2, axis=-1)
+
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+ position_bias = gate_a_1.reshape([bsz * self.num_heads, -1, 1]) * position_bias
+
+ position_bias = position_bias.reshape(attn_weights.shape)
+
+ attn_weights = attn_weights + position_bias
+
+ attn_weights_float = F.softmax(
+ attn_weights, axis=-1
+ )
+ attn_weights = attn_weights_float.astype(attn_weights.dtype)
+ attn_probs = self.dropout_module(attn_weights)
+
+ assert v is not None
+ attn = paddle.bmm(attn_probs, v)
+ assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
+ attn = self.out_proj(attn)
+ attn_weights: Optional[Tensor] = None
+ if need_weights:
+ attn_weights = attn_weights_float.reshape(
+ [bsz, self.num_heads, tgt_len, src_len]
+ ).transpose([1, 0, 2, 3])
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(axis=0)
+
+ return attn, attn_weights, position_bias
+
+ @staticmethod
+ def _append_prev_key_padding_mask(
+ key_padding_mask: Optional[Tensor],
+ prev_key_padding_mask: Optional[Tensor],
+ batch_size: int,
+ src_len: int,
+ static_kv: bool,
+ ) -> Optional[Tensor]:
+ # saved key padding masks have shape (bsz, seq_len)
+ if prev_key_padding_mask is not None and static_kv:
+ new_key_padding_mask = prev_key_padding_mask
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
+ new_key_padding_mask = paddle.concat(
+ [prev_key_padding_mask.astype("float32"), key_padding_mask.astype("float32")], axis=1
+ )
+ # During incremental decoding, as the padding token enters and
+ # leaves the frame, there will be a time when prev or current
+ # is None
+ elif prev_key_padding_mask is not None:
+ if src_len > prev_key_padding_mask.shape[1]:
+ filler = paddle.zeros(
+ [batch_size, src_len - prev_key_padding_mask.shape[1]],
+ dtype="float32",
+ )
+ new_key_padding_mask = paddle.concat(
+ [prev_key_padding_mask.astype("float32"), filler], axis=1
+ )
+
+ else:
+ new_key_padding_mask = prev_key_padding_mask.astype("float32")
+ elif key_padding_mask is not None:
+ if src_len > key_padding_mask.shape[1]:
+ filler = paddle.zeros(
+ [batch_size, src_len - key_padding_mask.shape[1]],
+ dtype="float32",
+ )
+ new_key_padding_mask = paddle.concat(
+ [filler, key_padding_mask.astype("float32")], axis=1
+ )
+
+ else:
+ new_key_padding_mask = key_padding_mask.astype("float32")
+ else:
+ new_key_padding_mask = prev_key_padding_mask
+ return new_key_padding_mask
+
+ def _get_input_buffer(
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+ ) -> Dict[str, Optional[Tensor]]:
+ result = self.get_incremental_state(incremental_state, "attn_state")
+ if result is not None:
+ return result
+ else:
+ empty_result: Dict[str, Optional[Tensor]] = {}
+ return empty_result
+
+ def _set_input_buffer(
+ self,
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+ buffer: Dict[str, Optional[Tensor]],
+ ):
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+ return attn_weights
\ No newline at end of file
diff --git a/paddlespeech/s2t/models/wavlm/wavlm_asr.py b/paddlespeech/s2t/models/wavlm/wavlm_asr.py
new file mode 100644
index 00000000000..5764890d200
--- /dev/null
+++ b/paddlespeech/s2t/models/wavlm/wavlm_asr.py
@@ -0,0 +1,323 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import defaultdict
+from typing import Dict
+from typing import List
+from typing import Tuple
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
+from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
+from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
+from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
+from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
+from paddlespeech.s2t.utils.utility import log_add
+
+from .wavlm_paddle import WavLM, WavLMConfig
+
+
+class WavLMASR(nn.Layer):
+ def __init__(self, config: dict):
+ super().__init__()
+ init_type = config.get("init_type", None)
+ with DefaultInitializerContext(init_type):
+ self.config = config
+ wavlm_config = WavLMConfig(config)
+ wavlm = WavLM(wavlm_config)
+
+ self.normalize_wav = config.normalize_wav
+ self.output_norm = config.output_norm
+ if hasattr(config, 'spec_augment'):
+ self.spec_augment = SpecAugment(**config.spec_augment)
+
+ if config.freeze_wavlm:
+ wavlm.eval()
+ for parm in wavlm.parameters():
+ parm.trainable = False
+ self.wavlm = wavlm
+ self.enc = VanillaNN(**config.enc)
+ self.ctc = CTC(**config.ctc,
+ odim=config.output_dim,
+ batch_average=False,
+ reduction='mean')
+
+ def forward(self, wav, wavs_lens_rate, target, target_lens):
+ if self.normalize_wav:
+ wav = F.layer_norm(wav, wav.shape)
+
+ # Extract wavlm output
+ out = self.wavlm(wav)
+ # We normalize the output if required
+ if self.output_norm:
+ out = F.layer_norm(out, out.shape)
+
+ if self.training and hasattr(self.config, 'spec_augment'):
+ feats = self.spec_augment(out)
+ else:
+ feats = out
+
+ x = self.enc(feats)
+ # x = feats
+
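+ # wavs_lens_rate is the valid-length ratio of each padded utterance; scaling it by the encoder output length gives the CTC input lengths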
+ x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
+ target_lens = target_lens.astype(paddle.int64)
+ # target = target.astype(paddle.int32)
+ ctc_loss = self.ctc(x, x_lens, target, target_lens)
+
+ return ctc_loss
+
+ @paddle.no_grad()
+ def decode(self,
+ feats: paddle.Tensor,
+ text_feature: Dict[str, int],
+ decoding_method: str,
+ beam_size: int,
+ tokenizer: str=None,
+ sb_pipeline=False):
+ batch_size = feats.shape[0]
+
+ if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
+ print(
+ f"decoding mode {decoding_method} must be running with batch_size == 1"
+ )
+ print(f"current batch_size is {batch_size}")
+
+ if decoding_method == 'ctc_greedy_search':
+ if tokenizer is None and sb_pipeline is False:
+ hyps = self.ctc_greedy_search(feats)
+ res = [text_feature.defeaturize(hyp) for hyp in hyps]
+ res_tokenids = [hyp for hyp in hyps]
+ else:
+ if sb_pipeline is True:
+ hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
+ else:
+ hyps = self.ctc_greedy_search(feats)
+ res = []
+ res_tokenids = []
+ for sequence in hyps:
+ # Decode token terms to words
+ predicted_tokens = text_feature.convert_ids_to_tokens(
+ sequence)
+ tmp_res = []
+ tmp_res_tokenids = []
+ for c in predicted_tokens:
+ if c == "[CLS]":
+ continue
+ elif c == "[SEP]" or c == "[PAD]":
+ break
+ else:
+ tmp_res.append(c)
+ tmp_res_tokenids.append(text_feature.vocab[c])
+ res.append(''.join(tmp_res))
+ res_tokenids.append(tmp_res_tokenids)
+
+ # ctc_prefix_beam_search and attention_rescoring only return one
+ # result in List[int], change it to List[List[int]] for compatible
+ # with other batch decoding mode
+ elif decoding_method == 'ctc_prefix_beam_search':
+ assert feats.shape[0] == 1
+ if tokenizer is None and sb_pipeline is False:
+ hyp = self.ctc_prefix_beam_search(feats, beam_size)
+ res = [text_feature.defeaturize(hyp)]
+ res_tokenids = [hyp]
+ else:
+ if sb_pipeline is True:
+ hyp = self.ctc_prefix_beam_search(
+ feats.unsqueeze(-1), beam_size)
+ else:
+ hyp = self.ctc_prefix_beam_search(feats, beam_size)
+ res = []
+ res_tokenids = []
+ predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
+ tmp_res = []
+ tmp_res_tokenids = []
+ for c in predicted_tokens:
+ if c == "[CLS]":
+ continue
+ elif c == "[SEP]" or c == "[PAD]":
+ break
+ else:
+ tmp_res.append(c)
+ tmp_res_tokenids.append(text_feature.vocab[c])
+ res.append(''.join(tmp_res))
+ res_tokenids.append(tmp_res_tokenids)
+ else:
+ raise ValueError(
+ f"WavLM not support decoding method: {decoding_method}")
+
+ return res, res_tokenids
+
+ @classmethod
+ def from_config(cls, config):
+ model = cls(config)
+ return model
+
+ def ctc_greedy_search(self, wav) -> List[List[int]]:
+ """ Apply CTC greedy search
+ Args:
+ speech (paddle.Tensor): (batch, max_len)
+ speech_length (paddle.Tensor): (batch, )
+ Returns:
+ List[List[int]]: best path result
+ """
+ batch_size = wav.shape[0]
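+ # the decoding pipeline feeds audio as (B, T, 1); keep only the waveform channel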
+ wav = wav[:, :, 0]
+ if self.normalize_wav:
+ wav = F.layer_norm(wav, wav.shape[1:])
+ # Extract wavlm output
+ out = self.wavlm(wav)
+ # We normalize the output if required
+ if self.output_norm:
+ out = F.layer_norm(out, out.shape[1:])
+ feats = out
+ x = self.enc(feats)
+ x_lens = x.shape[1]
+ ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size)
+ topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
+ topk_index = topk_index.view(batch_size, x_lens) # (B, maxlen)
+
+ hyps = [hyp.tolist() for hyp in topk_index]
+ hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
+ return hyps
+
+ def _ctc_prefix_beam_search(
+ self,
+ wav,
+ beam_size,
+ blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
+ """ CTC prefix beam search inner implementation
+ Args:
+ speech (paddle.Tensor): (batch, max_len, feat_dim)
+ speech_length (paddle.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+ Returns:
+ List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
+ paddle.Tensor: encoder output, (1, max_len, encoder_dim),
+ it will be used for rescoring in attention rescoring mode
+ """
+ wav = wav[:, :, 0]
+
+ if self.normalize_wav:
+ wav = F.layer_norm(wav, wav.shape[1:])
+ # Extract wavlm output
+ out = self.wavlm(wav)
+ # We normalize the output if required
+ if self.output_norm:
+ out = F.layer_norm(out, out.shape[1:])
+ feats = out
+
+ x = self.enc(feats)
+ maxlen = x.shape[1]
+ ctc_probs = self.ctc.log_softmax(x) # (1, maxlen, vocab_size)
+ ctc_probs = ctc_probs.squeeze(0)
+
+ # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
+ # blank_ending_score and none_blank_ending_score in ln domain
+ cur_hyps = [(tuple(), (0.0, -float('inf')))]
+ # 2. CTC beam search step by step
+ for t in range(0, maxlen):
+ logp = ctc_probs[t] # (vocab_size,)
+ # key: prefix, value (pb, pnb), default value(-inf, -inf)
+ next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
+ # 2.1 First beam prune: select topk best
+ top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
+ for s in top_k_index:
+ s = s.item()
+ ps = logp[s].item()
+ for prefix, (pb, pnb) in cur_hyps:
+ last = prefix[-1] if len(prefix) > 0 else None
+ if s == blank_id: # blank
+ n_pb, n_pnb = next_hyps[prefix]
+ n_pb = log_add([n_pb, pb + ps, pnb + ps])
+ next_hyps[prefix] = (n_pb, n_pnb)
+ elif s == last:
+ # Update *ss -> *s;
+ n_pb, n_pnb = next_hyps[prefix]
+ n_pnb = log_add([n_pnb, pnb + ps])
+ next_hyps[prefix] = (n_pb, n_pnb)
+ # Update *s-s -> *ss, - is for blank
+ n_prefix = prefix + (s, )
+ n_pb, n_pnb = next_hyps[n_prefix]
+ n_pnb = log_add([n_pnb, pb + ps])
+ next_hyps[n_prefix] = (n_pb, n_pnb)
+ else:
+ n_prefix = prefix + (s, )
+ n_pb, n_pnb = next_hyps[n_prefix]
+ n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
+ next_hyps[n_prefix] = (n_pb, n_pnb)
+
+ # 2.2 Second beam prune
+ next_hyps = sorted(
+ next_hyps.items(),
+ key=lambda x: log_add(list(x[1])),
+ reverse=True)
+ cur_hyps = next_hyps[:beam_size]
+
+ hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
+ return hyps
+
+ def ctc_prefix_beam_search(self, wav, beam_size) -> List[int]:
+ """ Apply CTC prefix beam search
+ Args:
+ speech (paddle.Tensor): (batch, max_len, feat_dim)
+ speech_length (paddle.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+ Returns:
+ List[int]: CTC prefix beam search nbest results
+ """
+ hyps = self._ctc_prefix_beam_search(wav, beam_size)
+ return hyps[0][0]
+
+
+class WavLMBase(nn.Layer):
+ """WavLM model"""
+
+ def __init__(self, config: dict):
+ super().__init__()
+ wavlm_config = WavLMConfig(config)
+ wavlm = WavLM(wavlm_config)
+ self.wavlm = wavlm
+
+ @classmethod
+ def from_config(cls, configs: dict):
+ """init model.
+ Args:
+ configs (dict): config dict.
+ Raises:
+ ValueError: raise when using not support encoder type.
+ Returns:
+ nn.Layer: WavLMBase
+ """
+ model = cls(configs)
+ return model
+
+ def forward(self, wav):
+ out = self.wavlm(wav)
+ return out
diff --git a/paddlespeech/s2t/models/wavlm/wavlm_paddle.py b/paddlespeech/s2t/models/wavlm/wavlm_paddle.py
new file mode 100644
index 00000000000..6ed9ecd0e60
--- /dev/null
+++ b/paddlespeech/s2t/models/wavlm/wavlm_paddle.py
@@ -0,0 +1,756 @@
+# --------------------------------------------------------
+# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
+# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import logging
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import LayerNorm
+from paddle import Tensor
+from .modules.modules import (
+ MultiheadAttention,
+ SamePad,
+ get_activation_fn,
+ TransposeLast,
+ GLU_Linear,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+ shape: the shape for which to compute masks.
+ should be of size 2 where first element is batch size and 2nd is timesteps
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+ mask_type: how to compute mask lengths
+ static = fixed size
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+ poisson = sample from Poisson distribution with lambda = mask length
+ min_masks: minimum number of masked spans
+ no_overlap: if True, use an alternative recursive algorithm that prevents spans from overlapping
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+ """
+
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ all_num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+
+ all_num_mask = max(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].astype("int64").sum().item()
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = all_sz
+ num_mask = all_num_mask
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = np.random.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+ span_start = np.random.randint(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - keep_length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+ int,
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = np.random.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ]
+ )
+
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
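+ # trim every sample to the same number of masked frames so the mask stays rectangular across the batch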
+ min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ if len(mask_idc) > min_len:
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ mask[i, mask_idc] = True
+
+ return mask
+
+
+class WavLMConfig:
+ def __init__(self, cfg=None):
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
+
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
+ self.activation_fn: str = "gelu" # activation function to use
+
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
+ self.conv_bias: bool = False # include bias in conv encoder
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
+
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
+
+ # dropouts
+ self.dropout: float = 0.1 # dropout probability for the transformer
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
+
+ # masking
+ self.mask_length: int = 10 # mask length
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
+ self.mask_selection: str = "static" # how to choose mask length
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
+
+ # channel masking
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
+
+ # positional embeddings
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
+
+ # relative position embedding
+ self.relative_position_embedding: bool = True # apply relative position embedding
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
+ self.gru_rel_pos: bool = True # apply gated relative position embedding
+
+ if cfg is not None:
+ self.update(cfg)
+
+ def update(self, cfg: dict):
+ self.__dict__.update(cfg)
+
+
+class WavLM(nn.Layer):
+ def __init__(
+ self,
+ cfg: WavLMConfig,
+ ) -> None:
+ super().__init__()
+ logger.info(f"WavLM Config: {cfg.__dict__}")
+
+ self.cfg = cfg
+ feature_enc_layers = eval(cfg.conv_feature_layers)
+ self.embed = feature_enc_layers[-1][0]
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
+
+ self.post_extract_proj = (
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
+ if self.embed != cfg.encoder_embed_dim
+ else None
+ )
+
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+ self.feature_grad_mult = cfg.feature_grad_mult
+
+ self.mask_emb = self.create_parameter(
+ shape=[cfg.encoder_embed_dim],
+ default_initializer=nn.initializer.Uniform(),
+ )
+
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+
+ def apply_mask(self, x, padding_mask):
+ B, T, C = x.shape
+ if self.mask_prob > 0:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ # mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ mask_indices = paddle.to_tensor(mask_indices, dtype='bool')
+ x[mask_indices] = self.mask_emb
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ # torch.from_numpy(mask_channel_indices)
+ paddle.to_tensor(mask_channel_indices, dtype='bool')
+ .unsqueeze(1)
+ .expand([-1, T, -1])
+ )
+ x[mask_channel_indices] = 0
+
+ return x, mask_indices
+
+ def forward_padding_mask(
+ self, features: Tensor, padding_mask: Tensor,
+ ) -> Tensor:
+ extra = padding_mask.shape[1] % features.shape[1]
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.reshape(
+ [padding_mask.shape[0], features.shape[1], -1]
+ )
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
+
+ def extract_features(
+ self,
+ source: Tensor,
+ padding_mask: Optional[Tensor] = None,
+ mask: bool = False,
+ ret_conv: bool = False,
+ output_layer: Optional[int] = None,
+ ret_layer_results: bool = False,
+ ):
+
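+ # run the convolutional feature extractor with gradients only when feature_grad_mult > 0; otherwise keep the front-end frozen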
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(source)
+ # if self.feature_grad_mult != 1.0:
+ # features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ # with torch.no_grad():
+ with paddle.no_grad():
+ features = self.feature_extractor(source)
+
+ features = features.transpose([0, 2, 1]) # [1, 49, 512]
+ features = self.layer_norm(features)
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+ # [1, 49, 768]
+ features = self.dropout_input(features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(
+ features, padding_mask
+ )
+ else:
+ x = features
+
+ # feature: (B, T, D), float
+ # target: (B, T), long
+ # x: (B, T, D), float
+ # padding_mask: (B, T), bool
+ # mask_indices: (B, T), bool
+
+ x, layer_results = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ layer=None if output_layer is None else output_layer - 1
+ )
+ # print(f"Debugging: x.shape: {x.shape}, x.mean(): {x.mean()}, x.std(): {x.std()}")
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
+
+ feature = res["features"] if ret_conv else res["x"]
+ if ret_layer_results:
+ feature = (feature, res["layer_results"])
+ return feature, res["padding_mask"]
+
+ def forward(self, x):
+ return self.extract_features(x)[0]
+
+
+class ConvFeatureExtractionModel(nn.Layer):
+ def __init__(
+ self,
+ conv_layers: List[Tuple[int, int, int]],
+ dropout: float = 0.0,
+ mode: str = "default",
+ conv_bias: bool = False,
+ conv_type: str = "default"
+ ):
+ super().__init__()
+
+ assert mode in {"default", "layer_norm"}
+
+ def block(
+ n_in,
+ n_out,
+ k,
+ stride,
+ is_layer_norm=False,
+ is_group_norm=False,
+ conv_bias=False,
+ ):
+ def make_conv():
+ conv = nn.Conv1D(n_in, n_out, k, stride=stride, bias_attr=conv_bias,
+ weight_attr=nn.initializer.KaimingNormal())
+ # nn.init.kaiming_normal_(conv.weight)
+ return conv
+
+ assert (
+ is_layer_norm and is_group_norm
+ ) == False, "layer norm and group norm are exclusive"
+
+ if is_layer_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ nn.Sequential(
+ TransposeLast(),
+ nn.LayerNorm(normalized_shape=dim, epsilon=1e-5),
+ TransposeLast(),
+ ),
+ nn.GELU(),
+ )
+ elif is_group_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ nn.GroupNorm(num_groups=dim, num_channels=dim, epsilon=1e-5),
+ nn.GELU(),
+ )
+ else:
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
+
+ self.conv_type = conv_type
+ if self.conv_type == "default":
+ in_d = 1
+ self.conv_layers = nn.LayerList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
+ (dim, k, stride) = cl
+
+ self.conv_layers.append(
+ block(
+ in_d,
+ dim,
+ k,
+ stride,
+ is_layer_norm=mode == "layer_norm",
+ is_group_norm=mode == "default" and i == 0,
+ conv_bias=conv_bias,
+ )
+ )
+ in_d = dim
+ elif self.conv_type == "conv2d":
+ in_d = 1
+ self.conv_layers = nn.LayerList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3
+ (dim, k, stride) = cl
+
+ self.conv_layers.append(
+ paddle.nn.Conv2D(in_d, dim, k, stride)
+ )
+ self.conv_layers.append(paddle.nn.ReLU())
+ in_d = dim
+ elif self.conv_type == "custom":
+ in_d = 1
+ idim = 80
+ self.conv_layers = nn.LayerList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3
+ (dim, k, stride) = cl
+ self.conv_layers.append(
+ paddle.nn.Conv2D(in_d, dim, k, stride, padding=1)
+ )
+ self.conv_layers.append(
+ paddle.nn.LayerNorm([dim, idim])
+ )
+ self.conv_layers.append(paddle.nn.ReLU())
+ in_d = dim
+ if (i + 1) % 2 == 0:
+ self.conv_layers.append(
+ paddle.nn.MaxPool2D(2, stride=2, ceil_mode=True)
+ )
+ idim = int(math.ceil(idim / 2))
+ else:
+ pass
+
+ def forward(self, x, mask=None):
+
+ # BxT -> BxCxT
+ x = x.unsqueeze(1)
+ if self.conv_type == "custom":
+ for conv in self.conv_layers:
+ if isinstance(conv, nn.LayerNorm):
+ x = x.transpose([0, 2, 1])
+ x = conv(x).transpose([0, 2, 1])
+ else:
+ x = conv(x)
+ x = x.transpose([0, 1, 3, 2])
+ x = x.reshape([x.shape[0], -1, x.shape[-1]])
+ else:
+ for conv in self.conv_layers:
+ x = conv(x)
+ if self.conv_type == "conv2d":
+ b, c, t, f = x.size()
+ # x = x.transpose(2, 3).contiguous().view(b, c * f, t)
+ x = x.transpose([0, 1, 3, 2]).contiguous().view(b, c * f, t)
+ return x
+
+
+class TransformerEncoder(nn.Layer):
+ def __init__(self, args):
+ super().__init__()
+
+ self.dropout = args.dropout
+ self.embedding_dim = args.encoder_embed_dim
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
+
+
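+ # convolutional positional embedding: a weight-normalized grouped 1-D conv followed by SamePad trimming and GELU, as in wav2vec 2.0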
+ self.pos_conv = nn.Conv1D(
+ self.embedding_dim,
+ self.embedding_dim,
+ kernel_size=args.conv_pos,
+ padding=args.conv_pos // 2,
+ groups=args.conv_pos_groups,
+ weight_attr=nn.initializer.Normal(mean=0, std=std),
+ bias_attr=True
+ )
+ # nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+ # nn.init.constant_(self.pos_conv.bias, 0)
+
+ # self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ # self.pos_conv.weight_g = self.pos_conv.weight_g.unsqueeze(0).unsqueeze(0)
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
+
+ if hasattr(args, "relative_position_embedding"):
+ self.relative_position_embedding = args.relative_position_embedding
+ self.num_buckets = args.num_buckets
+ self.max_distance = args.max_distance
+ else:
+ self.relative_position_embedding = False
+ self.num_buckets = 0
+ self.max_distance = 0
+
+ self.layers = nn.LayerList(
+ [
+ TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ gru_rel_pos=args.gru_rel_pos,
+ )
+ for i in range(args.encoder_layers)
+ ]
+ )
+
+ self.layer_norm_first = args.layer_norm_first
+ self.layer_norm = LayerNorm(self.embedding_dim)
+ self.layerdrop = args.encoder_layerdrop
+
+ # self.apply(init_bert_params)
+
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
+ # print("x.shape", x.shape)
+ if self.layer_norm_first and layer is None:
+ x = self.layer_norm(x)
+
+ return x, layer_results
+
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
+
+ if padding_mask is not None:
+ x[padding_mask] = 0
+
+ x_conv = self.pos_conv(x.transpose([0, 2, 1]))
+ x_conv = x_conv.transpose([0, 2, 1])
+ x += x_conv
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ # x = x.transpose(0, 1)
+ x = x.transpose([1, 0, 2])
+
+
+ layer_results = []
+ z = None
+ if tgt_layer is not None:
+ layer_results.append((x, z))
+ r = None
+ pos_bias = None
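+ # LayerDrop: during training each layer is skipped with probability encoder_layerdrop; the relative position bias from the first layer is passed along and reused by later layers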
+ for i, layer in enumerate(self.layers):
+ dropout_probability = np.random.random()
+ if not self.training or (dropout_probability > self.layerdrop):
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, self_attn_mask=streaming_mask, pos_bias=pos_bias)
+ if tgt_layer is not None:
+ layer_results.append((x, z))
+ if i == tgt_layer:
+ r = x
+ break
+
+ if r is not None:
+ x = r
+
+ # T x B x C -> B x T x C
+ # x = x.transpose(0, 1)
+ x = x.transpose([1, 0, 2])
+
+ return x, layer_results
+
+
+class TransformerSentenceEncoderLayer(nn.Layer):
+ """
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+ models.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: float = 768,
+ ffn_embedding_dim: float = 3072,
+ num_attention_heads: float = 8,
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.1,
+ activation_fn: str = "relu",
+ layer_norm_first: bool = False,
+ has_relative_attention_bias: bool = True,
+ num_buckets: int = 0,
+ max_distance: int = 0,
+ rescale_init: bool = False,
+ gru_rel_pos: bool = True,
+ ) -> None:
+
+ super().__init__()
+ # Initialize parameters
+ self.embedding_dim = embedding_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+
+ # Initialize blocks
+ self.activation_name = activation_fn
+ self.activation_fn = get_activation_fn(activation_fn)
+ self.self_attn = MultiheadAttention(
+ self.embedding_dim,
+ num_attention_heads,
+ dropout=attention_dropout,
+ self_attention=True,
+ has_relative_attention_bias=has_relative_attention_bias,
+ num_buckets=num_buckets,
+ max_distance=max_distance,
+ rescale_init=rescale_init,
+ gru_rel_pos=gru_rel_pos,
+ )
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(self.activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.layer_norm_first = layer_norm_first
+
+ # layer norm associated with the self attention layer
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+
+ if self.activation_name == "glu":
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
+ else:
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+ # layer norm associated with the position wise feed-forward NN
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+ def forward(
+ self,
+ x: Tensor,
+ self_attn_mask: Tensor = None,
+ self_attn_padding_mask: Tensor = None,
+ need_weights: bool = False,
+ pos_bias=None
+ ):
+ """
+ LayerNorm is applied either before or after the self-attention/ffn
+ modules similar to the original Transformer implementation.
+ """
+ residual = x
+ if self.layer_norm_first:
+
+ x = self.self_attn_layer_norm(x)
+ x, attn, pos_bias = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=False,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias
+ )
+ x = self.dropout1(x)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ if self.activation_name == "glu":
+ x = self.fc1(x)
+ else:
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ else:
+ x, attn, pos_bias = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=need_weights,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias
+ )
+
+ x = self.dropout1(x)
+ x = residual + x
+
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ if self.activation_name == "glu":
+ x = self.fc1(x)
+ else:
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ x = self.final_layer_norm(x)
+
+ return x, attn, pos_bias