diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py
index 3cd06c48e5249..2afa964b50b41 100644
--- a/nemo/collections/asr/modules/rnnt.py
+++ b/nemo/collections/asr/modules/rnnt.py
@@ -429,6 +429,10 @@ def mask_select_states(
             return None
         return [states[0][mask]]
 
+    def batch_copy_states_mask(
+        self, old_states: List[torch.Tensor], new_states: List[torch.Tensor], mask: torch.Tensor):
+        torch.where(mask.unsqueeze(1), new_states[0], old_states[0], out=old_states[0])
+
     def batch_score_hypothesis(
         self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
     ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
@@ -1121,6 +1125,8 @@ def batch_copy_states_mask(
         self, old_states: List[torch.Tensor], new_states: List[torch.Tensor], mask: torch.Tensor,
     ) -> List[torch.Tensor]:
         """Copy states from new state to old state at certain indices specified by mask
+        Unlike `mask_select_states` and `batch_copy_states`, this does not require any
+        synchronization with the host.
 
         Args:
             old_states(list): packed decoder states
@@ -1141,8 +1147,6 @@ def batch_copy_states_mask(
                 mask.unsqueeze(1).unsqueeze(0), new_states[state_id], old_states[state_id], out=old_states[state_id]
             )
 
-        return old_states
-
     # Adapter method overrides
     def add_adapter(self, name: str, cfg: DictConfig):
         # Update the config with correct input dim
diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py
index 01f23d682da2b..ab2f938447309 100644
--- a/nemo/collections/asr/modules/rnnt_abstract.py
+++ b/nemo/collections/asr/modules/rnnt_abstract.py
@@ -342,3 +342,24 @@ def mask_select_states(self, states: Any, mask: torch.Tensor) -> Any:
             states filtered by mask (same type as `states`)
         """
         raise NotImplementedError()
+
+    def batch_copy_states_mask(
+        self, old_states: List[torch.Tensor], new_states: List[torch.Tensor], mask: torch.Tensor):
+        """Copy states from new state to old state at certain indices specified by mask.
+        Unlike `mask_select_states` and `batch_copy_states`, this does not require any
+        synchronization with the host.
+
+        Args:
+            old_states(list): packed decoder states
+                (L x B x H, L x B x H)
+
+            new_states (list): packed decoder states
+                (L x B x H, L x B x H)
+
+            mask (tensor): Boolean tensor mask of length B. If True, copy from new_states to old_states at that index
+
+        Returns:
+            batch of decoder states with the masked entries copied over from new_states
+                (L x B x H, L x B x H)
+        """
+        raise NotImplementedError()
diff --git a/nemo/collections/asr/parts/submodules/fast_rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py
similarity index 80%
rename from nemo/collections/asr/parts/submodules/fast_rnnt_greedy_decoding.py
rename to nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py
index 51041910011e5..ace4b407534da 100644
--- a/nemo/collections/asr/parts/submodules/fast_rnnt_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py
@@ -12,57 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
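Aside for reviewers (not part of the patch): the `batch_copy_states_mask` contract documented above is easiest to see in a tiny standalone sketch. The boolean mask never leaves the device, and `torch.where(..., out=...)` overwrites the packed states in place, which is why no host synchronization is needed. Shapes and values below are made up for illustration.

```python
import torch

# Packed LSTM-style decoder states: (L x B x H, L x B x H), as in the docstring above.
L, B, H = 2, 4, 8
old_states = [torch.zeros(L, B, H), torch.zeros(L, B, H)]
new_states = [torch.ones(L, B, H), torch.ones(L, B, H)]

# Boolean mask of length B: True means "overwrite this batch element with the new state".
mask = torch.tensor([True, False, True, False])

for state_id in range(len(old_states)):
    # Broadcast (B,) -> (1, B, 1) across layers and hidden units, and write the
    # result straight back into old_states[state_id]; nothing is copied to the host.
    torch.where(
        mask.unsqueeze(0).unsqueeze(-1),
        new_states[state_id],
        old_states[state_id],
        out=old_states[state_id],
    )

assert old_states[0][:, 0].eq(1).all() and old_states[0][:, 1].eq(0).all()
```

In the real decoder these tensors live on the GPU, so the whole copy stays on the stream that is captured into the CUDA graph.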
-from nemo.core.utils.cuda_python_utils import check_cuda_python_cuda_graphs_conditional_nodes_supported - -check_cuda_python_cuda_graphs_conditional_nodes_supported() - import contextlib -import ctypes -import time -from dataclasses import dataclass, field -from itertools import product -from typing import List, Optional, Tuple, Union import numpy as np import torch -from cuda import cuda, cudart, nvrtc -from omegaconf import DictConfig, OmegaConf +try: + from cuda import cuda, cudart, nvrtc + HAVE_CUDA_PYTHON = True +except ImportError: + HAVE_CUDA_PYTHON = False +from typing import List, Optional -from nemo.collections.asr.modules import rnnt_abstract from nemo.collections.asr.parts.utils import rnnt_utils -from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin -from nemo.collections.common.parts.rnn import label_collate from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, ElementType, HypothesisType, LengthsType, NeuralType +from nemo.core.utils.cuda_python_utils import check_cuda_python_cuda_graphs_conditional_nodes_supported, assert_drv, cu_call from nemo.utils import logging - -def ASSERT_DRV(err): - if isinstance(err, cuda.CUresult): - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("Cuda Error: {}".format(err)) - elif isinstance(err, nvrtc.nvrtcResult): - if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError("Nvrtc Error: {}".format(err)) - else: - raise RuntimeError("Unknown error type: {}".format(err)) - - -def cu_call(f_call_out): - error, *others = f_call_out - if error != cudart.cudaError_t.cudaSuccess: - # import ipdb; ipdb.set_trace() - raise Exception(f"CUDA failure! {error}") - else: - # print("GALVEZ:", others) - return tuple(others) - - def run_nvrtc(kernel_string, kernel_name): err, prog = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"while_loop_conditional.cu", 0, [], []) - - ASSERT_DRV(err) - + assert_drv(err) # Compile program # Not specifying --gpu-architecture will default us to a fairly low compute capability, which is a safe bet. # Otherwise, there are ways to query the current device's compute capability. @@ -72,28 +41,30 @@ def run_nvrtc(kernel_string, kernel_name): err, size = nvrtc.nvrtcGetProgramLogSize(prog) buf = b" " * size (err,) = nvrtc.nvrtcGetProgramLog(prog, buf) - # print(buf.decode("utf-8")) - ASSERT_DRV(err) + assert_drv(err) # Get PTX from compilation err, ptxSize = nvrtc.nvrtcGetPTXSize(prog) - ASSERT_DRV(err) + assert_drv(err) ptx = b" " * ptxSize (err,) = nvrtc.nvrtcGetPTX(prog, ptx) - ASSERT_DRV(err) + assert_drv(err) - # print("GALVEZ:PTX=") - # print(ptx.decode("utf-8")) ptx = np.char.array(ptx) err, module = cuda.cuModuleLoadData(ptx.ctypes.data) - ASSERT_DRV(err) + assert_drv(err) err, kernel = cuda.cuModuleGetFunction(module, kernel_name) - ASSERT_DRV(err) + assert_drv(err) return kernel def create_outer_for_loop_kernel(): + """ + Creates a kernel that evaluates whether or not to enter the for loop body. + Effectively substitutes for `for time_idx in range(trip_count)` + such that that for loop can run on a GPU. + """ kernel_string = r"""\ typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle; @@ -108,12 +79,12 @@ def create_outer_for_loop_kernel(): return run_nvrtc(kernel_string, b"for_loop_conditional") -# Observations: If cudaGraphSetConditional is true once, the kernel never completes... -# The GPU is doing *something*. I'm just not sure what... 
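As a side note for reviewers, `run_nvrtc` above compiles a CUDA source string at runtime and returns a loaded `CUfunction`. A minimal usage sketch (illustrative only, not part of the patch; it assumes a CUDA context already exists, e.g. because PyTorch has already touched the GPU):

```python
# Compile the outer-loop conditional kernel defined above.
outer_loop_kernel = create_outer_for_loop_kernel()

# run_nvrtc can also compile ad-hoc kernels; the kernel name must be passed as
# bytes and must match the extern "C" symbol in the source string.
noop_kernel = run_nvrtc(r'extern "C" __global__ void noop() {}', b"noop")
```

These handles are what later get launched via the CUDA driver API while the decoding graph is being captured.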
-# No way to query that at runtime... - - -def create_while_loop_kernel(): +def create_inner_while_loop_kernel(): + """ + Evaluates whether or not to keep evaluating the inner while loop body. + Continue until all elements of the batch output blank or the while loop + has run max_symbols times. + """ kernel_string = r"""\ typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle; @@ -130,6 +101,14 @@ def create_while_loop_kernel(): @contextlib.contextmanager def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle): + """ + Even though we add a conditional node only once, we need to + capture the kernel that calls cudaGraphSetConditional() both + before in the parent graph containing the while loop body graph + and after the rest of the while loop body graph (because we need + to decide both whether to enter the loop, and also whether to + execute the next iteration of the loop). + """ capture_status, _, graph, _, _ = cu_call(cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream().cuda_stream)) assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive @@ -153,6 +132,7 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi driver_params.conditional.ctx = ctx # Use driver API here because of bug in cuda-python runtime API: https://github.com/NVIDIA/cuda-python/issues/55 + # TODO: Change call to this after fix goes in: # node, = cu_call(cudart.cudaGraphAddNode(graph, dependencies, len(dependencies), driver_params)) (node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, len(dependencies), driver_params)) body_graph = driver_params.conditional.phGraph_out[0] @@ -190,8 +170,11 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi torch.cuda.set_stream(previous_stream) -class RNNTGreedyDecodeFast: +class RNNTGreedyDecodeCudaGraph: def __init__(self, max_symbols: int, cuda_device, caller): + if not HAVE_CUDA_PYTHON: + check_cuda_python_cuda_graphs_conditional_nodes_supported() + assert max_symbols is not None self.symbols_added_t = torch.tensor(0, dtype=torch.int64, device=cuda_device) @@ -208,8 +191,9 @@ def __init__(self, max_symbols: int, cuda_device, caller): self.encoder_output_length = None self.f = None - # Reasonable default maximum time. 375 frames * 40ms / frame = 15 seconds - # This affects only the size of the CPU-pinned memory buffers + # Reasonable default maximum time. 375 frames * (80ms / frame) = 30 seconds + # 80ms is the frame size of recent fastconformer models + # This does not affect correctness. self.max_time = 375 self.batch_size = 0 @@ -218,15 +202,24 @@ def __init__(self, max_symbols: int, cuda_device, caller): self.graph = None self.graph_exec = None + self.first_call = True + self.caller = caller def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_length): - # We need to call _greedy_decode_blank_as_pad at least once - # before hand in order to make sure that pytorch is - # "initialize". - self.caller._greedy_decode_blank_as_pad_loop_frames( - encoder_output, encoder_output_length, encoder_output.device - ) + if self.first_call: + # We need to call the original _greedy_decode_blank_as_pad + # implementation at least once beforehand in order to make + # sure that pytorch is "initialized". Pytorch may be + # uninitialized if this code runs before any other pytorch + # operation in this process. 
Pytorch often lazily + # initializes things like a cudnnHandle_t via + # cudnnCreate(), which can involve synchronizing with the + # host. Such actions are not stream capturable to a graph. + self.caller._greedy_decode_blank_as_pad_loop_frames( + encoder_output, encoder_output_length, encoder_output.device + ) + self.first_call = False self.max_time = max(self.max_time, max_time) self.batch_size = max(self.batch_size, batch_size) @@ -255,6 +248,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len self.graph = torch.cuda.CUDAGraph() + # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. with torch.cuda.stream(torch.cuda.Stream()), torch.inference_mode(), torch.cuda.graph(self.graph): self.f = torch.zeros( @@ -307,7 +301,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len torch.ge(self.time_idx_t, self.encoder_output_length, out=self.blank_mask) - while_loop_kernel = create_while_loop_kernel() + while_loop_kernel = create_inner_while_loop_kernel() (while_loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) not_blank_ptr = np.array([self.not_all_blank_t.data_ptr()], dtype=np.uint64) symbols_added_ptr = np.array([self.symbols_added_t.data_ptr()], dtype=np.uint64) @@ -369,58 +363,43 @@ def __call__( device: torch.device, partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, ): - - # Need to copy x and out_len into "staging buffers", or do a graph update. - if partial_hypotheses is not None: - raise NotImplementedError("`partial_hypotheses` support is not supported") + raise NotImplementedError("`partial_hypotheses` support is not available with cuda graphs (but could be)") + + if self.caller.preserve_alignments: + raise NotImplementedError("`preserve_alignments` support is not available with cuda graphs (but could be)") - assert not self.caller.preserve_alignments - assert not self.caller.preserve_frame_confidence + if self.caller.preserve_frame_confidence: + raise NotImplementedError("`preserve_frame_confidence` support is not available with cuda graphs (but could be)") batch_size = x.shape[0] - # ideally we would use out_len.max() here... + # We could use out_len.max() here instead of x.shape[1], in + # case for some reason the user passes in a larger buffer than + # required, since we know that `out_len.max() <= x.shape[1]`. max_time = x.shape[1] if torch.is_autocast_enabled(): x = x.to(torch.get_autocast_gpu_dtype()) - # What do we do if batch_size is actually smaller for the - # input? Is this a problem? the clone() call will fail... 
if max_time > self.max_time or batch_size > self.batch_size: self._reinitialize(max_time, batch_size, x, out_len) - torch.cuda.nvtx.range_push("Graph exec") self.encoder_output[: x.shape[0], : x.shape[1], ...].copy_(x) self.encoder_output_length[: out_len.shape[0]].copy_(out_len) self.graph.replay() torch.cuda.current_stream().synchronize() - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_push("Copy data out") - - torch.cuda.nvtx.range_push("scores_cpu mask") self.scores_cpu[self.labels_cpu == self.caller._blank_index] = 0.0 - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_push("scores_cpu sum") - total_scores = self.scores_cpu.sum(dtype=torch.float32, axis=(1, 2)) - torch.cuda.nvtx.range_pop() + total_scores = self.scores_cpu.sum(dtype=torch.float32, axis=(1,2)) tokens_per_timestep = (self.labels_cpu != self.caller._blank_index).sum(axis=-1) - torch.cuda.nvtx.range_push("repeat_interleave") - timesteps_packed = torch.repeat_interleave( - torch.arange(self.max_time).repeat(self.batch_size), tokens_per_timestep.flatten() - ) + timesteps_packed = torch.repeat_interleave(torch.arange(self.max_time).repeat(self.batch_size), tokens_per_timestep.flatten()) timestep_segments = tokens_per_timestep.sum(axis=-1) - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_push("labels") - valid_labels_mask = self.labels_cpu != self.caller._blank_index - labels_segments = valid_labels_mask.sum(axis=(1, 2)) + valid_labels_mask = (self.labels_cpu != self.caller._blank_index) + labels_segments = valid_labels_mask.sum(axis=(1,2)) labels_packed = self.labels_cpu[valid_labels_mask] - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_push("Convert to Hypotheses") hypotheses = [ rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batch_size) ] @@ -433,8 +412,5 @@ def __call__( hypotheses[i].score = float(total_scores[i]) hypotheses[i].y_sequence = labels_packed[labels_start : labels_start + labels_segments[i]].tolist() labels_start += labels_segments[i] - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_pop() return hypotheses diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 01decdc49901e..0b5a2f4222ce4 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -621,9 +621,9 @@ def __init__( confidence_method_cfg=confidence_method_cfg, ) elif use_cuda_graph_decoder: - from nemo.collections.asr.parts.submodules.fast_rnnt_greedy_decoding import RNNTGreedyDecodeFast + from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph - self._greedy_decode = RNNTGreedyDecodeFast(max_symbols_per_step, torch.device("cuda"), self) + self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, torch.device("cuda"), self) else: # previous algo: loop over frames self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames diff --git a/nemo/core/utils/cuda_python_utils.py b/nemo/core/utils/cuda_python_utils.py index 45d5f61c1ad93..b6423d1b9c65f 100644 --- a/nemo/core/utils/cuda_python_utils.py +++ b/nemo/core/utils/cuda_python_utils.py @@ -20,7 +20,7 @@ def check_cuda_python_cuda_graphs_conditional_nodes_supported(): try: from cuda import cuda - except ModuleNotFoundError: + except ImportError: raise ModuleNotFoundError("Please do `pip install cuda-python>=12.3`") from cuda import __version__ as cuda_python_version @@ -55,5 +55,35 @@ 
def skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported():
         import pytest
 
         pytest.skip(
-            f"Test using cuda grapphs is being skipped because cuda graphs aren't supported. Error message: {e}"
+            f"Test using cuda graphs with conditional nodes is being skipped because cuda graphs with conditional nodes aren't supported. Error message: {e}"
         )
+
+
+def assert_drv(err):
+    """
+    Throws an exception if the return value of a cuda-python call does not indicate success.
+    """
+    from cuda import cuda, cudart, nvrtc
+    if isinstance(err, cuda.CUresult):
+        if err != cuda.CUresult.CUDA_SUCCESS:
+            raise RuntimeError("Cuda Error: {}".format(err))
+    elif isinstance(err, nvrtc.nvrtcResult):
+        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+            raise RuntimeError("Nvrtc Error: {}".format(err))
+    elif isinstance(err, cudart.cudaError_t):
+        if err != cudart.cudaError_t.cudaSuccess:
+            raise RuntimeError("Cuda Runtime Error: {}".format(err))
+    else:
+        raise RuntimeError("Unknown error type: {}".format(err))
+
+
+def cu_call(f_call_out):
+    """
+    Makes calls to cuda-python's functions inside cuda.cuda more pythonic by throwing an exception if they return a status which is not cudaSuccess.
+    """
+    from cuda import cudart
+    error, *others = f_call_out
+    if error != cudart.cudaError_t.cudaSuccess:
+        raise Exception(f"CUDA failure! {error}")
+    else:
+        return tuple(others)
diff --git a/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py b/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py
new file mode 100644
index 0000000000000..242147c5aebed
--- /dev/null
+++ b/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py
@@ -0,0 +1,62 @@
+import glob
+import json
+import os
+import tempfile
+
+import jiwer
+import pytest
+import torch
+from omegaconf import OmegaConf, open_dict
+
+from nemo.collections.asr.models import ASRModel
+from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInfer
+from nemo.core.utils.cuda_python_utils import \
+    skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported
+
+
+@pytest.mark.parametrize(
+    ("model_name", "batch_size", "enable_bfloat16"),
+    [
+        ("stt_en_fastconformer_transducer_large", 7, True),
+        ("stt_en_fastconformer_transducer_large", 7, False),
+        pytest.param("stt_en_fastconformer_transducer_large", 8, True,
+            marks=pytest.mark.xfail(reason="""Cannot instantiate the
+body cuda graph of a conditional node with a persistent kernel (in this case,
+a persistent LSTM), which is triggered in cudnn by using a batch size of 8.""")),
+    ],
+)
+def test_cuda_graph_rnnt_greedy_decoder(model_name, batch_size, enable_bfloat16):
+    skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported()
+
+    conf = ASRModel.from_pretrained(model_name, return_config=True)
+    with open_dict(conf):
+        conf["decoding"]["greedy"]["max_symbols"] = 5
+        conf["decoding"]["greedy"]["loop_labels"] = False
+        conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = False
+
+    with tempfile.NamedTemporaryFile() as fp:
+        OmegaConf.save(config=conf, f=fp.name)
+        nemo_model = ASRModel.from_pretrained(model_name, override_config_path=fp.name, map_location="cuda")
+
+    audio_filepaths = glob.glob("tests/.data/asr/test/an4/wav/*.wav")
+
+    with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16):
+        actual_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None)
+
+    with open_dict(conf):
+        conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = True
+
+    
nemo_model.change_decoding_strategy(conf["decoding"]) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): + fast_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) + + wer = jiwer.wer(actual_transcripts, fast_transcripts) + + assert wer <= 1e-3, "Cuda graph greedy decoder should match original decoder implementation." + + for actual, fast in zip(actual_transcripts, fast_transcripts): + if actual != fast: + print("erroneous samples:") + print("Original transcript:", actual) + print("New transcript:", fast) diff --git a/tests/collections/asr/decoding/test_fast_rnnt_greedy_decoding.py b/tests/collections/asr/decoding/test_fast_rnnt_greedy_decoding.py deleted file mode 100644 index 03f3e141c1d00..0000000000000 --- a/tests/collections/asr/decoding/test_fast_rnnt_greedy_decoding.py +++ /dev/null @@ -1,187 +0,0 @@ -import glob -import json -import os -import sys -import tempfile -import traceback - -import ipdb -import jiwer -import pytest -import torch -from omegaconf import OmegaConf, open_dict - -from nemo.collections.asr.models import ASRModel -from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel -from nemo.collections.asr.parts.submodules.fast_rnnt_greedy_decoding import RNNTGreedyDecodeFast -from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInfer - - -@pytest.mark.parametrize( - ("model_name", "batch_size", "use_subset", "enable_bfloat16"), - [ - # The xfail is not catching for some reason... - # pytest.param("stt_en_fastconformer_transducer_large", 8, True, True, marks=pytest.mark.xfail(reason="Cannot instantiate graph with persistent RNN")), - ("stt_en_fastconformer_transducer_large", 7, False, True), - ("stt_en_fastconformer_transducer_xlarge", 16, False, True), - ("stt_en_fastconformer_transducer_xlarge", 16, True, True), - ("stt_en_fastconformer_transducer_xlarge", 128, True, True), - ("stt_en_fastconformer_transducer_xlarge", 16, False, False), - ], -) -def test_for_loop(model_name, batch_size, use_subset, enable_bfloat16): - nemo_model = ASRModel.from_pretrained(model_name, map_location="cuda") - conf = nemo_model.to_config_dict() - with open_dict(conf): - conf["decoding"]["greedy"]["max_symbols"] = 5 - conf["decoding"]["greedy"]["loop_labels"] = False - - with tempfile.NamedTemporaryFile() as fp: - OmegaConf.save(config=conf, f=fp.name) - nemo_model = ASRModel.from_pretrained(model_name, override_config_path=fp.name, map_location="cuda") - nemo_model.freeze() - - nemo_model.preprocessor.featurizer.dither = 0.0 - nemo_model.preprocessor.featurizer.pad_to = 0 - - # Switch model to evaluation mode - nemo_model.eval() - # Freeze the encoder and decoder modules - nemo_model.encoder.freeze() - nemo_model.decoder.freeze() - nemo_model.joint.freeze() - - audio_filepaths = glob.glob("/home/dgalvez/scratch/data/LibriSpeech/test-clean-processed/*.wav") - - if use_subset: - audio_filepaths = audio_filepaths[: batch_size * 4] - - conf = nemo_model.to_config_dict() - - with open_dict(conf): - conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = True - conf["decoding"]["greedy"]["loop_labels"] = False - conf["decoding"]["greedy"]["max_symbols"] = 5 - with tempfile.NamedTemporaryFile() as fp: - OmegaConf.save(config=conf, f=fp.name) - fast_model = ASRModel.from_pretrained(model_name, override_config_path=fp.name, map_location="cuda") - - fast_model.freeze() - - fast_model.preprocessor.featurizer.dither = 0.0 - fast_model.preprocessor.featurizer.pad_to = 0 - - # 
Switch model to evaluation mode - fast_model.eval() - # Freeze the encoder and decoder modules - fast_model.encoder.freeze() - fast_model.decoder.freeze() - fast_model.joint.freeze() - - torch.cuda.cudart().cudaProfilerStart() - - with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): - fast_transcripts, _ = fast_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): - actual_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) - - wer = jiwer.wer(actual_transcripts, fast_transcripts) - - assert wer <= 1e-3 - - print("GALVEZ:wer=", wer) - - for actual, fast in zip(actual_transcripts, fast_transcripts): - if actual != fast: - print("GALVEZ:erroneous!") - print(actual) - print(fast) - - torch.cuda.cudart().cudaProfilerStop() - - # import ipdb; ipdb.set_trace() - - -def test_reproducibility(): - nemo_model = ASRModel.from_pretrained("stt_en_fastconformer_transducer_xlarge", map_location="cuda") - - conf = nemo_model.to_config_dict() - with open_dict(conf): - conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = True - conf["decoding"]["greedy"]["loop_labels"] = False - conf["decoding"]["greedy"]["max_symbols"] = 5 - - with tempfile.NamedTemporaryFile() as fp: - OmegaConf.save(config=conf, f=fp.name) - nemo_model = ASRModel.from_pretrained( - "stt_en_fastconformer_transducer_xlarge", override_config_path=fp.name, map_location="cuda" - ) - - device = "cuda" - - paths2audio_files = glob.glob("/home/dgalvez/scratch/data/LibriSpeech/test-clean-processed/*.wav") - batch_size = 16 - num_workers = 2 - - nemo_model.preprocessor.featurizer.dither = 0.0 - nemo_model.preprocessor.featurizer.pad_to = 0 - - # Switch model to evaluation mode - nemo_model.eval() - # Freeze the encoder and decoder modules - nemo_model.encoder.freeze() - nemo_model.decoder.freeze() - nemo_model.joint.freeze() - # Work in tmp directory - will store manifest file there - with tempfile.TemporaryDirectory() as tmpdir: - with open(os.path.join(tmpdir, 'manifest.json'), 'w', encoding='utf-8') as fp: - for audio_file in paths2audio_files: - entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': ''} - fp.write(json.dumps(entry) + '\n') - - config = { - 'paths2audio_files': paths2audio_files, - 'batch_size': batch_size, - 'temp_dir': tmpdir, - 'num_workers': num_workers, - 'channel_selector': None, - } - - temporary_datalayer = nemo_model._setup_transcribe_dataloader(config) - for test_batch in temporary_datalayer: - encoded, encoded_len = nemo_model.forward( - input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) - ) - - best_hyp, all_hyp = nemo_model.decoding.rnnt_decoder_predictions_tensor( - encoded, encoded_len, return_hypotheses=False, partial_hypotheses=None, - ) - - best_hyp_0, all_hyp_0 = nemo_model.decoding.rnnt_decoder_predictions_tensor( - encoded[0:1, ...], encoded_len[0:1], return_hypotheses=False, partial_hypotheses=None, - ) - - best_hyp_1, all_hyp_1 = nemo_model.decoding.rnnt_decoder_predictions_tensor( - encoded[1:2, ...], encoded_len[1:2], return_hypotheses=False, partial_hypotheses=None, - ) - - encoded_0, encoded_len_0 = nemo_model.forward( - input_signal=test_batch[0][0:1].to(device), input_signal_length=test_batch[1][0:1].to(device) - ) - - best_hyp_0_single, all_hyp_0_single = nemo_model.decoding.rnnt_decoder_predictions_tensor( - encoded_0, encoded_len_0, return_hypotheses=False, partial_hypotheses=None, 
- ) - - encoded_1, encoded_len_1 = nemo_model.forward( - input_signal=test_batch[0][1:2].to(device), input_signal_length=test_batch[1][1:2].to(device) - ) - - best_hyp_1_single, all_hyp_1_single = nemo_model.decoding.rnnt_decoder_predictions_tensor( - encoded_1, encoded_len_1, return_hypotheses=False, partial_hypotheses=None, - ) - - import ipdb - - ipdb.set_trace()
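To try the new path end to end, here is a rough sketch mirroring what the new test does (illustrative only, not part of the patch; the checkpoint name comes from the test above, and `sample.wav` is a placeholder audio file):

```python
import torch
from omegaconf import open_dict

from nemo.collections.asr.models import ASRModel

model = ASRModel.from_pretrained("stt_en_fastconformer_transducer_large", map_location="cuda")

decoding_cfg = model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.greedy.max_symbols = 5
    decoding_cfg.greedy.loop_labels = False            # use the frame-loop greedy decoder
    decoding_cfg.greedy.use_cuda_graph_decoder = True  # flag added by this PR
model.change_decoding_strategy(decoding_cfg)

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    transcripts, _ = model.transcribe(["sample.wav"], batch_size=4)
```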