diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index 1a842f8fe..e7be3be39 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -28,6 +28,7 @@ from wenet.utils.init_model import init_model from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.context_graph import ContextGraph +from wenet.utils.ctc_utils import get_blank_id def get_args(): @@ -159,7 +160,7 @@ def get_args(): type=str, default='', help='''Context bias mode, selectable from the following - option: decoding-graph、deep-biasing''') + option: decoding-graph, deep-biasing''') parser.add_argument('--context_list_path', type=str, default='', @@ -216,6 +217,7 @@ def main(): test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) # Init asr model from configs + args.jit = False model, configs = init_model(args, configs) use_cuda = args.gpu >= 0 and torch.cuda.is_available() @@ -228,6 +230,9 @@ def main(): context_graph = ContextGraph(args.context_list_path, tokenizer.symbol_table, args.bpe_model, args.context_graph_score) + _, blank_id = get_blank_id(configs, tokenizer.symbol_table) + logging.info("blank_id is {}".format(blank_id)) + # TODO(Dinghao Zhou): Support RNN-T related decoding # TODO(Lv Xiang): Support k2 related decoding # TODO(Kaixun Huang): Support context graph @@ -255,7 +260,8 @@ def main(): ctc_weight=args.ctc_weight, simulate_streaming=args.simulate_streaming, reverse_weight=args.reverse_weight, - context_graph=context_graph) + context_graph=context_graph, + blank_id=blank_id) for i, key in enumerate(keys): for mode, hyps in results.items(): tokens = hyps[i].tokens diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index b5c57a9f4..d790a0919 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -194,6 +194,7 @@ def decode( simulate_streaming: bool = False, reverse_weight: float = 0.0, context_graph: ContextGraph = None, + blank_id: int = 0, ) -> Dict[str, List[DecodeResult]]: """ Decode input speech @@ -233,10 +234,11 @@ def decode( self, encoder_out, encoder_mask, beam_size) if 'ctc_greedy_search' in methods: results['ctc_greedy_search'] = ctc_greedy_search( - ctc_probs, encoder_lens) + ctc_probs, encoder_lens, blank_id) if 'ctc_prefix_beam_search' in methods: ctc_prefix_result = ctc_prefix_beam_search(ctc_probs, encoder_lens, - beam_size, context_graph) + beam_size, context_graph, + blank_id) results['ctc_prefix_beam_search'] = ctc_prefix_result if 'attention_rescoring' in methods: # attention_rescoring depends on ctc_prefix_beam_search nbest diff --git a/wenet/transformer/search.py b/wenet/transformer/search.py index 762c9b0a0..9f8827a85 100644 --- a/wenet/transformer/search.py +++ b/wenet/transformer/search.py @@ -102,24 +102,26 @@ def update_context(self, context_graph, prefix_score, word_id): def ctc_greedy_search(ctc_probs: torch.Tensor, - ctc_lens: torch.Tensor) -> List[DecodeResult]: + ctc_lens: torch.Tensor, + blank_id: int = 0) -> List[DecodeResult]: batch_size = ctc_probs.shape[0] maxlen = ctc_probs.size(1) topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) mask = make_pad_mask(ctc_lens, maxlen) # (B, maxlen) - topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen) + topk_index = topk_index.masked_fill_(mask, blank_id) # (B, maxlen) hyps = [hyp.tolist() for hyp in topk_index] scores = topk_prob.max(1) results = [] for hyp in hyps: - r = DecodeResult(remove_duplicates_and_blank(hyp)) + r = DecodeResult(remove_duplicates_and_blank(hyp, blank_id)) results.append(r) return results def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor, beam_size: int, context_graph: ContextGraph = None, + blank_id: int = 0, ) -> List[DecodeResult]: """ Returns: @@ -151,7 +153,7 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor, prob = logp[u].item() for prefix, prefix_score in cur_hyps: last = prefix[-1] if len(prefix) > 0 else None - if u == 0: # blank + if u == blank_id: # blank next_score = next_hyps[prefix] next_score.s = log_add(next_score.s, prefix_score.score() + prob) diff --git a/wenet/utils/ctc_utils.py b/wenet/utils/ctc_utils.py index ee4ab43b1..084e32c1c 100644 --- a/wenet/utils/ctc_utils.py +++ b/wenet/utils/ctc_utils.py @@ -19,11 +19,11 @@ import torch -def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: +def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): - if hyp[cur] != 0: + if hyp[cur] != blank_id: new_hyp.append(hyp[cur]) prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: @@ -31,24 +31,24 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: return new_hyp -def replace_duplicates_with_blank(hyp: List[int]) -> List[int]: +def replace_duplicates_with_blank(hyp: List[int], blank_id: int = 0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): new_hyp.append(hyp[cur]) prev = cur cur += 1 - while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != 0: - new_hyp.append(0) + while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != blank_id: + new_hyp.append(blank_id) cur += 1 return new_hyp -def gen_ctc_peak_time(hyp: List[int]) -> List[int]: +def gen_ctc_peak_time(hyp: List[int], blank_id: int = 0) -> List[int]: times = [] cur = 0 while cur < len(hyp): - if hyp[cur] != 0: + if hyp[cur] != blank_id: times.append(cur) prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: @@ -156,3 +156,18 @@ def force_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list: output_alignment.append(y_insert_blank[state_seq[t, 0]]) return output_alignment + + +def get_blank_id(configs, symbol_table): + if 'ctc_conf' not in configs: + configs['ctc_conf'] = {} + + if '' in symbol_table: + if 'ctc_blank_id' in configs['ctc_conf']: + assert configs['ctc_conf']['ctc_blank_id'] == symbol_table[''] + else: + configs['ctc_conf']['ctc_blank_id'] = symbol_table[''] + else: + assert 'ctc_blank_id' in configs['ctc_conf'], "PLZ set ctc_blank_id in yaml" + + return configs, configs['ctc_conf']['ctc_blank_id'] diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index 43c0cf12a..9343c2431 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -40,6 +40,7 @@ from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import save_checkpoint from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing +from wenet.utils.ctc_utils import get_blank_id def add_model_args(parser): @@ -211,16 +212,7 @@ def check_modify_and_save_config(args, configs, symbol_table): else: input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] - if 'ctc_conf' not in configs: - configs['ctc_conf'] = {} - - if '' in symbol_table: - if 'ctc_blank_id' in configs['ctc_conf']: - assert configs['ctc_conf']['ctc_blank_id'] == symbol_table[''] - else: - configs['ctc_conf']['ctc_blank_id'] = symbol_table[''] - else: - assert 'ctc_blank_id' in configs['ctc_conf'], "PLZ set ctc_blank_id in yaml" + configs, _ = get_blank_id(configs, symbol_table) configs['input_dim'] = input_dim configs['output_dim'] = configs['vocab_size']