From 697c04b8e5e73c864c94202e46d68596a1d15787 Mon Sep 17 00:00:00 2001 From: Soohwan Kim Date: Tue, 8 Jun 2021 00:56:54 +0900 Subject: [PATCH] Release v0.2 (resolved #11 resolved #12) --- README.md | 26 +++++ openspeech/configs/eval.yaml | 5 + .../configs/{configs.yaml => train.yaml} | 0 .../data/audio/filter_bank/configuration.py | 2 +- .../audio/melspectrogram/configuration.py | 2 +- openspeech/data/audio/mfcc/configuration.py | 2 +- .../data/audio/spectrogram/configuration.py | 2 +- openspeech/data/data_loader.py | 25 +++++ openspeech/data/dataset.py | 2 +- openspeech/dataclass/__init__.py | 6 + openspeech/dataclass/configurations.py | 56 ++++++++++ openspeech/dataclass/initialize.py | 18 +++ openspeech/decoders/transformer_decoder.py | 2 +- openspeech/models/conformer/configurations.py | 2 +- .../models/conformer_lstm/configurations.py | 2 +- openspeech/models/conformer_lstm/model.py | 3 +- .../conformer_transducer/configurations.py | 2 +- .../configurations.py | 2 +- .../model.py | 3 +- .../models/deepspeech2/configurations.py | 2 +- .../models/jasper10x5/configurations.py | 2 +- openspeech/models/jasper5x3/configurations.py | 2 +- .../configurations.py | 2 +- .../models/joint_ctc_conformer_lstm/model.py | 3 +- .../configurations.py | 30 ++--- .../joint_ctc_listen_attend_spell/model.py | 3 +- .../joint_ctc_transformer/configurations.py | 2 +- .../models/joint_ctc_transformer/model.py | 3 +- .../listen_attend_spell/configurations.py | 2 +- .../models/listen_attend_spell/model.py | 3 +- .../configurations.py | 2 +- .../model.py | 3 +- .../configurations.py | 2 +- .../model.py | 3 +- openspeech/models/openspeech_ctc_model.py | 2 +- .../models/openspeech_transducer_model.py | 4 + .../models/quartznet10x5/configurations.py | 2 +- .../models/quartznet15x5/configurations.py | 2 +- .../models/quartznet5x5/configurations.py | 2 +- .../models/rnn_transducer/configurations.py | 2 +- .../models/transformer/configurations.py | 2 +- openspeech/models/transformer/model.py | 3 +- .../transformer_transducer/configurations.py | 2 +- .../transformer_with_ctc/configurations.py | 2 +- .../models/vgg_transformer/configurations.py | 2 +- openspeech/models/vgg_transformer/model.py | 3 +- openspeech/search/base.py | 4 +- openspeech/search/beam_search_lstm.py | 8 +- openspeech/search/beam_search_transformer.py | 7 +- openspeech/search/ensemble_search.py | 96 ++++++++++++++++ openspeech_cli/hydra_ensemble_eval.py | 103 ++++++++++++++++++ openspeech_cli/hydra_eval.py | 88 +++++++++++++++ openspeech_cli/hydra_train.py | 3 +- setup.py | 2 +- tests/test_conformer_lstm.py | 2 +- ..._cnn_with_joint_ctc_listen_attend_spell.py | 2 +- tests/test_joint_ctc_conformer_lstm.py | 2 +- tests/test_joint_ctc_listen_attend_spell.py | 2 +- tests/test_joint_ctc_transformer.py | 2 +- tests/test_listen_attend_spell.py | 2 +- ...listen_attend_spell_with_location_aware.py | 2 +- ...est_listen_attend_spell_with_multi_head.py | 2 +- tests/test_transformer.py | 2 +- tests/test_vgg_transformer.py | 2 +- 64 files changed, 504 insertions(+), 81 deletions(-) create mode 100644 openspeech/configs/eval.yaml rename openspeech/configs/{configs.yaml => train.yaml} (100%) create mode 100644 openspeech/search/ensemble_search.py create mode 100644 openspeech_cli/hydra_ensemble_eval.py create mode 100644 openspeech_cli/hydra_eval.py diff --git a/README.md b/README.md index f2412b4..4e1982f 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,32 @@ $ python ./openspeech_cli/hydra_train.py \ criterion=ctc ``` +### Evaluation examples + +- Example1: Evaluation the `listen_attend_spell` model: + +``` +$ python ./openspeech_cli/hydra_eval.py \ + audio=melspectrogram \ + eval.model_name=listen_attend_spell \ + eval.dataset_path=$DATASET_PATH \ + eval.checkpoint_path=$CHECKPOINT_PATH \ + eval.manifest_file_path=$MANIFEST_FILE_PATH +``` + +- Example2: Evaluation the `listen_attend_spell`, `conformer_lstm` models with ensemble: + +``` +$ python ./openspeech_cli/hydra_eval.py \ + audio=melspectrogram \ + eval.model_names=(listen_attend_spell, conformer_lstm) \ + eval.dataset_path=$DATASET_PATH \ + eval.checkpoint_paths=($CHECKPOINT_PATH1, $CHECKPOINT_PATH2) \ + eval.ensemble_weights=(0.3, 0.7) \ + eval.ensemble_method=weighted \ + eval.manifest_file_path=$MANIFEST_FILE_PATH +``` + ## Installation This project recommends Python 3.7 or higher. diff --git a/openspeech/configs/eval.yaml b/openspeech/configs/eval.yaml new file mode 100644 index 0000000..1936598 --- /dev/null +++ b/openspeech/configs/eval.yaml @@ -0,0 +1,5 @@ +# @package _group_ + +defaults: + - audio: null + - eval: default \ No newline at end of file diff --git a/openspeech/configs/configs.yaml b/openspeech/configs/train.yaml similarity index 100% rename from openspeech/configs/configs.yaml rename to openspeech/configs/train.yaml diff --git a/openspeech/data/audio/filter_bank/configuration.py b/openspeech/data/audio/filter_bank/configuration.py index 073b354..9de1364 100644 --- a/openspeech/data/audio/filter_bank/configuration.py +++ b/openspeech/data/audio/filter_bank/configuration.py @@ -35,7 +35,7 @@ class FilterBankConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: name (str): name of feature transform. (default: fbank) sample_rate (int): sampling rate of audio (default: 16000) frame_length (float): frame length for spectrogram (default: 20.0) diff --git a/openspeech/data/audio/melspectrogram/configuration.py b/openspeech/data/audio/melspectrogram/configuration.py index 5e900d4..5a5b706 100644 --- a/openspeech/data/audio/melspectrogram/configuration.py +++ b/openspeech/data/audio/melspectrogram/configuration.py @@ -35,7 +35,7 @@ class MelSpectrogramConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.OpenspeechDataclass`. - Configurations: + Args: name (str): name of feature transform. (default: melspectrogram) sample_rate (int): sampling rate of audio (default: 16000) frame_length (float): frame length for spectrogram (default: 20.0) diff --git a/openspeech/data/audio/mfcc/configuration.py b/openspeech/data/audio/mfcc/configuration.py index fc419ec..5265c2d 100644 --- a/openspeech/data/audio/mfcc/configuration.py +++ b/openspeech/data/audio/mfcc/configuration.py @@ -35,7 +35,7 @@ class MFCCConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.OpenspeechDataclass`. - Configurations: + Args: name (str): name of feature transform. (default: mfcc) sample_rate (int): sampling rate of audio (default: 16000) frame_length (float): frame length for spectrogram (default: 20.0) diff --git a/openspeech/data/audio/spectrogram/configuration.py b/openspeech/data/audio/spectrogram/configuration.py index f8f9ef0..7720712 100644 --- a/openspeech/data/audio/spectrogram/configuration.py +++ b/openspeech/data/audio/spectrogram/configuration.py @@ -35,7 +35,7 @@ class SpectrogramConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.OpenspeechDataclass`. - Configurations: + Args: name (str): name of feature transform. (default: spectrogram) sample_rate (int): sampling rate of audio (default: 16000) frame_length (float): frame length for spectrogram (default: 20.0) diff --git a/openspeech/data/data_loader.py b/openspeech/data/data_loader.py index 3fd0a00..e73886b 100644 --- a/openspeech/data/data_loader.py +++ b/openspeech/data/data_loader.py @@ -19,6 +19,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from typing import Tuple import torch import numpy as np @@ -132,3 +133,27 @@ def __len__(self): def shuffle(self, epoch): np.random.shuffle(self.bins) + + +def load_dataset(manifest_file_path: str) -> Tuple[list, list]: + """ + Provides dictionary of filename and labels. + + Args: + manifest_file_path (str): evaluation manifest file path. + + Returns: target_dict + * target_dict (dict): dictionary of filename and labels + """ + audio_paths = list() + transcripts = list() + + with open(manifest_file_path) as f: + for idx, line in enumerate(f.readlines()): + audio_path, korean_transcript, transcript = line.split('\t') + transcript = transcript.replace('\n', '') + + audio_paths.append(audio_path) + transcripts.append(transcript) + + return audio_paths, transcripts diff --git a/openspeech/data/dataset.py b/openspeech/data/dataset.py index aae3b5a..1d7e4b6 100644 --- a/openspeech/data/dataset.py +++ b/openspeech/data/dataset.py @@ -90,7 +90,7 @@ def __init__( self.apply_noise_augment = apply_noise_augment self.apply_time_stretch_augment = apply_time_stretch_augment self.apply_joining_augment = apply_joining_augment - self.transforms = AUDIO_FEATURE_TRANSFORM_DATACLASS_REGISTRY[configs.name](configs) + self.transforms = AUDIO_FEATURE_TRANSFORM_DATACLASS_REGISTRY[configs.audio.name](configs) self._load_audio = load_audio if self.apply_spec_augment: diff --git a/openspeech/dataclass/__init__.py b/openspeech/dataclass/__init__.py index 5688b10..6ed175f 100644 --- a/openspeech/dataclass/__init__.py +++ b/openspeech/dataclass/__init__.py @@ -31,6 +31,8 @@ Fp16GPUTrainerConfigs, Fp16TPUTrainerConfigs, Fp64CPUTrainerConfigs, + EvaluationConfigs, + EnsembleEvaluationConfigs, ) OPENSPEECH_CONFIGS = [ @@ -62,3 +64,7 @@ AUGMENT_DATACLASS_REGISTRY = { "default": AugmentConfigs, } +EVAL_DATACLASS_REGISTRY = { + "default": EvaluationConfigs, + "ensemble": EnsembleEvaluationConfigs, +} \ No newline at end of file diff --git a/openspeech/dataclass/configurations.py b/openspeech/dataclass/configurations.py index 3d8ca9b..578c8ef 100644 --- a/openspeech/dataclass/configurations.py +++ b/openspeech/dataclass/configurations.py @@ -311,6 +311,62 @@ class VocabularyConfigs(OpenspeechDataclass): ) +@dataclass +class EvaluationConfigs(OpenspeechDataclass): + model_name: str = field( + default=MISSING, metadata={"help": "Model name."} + ) + dataset_path: str = field( + default=MISSING, metadata={"help": "Path of dataset."} + ) + checkpoint_path: str = field( + default=MISSING, metadata={"help": "Path of model checkpoint."} + ) + manifest_file_path: str = field( + default=MISSING, metadata={"help": "Path of evaluation manifest file."} + ) + num_workers: int = field( + default=4, metadata={"help": "Number of worker."} + ) + batch_size: int = field( + default=32, metadata={"help": "Batch size."} + ) + beam_size: int = field( + default=1, metadata={"help": "Beam size of beam search."} + ) + + +@dataclass +class EnsembleEvaluationConfigs(OpenspeechDataclass): + model_names: str = field( + default=MISSING, metadata={"help": "List of model name."} + ) + dataset_paths: str = field( + default=MISSING, metadata={"help": "Path of dataset."} + ) + checkpoint_paths: str = field( + default=MISSING, metadata={"help": "List of model checkpoint path."} + ) + manifest_file_path: str = field( + default=MISSING, metadata={"help": "Path of evaluation manifest file."} + ) + ensemble_method: str = field( + default="vanilla", metadata={"help": "Method of ensemble (vanilla, weighted)"} + ) + ensemble_weights: str = field( + default="(1.0, 1.0, 1.0 ..)", metadata={"help": "Weights of ensemble models."} + ) + num_workers: int = field( + default=4, metadata={"help": "Number of worker."} + ) + batch_size: int = field( + default=32, metadata={"help": "Batch size."} + ) + beam_size: int = field( + default=1, metadata={"help": "Beam size of beam search."} + ) + + def generate_openspeech_configs_with_help(): from openspeech.dataclass import OPENSPEECH_CONFIGS, TRAINER_DATACLASS_REGISTRY from openspeech.models import MODEL_DATACLASS_REGISTRY diff --git a/openspeech/dataclass/initialize.py b/openspeech/dataclass/initialize.py index 07604a6..2a0c7f6 100644 --- a/openspeech/dataclass/initialize.py +++ b/openspeech/dataclass/initialize.py @@ -51,3 +51,21 @@ def hydra_init() -> None: for k, v in dataclass_registry.items(): cs.store(group=group, name=k, node=v, provider="openspeech") + + +def hydra_eval_init() -> None: + from openspeech.data import AUDIO_FEATURE_TRANSFORM_DATACLASS_REGISTRY + from openspeech.dataclass import EVAL_DATACLASS_REGISTRY + + registries = { + "audio": AUDIO_FEATURE_TRANSFORM_DATACLASS_REGISTRY, + "eval": EVAL_DATACLASS_REGISTRY, + } + + cs = ConfigStore.instance() + + for group in registries.keys(): + dataclass_registry = registries[group] + + for k, v in dataclass_registry.items(): + cs.store(group=group, name=k, node=v, provider="openspeech") \ No newline at end of file diff --git a/openspeech/decoders/transformer_decoder.py b/openspeech/decoders/transformer_decoder.py index 11a0eef..b36ebc3 100644 --- a/openspeech/decoders/transformer_decoder.py +++ b/openspeech/decoders/transformer_decoder.py @@ -255,7 +255,7 @@ def forward( input_var = input_var.fill_(self.pad_id) input_var[:, 0] = self.sos_id - for di in range(1, self.max_length): + for di in range(self.max_length): input_lengths = torch.IntTensor(batch_size).fill_(di) outputs = self.forward_step( diff --git a/openspeech/models/conformer/configurations.py b/openspeech/models/conformer/configurations.py index 9012c9f..015dec4 100644 --- a/openspeech/models/conformer/configurations.py +++ b/openspeech/models/conformer/configurations.py @@ -35,7 +35,7 @@ class ConformerConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: conformer) encoder_dim (int): Dimension of encoder. (default: 512) num_encoder_layers (int): The number of encoder layers. (default: 17) diff --git a/openspeech/models/conformer_lstm/configurations.py b/openspeech/models/conformer_lstm/configurations.py index e1d6cc3..37c1375 100644 --- a/openspeech/models/conformer_lstm/configurations.py +++ b/openspeech/models/conformer_lstm/configurations.py @@ -35,7 +35,7 @@ class ConformerLSTMConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: conformer_lstm) encoder_dim (int): Dimension of encoder. (default: 512) num_encoder_layers (int): The number of encoder layers. (default: 17) diff --git a/openspeech/models/conformer_lstm/model.py b/openspeech/models/conformer_lstm/model.py index f181614..a66bdd1 100644 --- a/openspeech/models/conformer_lstm/model.py +++ b/openspeech/models/conformer_lstm/model.py @@ -85,13 +85,12 @@ def build_model(self): rnn_type=self.configs.model.rnn_type, ) - def set_beam_decoder(self, batch_size: int, beam_size: int = 3): + def set_beam_decoder(self, beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search import BeamSearchLSTM self.decoder = BeamSearchLSTM( decoder=self.decoder, beam_size=beam_size, - batch_size=batch_size, ) def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]: diff --git a/openspeech/models/conformer_transducer/configurations.py b/openspeech/models/conformer_transducer/configurations.py index 9ad0b68..de3772a 100644 --- a/openspeech/models/conformer_transducer/configurations.py +++ b/openspeech/models/conformer_transducer/configurations.py @@ -35,7 +35,7 @@ class ConformerTransducerConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: conformer_transducer) encoder_dim (int): Dimension of encoder. (default: 512) num_encoder_layers (int): The number of encoder layers. (default: 17) diff --git a/openspeech/models/deep_cnn_with_joint_ctc_listen_attend_spell/configurations.py b/openspeech/models/deep_cnn_with_joint_ctc_listen_attend_spell/configurations.py index cfc6da8..2e17b83 100644 --- a/openspeech/models/deep_cnn_with_joint_ctc_listen_attend_spell/configurations.py +++ b/openspeech/models/deep_cnn_with_joint_ctc_listen_attend_spell/configurations.py @@ -35,7 +35,7 @@ class DeepCNNWithJointCTCListenAttendSpellConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: deep_cnn_with_joint_ctc_listen_attend_spell) num_encoder_layers (int): The number of encoder layers. (default: 3) num_decoder_layers (int): The number of decoder layers. (default: 2) diff --git a/openspeech/models/deep_cnn_with_joint_ctc_listen_attend_spell/model.py b/openspeech/models/deep_cnn_with_joint_ctc_listen_attend_spell/model.py index 3fe6c16..eb5776b 100644 --- a/openspeech/models/deep_cnn_with_joint_ctc_listen_attend_spell/model.py +++ b/openspeech/models/deep_cnn_with_joint_ctc_listen_attend_spell/model.py @@ -83,13 +83,12 @@ def build_model(self): rnn_type=self.configs.model.rnn_type, ) - def set_beam_decoder(self, batch_size: int, beam_size: int = 3): + def set_beam_decoder(self, beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search import BeamSearchLSTM self.decoder = BeamSearchLSTM( decoder=self.decoder, beam_size=beam_size, - batch_size=batch_size, ) def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]: diff --git a/openspeech/models/deepspeech2/configurations.py b/openspeech/models/deepspeech2/configurations.py index 16848c2..e75414b 100644 --- a/openspeech/models/deepspeech2/configurations.py +++ b/openspeech/models/deepspeech2/configurations.py @@ -35,7 +35,7 @@ class DeepSpeech2Configs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: deepspeech2) num_rnn_layers (int): The number of rnn layers. (default: 5) rnn_hidden_dim (int): The hidden state dimension of rnn. (default: 1024) diff --git a/openspeech/models/jasper10x5/configurations.py b/openspeech/models/jasper10x5/configurations.py index 0781d8b..d89a2f8 100644 --- a/openspeech/models/jasper10x5/configurations.py +++ b/openspeech/models/jasper10x5/configurations.py @@ -35,7 +35,7 @@ class Jasper10x5Config(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: jasper10x5) num_blocks (int): Number of jasper blocks (default: 10) num_sub_blocks (int): Number of jasper sub blocks (default: 5) diff --git a/openspeech/models/jasper5x3/configurations.py b/openspeech/models/jasper5x3/configurations.py index b0cee88..6382481 100644 --- a/openspeech/models/jasper5x3/configurations.py +++ b/openspeech/models/jasper5x3/configurations.py @@ -35,7 +35,7 @@ class Jasper5x3Config(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: jasper5x3) num_blocks (int): Number of jasper blocks (default: 5) num_sub_blocks (int): Number of jasper sub blocks (default: 3) diff --git a/openspeech/models/joint_ctc_conformer_lstm/configurations.py b/openspeech/models/joint_ctc_conformer_lstm/configurations.py index ec40d2e..dec14f5 100644 --- a/openspeech/models/joint_ctc_conformer_lstm/configurations.py +++ b/openspeech/models/joint_ctc_conformer_lstm/configurations.py @@ -35,7 +35,7 @@ class JointCTCConformerLSTMConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: joint_ctc_conformer_lstm) encoder_dim (int): Dimension of encoder. (default: 512) num_encoder_layers (int): The number of encoder layers. (default: 17) diff --git a/openspeech/models/joint_ctc_conformer_lstm/model.py b/openspeech/models/joint_ctc_conformer_lstm/model.py index 7ed000f..4acb760 100644 --- a/openspeech/models/joint_ctc_conformer_lstm/model.py +++ b/openspeech/models/joint_ctc_conformer_lstm/model.py @@ -85,13 +85,12 @@ def build_model(self): rnn_type=self.configs.model.rnn_type, ) - def set_beam_decoder(self, batch_size: int, beam_size: int = 3): + def set_beam_decoder(self, beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search import BeamSearchLSTM self.decoder = BeamSearchLSTM( decoder=self.decoder, beam_size=beam_size, - batch_size=batch_size, ) def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]: diff --git a/openspeech/models/joint_ctc_listen_attend_spell/configurations.py b/openspeech/models/joint_ctc_listen_attend_spell/configurations.py index 221546b..62d734d 100644 --- a/openspeech/models/joint_ctc_listen_attend_spell/configurations.py +++ b/openspeech/models/joint_ctc_listen_attend_spell/configurations.py @@ -35,21 +35,21 @@ class JointCTCListenAttendSpellConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: - model_name (str): Model name (default: joint_ctc_listen_attend_spell) - num_encoder_layers (int): The number of encoder layers. (default: 3) - num_decoder_layers (int): The number of decoder layers. (default: 2) - hidden_state_dim (int): The hidden state dimension of encoder. (default: 768) - encoder_dropout_p (float): The dropout probability of encoder. (default: 0.3) - encoder_bidirectional (bool): If True, becomes a bidirectional encoders (default: True) - rnn_type (str): Type of rnn cell (rnn, lstm, gru) (default: lstm) - joint_ctc_attention (bool): Flag indication joint ctc attention or not (default: True) - max_length (int): Max decoding length. (default: 128) - num_attention_heads (int): The number of attention heads. (default: 1) - decoder_dropout_p (float): The dropout probability of decoder. (default: 0.2) - decoder_attn_mechanism (str): The attention mechanism for decoder. (default: loc) - teacher_forcing_ratio (float): The ratio of teacher forcing. (default: 1.0) - optimizer (str): Optimizer for training. (default: adam) + Args: + model_name (str): Model name (default: joint_ctc_listen_attend_spell) + num_encoder_layers (int): The number of encoder layers. (default: 3) + num_decoder_layers (int): The number of decoder layers. (default: 2) + hidden_state_dim (int): The hidden state dimension of encoder. (default: 768) + encoder_dropout_p (float): The dropout probability of encoder. (default: 0.3) + encoder_bidirectional (bool): If True, becomes a bidirectional encoders (default: True) + rnn_type (str): Type of rnn cell (rnn, lstm, gru) (default: lstm) + joint_ctc_attention (bool): Flag indication joint ctc attention or not (default: True) + max_length (int): Max decoding length. (default: 128) + num_attention_heads (int): The number of attention heads. (default: 1) + decoder_dropout_p (float): The dropout probability of decoder. (default: 0.2) + decoder_attn_mechanism (str): The attention mechanism for decoder. (default: loc) + teacher_forcing_ratio (float): The ratio of teacher forcing. (default: 1.0) + optimizer (str): Optimizer for training. (default: adam) """ model_name: str = field( default="joint_ctc_listen_attend_spell", metadata={"help": "Model name"} diff --git a/openspeech/models/joint_ctc_listen_attend_spell/model.py b/openspeech/models/joint_ctc_listen_attend_spell/model.py index ef87423..53764b3 100644 --- a/openspeech/models/joint_ctc_listen_attend_spell/model.py +++ b/openspeech/models/joint_ctc_listen_attend_spell/model.py @@ -82,13 +82,12 @@ def build_model(self): rnn_type=self.configs.model.rnn_type, ) - def set_beam_decoder(self, batch_size: int, beam_size: int = 3): + def set_beam_decoder(self, beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search.beam_search_lstm import BeamSearchLSTM self.decoder = BeamSearchLSTM( decoder=self.decoder, beam_size=beam_size, - batch_size=batch_size, ) def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]: diff --git a/openspeech/models/joint_ctc_transformer/configurations.py b/openspeech/models/joint_ctc_transformer/configurations.py index db5d063..7970479 100644 --- a/openspeech/models/joint_ctc_transformer/configurations.py +++ b/openspeech/models/joint_ctc_transformer/configurations.py @@ -35,7 +35,7 @@ class JointCTCTransformerConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: joint_ctc_transformer) extractor (str): The CNN feature extractor. (default: conv2d_subsample) d_model (int): Dimension of model. (default: 512) diff --git a/openspeech/models/joint_ctc_transformer/model.py b/openspeech/models/joint_ctc_transformer/model.py index f481e69..80df9c0 100644 --- a/openspeech/models/joint_ctc_transformer/model.py +++ b/openspeech/models/joint_ctc_transformer/model.py @@ -79,12 +79,11 @@ def build_model(self): max_length=self.configs.model.max_length, ) - def set_beam_decoder(self, batch_size: int = None, beam_size: int = 3, n_best: int = 1): + def set_beam_decoder(self, beam_size: int = 3, n_best: int = 1): """ Setting beam search decoder """ from openspeech.search import BeamSearchTransformer self.decoder = BeamSearchTransformer( decoder=self.decoder, - batch_size=batch_size, beam_size=beam_size, ) diff --git a/openspeech/models/listen_attend_spell/configurations.py b/openspeech/models/listen_attend_spell/configurations.py index 680450c..12b4039 100644 --- a/openspeech/models/listen_attend_spell/configurations.py +++ b/openspeech/models/listen_attend_spell/configurations.py @@ -35,7 +35,7 @@ class ListenAttendSpellConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: listen_attend_spell) num_encoder_layers (int): The number of encoder layers. (default: 3) num_decoder_layers (int): The number of decoder layers. (default: 2) diff --git a/openspeech/models/listen_attend_spell/model.py b/openspeech/models/listen_attend_spell/model.py index 89942c2..f835d3f 100644 --- a/openspeech/models/listen_attend_spell/model.py +++ b/openspeech/models/listen_attend_spell/model.py @@ -82,13 +82,12 @@ def build_model(self): rnn_type=self.configs.model.rnn_type, ) - def set_beam_decoder(self, batch_size: int, beam_size: int = 3): + def set_beam_decoder(self, beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search import BeamSearchLSTM self.decoder = BeamSearchLSTM( decoder=self.decoder, beam_size=beam_size, - batch_size=batch_size, ) def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]: diff --git a/openspeech/models/listen_attend_spell_with_location_aware/configurations.py b/openspeech/models/listen_attend_spell_with_location_aware/configurations.py index 9043695..2174918 100644 --- a/openspeech/models/listen_attend_spell_with_location_aware/configurations.py +++ b/openspeech/models/listen_attend_spell_with_location_aware/configurations.py @@ -35,7 +35,7 @@ class ListenAttendSpellWithLocationAwareConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: listen_attend_spell_with_location_aware) num_encoder_layers (int): The number of encoder layers. (default: 3) num_decoder_layers (int): The number of decoder layers. (default: 2) diff --git a/openspeech/models/listen_attend_spell_with_location_aware/model.py b/openspeech/models/listen_attend_spell_with_location_aware/model.py index 9acb340..cfed86c 100644 --- a/openspeech/models/listen_attend_spell_with_location_aware/model.py +++ b/openspeech/models/listen_attend_spell_with_location_aware/model.py @@ -82,13 +82,12 @@ def build_model(self): rnn_type=self.configs.model.rnn_type, ) - def set_beam_decoder(self, batch_size: int, beam_size: int = 3): + def set_beam_decoder(self, beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search import BeamSearchLSTM self.decoder = BeamSearchLSTM( decoder=self.decoder, beam_size=beam_size, - batch_size=batch_size, ) def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]: diff --git a/openspeech/models/listen_attend_spell_with_multi_head/configurations.py b/openspeech/models/listen_attend_spell_with_multi_head/configurations.py index 16b3b8e..398c1f5 100644 --- a/openspeech/models/listen_attend_spell_with_multi_head/configurations.py +++ b/openspeech/models/listen_attend_spell_with_multi_head/configurations.py @@ -35,7 +35,7 @@ class ListenAttendSpellWithMultiHeadConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: listen_attend_spell_with_multi_head) num_encoder_layers (int): The number of encoder layers. (default: 3) num_decoder_layers (int): The number of decoder layers. (default: 2) diff --git a/openspeech/models/listen_attend_spell_with_multi_head/model.py b/openspeech/models/listen_attend_spell_with_multi_head/model.py index c17ed7f..de1cca1 100644 --- a/openspeech/models/listen_attend_spell_with_multi_head/model.py +++ b/openspeech/models/listen_attend_spell_with_multi_head/model.py @@ -82,13 +82,12 @@ def build_model(self): rnn_type=self.configs.model.rnn_type, ) - def set_beam_decoder(self, batch_size: int, beam_size: int = 3): + def set_beam_decoder(self, beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search import BeamSearchLSTM self.decoder = BeamSearchLSTM( decoder=self.decoder, beam_size=beam_size, - batch_size=batch_size, ) def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]: diff --git a/openspeech/models/openspeech_ctc_model.py b/openspeech/models/openspeech_ctc_model.py index f038bad..955c329 100644 --- a/openspeech/models/openspeech_ctc_model.py +++ b/openspeech/models/openspeech_ctc_model.py @@ -116,7 +116,7 @@ def forward(self, inputs: torch.FloatTensor, input_lengths: torch.IntTensor) -> else: y_hats = logits.max(-1)[1] return { - "y_hats": y_hats, + "predictions": y_hats, "logits": logits, "output_lengths": output_lengths, } diff --git a/openspeech/models/openspeech_transducer_model.py b/openspeech/models/openspeech_transducer_model.py index 6f684dc..656a7f3 100644 --- a/openspeech/models/openspeech_transducer_model.py +++ b/openspeech/models/openspeech_transducer_model.py @@ -22,6 +22,7 @@ import torch import torch.nn as nn +import warnings from torch import Tensor from collections import OrderedDict from omegaconf import DictConfig @@ -68,6 +69,9 @@ def __init__(self, configs: DictConfig, vocab: Vocabulary, ) -> None: Linear(in_features=in_features, out_features=self.num_classes), ) + def set_beam_decoder(self, beam_size: int = 3): + warnings.warn("Currently, Beamsearch has not yet been implemented in the transducer model.") + def collect_outputs( self, stage: str, diff --git a/openspeech/models/quartznet10x5/configurations.py b/openspeech/models/quartznet10x5/configurations.py index f4d1571..96c8c82 100644 --- a/openspeech/models/quartznet10x5/configurations.py +++ b/openspeech/models/quartznet10x5/configurations.py @@ -35,7 +35,7 @@ class QuartzNet10x5Configs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: quartznet5x5) num_blocks (int): Number of quartznet blocks (default: 10) num_sub_blocks (int): Number of quartznet sub blocks (default: 5) diff --git a/openspeech/models/quartznet15x5/configurations.py b/openspeech/models/quartznet15x5/configurations.py index 55087fc..10f991f 100644 --- a/openspeech/models/quartznet15x5/configurations.py +++ b/openspeech/models/quartznet15x5/configurations.py @@ -35,7 +35,7 @@ class QuartzNet15x5Configs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: quartznet15x5) num_blocks (int): Number of quartznet blocks (default: 15) num_sub_blocks (int): Number of quartznet sub blocks (default: 5) diff --git a/openspeech/models/quartznet5x5/configurations.py b/openspeech/models/quartznet5x5/configurations.py index 1c9db3a..525aacf 100644 --- a/openspeech/models/quartznet5x5/configurations.py +++ b/openspeech/models/quartznet5x5/configurations.py @@ -35,7 +35,7 @@ class QuartzNet5x5Configs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: quartznet5x5) num_blocks (int): Number of quartznet blocks (default: 5) num_sub_blocks (int): Number of quartznet sub blocks (default: 5) diff --git a/openspeech/models/rnn_transducer/configurations.py b/openspeech/models/rnn_transducer/configurations.py index 1ba5a8e..0efb842 100644 --- a/openspeech/models/rnn_transducer/configurations.py +++ b/openspeech/models/rnn_transducer/configurations.py @@ -35,7 +35,7 @@ class RNNTransducerConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: transformer_transducer) encoder_hidden_state_dim (int): Hidden state dimension of encoder (default: 312) decoder_hidden_state_dim (int): Hidden state dimension of decoder (default: 512) diff --git a/openspeech/models/transformer/configurations.py b/openspeech/models/transformer/configurations.py index fd9819d..607cedd 100644 --- a/openspeech/models/transformer/configurations.py +++ b/openspeech/models/transformer/configurations.py @@ -35,7 +35,7 @@ class TransformerConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: transformer) d_model (int): Dimension of model. (default: 512) d_ff (int): Dimenstion of feed forward network. (default: 2048) diff --git a/openspeech/models/transformer/model.py b/openspeech/models/transformer/model.py index d6a922b..f33fc07 100644 --- a/openspeech/models/transformer/model.py +++ b/openspeech/models/transformer/model.py @@ -78,12 +78,11 @@ def build_model(self): max_length=self.configs.model.max_length, ) - def set_beam_decoder(self, batch_size: int = None, beam_size: int = 3, n_best: int = 1): + def set_beam_decoder(self, beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search import BeamSearchTransformer self.decoder = BeamSearchTransformer( decoder=self.decoder, - batch_size=batch_size, beam_size=beam_size, ) diff --git a/openspeech/models/transformer_transducer/configurations.py b/openspeech/models/transformer_transducer/configurations.py index 51f1aa9..b0796f1 100644 --- a/openspeech/models/transformer_transducer/configurations.py +++ b/openspeech/models/transformer_transducer/configurations.py @@ -35,7 +35,7 @@ class TransformerTransducerConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: transformer_transducer) extractor (str): The CNN feature extractor. (default: conv2d_subsample) d_model (int): Dimension of model. (default: 512) diff --git a/openspeech/models/transformer_with_ctc/configurations.py b/openspeech/models/transformer_with_ctc/configurations.py index eb5b2ce..efb82e6 100644 --- a/openspeech/models/transformer_with_ctc/configurations.py +++ b/openspeech/models/transformer_with_ctc/configurations.py @@ -35,7 +35,7 @@ class TransformerWithCTCConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: transformer_with_ctc) extractor (str): The CNN feature extractor. (default: vgg) d_model (int): Dimension of model. (default: 512) diff --git a/openspeech/models/vgg_transformer/configurations.py b/openspeech/models/vgg_transformer/configurations.py index 8de17e6..28d816e 100644 --- a/openspeech/models/vgg_transformer/configurations.py +++ b/openspeech/models/vgg_transformer/configurations.py @@ -35,7 +35,7 @@ class VGGTransformerConfigs(OpenspeechDataclass): Configuration objects inherit from :class: `~openspeech.dataclass.configs.OpenspeechDataclass`. - Configurations: + Args: model_name (str): Model name (default: vgg_transformer) extractor (str): The CNN feature extractor. (default: vgg) d_model (int): Dimension of model. (default: 512) diff --git a/openspeech/models/vgg_transformer/model.py b/openspeech/models/vgg_transformer/model.py index 3165fcb..cf7633a 100644 --- a/openspeech/models/vgg_transformer/model.py +++ b/openspeech/models/vgg_transformer/model.py @@ -79,12 +79,11 @@ def build_model(self): max_length=self.configs.model.max_length, ) - def set_beam_decoder(self, batch_size: int = None, beam_size: int = 3, n_best: int = 1): + def set_beam_decoder(self,beam_size: int = 3): """ Setting beam search decoder """ from openspeech.search import BeamSearchTransformer self.decoder = BeamSearchTransformer( decoder=self.decoder, - batch_size=batch_size, beam_size=beam_size, ) diff --git a/openspeech/search/base.py b/openspeech/search/base.py index e809093..bec9dfc 100644 --- a/openspeech/search/base.py +++ b/openspeech/search/base.py @@ -25,7 +25,7 @@ class OpenspeechBeamSearchBase(nn.Module): - def __init__(self, decoder, beam_size: int, batch_size: int): + def __init__(self, decoder, beam_size: int): super(OpenspeechBeamSearchBase, self).__init__() self.decoder = decoder self.beam_size = beam_size @@ -34,8 +34,6 @@ def __init__(self, decoder, beam_size: int, batch_size: int): self.eos_id = decoder.eos_id self.ongoing_beams = None self.cumulative_ps = None - self.finished = [[] for _ in range(batch_size)] - self.finished_ps = [[] for _ in range(batch_size)] self.forward_step = decoder.forward_step def _inflate(self, tensor: torch.Tensor, n_repeat: int, dim: int) -> torch.Tensor: diff --git a/openspeech/search/beam_search_lstm.py b/openspeech/search/beam_search_lstm.py index 029226a..9459ec9 100644 --- a/openspeech/search/beam_search_lstm.py +++ b/openspeech/search/beam_search_lstm.py @@ -47,8 +47,8 @@ class BeamSearchLSTM(OpenspeechBeamSearchBase): Returns: * logits (torch.FloatTensor): Log probability of model predictions. """ - def __init__(self, decoder: LSTMDecoder, beam_size: int, batch_size: int): - super(BeamSearchLSTM, self).__init__(decoder, beam_size, batch_size) + def __init__(self, decoder: LSTMDecoder, beam_size: int): + super(BeamSearchLSTM, self).__init__(decoder, beam_size) self.hidden_state_dim = decoder.hidden_state_dim self.num_layers = decoder.num_layers self.validate_args = decoder.validate_args @@ -69,6 +69,10 @@ def forward( * logits (torch.FloatTensor): Log probability of model predictions. """ batch_size, hidden_states = encoder_outputs.size(0), None + + self.finished = [[] for _ in range(batch_size)] + self.finished_ps = [[] for _ in range(batch_size)] + inputs, batch_size, max_length = self.validate_args(None, encoder_outputs, teacher_forcing_ratio=0.0) step_outputs, hidden_states, attn = self.forward_step(inputs, hidden_states, encoder_outputs) diff --git a/openspeech/search/beam_search_transformer.py b/openspeech/search/beam_search_transformer.py index 505df51..d165358 100644 --- a/openspeech/search/beam_search_transformer.py +++ b/openspeech/search/beam_search_transformer.py @@ -27,8 +27,8 @@ class BeamSearchTransformer(OpenspeechBeamSearchBase): - def __init__(self, decoder: TransformerDecoder, batch_size: int, beam_size: int = 3) -> None: - super(BeamSearchTransformer, self).__init__(decoder, beam_size, batch_size) + def __init__(self, decoder: TransformerDecoder, beam_size: int = 3) -> None: + super(BeamSearchTransformer, self).__init__(decoder, beam_size) self.use_cuda = True if torch.cuda.is_available() else False def forward( @@ -38,6 +38,9 @@ def forward( ): batch_size = encoder_outputs.size(0) + self.finished = [[] for _ in range(batch_size)] + self.finished_ps = [[] for _ in range(batch_size)] + decoder_inputs = torch.IntTensor(batch_size, self.decoder.max_length).fill_(self.sos_id).long() decoder_input_lengths = torch.IntTensor(batch_size).fill_(1) diff --git a/openspeech/search/ensemble_search.py b/openspeech/search/ensemble_search.py new file mode 100644 index 0000000..5114c72 --- /dev/null +++ b/openspeech/search/ensemble_search.py @@ -0,0 +1,96 @@ +# MIT License +# +# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch +import torch.nn as nn +from typing import Union + + +class EnsembleSearch(nn.Module): + """ + Class for ensemble search. + + Args: + models (tuple): list of ensemble model + + Inputs: + - **inputs** (torch.FloatTensor): A input sequence passed to encoders. Typically for inputs this will be + a padded `FloatTensor` of size ``(batch, seq_length, dimension)``. + - **input_lengths** (torch.LongTensor): The length of input tensor. ``(batch)`` + + Returns: + * predictions (torch.LongTensor): prediction of ensemble models + """ + def __init__(self, models: Union[list, tuple]): + super(EnsembleSearch, self).__init__() + assert len(models) > 1, "Ensemble search should be multiple models." + self.models = models + + def forward(self, inputs: torch.FloatTensor, input_lengths: torch.LongTensor): + logits = list() + + for model in self.models: + output = model(inputs, input_lengths) + logits.append(output["logits"]) + + output = logits[0] + + for logit in logits[1:]: + output += logit + + return output.max(-1)[1] + + +class WeightedEnsembleSearch(nn.Module): + """ + Args: + models (tuple): list of ensemble model + weights (tuple: list of ensemble's weight + + Inputs: + - **inputs** (torch.FloatTensor): A input sequence passed to encoders. Typically for inputs this will be + a padded `FloatTensor` of size ``(batch, seq_length, dimension)``. + - **input_lengths** (torch.LongTensor): The length of input tensor. ``(batch)`` + + Returns: + * predictions (torch.LongTensor): prediction of ensemble models + """ + def __init__(self, models: Union[list, tuple], weights: Union[list, tuple]): + super(WeightedEnsembleSearch, self).__init__() + assert len(models) > 1, "Ensemble search should be multiple models." + assert len(models) == len(weights), "len(models), len(weight) should be same." + self.models = models + self.weights = weights + + def forward(self, inputs: torch.FloatTensor, input_lengths: torch.LongTensor): + logits = list() + + for model in self.models: + output = model(inputs, input_lengths) + logits.append(output["logits"]) + + output = logits[0] * self.weights[0] + + for idx, logit in enumerate(logits[1:]): + output += logit * self.weights[1] + + return output.max(-1)[1] diff --git a/openspeech_cli/hydra_ensemble_eval.py b/openspeech_cli/hydra_ensemble_eval.py new file mode 100644 index 0000000..58ac2a0 --- /dev/null +++ b/openspeech_cli/hydra_ensemble_eval.py @@ -0,0 +1,103 @@ +# MIT License +# +# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import os +import hydra +import warnings +import logging +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.utilities import rank_zero_info + +from openspeech.metrics import WordErrorRate, CharacterErrorRate +from openspeech.data.dataset import SpeechToTextDataset +from openspeech.data.data_loader import load_dataset, AudioDataLoader, BucketingSampler +from openspeech.dataclass.initialize import hydra_eval_init +from openspeech.models import MODEL_REGISTRY +from openspeech.search.ensemble_search import EnsembleSearch, WeightedEnsembleSearch + +logger = logging.getLogger(__name__) + + +@hydra.main(config_path=os.path.join("..", "openspeech", "configs"), config_name="eval") +def hydra_main(configs: DictConfig) -> None: + rank_zero_info(OmegaConf.to_yaml(configs)) + wer, cer = 1.0, 1.0 + models = list() + + audio_paths, transcripts = load_dataset(configs.eval.manifest_file_path) + + model_names = eval(configs.eval.model_names) + checkpoint_paths = eval(configs.eval.checkpoint_paths) + ensemble_weights = eval(configs.eval.ensemble_weights) + + for model_name, checkpoint_path in zip(model_names, checkpoint_paths): + models.append(MODEL_REGISTRY[model_name].load_from_checkpoint(checkpoint_path)) + + if configs.eval.beam_size > 1: + warnings.warn("Currently, Ensemble + beam search is not supports.") + + vocab = models[0].vocab + + if configs.eval.ensemble_method == "vanilla": + model = EnsembleSearch(models) + elif configs.eval.ensemble_method == "weighted": + model = WeightedEnsembleSearch(models, ensemble_weights) + else: + raise ValueError(f"Unsupported ensemble method: {configs.eval.ensemble_method}") + + dataset = SpeechToTextDataset( + configs=configs, + dataset_path=configs.eval.dataset_path, + audio_paths=audio_paths, + transcripts=transcripts, + sos_id=vocab.sos_id, + eos_id=vocab.eos_id, + ) + sampler = BucketingSampler( + data_source=dataset, + batch_size=configs.eval.batch_size + ) + data_loader = AudioDataLoader( + dataset=dataset, + num_workers=configs.eval.num_workers, + batch_sampler=sampler, + ) + + wer_metric = WordErrorRate(vocab) + cer_metric = CharacterErrorRate(vocab) + + for i, (batch) in enumerate(data_loader): + inputs, targets, input_lengths, target_lengths = batch + + outputs = model(inputs, input_lengths) + + wer = wer_metric(targets, outputs) + cer = cer_metric(targets, outputs) + + logger.info(f"Word Error Rate: {wer}, Character Error Rate: {cer}") + + +if __name__ == '__main__': + warnings.filterwarnings("ignore") + hydra_eval_init() + hydra_main() + diff --git a/openspeech_cli/hydra_eval.py b/openspeech_cli/hydra_eval.py new file mode 100644 index 0000000..404b0ee --- /dev/null +++ b/openspeech_cli/hydra_eval.py @@ -0,0 +1,88 @@ +# MIT License +# +# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import os +import hydra +import warnings +import logging +from omegaconf import DictConfig, OmegaConf +from openspeech.metrics import WordErrorRate, CharacterErrorRate +from pytorch_lightning.utilities import rank_zero_info + +from openspeech.data.dataset import SpeechToTextDataset +from openspeech.data.data_loader import load_dataset, AudioDataLoader, BucketingSampler +from openspeech.dataclass.initialize import hydra_eval_init +from openspeech.models import MODEL_REGISTRY + +logger = logging.getLogger(__name__) + + +@hydra.main(config_path=os.path.join("..", "openspeech", "configs"), config_name="eval") +def hydra_main(configs: DictConfig) -> None: + rank_zero_info(OmegaConf.to_yaml(configs)) + wer, cer = 1.0, 1.0 + + audio_paths, transcripts = load_dataset(configs.eval.manifest_file_path) + + model = MODEL_REGISTRY[configs.eval.model_name].load_from_checkpoint(configs.eval.checkpoint_path) + + if configs.eval.beam_size > 1: + model.set_beam_decoder(beam_size=configs.eval.beam_size) + + vocab = model.vocab + + dataset = SpeechToTextDataset( + configs=configs, + dataset_path=configs.eval.dataset_path, + audio_paths=audio_paths, + transcripts=transcripts, + sos_id=vocab.sos_id, + eos_id=vocab.eos_id, + ) + sampler = BucketingSampler( + data_source=dataset, + batch_size=configs.eval.batch_size + ) + data_loader = AudioDataLoader( + dataset=dataset, + num_workers=configs.eval.num_workers, + batch_sampler=sampler, + ) + + wer_metric = WordErrorRate(vocab) + cer_metric = CharacterErrorRate(vocab) + + for i, (batch) in enumerate(data_loader): + inputs, targets, input_lengths, target_lengths = batch + + outputs = model(inputs, input_lengths) + + wer = wer_metric(targets, outputs["predictions"]) + cer = cer_metric(targets, outputs["predictions"]) + + logger.info(f"Word Error Rate: {wer}, Character Error Rate: {cer}") + + +if __name__ == '__main__': + warnings.filterwarnings("ignore") + hydra_eval_init() + hydra_main() diff --git a/openspeech_cli/hydra_train.py b/openspeech_cli/hydra_train.py index aff04b2..187bde2 100644 --- a/openspeech_cli/hydra_train.py +++ b/openspeech_cli/hydra_train.py @@ -33,7 +33,7 @@ from openspeech.utils import parse_configs, get_pl_trainer -@hydra.main(config_path=os.path.join("..", "openspeech", "configs"), config_name="configs") +@hydra.main(config_path=os.path.join("..", "openspeech", "configs"), config_name="train") def hydra_main(configs: DictConfig) -> None: rank_zero_info(OmegaConf.to_yaml(configs)) pl.seed_everything(configs.trainer.seed) @@ -49,6 +49,7 @@ def hydra_main(configs: DictConfig) -> None: trainer = get_pl_trainer(configs, num_devices, logger) trainer.fit(model, data_module) + trainer.test() if __name__ == '__main__': diff --git a/setup.py b/setup.py index 63e411f..e3c6edb 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ setup( name='openspeech-py', - version='v0.1', + version='0.2', description='Open-Source Toolkit for End-to-End Automatic Speech Recognition', author='Kim, Soohwan and Ha, Sangchun and Cho, Soyoung', author_email='sh951011@gmail.com, seomk9896@naver.com, soyoung.cho@kaist.ac.kr', diff --git a/tests/test_conformer_lstm.py b/tests/test_conformer_lstm.py index ffb1cea..3ecab5a 100644 --- a/tests/test_conformer_lstm.py +++ b/tests/test_conformer_lstm.py @@ -50,7 +50,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = ConformerLSTMModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_deep_cnn_with_joint_ctc_listen_attend_spell.py b/tests/test_deep_cnn_with_joint_ctc_listen_attend_spell.py index 8300664..5acc1dc 100644 --- a/tests/test_deep_cnn_with_joint_ctc_listen_attend_spell.py +++ b/tests/test_deep_cnn_with_joint_ctc_listen_attend_spell.py @@ -46,7 +46,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = DeepCNNWithJointCTCListenAttendSpellModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_joint_ctc_conformer_lstm.py b/tests/test_joint_ctc_conformer_lstm.py index eea4bbc..3634607 100644 --- a/tests/test_joint_ctc_conformer_lstm.py +++ b/tests/test_joint_ctc_conformer_lstm.py @@ -46,7 +46,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = JointCTCConformerLSTMModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_joint_ctc_listen_attend_spell.py b/tests/test_joint_ctc_listen_attend_spell.py index 12b1b71..9e9ebaf 100644 --- a/tests/test_joint_ctc_listen_attend_spell.py +++ b/tests/test_joint_ctc_listen_attend_spell.py @@ -46,7 +46,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = JointCTCListenAttendSpellModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_joint_ctc_transformer.py b/tests/test_joint_ctc_transformer.py index 9bff9f0..fc3dccd 100644 --- a/tests/test_joint_ctc_transformer.py +++ b/tests/test_joint_ctc_transformer.py @@ -46,7 +46,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = JointCTCTransformerModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_listen_attend_spell.py b/tests/test_listen_attend_spell.py index f9873b0..be9ab2c 100644 --- a/tests/test_listen_attend_spell.py +++ b/tests/test_listen_attend_spell.py @@ -50,7 +50,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = ListenAttendSpellModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_listen_attend_spell_with_location_aware.py b/tests/test_listen_attend_spell_with_location_aware.py index 58e0dbc..61ad7d0 100644 --- a/tests/test_listen_attend_spell_with_location_aware.py +++ b/tests/test_listen_attend_spell_with_location_aware.py @@ -50,7 +50,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = ListenAttendSpellWithLocationAwareModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_listen_attend_spell_with_multi_head.py b/tests/test_listen_attend_spell_with_multi_head.py index d332e03..7578ef2 100644 --- a/tests/test_listen_attend_spell_with_multi_head.py +++ b/tests/test_listen_attend_spell_with_multi_head.py @@ -50,7 +50,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = ListenAttendSpellWithMultiHeadModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_transformer.py b/tests/test_transformer.py index c17784c..b1ec9c8 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -50,7 +50,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = TransformerModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"] diff --git a/tests/test_vgg_transformer.py b/tests/test_vgg_transformer.py index 6200564..b737668 100644 --- a/tests/test_vgg_transformer.py +++ b/tests/test_vgg_transformer.py @@ -50,7 +50,7 @@ def test_beam_search(self): vocab = KsponSpeechCharacterVocabulary(configs) model = VGGTransformerModel(configs, vocab) model.build_model() - model.set_beam_decoder(batch_size=3, beam_size=3) + model.set_beam_decoder(beam_size=3) for i in range(3): prediction = model(DUMMY_INPUTS, DUMMY_INPUT_LENGTHS)["predictions"]