Speed up copying data into Hypothesis data structure by minimizing loops.
galv committed Jan 4, 2024
1 parent c87f35b commit f132c52
Showing 2 changed files with 21 additions and 63 deletions.
80 changes: 19 additions & 61 deletions nemo/collections/asr/parts/submodules/fast_rnnt_greedy_decoding.py
@@ -215,8 +215,8 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
             device=encoder_output_length.device)
 
         # (batch_size, max_time * max_symbols)
-        self.scores_cpu = torch.zeros((self.max_time * self.max_symbols, self.batch_size), dtype=self.dtype, device="cpu", pin_memory=True)
-        self.labels_cpu = torch.zeros((self.max_time * self.max_symbols, self.batch_size), dtype=torch.int64, device="cpu", pin_memory=True)
+        self.scores_cpu = torch.zeros((self.batch_size, self.max_time * self.max_symbols), dtype=self.dtype, device="cpu", pin_memory=True)
+        self.labels_cpu = torch.zeros((self.batch_size, self.max_time * self.max_symbols), dtype=torch.int64, device="cpu", pin_memory=True)
         self.symbols_per_time_step_cpu = torch.zeros(self.max_time, dtype=torch.int64, device="cpu", pin_memory=True)
 
         with torch.cuda.stream(torch.cuda.Stream()), torch.inference_mode():
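Note: this hunk flips the CPU staging buffers from time-major (max_time * max_symbols, batch_size) to batch-major (batch_size, max_time * max_symbols), so every candidate symbol for one utterance lands in a single contiguous, pinned row. A minimal sketch of the two layouts (sizes are illustrative, not from the commit; pinning requires a CUDA build of PyTorch):

import torch

batch_size, max_time, max_symbols = 16, 200, 5

# Batch-major: row b holds all candidate symbols for utterance b, so the
# copy-out loop can walk one cache-friendly row per hypothesis.
labels_cpu = torch.zeros(
    (batch_size, max_time * max_symbols),
    dtype=torch.int64, device="cpu", pin_memory=True,
)

# Time-major (the old shape): symbols for one utterance are strided
# batch_size elements apart, so a per-utterance scan jumps across rows.
labels_cpu_old = torch.zeros(
    (max_time * max_symbols, batch_size),
    dtype=torch.int64, device="cpu", pin_memory=True,
)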
@@ -293,8 +293,8 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
            self.symbols_per_time_step.index_copy_(0, self.time_idx_t, self.symbols_added_t)
            self.time_idx_t += 1
 
-            self.scores_cpu.copy_(self.scores, non_blocking=True)
-            self.labels_cpu.copy_(self.labels, non_blocking=True)
+            self.scores_cpu.copy_(self.scores.transpose(0, 1).contiguous(), non_blocking=True)
+            self.labels_cpu.copy_(self.labels.transpose(0, 1).contiguous(), non_blocking=True)
            self.symbols_per_time_step_cpu.copy_(self.symbols_per_time_step, non_blocking=True)
 
            self.last_label.fill_(self.caller._SOS)
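Note: the device-side scores and labels are still produced time-major, so they are transposed to batch-major and materialized with .contiguous() on the GPU before the asynchronous copy into the pinned buffers; the host transfer then stays one flat copy. A rough sketch of the pattern, assuming a CUDA device, a (T, B) source, and a (B, T) pinned destination (names and sizes are illustrative):

import torch

T, B = 1000, 16  # e.g. max_time * max_symbols and batch_size
scores_gpu = torch.randn(T, B, device="cuda")                  # time-major source
scores_cpu = torch.zeros(B, T, device="cpu", pin_memory=True)  # batch-major, pinned

# .contiguous() performs the layout change on the GPU; the device-to-host
# copy is then a flat transfer into pinned memory and can run
# asynchronously on the current stream (non_blocking=True).
scores_cpu.copy_(scores_gpu.transpose(0, 1).contiguous(), non_blocking=True)
torch.cuda.current_stream().synchronize()  # wait before reading scores_cpu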
@@ -340,69 +340,27 @@ def __call__(
 
         torch.cuda.nvtx.range_push("Copy data out")
 
+        out_len_cpu = out_len.cpu()
 
         hypotheses = [
             rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batch_size)
         ]
 
-        j = 0
-        # Can I make this more efficient somehow? max_symbols=5 makes this very inefficient. Hmm...
-        # No need to iterate over all times for all elements of the batch, right?
-        for t in range(max_time):
-            max_non_blank_symbols = self.symbols_per_time_step_cpu[t]
-            # print("GALVEZ:", t, max_non_blank_symbols)
-            for _ in range(max_non_blank_symbols):
-                for i in range(batch_size):
-                    if self.labels_cpu[j, i] == caller._blank_index:
-                        # This is slow because we go through every symbol in a batch even once we have seen our first blank
+        for i in range(batch_size):
+            j = 0
+            for t in range(out_len_cpu[i]):
+                max_non_blank_symbols = self.symbols_per_time_step_cpu[t]
+                for counter in range(max_non_blank_symbols):
+                    if self.labels_cpu[i, j] == caller._blank_index:
+                        j += 1
                         continue
-                    hypotheses[i].y_sequence.append(self.labels_cpu[j, i])
+                    # j += (max_non_blank_symbols - counter)
+                    # break
+                    hypotheses[i].y_sequence.append(self.labels_cpu[i, j])
                     hypotheses[i].timestep.append(t)
-                    hypotheses[i].score += self.scores_cpu[j, i]
-                j += 1
-        torch.cuda.nvtx.range_pop()
-
-        return hypotheses
-
+                    hypotheses[i].score += self.scores_cpu[i, j]
+                    j += 1
 
+        torch.cuda.nvtx.range_pop()
 
-        torch.cuda.nvtx.range_push("RNN-T greedy search Inference")
-
-        with torch.inference_mode():
-            start = time.time()
-            # torch.cuda.cudart().cudaProfilerStart()
-            cu_call(cudart.cudaGraphLaunch(self.graph_exec, torch.cuda.current_stream().cuda_stream))
-            cu_call(cudart.cudaStreamSynchronize(torch.cuda.current_stream().cuda_stream))
-            end = time.time()
-            # print("total time:", end - start)
-
-            # torch.set_printoptions(threshold=100_000)
-            # print("GALVEZ:symbols_per_time_step=", self.symbols_per_time_step_cpu)
-            # print("GALVEZ:scores=", self.scores_cpu)
-            # print("GALVEZ:labels=", self.labels_cpu)
-            # print("GALVEZ:symbols_per_time_step=", self.symbols_per_time_step_cpu)
-            # print("GALVEZ:hidden1=", torch.sum(torch.abs(hidden[0]), dim=2))
-            # print("GALVEZ:hidden2=", torch.sum(torch.abs(hidden[1]), dim=2))
-
-            torch.cuda.nvtx.range_push("Copy data out")
-            # js = torch.zeros(batch_size, dtype=torch.int64, device="cpu")
-            j = 0
-            for t in range(max_time):
-                max_non_blank_symbols = self.symbols_per_time_step_cpu[t]
-                # print("GALVEZ:", t, max_non_blank_symbols)
-                for _ in range(max_non_blank_symbols):
-                    for i in range(batch_size):
-                        if self.labels_cpu[j, i] == caller._blank_index:
-                            # Ooops! This is not correct!!!!! It's continue... It's fine...
-                            continue
-                        hypotheses[i].y_sequence.append(self.labels_cpu[j, i])
-                        hypotheses[i].timestep.append(t)
-                        hypotheses[i].score += self.scores_cpu[j, i]
-                    j += 1
-            torch.cuda.nvtx.range_pop()
-        torch.cuda.nvtx.range_pop()
-        # torch.cuda.cudart().cudaProfilerStop()
-        # print("NEW:", hypotheses)
-        # import ipdb; ipdb.set_trace()
-
-        return hypotheses
+        return hypotheses
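Note: the rewritten copy-out visits each utterance independently. It advances one flat index j along that utterance's batch-major row, stops at the utterance's own encoder length instead of max_time, and skips blank slots, so short utterances no longer pay for the longest one in the batch. A self-contained sketch of the same traversal on made-up data (blank_index, the tensors, and the sizes are illustrative, not taken from the commit):

import torch

blank_index = 0
batch_size = 2
symbols_per_time_step = torch.tensor([2, 1, 2])  # batch-wide non-blank slots per frame
out_len = torch.tensor([3, 2])                   # valid encoder frames per utterance
# Row i flattens utterance i's candidate labels over (frame, slot).
labels = torch.tensor([[3, 0, 5, 7, 0],
                       [0, 0, 4, 0, 0]])

results = [[] for _ in range(batch_size)]
for i in range(batch_size):
    j = 0  # flat position in row i; one slot per (frame, emitted symbol)
    for t in range(out_len[i]):
        for _ in range(symbols_per_time_step[t]):
            if labels[i, j] != blank_index:  # blank slots are padding
                results[i].append(int(labels[i, j]))
            j += 1

print(results)  # [[3, 5, 7], [4]]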
4 changes: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ def test_for_loop():
         map_location="cuda")
     conf = nemo_model.to_config_dict()
     with open_dict(conf):
-        conf["decoding"]["greedy"]["max_symbols"] = 1
+        conf["decoding"]["greedy"]["max_symbols"] = 5
 
     with tempfile.NamedTemporaryFile() as fp:
         OmegaConf.save(config=conf, f=fp.name)
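Note: max_symbols caps how many non-blank labels the greedy RNN-T decoder may emit per encoder frame; raising it from 1 to 5 makes the test exercise the multi-symbol path whose copy-out this commit speeds up. A hypothetical sketch of how such a cap bounds one decode step (decode_step and predict are stand-ins, not NeMo's API):

def decode_step(predict, state, blank_index=0, max_symbols=5):
    # Emit until the model predicts blank or the cap is hit; with
    # max_symbols=1 at most one label can be produced per frame.
    emitted = []
    for _ in range(max_symbols):
        label, state = predict(state)
        if label == blank_index:
            break
        emitted.append(label)
    return emitted, state

# Toy predictor that outputs 7, 7, then blank for one frame:
stream = iter([7, 7, 0])
labels, _ = decode_step(lambda s: (next(stream), s), state=None)
print(labels)  # [7, 7]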
@@ -48,7 +48,7 @@ def test_for_loop():
     nemo_model.decoder.freeze()
     nemo_model.joint.freeze()
 
-    audio_filepaths = glob.glob("/home/dgalvez/scratch/data/LibriSpeech/test-clean-processed/*.wav")[:64]
+    audio_filepaths = glob.glob("/home/dgalvez/scratch/data/LibriSpeech/test-clean-processed/*.wav")
     batch_size = 16
 
     torch.cuda.cudart().cudaProfilerStart()
