diff --git a/wenet/k2/__init__.py b/wenet/k2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/wenet/k2/model.py b/wenet/k2/model.py
new file mode 100644
index 000000000..735a8a2e1
--- /dev/null
+++ b/wenet/k2/model.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2023 Binbin Zhang (binbzha@qq.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+try:
+    import k2
+    from icefall.utils import get_texts
+    from icefall.decode import get_lattice, Nbest, one_best_decoding
+    from icefall.mmi import LFMMILoss
+    from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+except ImportError:
+    print('Warning: Failed to import k2 & icefall, which are for LF-MMI/hlg')
+
+from wenet.transformer.asr_model import ASRModel
+from wenet.transformer.ctc import CTC
+from wenet.transformer.decoder import TransformerDecoder
+from wenet.transformer.encoder import TransformerEncoder
+from wenet.utils.common import (IGNORE_ID, add_sos_eos, reverse_pad_list)
+
+
+class K2Model(ASRModel):
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder: TransformerEncoder,
+        decoder: TransformerDecoder,
+        ctc: CTC,
+        ctc_weight: float = 0.5,
+        ignore_id: int = IGNORE_ID,
+        reverse_weight: float = 0.0,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        lfmmi_dir: str = '',
+    ):
+        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
+                         ignore_id, reverse_weight, lsm_weight,
+                         length_normalized_loss)
+        self.lfmmi_dir = lfmmi_dir
+        if self.lfmmi_dir != '':
+            self.load_lfmmi_resource()
+
+    @torch.jit.ignore(drop=True)
+    def _forward_ctc(self, encoder_out: torch.Tensor,
+                     encoder_mask: torch.Tensor, text: torch.Tensor,
+                     text_lengths: torch.Tensor) -> torch.Tensor:
+        loss_ctc = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
+        return loss_ctc
+
+    @torch.jit.ignore(drop=True)
+    def load_lfmmi_resource(self):
+        with open('{}/tokens.txt'.format(self.lfmmi_dir), 'r') as fin:
+            for line in fin:
+                arr = line.strip().split()
+                if arr[0] == '<sos/eos>':
+                    self.sos_eos_id = int(arr[1])
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.graph_compiler = MmiTrainingGraphCompiler(
+            self.lfmmi_dir,
+            device=device,
+            oov="<UNK>",
+            sos_id=self.sos_eos_id,
+            eos_id=self.sos_eos_id,
+        )
+        self.lfmmi = LFMMILoss(
+            graph_compiler=self.graph_compiler,
+            den_scale=1,
+            use_pruned_intersect=False,
+        )
+        self.word_table = {}
+        with open('{}/words.txt'.format(self.lfmmi_dir), 'r') as fin:
+            for line in fin:
+                arr = line.strip().split()
+                assert len(arr) == 2
+                self.word_table[int(arr[1])] = arr[0]
+
+    @torch.jit.ignore(drop=True)
+    def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
+        ctc_probs = self.ctc.log_softmax(encoder_out)
+        supervision_segments = torch.stack((
+            torch.arange(len(encoder_mask)),
+            torch.zeros(len(encoder_mask)),
+            encoder_mask.squeeze(dim=1).sum(dim=1).to('cpu'),
+        ), 1).to(torch.int32)
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_probs,
+            supervision_segments,
+            
allow_truncate=3, + ) + text = [ + ' '.join([self.word_table[j.item()] for j in i if j != -1]) + for i in text + ] + loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text) + return loss + + def load_hlg_resource_if_necessary(self, hlg, word): + if not hasattr(self, 'hlg'): + device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device)) + if not hasattr(self.hlg, "lm_scores"): + self.hlg.lm_scores = self.hlg.scores.clone() + if not hasattr(self, 'word_table'): + self.word_table = {} + with open(word, 'r') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + self.word_table[int(arr[1])] = arr[0] + + @torch.no_grad() + def hlg_onebest( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + hlg: str = '', + word: str = '', + symbol_table: Dict[str, int] = None, + ) -> List[int]: + self.load_hlg_resource_if_necessary(hlg, word) + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + supervision_segments = torch.stack( + (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)), + encoder_mask.squeeze(dim=1).sum(dim=1).cpu()), + 1, + ).to(torch.int32) + lattice = get_lattice(nnet_output=ctc_probs, + decoding_graph=self.hlg, + supervision_segments=supervision_segments, + search_beam=20, + output_beam=7, + min_active_states=30, + max_active_states=10000, + subsampling_factor=4) + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + hyps = get_texts(best_path) + hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] + for i in hyps] + return hyps + + @torch.no_grad() + def hlg_rescore( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + lm_scale: float = 0, + decoder_scale: float = 0, + r_decoder_scale: float = 0, + hlg: str = '', + word: str = '', + symbol_table: Dict[str, int] = None, + ) -> List[int]: + self.load_hlg_resource_if_necessary(hlg, word) + device = speech.device + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + supervision_segments = torch.stack( + (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)), + encoder_mask.squeeze(dim=1).sum(dim=1).cpu()), + 1, + ).to(torch.int32) + lattice = get_lattice(nnet_output=ctc_probs, + decoding_graph=self.hlg, + supervision_segments=supervision_segments, + search_beam=20, + output_beam=7, + min_active_states=30, + max_active_states=10000, + subsampling_factor=4) + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=100, + use_double_scores=True, + nbest_scale=0.5, + ) + nbest = nbest.intersect(lattice) + assert hasattr(nbest.fsa, "lm_scores") + assert hasattr(nbest.fsa, "tokens") + assert isinstance(nbest.fsa.tokens, torch.Tensor) + + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) + tokens = tokens.remove_values_leq(0) + hyps = tokens.tolist() + + # cal attention_score + 
hyps_pad = pad_sequence([ + torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps + ], True, self.ignore_id) # (beam_size, max_hyps_len) + ori_hyps_pad = hyps_pad + hyps_lens = torch.tensor([len(hyp) for hyp in hyps], + device=device, + dtype=torch.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + encoder_out_repeat = [] + tot_scores = nbest.tot_scores() + repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)] + for i in range(len(encoder_out)): + encoder_out_repeat.append(encoder_out[i:i + 1].repeat( + repeats[i], 1, 1)) + encoder_out = torch.concat(encoder_out_repeat, dim=0) + encoder_mask = torch.ones(encoder_out.size(0), + 1, + encoder_out.size(1), + dtype=torch.bool, + device=device) + # used for right to left decoder + r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) + r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, + self.ignore_id) + reverse_weight = 0.5 + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, + reverse_weight) # (beam_size, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + decoder_out = decoder_out + # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a + # conventional transformer decoder. + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = r_decoder_out + + decoder_scores = torch.tensor([ + sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))]) + for i in range(len(hyps)) + ], + device=device) # noqa + r_decoder_scores = [] + for i in range(len(hyps)): + score = 0 + for j in range(len(hyps[i])): + score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]] + score += r_decoder_out[i, len(hyps[i]), self.eos] + r_decoder_scores.append(score) + r_decoder_scores = torch.tensor(r_decoder_scores, device=device) + + am_scores = nbest.compute_am_scores() + ngram_lm_scores = nbest.compute_lm_scores() + tot_scores = am_scores.values + lm_scale * ngram_lm_scores.values + \ + decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) + hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] + for i in hyps] + return hyps diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index b95518a16..7f0f2b353 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -16,16 +16,6 @@ from typing import Dict, List, Optional, Tuple import torch -from torch.nn.utils.rnn import pad_sequence - -try: - import k2 - from icefall.utils import get_texts - from icefall.decode import get_lattice, Nbest, one_best_decoding - from icefall.mmi import LFMMILoss - from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -except ImportError: - print('Warning: Failed to import k2 & icefall, which are for LF-MMI/hlg') from wenet.transformer.ctc import CTC from wenet.transformer.decoder import TransformerDecoder @@ -35,9 +25,7 @@ ctc_prefix_beam_search, attention_beam_search, attention_rescoring) -from wenet.utils.common import (IGNORE_ID, - add_sos_eos, - th_accuracy, +from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy, reverse_pad_list) from wenet.utils.context_graph import ContextGraph @@ -56,7 +44,6 @@ def __init__( 
reverse_weight: float = 0.0, lsm_weight: float = 0.0, length_normalized_loss: bool = False, - lfmmi_dir: str = '', ): assert 0.0 <= ctc_weight <= 1.0, ctc_weight @@ -78,9 +65,6 @@ def __init__( smoothing=lsm_weight, normalize_length=length_normalized_loss, ) - self.lfmmi_dir = lfmmi_dir - if self.lfmmi_dir != '': - self.load_lfmmi_resource() @torch.jit.ignore(drop=True) def forward( @@ -115,14 +99,10 @@ def forward( else: loss_att = None - # 2b. CTC branch or LF-MMI loss + # 2b. CTC branch if self.ctc_weight != 0.0: - if self.lfmmi_dir != '': - loss_ctc = self._calc_lfmmi_loss(encoder_out, encoder_mask, - text) - else: - loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, - text_lengths) + loss_ctc = self._forward_ctc(encoder_out, encoder_mask, text, + text_lengths) else: loss_ctc = None @@ -135,6 +115,14 @@ def forward( self.ctc_weight) * loss_att return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc} + @torch.jit.ignore(drop=True) + def _forward_ctc(self, encoder_out: torch.Tensor, + encoder_mask: torch.Tensor, text: torch.Tensor, + text_lengths: torch.Tensor) -> torch.Tensor: + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths) + return loss_ctc + def _calc_att_loss( self, encoder_out: torch.Tensor, @@ -392,223 +380,8 @@ def attention_rescoring( speech, speech_lengths, beam_size, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming, context_graph) - return attention_rescoring(self, hyps, encoder_out, ctc_weight, reverse_weight) - - - @torch.jit.ignore(drop=True) - def load_lfmmi_resource(self): - with open('{}/tokens.txt'.format(self.lfmmi_dir), 'r') as fin: - for line in fin: - arr = line.strip().split() - if arr[0] == '': - self.sos_eos_id = int(arr[1]) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - self.graph_compiler = MmiTrainingGraphCompiler( - self.lfmmi_dir, - device=device, - oov="", - sos_id=self.sos_eos_id, - eos_id=self.sos_eos_id, - ) - self.lfmmi = LFMMILoss( - graph_compiler=self.graph_compiler, - den_scale=1, - use_pruned_intersect=False, - ) - self.word_table = {} - with open('{}/words.txt'.format(self.lfmmi_dir), 'r') as fin: - for line in fin: - arr = line.strip().split() - assert len(arr) == 2 - self.word_table[int(arr[1])] = arr[0] - - @torch.jit.ignore(drop=True) - def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text): - ctc_probs = self.ctc.log_softmax(encoder_out) - supervision_segments = torch.stack(( - torch.arange(len(encoder_mask)), - torch.zeros(len(encoder_mask)), - encoder_mask.squeeze(dim=1).sum(dim=1).to('cpu'), - ), 1).to(torch.int32) - dense_fsa_vec = k2.DenseFsaVec( - ctc_probs, - supervision_segments, - allow_truncate=3, - ) - text = [ - ' '.join([self.word_table[j.item()] for j in i if j != -1]) - for i in text - ] - loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text) - return loss - - def load_hlg_resource_if_necessary(self, hlg, word): - if not hasattr(self, 'hlg'): - device = torch.device( - 'cuda' if torch.cuda.is_available() else 'cpu') - self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device)) - if not hasattr(self.hlg, "lm_scores"): - self.hlg.lm_scores = self.hlg.scores.clone() - if not hasattr(self, 'word_table'): - self.word_table = {} - with open(word, 'r') as fin: - for line in fin: - arr = line.strip().split() - assert len(arr) == 2 - self.word_table[int(arr[1])] = arr[0] - - @torch.no_grad() - def hlg_onebest( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - 
decoding_chunk_size: int = -1, - num_decoding_left_chunks: int = -1, - simulate_streaming: bool = False, - hlg: str = '', - word: str = '', - symbol_table: Dict[str, int] = None, - ) -> List[int]: - self.load_hlg_resource_if_necessary(hlg, word) - encoder_out, encoder_mask = self._forward_encoder( - speech, speech_lengths, decoding_chunk_size, - num_decoding_left_chunks, - simulate_streaming) # (B, maxlen, encoder_dim) - ctc_probs = self.ctc.log_softmax( - encoder_out) # (1, maxlen, vocab_size) - supervision_segments = torch.stack( - (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)), - encoder_mask.squeeze(dim=1).sum(dim=1).cpu()), - 1, - ).to(torch.int32) - lattice = get_lattice(nnet_output=ctc_probs, - decoding_graph=self.hlg, - supervision_segments=supervision_segments, - search_beam=20, - output_beam=7, - min_active_states=30, - max_active_states=10000, - subsampling_factor=4) - best_path = one_best_decoding(lattice=lattice, use_double_scores=True) - hyps = get_texts(best_path) - hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] - for i in hyps] - return hyps - - @torch.no_grad() - def hlg_rescore( - self, - speech: torch.Tensor, - speech_lengths: torch.Tensor, - decoding_chunk_size: int = -1, - num_decoding_left_chunks: int = -1, - simulate_streaming: bool = False, - lm_scale: float = 0, - decoder_scale: float = 0, - r_decoder_scale: float = 0, - hlg: str = '', - word: str = '', - symbol_table: Dict[str, int] = None, - ) -> List[int]: - self.load_hlg_resource_if_necessary(hlg, word) - device = speech.device - encoder_out, encoder_mask = self._forward_encoder( - speech, speech_lengths, decoding_chunk_size, - num_decoding_left_chunks, - simulate_streaming) # (B, maxlen, encoder_dim) - ctc_probs = self.ctc.log_softmax( - encoder_out) # (1, maxlen, vocab_size) - supervision_segments = torch.stack( - (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)), - encoder_mask.squeeze(dim=1).sum(dim=1).cpu()), - 1, - ).to(torch.int32) - lattice = get_lattice(nnet_output=ctc_probs, - decoding_graph=self.hlg, - supervision_segments=supervision_segments, - search_beam=20, - output_beam=7, - min_active_states=30, - max_active_states=10000, - subsampling_factor=4) - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=100, - use_double_scores=True, - nbest_scale=0.5, - ) - nbest = nbest.intersect(lattice) - assert hasattr(nbest.fsa, "lm_scores") - assert hasattr(nbest.fsa, "tokens") - assert isinstance(nbest.fsa.tokens, torch.Tensor) - - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) - tokens = tokens.remove_values_leq(0) - hyps = tokens.tolist() - - # cal attention_score - hyps_pad = pad_sequence([ - torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps - ], True, self.ignore_id) # (beam_size, max_hyps_len) - ori_hyps_pad = hyps_pad - hyps_lens = torch.tensor([len(hyp) for hyp in hyps], - device=device, - dtype=torch.long) # (beam_size,) - hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) - hyps_lens = hyps_lens + 1 # Add at begining - encoder_out_repeat = [] - tot_scores = nbest.tot_scores() - repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)] - for i in range(len(encoder_out)): - encoder_out_repeat.append(encoder_out[i:i + 1].repeat( - repeats[i], 1, 1)) - encoder_out = torch.concat(encoder_out_repeat, dim=0) - encoder_mask = torch.ones(encoder_out.size(0), - 1, - encoder_out.size(1), - dtype=torch.bool, - device=device) - # used for right 
to left decoder - r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) - r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, - self.ignore_id) - reverse_weight = 0.5 - decoder_out, r_decoder_out, _ = self.decoder( - encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, - reverse_weight) # (beam_size, max_hyps_len, vocab_size) - decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) - decoder_out = decoder_out - # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a - # conventional transformer decoder. - r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) - r_decoder_out = r_decoder_out - - decoder_scores = torch.tensor([ - sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))]) - for i in range(len(hyps)) - ], - device=device) # noqa - r_decoder_scores = [] - for i in range(len(hyps)): - score = 0 - for j in range(len(hyps[i])): - score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]] - score += r_decoder_out[i, len(hyps[i]), self.eos] - r_decoder_scores.append(score) - r_decoder_scores = torch.tensor(r_decoder_scores, device=device) - - am_scores = nbest.compute_am_scores() - ngram_lm_scores = nbest.compute_lm_scores() - tot_scores = am_scores.values + lm_scale * ngram_lm_scores.values + \ - decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] - for i in hyps] - return hyps + return attention_rescoring(self, hyps, encoder_out, ctc_weight, + reverse_weight) @torch.jit.export def subsampling_rate(self) -> int: diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 886bca77c..4c8544ced 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -13,6 +13,8 @@ # limitations under the License. import torch + +from wenet.k2.model import K2Model from wenet.transducer.joint import TransducerJoint from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor, RNNPredictor) @@ -55,13 +57,12 @@ def init_model(configs): global_cmvn=global_cmvn, **configs['encoder_conf']) elif encoder_type == 'efficientConformer': - encoder = EfficientConformerEncoder(input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf'], - **configs['encoder_conf'] - ['efficient_conf'] - if 'efficient_conf' in - configs['encoder_conf'] else {}) + encoder = EfficientConformerEncoder( + input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf'], + **configs['encoder_conf']['efficient_conf'] + if 'efficient_conf' in configs['encoder_conf'] else {}) elif encoder_type == 'branchformer': encoder = BranchformerEncoder(input_dim, global_cmvn=global_cmvn, @@ -123,10 +124,18 @@ def init_model(configs): predictor=predictor, **configs['model_conf']) else: - model = ASRModel(vocab_size=vocab_size, - encoder=encoder, - decoder=decoder, - ctc=ctc, - lfmmi_dir=configs.get('lfmmi_dir', ''), - **configs['model_conf']) + print(configs) + if configs.get('lfmmi_dir', '') != '': + model = K2Model(vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + lfmmi_dir=configs['lfmmi_dir'], + **configs['model_conf']) + else: + model = ASRModel(vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + **configs['model_conf']) return model
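
Usage note (outside the patch itself): the sketch below shows, under stated assumptions, how the new code path is picked up and exercised. It presumes a fully prepared WeNet `configs` dict (as built by the existing training/recognition scripts), an installed k2/icefall so the optional imports in wenet/k2/model.py succeed, and an icefall-style lang directory; every path, tensor and table below is a placeholder, not something this change ships.

    import torch

    from wenet.utils.init_model import init_model

    # Assumption: 'configs' already holds the usual WeNet keys
    # (encoder/decoder/model_conf, input_dim, output_dim, cmvn settings, ...).
    # A non-empty 'lfmmi_dir' pointing at an icefall-style lang directory
    # (tokens.txt, words.txt, ...) now makes init_model build a K2Model,
    # which trains with the LF-MMI loss via its _forward_ctc override.
    configs['lfmmi_dir'] = 'data/local/lang_lfmmi'   # placeholder path
    model = init_model(configs)
    model.eval()

    # HLG one-best decoding with graphs from the same lang directory.
    # 'speech' is a (B, T, feat_dim) fbank tensor, 'speech_lengths' gives the
    # valid frame count per utterance, and 'symbol_table' maps modeling units
    # (the strings making up each word in words.txt) to their ids.
    speech = torch.randn(1, 200, 80)                 # placeholder features
    speech_lengths = torch.tensor([200])
    symbol_table = {'<blank>': 0}                    # placeholder table
    hyps = model.hlg_onebest(speech, speech_lengths,
                             hlg='data/local/lang_lfmmi/HLG.pt',      # placeholder
                             word='data/local/lang_lfmmi/words.txt',  # placeholder
                             symbol_table=symbol_table)

    # hlg_rescore takes the same resources plus lm_scale / decoder_scale /
    # r_decoder_scale weights to rescore a 100-best list with the attention
    # decoder (forward and, where the decoder supports it, right-to-left).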