Skip to content

Commit

Permalink
Rework test for greedy decoding
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Bataev <[email protected]>
  • Loading branch information
artbataev committed Jan 15, 2024
1 parent 3df991a commit 31649fa
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions tests/collections/asr/decoding/test_rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,11 @@ def test_greedy_decoding_preserve_alignments(self, test_data_dir):
@pytest.mark.with_downloads
@pytest.mark.unit
@pytest.mark.parametrize("loop_labels", [True, False])
def test_greedy_decoding_preserve_alignments(self, test_data_dir, loop_labels: bool):
def test_batched_greedy_decoding_preserve_alignments(self, test_data_dir, loop_labels: bool):
"""Test batched greedy decoding using non-batched decoding as a reference"""
model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'stt_en_conformer_transducer_small')

beam = greedy_decode.GreedyBatchedRNNTInfer(
search_algo = greedy_decode.GreedyBatchedRNNTInfer(
model.decoder,
model.joint,
blank_index=model.joint.num_classes_with_blank - 1,
Expand All @@ -195,31 +196,43 @@ def test_greedy_decoding_preserve_alignments(self, test_data_dir, loop_labels: b
loop_labels=loop_labels,
)

etalon_search_algo = greedy_decode.GreedyRNNTInfer(
model.decoder,
model.joint,
blank_index=model.joint.num_classes_with_blank - 1,
max_symbols_per_step=5,
preserve_alignments=True,
)

enc_out = encoded
enc_len = encoded_len

with torch.no_grad():
hyps: list[rnnt_utils.Hypothesis] = beam(encoder_output=enc_out, encoded_lengths=enc_len)[0]
hyp = decode_text_from_greedy_hypotheses(hyps, model.decoding)
hyp = hyp[0]
hyps: list[rnnt_utils.Hypothesis] = search_algo(encoder_output=enc_out, encoded_lengths=enc_len)[0]
hyp = decode_text_from_greedy_hypotheses(hyps, model.decoding)[0]
etalon_hyps: list[rnnt_utils.Hypothesis] = etalon_search_algo(
encoder_output=enc_out, encoded_lengths=enc_len
)[0]
etalon_hyp = decode_text_from_greedy_hypotheses(etalon_hyps, model.decoding)[0]

assert hyp.alignments is not None
assert etalon_hyp.alignments is not None

assert hyp.text == etalon_hyp.text
assert len(hyp.alignments) == len(etalon_hyp.alignments)

# Use the following commented print statements to check
# the alignment of other algorithms compared to the default
print("Text", hyp.text)
for t in range(len(hyp.alignments)):
t_u = []
for u in range(len(hyp.alignments[t])):
logp, label = hyp.alignments[t][u]
assert torch.is_tensor(logp)
assert torch.is_tensor(label)
etalon_logp, etalon_label = etalon_hyp.alignments[t][u]
assert label == etalon_label
assert torch.allclose(logp, etalon_logp, atol=1e-4, rtol=1e-4)

t_u.append(int(label))

print(f"Tokens at timestep {t} = {t_u}")
print()

@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
)
Expand Down

0 comments on commit 31649fa

Please sign in to comment.