Skip to content

Commit

Permalink
feat(decode): support arbitrary blank_id during decoding (wenet-e2e#2193)
Browse files Browse the repository at this point in the history

* feat(decode): support arbitrary blank_id

* feat(decode): refine code
  • Loading branch information
xingchensong authored Dec 6, 2023
1 parent 91c618d commit 571382b
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 25 deletions.
10 changes: 8 additions & 2 deletions wenet/bin/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.context_graph import ContextGraph
from wenet.utils.ctc_utils import get_blank_id


def get_args():
Expand Down Expand Up @@ -159,7 +160,7 @@ def get_args():
type=str,
default='',
help='''Context bias mode, selectable from the following
option: decoding-graphdeep-biasing''')
option: decoding-graph, deep-biasing''')
parser.add_argument('--context_list_path',
type=str,
default='',
Expand Down Expand Up @@ -216,6 +217,7 @@ def main():
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

# Init asr model from configs
args.jit = False
model, configs = init_model(args, configs)

use_cuda = args.gpu >= 0 and torch.cuda.is_available()
Expand All @@ -228,6 +230,9 @@ def main():
context_graph = ContextGraph(args.context_list_path, tokenizer.symbol_table,
args.bpe_model, args.context_graph_score)

_, blank_id = get_blank_id(configs, tokenizer.symbol_table)
logging.info("blank_id is {}".format(blank_id))

# TODO(Dinghao Zhou): Support RNN-T related decoding
# TODO(Lv Xiang): Support k2 related decoding
# TODO(Kaixun Huang): Support context graph
Expand Down Expand Up @@ -255,7 +260,8 @@ def main():
ctc_weight=args.ctc_weight,
simulate_streaming=args.simulate_streaming,
reverse_weight=args.reverse_weight,
context_graph=context_graph)
context_graph=context_graph,
blank_id=blank_id)
for i, key in enumerate(keys):
for mode, hyps in results.items():
tokens = hyps[i].tokens
Expand Down
6 changes: 4 additions & 2 deletions wenet/transformer/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def decode(
simulate_streaming: bool = False,
reverse_weight: float = 0.0,
context_graph: ContextGraph = None,
blank_id: int = 0,
) -> Dict[str, List[DecodeResult]]:
""" Decode input speech
Expand Down Expand Up @@ -233,10 +234,11 @@ def decode(
self, encoder_out, encoder_mask, beam_size)
if 'ctc_greedy_search' in methods:
results['ctc_greedy_search'] = ctc_greedy_search(
ctc_probs, encoder_lens)
ctc_probs, encoder_lens, blank_id)
if 'ctc_prefix_beam_search' in methods:
ctc_prefix_result = ctc_prefix_beam_search(ctc_probs, encoder_lens,
beam_size, context_graph)
beam_size, context_graph,
blank_id)
results['ctc_prefix_beam_search'] = ctc_prefix_result
if 'attention_rescoring' in methods:
# attention_rescoring depends on ctc_prefix_beam_search nbest
Expand Down
10 changes: 6 additions & 4 deletions wenet/transformer/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,24 +102,26 @@ def update_context(self, context_graph, prefix_score, word_id):


def ctc_greedy_search(ctc_probs: torch.Tensor,
ctc_lens: torch.Tensor) -> List[DecodeResult]:
ctc_lens: torch.Tensor,
blank_id: int = 0) -> List[DecodeResult]:
batch_size = ctc_probs.shape[0]
maxlen = ctc_probs.size(1)
topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
mask = make_pad_mask(ctc_lens, maxlen) # (B, maxlen)
topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen)
topk_index = topk_index.masked_fill_(mask, blank_id) # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index]
scores = topk_prob.max(1)
results = []
for hyp in hyps:
r = DecodeResult(remove_duplicates_and_blank(hyp))
r = DecodeResult(remove_duplicates_and_blank(hyp, blank_id))
results.append(r)
return results


def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
beam_size: int, context_graph: ContextGraph = None,
blank_id: int = 0,
) -> List[DecodeResult]:
"""
Returns:
Expand Down Expand Up @@ -151,7 +153,7 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
prob = logp[u].item()
for prefix, prefix_score in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if u == 0: # blank
if u == blank_id: # blank
next_score = next_hyps[prefix]
next_score.s = log_add(next_score.s,
prefix_score.score() + prob)
Expand Down
29 changes: 22 additions & 7 deletions wenet/utils/ctc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,36 @@
import torch


def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
if hyp[cur] != blank_id:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp


def replace_duplicates_with_blank(hyp: List[int]) -> List[int]:
def replace_duplicates_with_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
new_hyp.append(hyp[cur])
prev = cur
cur += 1
while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != 0:
new_hyp.append(0)
while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != blank_id:
new_hyp.append(blank_id)
cur += 1
return new_hyp


def gen_ctc_peak_time(hyp: List[int]) -> List[int]:
def gen_ctc_peak_time(hyp: List[int], blank_id: int = 0) -> List[int]:
times = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
if hyp[cur] != blank_id:
times.append(cur)
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
Expand Down Expand Up @@ -156,3 +156,18 @@ def force_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list:
output_alignment.append(y_insert_blank[state_seq[t, 0]])

return output_alignment


def get_blank_id(configs, symbol_table):
    """Resolve the CTC blank id and make sure it is recorded in ``configs``.

    The id is looked up in two places: the ``'<blank>'`` entry of
    ``symbol_table`` and ``configs['ctc_conf']['ctc_blank_id']``. When both
    are present they must agree; when only the symbol table has it, the value
    is copied into ``configs``; when only the config has it, that value is
    used as is.

    Args:
        configs: Configuration mapping. It is modified in place: a
            ``'ctc_conf'`` sub-mapping is created when missing, and
            ``'ctc_blank_id'`` is filled in from the symbol table when absent.
        symbol_table: Mapping from token string to token id.

    Returns:
        A tuple ``(configs, blank_id)`` where ``configs`` is the same object
        that was passed in and ``blank_id`` is the resolved blank id.

    Raises:
        AssertionError: If the two sources disagree, or if neither of them
            provides a blank id.
    """
    ctc_conf = configs.setdefault('ctc_conf', {})

    # NOTE: the checks below raise AssertionError explicitly rather than
    # using ``assert`` statements. ``assert`` is stripped under ``python -O``,
    # which would let a mismatching blank id through silently, while callers
    # still see the same exception type as before.
    if '<blank>' in symbol_table:
        table_blank_id = symbol_table['<blank>']
        if 'ctc_blank_id' not in ctc_conf:
            ctc_conf['ctc_blank_id'] = table_blank_id
        elif ctc_conf['ctc_blank_id'] != table_blank_id:
            raise AssertionError(
                "ctc_blank_id in yaml ({}) does not match '<blank>' in "
                "symbol table ({})".format(ctc_conf['ctc_blank_id'],
                                           table_blank_id))
    elif 'ctc_blank_id' not in ctc_conf:
        raise AssertionError("PLZ set ctc_blank_id in yaml")

    return configs, ctc_conf['ctc_blank_id']
12 changes: 2 additions & 10 deletions wenet/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id


def add_model_args(parser):
Expand Down Expand Up @@ -211,16 +212,7 @@ def check_modify_and_save_config(args, configs, symbol_table):
else:
input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins']

if 'ctc_conf' not in configs:
configs['ctc_conf'] = {}

if '<blank>' in symbol_table:
if 'ctc_blank_id' in configs['ctc_conf']:
assert configs['ctc_conf']['ctc_blank_id'] == symbol_table['<blank>']
else:
configs['ctc_conf']['ctc_blank_id'] = symbol_table['<blank>']
else:
assert 'ctc_blank_id' in configs['ctc_conf'], "PLZ set ctc_blank_id in yaml"
configs, _ = get_blank_id(configs, symbol_table)

configs['input_dim'] = input_dim
configs['output_dim'] = configs['vocab_size']
Expand Down

0 comments on commit 571382b

Please sign in to comment.