[cli] add timestamp support #2082

Merged (2 commits) on Oct 24, 2023

Changes from all commits
wenet/cli/model.py (24 additions, 7 deletions)

@@ -19,6 +19,7 @@
 import torchaudio.compliance.kaldi as kaldi
 
 from wenet.cli.hub import Hub
+from wenet.utils.ctc_utils import gen_timestamps_from_peak
 from wenet.utils.file_utils import read_symbol_table
 from wenet.transformer.search import (attention_rescoring,
                                       ctc_prefix_beam_search)
@@ -33,7 +34,7 @@ def __init__(self, language: str):
         symbol_table = read_symbol_table(units_path)
         self.char_dict = {v: k for k, v in symbol_table.items()}
 
-    def transcribe(self, audio_file: str):
+    def transcribe(self, audio_file: str, token_times: bool = False):
         waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
         waveform = waveform.to(torch.float)
         feats = kaldi.fbank(waveform,
@@ -46,10 +47,26 @@ def transcribe(self, audio_file: str):
         encoder_out, _, _ = self.model.forward_encoder_chunk(feats, 0, -1)
         encoder_lens = torch.tensor([encoder_out.size(1)], dtype=torch.long)
         ctc_probs = self.model.ctc_activation(encoder_out)
-        ctc_prefix_results = ctc_prefix_beam_search(ctc_probs, encoder_lens,
-                                                    10)
-        results = attention_rescoring(self.model, ctc_prefix_results,
-                                      encoder_out, encoder_lens, 0.3, 0.5)
-        hyp = [self.char_dict[x] for x in results[0].tokens]
-        result = ''.join(hyp)
+        ctc_prefix_results = ctc_prefix_beam_search(ctc_probs, encoder_lens, 2)
+        rescoring_results = attention_rescoring(self.model, ctc_prefix_results,
+                                                encoder_out, encoder_lens, 0.3,
+                                                0.5)
+        res = rescoring_results[0]
+        result = {}
+        result['rec'] = ''.join([self.char_dict[x] for x in res.tokens])
+
+        if token_times:
+            frame_rate = self.model.subsampling_rate(
+            ) * 0.01  # 0.01 seconds per frame
+            max_duration = encoder_out.size(1) * frame_rate
+            times = gen_timestamps_from_peak(res.times, max_duration,
+                                             frame_rate, 1.0)
+            times_info = []
+            for i, x in enumerate(res.tokens):
+                times_info.append({
+                    'token': self.char_dict[x],
+                    'start': times[i][0],
+                    'end': times[i][1]
+                })
+            result['times'] = times_info
         return result
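With this change, transcribe() returns a dict instead of a bare string: 'rec' carries the recognized text, and 'times' (present only when token_times=True) carries per-token start/end offsets in seconds derived from the CTC peaks. A minimal usage sketch, assuming a wenet install that includes this PR; the audio path is a placeholder:

from wenet.cli.model import Model

model = Model('chinese')  # downloads and caches the pretrained model
result = model.transcribe('test.wav', token_times=True)

print(result['rec'])  # recognized text
for info in result['times']:
    # each entry looks like {'token': ..., 'start': ..., 'end': ...}
    print(info['token'], info['start'], info['end'])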
wenet/cli/transcribe.py (6 additions, 2 deletions)

@@ -16,6 +16,7 @@
 
 from wenet.cli.model import Model
 
+
 def get_args():
     parser = argparse.ArgumentParser(description='')
     parser.add_argument('audio_file', help='audio file to transcribe')
@@ -26,15 +27,18 @@ def get_args():
                         ],
                         default='chinese',
                         help='language type')
-
+    parser.add_argument('-t',
+                        '--gen_token_times',
+                        action='store_true',
+                        help='whether to generate token times')
     args = parser.parse_args()
     return args
 
 
 def main():
     args = get_args()
     model = Model(args.language)
-    result = model.transcribe(args.audio_file)
+    result = model.transcribe(args.audio_file, args.gen_token_times)
     print(result)
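The new -t/--gen_token_times flag simply threads through to Model.transcribe. A sketch of driving this CLI entry point programmatically, equivalent to passing the same arguments on a command line (the audio path is a placeholder):

import sys

from wenet.cli.transcribe import main

# Equivalent to: <script> test.wav --language chinese -t
sys.argv = ['transcribe', 'test.wav', '--language', 'chinese', '-t']
main()  # prints {'rec': ..., 'times': [...]} because -t is set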
wenet/transformer/search.py (66 additions, 18 deletions)

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from collections import defaultdict
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import torch
 from torch.nn.utils.rnn import pad_sequence
@@ -30,9 +30,10 @@ def __init__(self,
                  score: float = 0.0,
                  confidence: float = 0.0,
                  tokens_confidence: List[float] = None,
-                 times: List[Tuple[float, float]] = None,
+                 times: List[int] = None,
                  nbest: List[List[int]] = None,
-                 nbest_scores: List[float] = None):
+                 nbest_scores: List[float] = None,
+                 nbest_times: List[List[int]] = None):
         """
         Args:
             tokens: decode token list
@@ -42,6 +43,7 @@
             times: timestamp of each token, list of (start, end)
             nbest: nbest result
             nbest_scores: score of each nbest
+            nbest_times: times of each nbest
         """
         self.tokens = tokens
         self.score = score
@@ -50,18 +52,33 @@
         self.times = times
         self.nbest = nbest
         self.nbest_scores = nbest_scores
+        self.nbest_times = nbest_times
 
 
 class PrefixScore:
     """ For CTC prefix beam search """
-    def __init__(self, s=float('-inf'), ns=float('-inf')):
+
+    def __init__(self,
+                 s: float = float('-inf'),
+                 ns: float = float('-inf'),
+                 v_s: float = float('-inf'),
+                 v_ns: float = float('-inf')):
         self.s = s  # blank_ending_score
         self.ns = ns  # none_blank_ending_score
+        self.v_s = v_s  # viterbi blank ending score
+        self.v_ns = v_ns  # viterbi none blank ending score
+        self.cur_token_prob = float('-inf')  # prob of current token
+        self.times_s = []  # times of viterbi blank path
+        self.times_ns = []  # times of viterbi none blank path
 
     def score(self):
         return log_add(self.s, self.ns)
 
+    def viterbi_score(self):
+        return self.v_s if self.v_s > self.v_ns else self.v_ns
+
+    def times(self):
+        return self.times_s if self.v_s > self.v_ns else self.times_ns
+
 
 def ctc_greedy_search(ctc_probs: torch.Tensor,
                       ctc_lens: torch.Tensor) -> List[DecodeResult]:
@@ -92,7 +109,7 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
     for i in range(batch_size):
         ctc_prob = ctc_probs[i]
         num_t = ctc_lens[i]
-        cur_hyps = [(tuple(), PrefixScore(0.0, -float('inf')))]
+        cur_hyps = [(tuple(), PrefixScore(0.0, -float('inf'), 0.0, 0.0))]
         # 2. CTC beam search step by step
         for t in range(0, num_t):
             logp = ctc_prob[t]  # (vocab_size,)
@@ -106,22 +123,46 @@ def ctc_prefix_beam_search(ctc_probs: torch.Tensor, ctc_lens: torch.Tensor,
                 for prefix, prefix_score in cur_hyps:
                     last = prefix[-1] if len(prefix) > 0 else None
                     if u == 0:  # blank
-                        next_hyps[prefix].s = log_add(
-                            next_hyps[prefix].s,
-                            prefix_score.score() + prob)
+                        next_score = next_hyps[prefix]
+                        next_score.s = log_add(next_score.s,
+                                               prefix_score.score() + prob)
+                        next_score.v_s = prefix_score.viterbi_score() + prob
+                        next_score.times_s = prefix_score.times().copy()
                     elif u == last:
                         # Update *uu -> *u;
-                        next_hyps[prefix].ns = log_add(next_hyps[prefix].ns,
-                                                       prefix_score.ns + prob)
+                        next_score1 = next_hyps[prefix]
+                        next_score1.ns = log_add(next_score1.ns,
+                                                 prefix_score.ns + prob)
+                        if next_score1.v_ns < prefix_score.v_ns + prob:
+                            next_score1.v_ns = prefix_score.v_ns + prob
+                            if next_score1.cur_token_prob < prob:
+                                next_score1.cur_token_prob = prob
+                                next_score1.times_ns = prefix_score.times_ns.copy(
+                                )
+                                next_score1.times_ns[-1] = t
 
                         # Update *u-u -> *uu, - is for blank
                         n_prefix = prefix + (u, )
-                        next_hyps[n_prefix].ns = log_add(
-                            next_hyps[n_prefix].ns, prefix_score.s + prob)
+                        next_score2 = next_hyps[n_prefix]
+                        next_score2.ns = log_add(next_score2.ns,
+                                                 prefix_score.s + prob)
+                        if next_score2.v_ns < prefix_score.v_s + prob:
+                            next_score2.v_ns = prefix_score.v_s + prob
+                            next_score2.cur_token_prob = prob
+                            next_score2.times_ns = prefix_score.times_s.copy()
+                            next_score2.times_ns.append(t)
                     else:
                         n_prefix = prefix + (u, )
-                        next_hyps[n_prefix].ns = log_add(
-                            next_hyps[n_prefix].ns,
-                            prefix_score.score() + prob)
+                        next_score = next_hyps[n_prefix]
+                        next_score.ns = log_add(next_score.ns,
+                                                prefix_score.score() + prob)
+                        if next_score.v_ns < prefix_score.viterbi_score(
+                        ) + prob:
+                            next_score.v_ns = prefix_score.viterbi_score(
+                            ) + prob
+                            next_score.cur_token_prob = prob
+                            next_score.times_ns = prefix_score.times().copy()
+                            next_score.times_ns.append(t)
             # 2.2 Second beam prune
             next_hyps = sorted(next_hyps.items(),
                                key=lambda x: x[1].score(),
@@ -130,13 +171,17 @@
 
         nbest = [y[0] for y in cur_hyps]
         nbest_scores = [y[1].score() for y in cur_hyps]
+        nbest_times = [y[1].times() for y in cur_hyps]
         best = nbest[0]
         best_score = nbest_scores[0]
+        best_time = nbest_times[0]
         results.append(
             DecodeResult(tokens=best,
                          score=best_score,
+                         times=best_time,
                          nbest=nbest,
-                         nbest_scores=nbest_scores))
+                         nbest_scores=nbest_scores,
+                         nbest_times=nbest_times))
     return results
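After this hunk, each DecodeResult from ctc_prefix_beam_search carries times as one CTC peak frame index per token (converted to seconds later by gen_timestamps_from_peak), plus nbest_times for every n-best hypothesis. A sketch of reading those fields; the random log-probs are stand-ins for real CTC activations:

import torch

from wenet.transformer.search import ctc_prefix_beam_search

ctc_probs = torch.randn(1, 100, 500).log_softmax(dim=-1)  # (B, T, vocab)
ctc_lens = torch.tensor([100], dtype=torch.long)

best = ctc_prefix_beam_search(ctc_probs, ctc_lens, 10)[0]
print(best.tokens)       # token ids of the best hypothesis
print(best.times)        # one peak frame index (int) per token
print(best.nbest_times)  # a frame-index list per n-best hypothesis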


@@ -283,5 +328,8 @@ def attention_rescoring(
             if score > best_score:
                 best_score = score
                 best_index = i
-    results.append(DecodeResult(hyps[best_index], best_score))
+        results.append(
+            DecodeResult(hyps[best_index],
+                         best_score,
+                         times=ctc_prefix_results[b].nbest_times[best_index]))
     return results
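The timestamp machinery rides on PrefixScore: v_s/v_ns track Viterbi (single best path) scores for blank- and non-blank-ending paths, times_s/times_ns record the peak frame of each token along those paths, and viterbi_score()/times() follow whichever ending wins. A toy sketch of that selection logic (hand-picked values, not real search state):

from wenet.transformer.search import PrefixScore

ps = PrefixScore(s=-5.0, ns=-4.0, v_s=-6.0, v_ns=-3.0)
ps.times_s = [2, 7]   # peaks recorded along the blank-ending path
ps.times_ns = [2, 9]  # peaks recorded along the non-blank-ending path

print(ps.viterbi_score())  # -3.0: v_ns wins
print(ps.times())          # [2, 9]: times follow the winning path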
wenet/utils/ctc_utils.py (41 additions, 8 deletions)

@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Tuple
 
 import numpy as np
 
 import torch
 
+
 def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
     new_hyp: List[int] = []
     cur = 0
@@ -43,6 +44,39 @@ def replace_duplicates_with_blank(hyp: List[int]) -> List[int]:
     return new_hyp
 
 
+def gen_timestamps_from_peak(
+    peaks: List[int],
+    max_duration: float,
+    frame_rate: float = 0.04,
+    max_token_duration: float = 1.0,
+) -> List[Tuple[float, float]]:
+    """
+    Args:
+        peaks: CTC peak frame index of each token
+        max_duration: max duration of the sentence, in seconds
+        frame_rate: duration of each frame, in seconds
+        max_token_duration: max duration of a single token, in seconds
+    Returns:
+        list of (start, end) times for each token, in seconds
+    """
+    times = []
+    half_max = max_token_duration / 2
+    for i in range(len(peaks)):
+        if i == 0:
+            start = max(0, peaks[0] * frame_rate - half_max)
+        else:
+            start = max((peaks[i - 1] + peaks[i]) / 2 * frame_rate,
+                        peaks[i] * frame_rate - half_max)
+
+        if i == len(peaks) - 1:
+            end = min(max_duration, peaks[-1] * frame_rate + half_max)
+        else:
+            end = min((peaks[i] + peaks[i + 1]) / 2 * frame_rate,
+                      peaks[i] * frame_rate + half_max)
+        times.append((start, end))
+    return times
+
+
 def insert_blank(label, blank_id=0):
     """Insert blank token between every two label token."""
     label = np.expand_dims(label, 1)
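A worked example of the clamping behaviour in gen_timestamps_from_peak: each token's window runs from the midpoint with its left neighbour to the midpoint with its right neighbour, but never extends more than max_token_duration / 2 from the token's own peak, and never outside [0, max_duration]. With peaks at frames 3, 10 and 17 (values computed by hand, up to float rounding):

from wenet.utils.ctc_utils import gen_timestamps_from_peak

times = gen_timestamps_from_peak(peaks=[3, 10, 17],
                                 max_duration=1.2,
                                 frame_rate=0.04,
                                 max_token_duration=1.0)
# token 0: start clamped to 0, end = (3 + 10) / 2 * 0.04 = 0.26
# token 1: midpoint to midpoint, 0.26 to 0.54
# token 2: end = 17 * 0.04 + 0.5 = 1.18, still below max_duration
print(times)  # [(0, 0.26), (0.26, 0.54), (0.54, 1.18)]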
@@ -52,9 +86,8 @@ def insert_blank(label, blank_id=0):
     label = np.append(label, label[0])
     return label
 
-def forced_align(ctc_probs: torch.Tensor,
-                 y: torch.Tensor,
-                 blank_id=0) -> list:
+
+def forced_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list:
     """ctc forced alignment.
 
     Args:
@@ … @@
 
     log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
     log_alpha = log_alpha - float('inf')  # log of zero
-    state_path = (torch.zeros(
-        (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
-                  )  # state path
+    state_path = torch.zeros((ctc_probs.size(0), len(y_insert_blank)),
+                             dtype=torch.int16) - 1  # state path
 
     # init start state
     log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
@@ … @@
                 log_alpha[t - 1, s - 2],
             ])
             prev_state = [s, s - 1, s - 2]
-            log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
+            log_alpha[
+                t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
             state_path[t, s] = prev_state[torch.argmax(candidates)]
 
     state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)
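The forced_align edits above are purely cosmetic reformatting; the diff is truncated here, but the function's signature is unchanged. For context, a small sketch of calling it (toy tensors in place of real CTC output; the shape of the printed result is illustrative):

import torch

from wenet.utils.ctc_utils import forced_align

# 6 frames of CTC log-probs over a 4-symbol vocabulary (blank id 0),
# force-aligned against the label sequence [1, 2].
ctc_probs = torch.randn(6, 4).log_softmax(dim=-1)
y = torch.tensor([1, 2], dtype=torch.long)

alignment = forced_align(ctc_probs, y, blank_id=0)
print(alignment)  # one token id per frame, 0 where blank is emitted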