Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files (browse the repository at this point in the history)
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 4, 2024
1 parent 5cbf7ff commit 0345f30
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 34 deletions.
29 changes: 23 additions & 6 deletions nemo/collections/asr/parts/submodules/fast_rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,15 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
(self.batch_size,), dtype=encoder_output_length.dtype, device=encoder_output_length.device
)

self.scores_cpu = torch.zeros((self.batch_size, self.max_time * self.max_symbols), dtype=encoder_output.dtype, device="cpu", pin_memory=True)
self.labels_cpu = torch.zeros((self.batch_size, self.max_time * self.max_symbols), dtype=torch.int64, device="cpu", pin_memory=True)
self.scores_cpu = torch.zeros(
(self.batch_size, self.max_time * self.max_symbols),
dtype=encoder_output.dtype,
device="cpu",
pin_memory=True,
)
self.labels_cpu = torch.zeros(
(self.batch_size, self.max_time * self.max_symbols), dtype=torch.int64, device="cpu", pin_memory=True
)
self.symbols_per_time_step_cpu = torch.zeros(self.max_time, dtype=torch.int64, device="cpu", pin_memory=True)

with torch.cuda.stream(torch.cuda.Stream()), torch.inference_mode():
Expand All @@ -237,7 +244,11 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
)
)

self.f = torch.zeros((self.batch_size, 1, self.encoder_output.shape[-1]), dtype=encoder_output.dtype, device=encoder_output.device)
self.f = torch.zeros(
(self.batch_size, 1, self.encoder_output.shape[-1]),
dtype=encoder_output.dtype,
device=encoder_output.device,
)
hidden = self.caller.decoder.initialize_state(self.f)
self.last_label = torch.full(
[self.batch_size], fill_value=self.caller._SOS, dtype=torch.long, device=encoder_output.device
Expand All @@ -247,8 +258,14 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
)
self.seq_idx_t = torch.zeros([1], dtype=torch.int64, device=encoder_output.device)

self.scores = torch.zeros((self.max_time * self.max_symbols, self.batch_size), dtype=encoder_output.dtype, device=encoder_output.device)
self.labels = torch.zeros((self.max_time * self.max_symbols, self.batch_size), dtype=torch.int64, device=encoder_output.device)
self.scores = torch.zeros(
(self.max_time * self.max_symbols, self.batch_size),
dtype=encoder_output.dtype,
device=encoder_output.device,
)
self.labels = torch.zeros(
(self.max_time * self.max_symbols, self.batch_size), dtype=torch.int64, device=encoder_output.device
)
self.symbols_per_time_step = torch.zeros(self.max_time, dtype=torch.int64, device=encoder_output.device)

# Get max sequence length
Expand Down Expand Up @@ -380,7 +397,7 @@ def __call__(
max_non_blank_symbols = self.symbols_per_time_step_cpu[t]
for counter in range(max_non_blank_symbols):
if self.labels_cpu[i, j] == caller._blank_index:
j += (max_non_blank_symbols - counter)
j += max_non_blank_symbols - counter
break
hypotheses[i].y_sequence.append(self.labels_cpu[i, j])
hypotheses[i].timestep.append(t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import ipdb
import jiwer
import pytest
import torch
from omegaconf import OmegaConf, open_dict

Expand All @@ -15,39 +16,25 @@
from nemo.collections.asr.parts.submodules.fast_rnnt_greedy_decoding import RNNTGreedyDecodeFast
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInfer

from omegaconf import open_dict
from omegaconf import OmegaConf
import pytest

import jiwer


import torch

import tempfile
import sys, ipdb, traceback

@pytest.mark.parametrize(
("model_name", "batch_size", "use_subset"),
[
pytest.param("stt_en_fastconformer_transducer_large", 8, True),
# marks=pytest.mark.xfail(reason="Cannot instantiate graph with persistent RNN")),
("stt_en_fastconformer_transducer_large", 7, False),
("stt_en_fastconformer_transducer_xlarge", 16, False)
]
("stt_en_fastconformer_transducer_xlarge", 16, False),
],
)
def test_for_loop(model_name, batch_size, use_subset):
nemo_model = ASRModel.from_pretrained(model_name,
map_location="cuda")
nemo_model = ASRModel.from_pretrained(model_name, map_location="cuda")
conf = nemo_model.to_config_dict()
with open_dict(conf):
conf["decoding"]["greedy"]["max_symbols"] = 5

with tempfile.NamedTemporaryFile() as fp:
OmegaConf.save(config=conf, f=fp.name)
nemo_model = ASRModel.from_pretrained(
model_name, override_config_path=fp.name, map_location="cuda"
)
nemo_model = ASRModel.from_pretrained(model_name, override_config_path=fp.name, map_location="cuda")
nemo_model.freeze()

nemo_model.preprocessor.featurizer.dither = 0.0
Expand All @@ -68,8 +55,7 @@ def test_for_loop(model_name, batch_size, use_subset):
torch.cuda.cudart().cudaProfilerStart()

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
actual_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size,
num_workers=None)
actual_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None)

conf = nemo_model.to_config_dict()

Expand All @@ -78,9 +64,7 @@ def test_for_loop(model_name, batch_size, use_subset):
conf["decoding"]["greedy"]["max_symbols"] = 5
with tempfile.NamedTemporaryFile() as fp:
OmegaConf.save(config=conf, f=fp.name)
fast_model = ASRModel.from_pretrained(
model_name, override_config_path=fp.name, map_location="cuda"
)
fast_model = ASRModel.from_pretrained(model_name, override_config_path=fp.name, map_location="cuda")

fast_model.freeze()

Expand All @@ -95,8 +79,7 @@ def test_for_loop(model_name, batch_size, use_subset):
fast_model.joint.freeze()

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
fast_transcripts, _ = fast_model.transcribe(audio_filepaths, batch_size=batch_size,
num_workers=None)
fast_transcripts, _ = fast_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None)

wer = jiwer.wer(actual_transcripts, fast_transcripts)

Expand All @@ -116,8 +99,7 @@ def test_for_loop(model_name, batch_size, use_subset):


def test_reproducibility():
nemo_model = ASRModel.from_pretrained("stt_en_fastconformer_transducer_xlarge",
map_location="cuda")
nemo_model = ASRModel.from_pretrained("stt_en_fastconformer_transducer_xlarge", map_location="cuda")

conf = nemo_model.to_config_dict()
with open_dict(conf):
Expand Down Expand Up @@ -194,4 +176,6 @@ def test_reproducibility():
encoded_1, encoded_len_1, return_hypotheses=False, partial_hypotheses=None,
)

import ipdb; ipdb.set_trace()
import ipdb

ipdb.set_trace()

0 comments on commit 0345f30

Please sign in to comment.