Several fixes based on PR feedback.
Probably not quite done yet. In particular, need to iron out guarding
the cuda-python import.
galv committed Feb 9, 2024
1 parent 754844f commit e5827b1
Showing 7 changed files with 193 additions and 287 deletions.
8 changes: 6 additions & 2 deletions nemo/collections/asr/modules/rnnt.py
@@ -429,6 +429,10 @@ def mask_select_states(
return None
return [states[0][mask]]

def batch_copy_states_mask(
self, old_states: List[torch.Tensor], new_states: List[torch.Tensor], mask: torch.Tensor):
torch.where(mask.unsqueeze(1), new_states[0], old_states[0], out=old_states[0])

def batch_score_hypothesis(
self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
@@ -1121,6 +1125,8 @@ def batch_copy_states_mask(
self, old_states: List[torch.Tensor], new_states: List[torch.Tensor], mask: torch.Tensor,
) -> List[torch.Tensor]:
"""Copy states from new state to old state at certain indices specified by mask
Unlike `mask_select_states` and `batch_copy_states`, this does not require any
synchronization with the host.
Args:
old_states(list): packed decoder states
@@ -1141,8 +1147,6 @@ def batch_copy_states_mask(
mask.unsqueeze(1).unsqueeze(0), new_states[state_id], old_states[state_id], out=old_states[state_id]
)

return old_states

# Adapter method overrides
def add_adapter(self, name: str, cfg: DictConfig):
# Update the config with correct input dim
21 changes: 21 additions & 0 deletions nemo/collections/asr/modules/rnnt_abstract.py
@@ -342,3 +342,24 @@ def mask_select_states(self, states: Any, mask: torch.Tensor) -> Any:
states filtered by mask (same type as `states`)
"""
raise NotImplementedError()

def batch_copy_states_mask(
self, old_states: List[torch.Tensor], new_states: List[torch.Tensor], mask: torch.Tensor):
"""Copy states from new state to old state at certain indices specified by mask.
Unlike `mask_select_states` and `batch_copy_states`, this does not require any
synchronization with the host.
Args:
old_states(list): packed decoder states
(L x B x H, L x B x H)
new_states: packed decoder states
(L x B x H, L x B x H)
mask (tensor): Boolean tensor mask of length B. If True, copy from new_states to old_states at that index
Returns:
batch of decoder states with partial copy at ids (or a specific value).
(L x B x H, L x B x H)
"""
raise NotImplementedError()
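
For illustration, a minimal runnable sketch of what a synchronization-free implementation of this method looks like for LSTM-style (L x B x H) state pairs; the shapes and values below are made up:

import torch

L, B, H = 2, 4, 8  # layers, batch, hidden (illustrative sizes)
old_states = [torch.zeros(L, B, H), torch.zeros(L, B, H)]
new_states = [torch.ones(L, B, H), torch.ones(L, B, H)]
mask = torch.tensor([True, False, True, False])  # which batch elements to overwrite

for state_id in range(len(old_states)):
    # Broadcast the (B,) mask to (1, B, 1) and overwrite the selected batch
    # rows of old_states in place. torch.where with out= needs no .item(),
    # .nonzero(), or boolean indexing, so when the tensors live on the GPU
    # nothing forces a host synchronization and the op is graph-capturable.
    torch.where(
        mask.unsqueeze(1).unsqueeze(0),
        new_states[state_id],
        old_states[state_id],
        out=old_states[state_id],
    )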
nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py (renamed from nemo/collections/asr/parts/submodules/fast_rnnt_greedy_decoding.py)
@@ -12,57 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.core.utils.cuda_python_utils import check_cuda_python_cuda_graphs_conditional_nodes_supported

check_cuda_python_cuda_graphs_conditional_nodes_supported()

import contextlib
import ctypes
import time
from dataclasses import dataclass, field
from itertools import product
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from cuda import cuda, cudart, nvrtc
from omegaconf import DictConfig, OmegaConf
try:
from cuda import cuda, cudart, nvrtc
HAVE_CUDA_PYTHON = True
except ImportError:
HAVE_CUDA_PYTHON = False
from typing import List, Optional

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.collections.common.parts.rnn import label_collate
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import AcousticEncodedRepresentation, ElementType, HypothesisType, LengthsType, NeuralType
from nemo.core.utils.cuda_python_utils import check_cuda_python_cuda_graphs_conditional_nodes_supported, assert_drv, cu_call
from nemo.utils import logging


def ASSERT_DRV(err):
if isinstance(err, cuda.CUresult):
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
elif isinstance(err, nvrtc.nvrtcResult):
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError("Nvrtc Error: {}".format(err))
else:
raise RuntimeError("Unknown error type: {}".format(err))


def cu_call(f_call_out):
error, *others = f_call_out
if error != cudart.cudaError_t.cudaSuccess:
# import ipdb; ipdb.set_trace()
raise Exception(f"CUDA failure! {error}")
else:
# print("GALVEZ:", others)
return tuple(others)


def run_nvrtc(kernel_string, kernel_name):
err, prog = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"while_loop_conditional.cu", 0, [], [])

ASSERT_DRV(err)

assert_drv(err)
# Compile program
# Not specifying --gpu-architecture will default us to a fairly low compute capability, which is a safe bet.
# Otherwise, there are ways to query the current device's compute capability.
@@ -72,28 +41,30 @@ def run_nvrtc(kernel_string, kernel_name):
err, size = nvrtc.nvrtcGetProgramLogSize(prog)
buf = b" " * size
(err,) = nvrtc.nvrtcGetProgramLog(prog, buf)
# print(buf.decode("utf-8"))
ASSERT_DRV(err)
assert_drv(err)

# Get PTX from compilation
err, ptxSize = nvrtc.nvrtcGetPTXSize(prog)
ASSERT_DRV(err)
assert_drv(err)
ptx = b" " * ptxSize
(err,) = nvrtc.nvrtcGetPTX(prog, ptx)
ASSERT_DRV(err)
assert_drv(err)

# print("GALVEZ:PTX=")
# print(ptx.decode("utf-8"))
ptx = np.char.array(ptx)
err, module = cuda.cuModuleLoadData(ptx.ctypes.data)
ASSERT_DRV(err)
assert_drv(err)
err, kernel = cuda.cuModuleGetFunction(module, kernel_name)
ASSERT_DRV(err)
assert_drv(err)

return kernel


def create_outer_for_loop_kernel():
"""
Creates a kernel that evaluates whether or not to enter the for loop body.
Effectively substitutes for `for time_idx in range(trip_count)`
so that the for loop can run on a GPU.
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
@@ -108,12 +79,12 @@ def create_outer_for_loop_kernel():
return run_nvrtc(kernel_string, b"for_loop_conditional")
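
Taken together with the inner while-loop kernel defined just below, the two condition kernels reproduce on-device the control flow sketched here in host-side Python (a schematic stand-in, not code from this commit):

trip_count, max_symbols = 4, 3
for time_idx in range(trip_count):                        # outer for-loop condition
    symbols_added = 0
    not_all_blank = True
    while not_all_blank and symbols_added < max_symbols:  # inner while-loop condition
        # stand-in for one batched decoder/joint step; in the real code this
        # is the captured graph segment inside the conditional node
        not_all_blank = False
        symbols_added += 1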


# Observations: If cudaGraphSetConditional is true once, the kernel never completes...
# The GPU is doing *something*. I'm just not sure what...
# No way to query that at runtime...


def create_while_loop_kernel():
def create_inner_while_loop_kernel():
"""
Evaluates whether or not to keep evaluating the inner while loop body.
Continue until all elements of the batch output blank or the while loop
has run max_symbols times.
"""
kernel_string = r"""\
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
@@ -130,6 +101,14 @@ def create_while_loop_kernel():

@contextlib.contextmanager
def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle):
"""
Even though we add a conditional node only once, we need to
capture the kernel that calls cudaGraphSetConditional() both
before in the parent graph containing the while loop body graph
and after the rest of the while loop body graph (because we need
to decide both whether to enter the loop, and also whether to
execute the next iteration of the loop).
"""
capture_status, _, graph, _, _ = cu_call(cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream().cuda_stream))
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

@@ -153,6 +132,7 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
driver_params.conditional.ctx = ctx

# Use driver API here because of bug in cuda-python runtime API: https://github.com/NVIDIA/cuda-python/issues/55
# TODO: Change call to this after fix goes in:
# node, = cu_call(cudart.cudaGraphAddNode(graph, dependencies, len(dependencies), driver_params))
(node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, len(dependencies), driver_params))
body_graph = driver_params.conditional.phGraph_out[0]
@@ -190,8 +170,11 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
torch.cuda.set_stream(previous_stream)


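A rough capture-time usage sketch of the context manager (placeholder names are marked; stream capture must already be active, and `graph` is the graph being captured, as returned by cudaStreamGetCaptureInfo):

while_loop_kernel = create_inner_while_loop_kernel()
(handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0))
# while_loop_args: packed array of device pointers for the condition kernel,
# assembled from tensor.data_ptr() values as done in _reinitialize() below.

with with_conditional_node(while_loop_kernel, while_loop_args, handle):
    decoder_inner_step()  # placeholder: the work that becomes the conditional node's body graph
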
class RNNTGreedyDecodeFast:
class RNNTGreedyDecodeCudaGraph:
def __init__(self, max_symbols: int, cuda_device, caller):
if not HAVE_CUDA_PYTHON:
check_cuda_python_cuda_graphs_conditional_nodes_supported()

assert max_symbols is not None

self.symbols_added_t = torch.tensor(0, dtype=torch.int64, device=cuda_device)
@@ -208,8 +191,9 @@ def __init__(self, max_symbols: int, cuda_device, caller):
self.encoder_output_length = None
self.f = None

# Reasonable default maximum time. 375 frames * 40ms / frame = 15 seconds
# This affects only the size of the CPU-pinned memory buffers
# Reasonable default maximum time. 375 frames * (80ms / frame) = 30 seconds
# 80ms is the frame size of recent fastconformer models
# This does not affect correctness.
self.max_time = 375
self.batch_size = 0

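As a back-of-envelope check of the default above (an illustration, not code from the commit):

frame_stride_s = 0.080        # 80 ms per encoder frame for recent fastconformer models
default_max_time = 375        # frames
print(default_max_time * frame_stride_s)  # 30.0 seconds of audio covered by the default buffers
# Longer utterances or larger batches still work: _reinitialize() grows the
# buffers and re-captures the graph, so the default only affects how often
# that (slow) re-capture happens, never correctness.
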
@@ -218,15 +202,24 @@ def __init__(self, max_symbols: int, cuda_device, caller):
self.graph = None
self.graph_exec = None

self.first_call = True

self.caller = caller

def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_length):
# We need to call _greedy_decode_blank_as_pad at least once
# before hand in order to make sure that pytorch is
# "initialize".
self.caller._greedy_decode_blank_as_pad_loop_frames(
encoder_output, encoder_output_length, encoder_output.device
)
if self.first_call:
# We need to call the original _greedy_decode_blank_as_pad
# implementation at least once beforehand in order to make
# sure that pytorch is "initialized". Pytorch may be
# uninitialized if this code runs before any other pytorch
# operation in this process. Pytorch often lazily
# initializes things like a cudnnHandle_t via
# cudnnCreate(), which can involve synchronizing with the
# host. Such actions are not stream capturable to a graph.
self.caller._greedy_decode_blank_as_pad_loop_frames(
encoder_output, encoder_output_length, encoder_output.device
)
self.first_call = False

self.max_time = max(self.max_time, max_time)
self.batch_size = max(self.batch_size, batch_size)
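
The first_call warm-up matters because lazy CUDA initialization is not stream-capturable; condensed into a standalone sketch (the function and argument names here are placeholders, not the commit's code):

import torch

def build_decode_graph(eager_decode_fn, *example_inputs):
    # 1) Run the eager path once so cuDNN/cuBLAS handle creation and other
    #    lazy, host-synchronizing initialization happens outside of capture.
    eager_decode_fn(*example_inputs)
    torch.cuda.synchronize()
    # 2) Capture on a fresh side stream (the default stream cannot be
    #    captured), then replay the graph on subsequent calls.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.stream(torch.cuda.Stream()), torch.inference_mode(), torch.cuda.graph(graph):
        eager_decode_fn(*example_inputs)
    return graph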
@@ -255,6 +248,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len

self.graph = torch.cuda.CUDAGraph()

# Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
with torch.cuda.stream(torch.cuda.Stream()), torch.inference_mode(), torch.cuda.graph(self.graph):

self.f = torch.zeros(
@@ -307,7 +301,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len

torch.ge(self.time_idx_t, self.encoder_output_length, out=self.blank_mask)

while_loop_kernel = create_while_loop_kernel()
while_loop_kernel = create_inner_while_loop_kernel()
(while_loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0))
not_blank_ptr = np.array([self.not_all_blank_t.data_ptr()], dtype=np.uint64)
symbols_added_ptr = np.array([self.symbols_added_t.data_ptr()], dtype=np.uint64)
@@ -369,58 +363,43 @@ def __call__(
device: torch.device,
partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):

# Need to copy x and out_len into "staging buffers", or do a graph update.

if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not supported")
raise NotImplementedError("`partial_hypotheses` support is not available with cuda graphs (but could be)")

if self.caller.preserve_alignments:
raise NotImplementedError("`preserve_alignments` support is not available with cuda graphs (but could be)")

assert not self.caller.preserve_alignments
assert not self.caller.preserve_frame_confidence
if self.caller.preserve_frame_confidence:
raise NotImplementedError("`preserve_frame_confidence` support is not available with cuda graphs (but could be)")

batch_size = x.shape[0]
# ideally we would use out_len.max() here...
# We could use out_len.max() here instead of x.shape[1], in
# case for some reason the user passes in a larger buffer than
# required, since we know that `out_len.max() <= x.shape[1]`.
max_time = x.shape[1]

if torch.is_autocast_enabled():
x = x.to(torch.get_autocast_gpu_dtype())

# What do we do if batch_size is actually smaller for the
# input? Is this a problem? the clone() call will fail...
if max_time > self.max_time or batch_size > self.batch_size:
self._reinitialize(max_time, batch_size, x, out_len)

torch.cuda.nvtx.range_push("Graph exec")
self.encoder_output[: x.shape[0], : x.shape[1], ...].copy_(x)
self.encoder_output_length[: out_len.shape[0]].copy_(out_len)
self.graph.replay()
torch.cuda.current_stream().synchronize()
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("Copy data out")

torch.cuda.nvtx.range_push("scores_cpu mask")
self.scores_cpu[self.labels_cpu == self.caller._blank_index] = 0.0
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("scores_cpu sum")
total_scores = self.scores_cpu.sum(dtype=torch.float32, axis=(1, 2))
torch.cuda.nvtx.range_pop()
total_scores = self.scores_cpu.sum(dtype=torch.float32, axis=(1,2))

tokens_per_timestep = (self.labels_cpu != self.caller._blank_index).sum(axis=-1)
torch.cuda.nvtx.range_push("repeat_interleave")
timesteps_packed = torch.repeat_interleave(
torch.arange(self.max_time).repeat(self.batch_size), tokens_per_timestep.flatten()
)
timesteps_packed = torch.repeat_interleave(torch.arange(self.max_time).repeat(self.batch_size), tokens_per_timestep.flatten())
timestep_segments = tokens_per_timestep.sum(axis=-1)
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("labels")
valid_labels_mask = self.labels_cpu != self.caller._blank_index
labels_segments = valid_labels_mask.sum(axis=(1, 2))
valid_labels_mask = (self.labels_cpu != self.caller._blank_index)
labels_segments = valid_labels_mask.sum(axis=(1,2))
labels_packed = self.labels_cpu[valid_labels_mask]
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("Convert to Hypotheses")
hypotheses = [
rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batch_size)
]
Expand All @@ -433,8 +412,5 @@ def __call__(
hypotheses[i].score = float(total_scores[i])
hypotheses[i].y_sequence = labels_packed[labels_start : labels_start + labels_segments[i]].tolist()
labels_start += labels_segments[i]
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_pop()

return hypotheses
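
To make the packing above concrete, a tiny worked example with made-up labels (blank index 0, one utterance, three time steps, max_symbols = 2); illustration only:

import torch

blank = 0
labels = torch.tensor([[[5, 2],            # two tokens emitted at frame 0
                        [7, blank],        # one token at frame 1
                        [blank, blank]]])  # nothing at frame 2 -> shape (B=1, T=3, max_symbols=2)

valid = labels != blank
tokens_per_timestep = valid.sum(axis=-1)          # tensor([[2, 1, 0]])
labels_packed = labels[valid]                     # tensor([5, 2, 7])
timesteps_packed = torch.repeat_interleave(
    torch.arange(labels.shape[1]).repeat(labels.shape[0]),
    tokens_per_timestep.flatten(),
)                                                 # tensor([0, 0, 1])
# Each hypothesis then slices its own contiguous segment out of labels_packed
# and timesteps_packed, exactly as the loop above does with labels_segments.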
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -621,9 +621,9 @@ def __init__(
confidence_method_cfg=confidence_method_cfg,
)
elif use_cuda_graph_decoder:
from nemo.collections.asr.parts.submodules.fast_rnnt_greedy_decoding import RNNTGreedyDecodeFast
from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph

self._greedy_decode = RNNTGreedyDecodeFast(max_symbols_per_step, torch.device("cuda"), self)
self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, torch.device("cuda"), self)
else:
# previous algo: loop over frames
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
