-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speedup RNN-T greedy decoding #7926
Conversation
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days. |
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
jenkins |
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days. |
jenkins |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but see comments.
I'd like to especially commend your tests, thanks for improving NeMo!
|
||
# Use the following commented print statements to check | ||
# the alignment of other algorithms compared to the default | ||
print("Text", hyp.text) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the following commented print statements
not commented
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was copied from the code nearby.
I reworked the test: instead of just printing the alignment, I use non-batched greedy decoding as a reference, and check if the batched version returns the same results.
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
jenkins |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent work. Minor comments on inline documentation of the actual decoding loop and explain what is loop labels.
I also want to ask why the separation of joint into 3 functions - it seems ok but for example allows HAT to use less memory efficient path which can cause oom.
Finally, excellent tests, much better coverage of cases than before
@@ -138,9 +138,9 @@ def return_hat_ilm(self): | |||
def return_hat_ilm(self, hat_subtract_ilm): | |||
self._return_hat_ilm = hat_subtract_ilm | |||
|
|||
def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]: | |||
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to have the similar API name across the RNNT Joints, is it necessary to change this ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The API is changed for all Joints, starting from AbstractRNNTJoint
(see details in Slack)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now it is the following:
class AbstractRNNTJoint(NeuralModule, ABC):
@abstractmethod
def project_encoder(self, encoder_output):
raise NotImplementedError() # can be Linear or identity
@abstractmethod
def project_prednet(self, encoder_output):
raise NotImplementedError() # can be Linear or identity
@abstractmethod
def joint_after_projection(self, f, g):
"""This is the main method that one should implement for Joint"""
raise NotImplementedError()
def joint(self, f, g):
"""Full joint computation. Not abstract anymore!"""
return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g))
g = self.pred(g) | ||
g.unsqueeze_(dim=1) # (B, 1, U, H) | ||
|
||
f = f.unsqueeze(dim=2) # (B, T, 1, H) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove the preemptive enc() pred() ? This is shown to be equivalent to RNNT and saves a ton of memory
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inplace unsqueeze_
does not save memory.
Due to separating projections I needed to replace in-place unsqueeze_
operation with unsqueeze
. There is no overhead in memory.
According to the documentation https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
The returned tensor shares the same underlying data with this tensor.
You can check it manually:
import torch
device = torch.device('cuda:0')
def print_allocated(device, prefix=""):
allocated_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
print(f"{prefix}{allocated_mb:.0f}MB")
print_allocated(device, prefix="Before: ") # Should be 0MB
# allocate memory ~projection result
data = torch.rand([128, 30 * 1000 // 10 // 8, 640], device=device)
print_allocated(device, prefix="After project encoder output: ") # 118MB
# apply unsqueeze
data2 = data.unsqueeze(-1) # unsqueeze returns a new tensor, but storage is the same (only metadata is new!)
print_allocated(device, prefix="After Unsqueeze: ") # same, 118MB
""" | ||
return self.pred(prednet_output) | ||
|
||
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert name change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is essential to separate projections from other joint computations. It introduces no memory/computational overhead. See details in slack
@@ -28,6 +28,45 @@ class AbstractRNNTJoint(NeuralModule, ABC): | |||
""" | |||
|
|||
@abstractmethod | |||
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert name change. It's fine to keep joint
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See the comments above
@@ -545,6 +573,7 @@ def __init__( | |||
preserve_alignments: bool = False, | |||
preserve_frame_confidence: bool = False, | |||
confidence_method_cfg: Optional[DictConfig] = None, | |||
loop_labels: bool = True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Explain in docstring what this isc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, missed the class docstring before)
if self.preserve_frame_confidence | ||
else None, | ||
) | ||
advance_mask = torch.logical_and(blank_mask, (time_indices[active_indices] + 1 < out_len[active_indices])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document line
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a comment
.squeeze(1) | ||
.squeeze(1) | ||
) | ||
more_scores, more_labels = logits.max(-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a comment (above this line)
|
||
# stage 4: to avoid looping, go to next frame after max_symbols emission | ||
if self.max_symbols is not None: | ||
force_blank_mask = torch.logical_and( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a comment above
|
||
|
||
class BatchedHyps: | ||
"""Class to store batched hypotheses (labels, time_indices, scores) for efficient RNNT decoding""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very neat, this is done so that jit compile is happy?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep) There is also a test that torch.jit is fine with this structure :)
return hypotheses | ||
|
||
|
||
def return_empty_hypotheses( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Empty hys might be needed for beam search init and temp placeholders now that I remember
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed this function, this was used only when max_symbols=0
for the new decoding algorithm
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Vladimir Bataev <[email protected]>
jenkins |
Signed-off-by: Vladimir Bataev <[email protected]>
jenkins |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After detailed explanation, the changes make sense design wize. If you could put those explanation in the PR itself it will help for future discussion. Thanks again for the significant speedup !
Pasting here the discussion from Slack about Joint refactoring, Main points:
1) separating projections in Joint from other operations is helpful in many cases.Even in the original encoder algorithm, when we loop over encoder frames, we can project the frame immediately (one-by-one => no memory overhead), but this will save computations: for each encoder frame, multiple evaluations for Joint are used => we waste time when recalculating the encoder vector's projection. 2) The immediate projection of encoder output is a tiny overheadI see the speedup from projecting the encoder output immediately.
Given all these facts, it is acceptable to project the encoder output immediately. If we need a robust memory consumption optimization, we can use a separate flag ( 3) in-place
|
* Add structure for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Add faster decoding algo Signed-off-by: Vladimir Bataev <[email protected]> * Simplify max_symbols support. More speedup Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Filtering only when necessary Signed-off-by: Vladimir Bataev <[email protected]> * Move max_symbols check to the end of loop Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support returning prediction network states Signed-off-by: Vladimir Bataev <[email protected]> * Support preserve_alignments flag Signed-off-by: Vladimir Bataev <[email protected]> * Support confidence Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Partial fix for jit compatibility Signed-off-by: Vladimir Bataev <[email protected]> * Support switching between decoding algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Fix switching algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Fix max symbols per step Signed-off-by: Vladimir Bataev <[email protected]> * Add tests. Preserve torch.jit compatibility for BatchedHyps Signed-off-by: Vladimir Bataev <[email protected]> * Separate projection from Joint calculation in decoding Signed-off-by: Vladimir Bataev <[email protected]> * Fix config instantiation Signed-off-by: Vladimir Bataev <[email protected]> * Fix after main merge Signed-off-by: Vladimir Bataev <[email protected]> * Add tests for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Speedup alignments Signed-off-by: Vladimir Bataev <[email protected]> * Test alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests for alignments Signed-off-by: Vladimir Bataev <[email protected]> * Add more tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix confidence tests Signed-off-by: Vladimir Bataev <[email protected]> * Avoid common package modification Signed-off-by: Vladimir Bataev <[email protected]> * Support Stateless prediction network Signed-off-by: Vladimir Bataev <[email protected]> * Improve stateless decoder support. Separate alignments and confidence Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step=0 Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Batched Hyps/Alignments: lengths -> current_lengths Signed-off-by: Vladimir Bataev <[email protected]> * Simplify indexing Signed-off-by: Vladimir Bataev <[email protected]> * Improve type annotations Signed-off-by: Vladimir Bataev <[email protected]> * Rework test for greedy decoding Signed-off-by: Vladimir Bataev <[email protected]> * Document loop_labels Signed-off-by: Vladimir Bataev <[email protected]> * Raise ValueError if max_symbols_per_step <= 0 Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> --------- Signed-off-by: Vladimir Bataev <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Add structure for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Add faster decoding algo Signed-off-by: Vladimir Bataev <[email protected]> * Simplify max_symbols support. More speedup Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Filtering only when necessary Signed-off-by: Vladimir Bataev <[email protected]> * Move max_symbols check to the end of loop Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support returning prediction network states Signed-off-by: Vladimir Bataev <[email protected]> * Support preserve_alignments flag Signed-off-by: Vladimir Bataev <[email protected]> * Support confidence Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Partial fix for jit compatibility Signed-off-by: Vladimir Bataev <[email protected]> * Support switching between decoding algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Fix switching algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Fix max symbols per step Signed-off-by: Vladimir Bataev <[email protected]> * Add tests. Preserve torch.jit compatibility for BatchedHyps Signed-off-by: Vladimir Bataev <[email protected]> * Separate projection from Joint calculation in decoding Signed-off-by: Vladimir Bataev <[email protected]> * Fix config instantiation Signed-off-by: Vladimir Bataev <[email protected]> * Fix after main merge Signed-off-by: Vladimir Bataev <[email protected]> * Add tests for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Speedup alignments Signed-off-by: Vladimir Bataev <[email protected]> * Test alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests for alignments Signed-off-by: Vladimir Bataev <[email protected]> * Add more tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix confidence tests Signed-off-by: Vladimir Bataev <[email protected]> * Avoid common package modification Signed-off-by: Vladimir Bataev <[email protected]> * Support Stateless prediction network Signed-off-by: Vladimir Bataev <[email protected]> * Improve stateless decoder support. Separate alignments and confidence Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step=0 Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Batched Hyps/Alignments: lengths -> current_lengths Signed-off-by: Vladimir Bataev <[email protected]> * Simplify indexing Signed-off-by: Vladimir Bataev <[email protected]> * Improve type annotations Signed-off-by: Vladimir Bataev <[email protected]> * Rework test for greedy decoding Signed-off-by: Vladimir Bataev <[email protected]> * Document loop_labels Signed-off-by: Vladimir Bataev <[email protected]> * Raise ValueError if max_symbols_per_step <= 0 Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> --------- Signed-off-by: Vladimir Bataev <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Add structure for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Add faster decoding algo Signed-off-by: Vladimir Bataev <[email protected]> * Simplify max_symbols support. More speedup Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Filtering only when necessary Signed-off-by: Vladimir Bataev <[email protected]> * Move max_symbols check to the end of loop Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support returning prediction network states Signed-off-by: Vladimir Bataev <[email protected]> * Support preserve_alignments flag Signed-off-by: Vladimir Bataev <[email protected]> * Support confidence Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Partial fix for jit compatibility Signed-off-by: Vladimir Bataev <[email protected]> * Support switching between decoding algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Fix switching algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Fix max symbols per step Signed-off-by: Vladimir Bataev <[email protected]> * Add tests. Preserve torch.jit compatibility for BatchedHyps Signed-off-by: Vladimir Bataev <[email protected]> * Separate projection from Joint calculation in decoding Signed-off-by: Vladimir Bataev <[email protected]> * Fix config instantiation Signed-off-by: Vladimir Bataev <[email protected]> * Fix after main merge Signed-off-by: Vladimir Bataev <[email protected]> * Add tests for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Speedup alignments Signed-off-by: Vladimir Bataev <[email protected]> * Test alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests for alignments Signed-off-by: Vladimir Bataev <[email protected]> * Add more tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix confidence tests Signed-off-by: Vladimir Bataev <[email protected]> * Avoid common package modification Signed-off-by: Vladimir Bataev <[email protected]> * Support Stateless prediction network Signed-off-by: Vladimir Bataev <[email protected]> * Improve stateless decoder support. Separate alignments and confidence Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step=0 Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Batched Hyps/Alignments: lengths -> current_lengths Signed-off-by: Vladimir Bataev <[email protected]> * Simplify indexing Signed-off-by: Vladimir Bataev <[email protected]> * Improve type annotations Signed-off-by: Vladimir Bataev <[email protected]> * Rework test for greedy decoding Signed-off-by: Vladimir Bataev <[email protected]> * Document loop_labels Signed-off-by: Vladimir Bataev <[email protected]> * Raise ValueError if max_symbols_per_step <= 0 Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> --------- Signed-off-by: Vladimir Bataev <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: stevehuang52 <[email protected]>
* Add structure for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Add faster decoding algo Signed-off-by: Vladimir Bataev <[email protected]> * Simplify max_symbols support. More speedup Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Filtering only when necessary Signed-off-by: Vladimir Bataev <[email protected]> * Move max_symbols check to the end of loop Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support returning prediction network states Signed-off-by: Vladimir Bataev <[email protected]> * Support preserve_alignments flag Signed-off-by: Vladimir Bataev <[email protected]> * Support confidence Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Partial fix for jit compatibility Signed-off-by: Vladimir Bataev <[email protected]> * Support switching between decoding algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Fix switching algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Fix max symbols per step Signed-off-by: Vladimir Bataev <[email protected]> * Add tests. Preserve torch.jit compatibility for BatchedHyps Signed-off-by: Vladimir Bataev <[email protected]> * Separate projection from Joint calculation in decoding Signed-off-by: Vladimir Bataev <[email protected]> * Fix config instantiation Signed-off-by: Vladimir Bataev <[email protected]> * Fix after main merge Signed-off-by: Vladimir Bataev <[email protected]> * Add tests for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Speedup alignments Signed-off-by: Vladimir Bataev <[email protected]> * Test alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests for alignments Signed-off-by: Vladimir Bataev <[email protected]> * Add more tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix confidence tests Signed-off-by: Vladimir Bataev <[email protected]> * Avoid common package modification Signed-off-by: Vladimir Bataev <[email protected]> * Support Stateless prediction network Signed-off-by: Vladimir Bataev <[email protected]> * Improve stateless decoder support. Separate alignments and confidence Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step=0 Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Batched Hyps/Alignments: lengths -> current_lengths Signed-off-by: Vladimir Bataev <[email protected]> * Simplify indexing Signed-off-by: Vladimir Bataev <[email protected]> * Improve type annotations Signed-off-by: Vladimir Bataev <[email protected]> * Rework test for greedy decoding Signed-off-by: Vladimir Bataev <[email protected]> * Document loop_labels Signed-off-by: Vladimir Bataev <[email protected]> * Raise ValueError if max_symbols_per_step <= 0 Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> --------- Signed-off-by: Vladimir Bataev <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister <[email protected]>
* Add structure for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Add faster decoding algo Signed-off-by: Vladimir Bataev <[email protected]> * Simplify max_symbols support. More speedup Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Filtering only when necessary Signed-off-by: Vladimir Bataev <[email protected]> * Move max_symbols check to the end of loop Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support returning prediction network states Signed-off-by: Vladimir Bataev <[email protected]> * Support preserve_alignments flag Signed-off-by: Vladimir Bataev <[email protected]> * Support confidence Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Partial fix for jit compatibility Signed-off-by: Vladimir Bataev <[email protected]> * Support switching between decoding algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Fix switching algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Fix max symbols per step Signed-off-by: Vladimir Bataev <[email protected]> * Add tests. Preserve torch.jit compatibility for BatchedHyps Signed-off-by: Vladimir Bataev <[email protected]> * Separate projection from Joint calculation in decoding Signed-off-by: Vladimir Bataev <[email protected]> * Fix config instantiation Signed-off-by: Vladimir Bataev <[email protected]> * Fix after main merge Signed-off-by: Vladimir Bataev <[email protected]> * Add tests for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Speedup alignments Signed-off-by: Vladimir Bataev <[email protected]> * Test alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests for alignments Signed-off-by: Vladimir Bataev <[email protected]> * Add more tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix confidence tests Signed-off-by: Vladimir Bataev <[email protected]> * Avoid common package modification Signed-off-by: Vladimir Bataev <[email protected]> * Support Stateless prediction network Signed-off-by: Vladimir Bataev <[email protected]> * Improve stateless decoder support. Separate alignments and confidence Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step=0 Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Batched Hyps/Alignments: lengths -> current_lengths Signed-off-by: Vladimir Bataev <[email protected]> * Simplify indexing Signed-off-by: Vladimir Bataev <[email protected]> * Improve type annotations Signed-off-by: Vladimir Bataev <[email protected]> * Rework test for greedy decoding Signed-off-by: Vladimir Bataev <[email protected]> * Document loop_labels Signed-off-by: Vladimir Bataev <[email protected]> * Raise ValueError if max_symbols_per_step <= 0 Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> --------- Signed-off-by: Vladimir Bataev <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Pablo Garay <[email protected]>
* Add structure for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Add faster decoding algo Signed-off-by: Vladimir Bataev <[email protected]> * Simplify max_symbols support. More speedup Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Filtering only when necessary Signed-off-by: Vladimir Bataev <[email protected]> * Move max_symbols check to the end of loop Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support returning prediction network states Signed-off-by: Vladimir Bataev <[email protected]> * Support preserve_alignments flag Signed-off-by: Vladimir Bataev <[email protected]> * Support confidence Signed-off-by: Vladimir Bataev <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Partial fix for jit compatibility Signed-off-by: Vladimir Bataev <[email protected]> * Support switching between decoding algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Fix switching algorithms Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Clean up Signed-off-by: Vladimir Bataev <[email protected]> * Fix max symbols per step Signed-off-by: Vladimir Bataev <[email protected]> * Add tests. Preserve torch.jit compatibility for BatchedHyps Signed-off-by: Vladimir Bataev <[email protected]> * Separate projection from Joint calculation in decoding Signed-off-by: Vladimir Bataev <[email protected]> * Fix config instantiation Signed-off-by: Vladimir Bataev <[email protected]> * Fix after main merge Signed-off-by: Vladimir Bataev <[email protected]> * Add tests for batched hypotheses Signed-off-by: Vladimir Bataev <[email protected]> * Speedup alignments Signed-off-by: Vladimir Bataev <[email protected]> * Test alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests for alignments Signed-off-by: Vladimir Bataev <[email protected]> * Add more tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix confidence tests Signed-off-by: Vladimir Bataev <[email protected]> * Avoid common package modification Signed-off-by: Vladimir Bataev <[email protected]> * Support Stateless prediction network Signed-off-by: Vladimir Bataev <[email protected]> * Improve stateless decoder support. Separate alignments and confidence Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step Signed-off-by: Vladimir Bataev <[email protected]> * Fix alignments for max_symbols_per_step=0 Signed-off-by: Vladimir Bataev <[email protected]> * Fix tests Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Batched Hyps/Alignments: lengths -> current_lengths Signed-off-by: Vladimir Bataev <[email protected]> * Simplify indexing Signed-off-by: Vladimir Bataev <[email protected]> * Improve type annotations Signed-off-by: Vladimir Bataev <[email protected]> * Rework test for greedy decoding Signed-off-by: Vladimir Bataev <[email protected]> * Document loop_labels Signed-off-by: Vladimir Bataev <[email protected]> * Raise ValueError if max_symbols_per_step <= 0 Signed-off-by: Vladimir Bataev <[email protected]> * Add comments Signed-off-by: Vladimir Bataev <[email protected]> * Fix test Signed-off-by: Vladimir Bataev <[email protected]> --------- Signed-off-by: Vladimir Bataev <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
What does this PR do ?
New algorithm for greedy batched decoding for RNN-Transducer.
With large batch sizes (e.g., 128) the expected speedup for large Fast Conformer-Transducer (full evaluation time including Encoder) is 1.7x-1.9x (when using
speech_to_text_eval.py
). For small batch sizes, e.g., 16, the observed speedup is ~1.3x.The original algorithm is preserved and can be enabled by using
loop_labels=False
E.g., on my local machine, with bf16, bs=128, Fast Conformer-Transducer Large, full
test-other
decodingCollection: [ASR]
Changelog
Usage
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
Additional Information