Speedup RNN-T greedy decoding #7926
Changes from all commits:
---

```diff
@@ -138,9 +138,9 @@ def return_hat_ilm(self):
     def return_hat_ilm(self, hat_subtract_ilm):
         self._return_hat_ilm = hat_subtract_ilm
 
-    def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]:
+    def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]:
         """
-        Compute the joint step of the network.
+        Compute the joint step of the network after Encoder/Decoder projection.
 
         Here,
         B = Batch size
@@ -169,14 +169,8 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJoin
             Log softmaxed tensor of shape (B, T, U, V + 1).
             Internal LM probability (B, 1, U, V) -- in case of return_ilm==True.
         """
-        # f = [B, T, H1]
-        f = self.enc(f)
-        f.unsqueeze_(dim=2)  # (B, T, 1, H)
-
-        # g = [B, U, H2]
-        g = self.pred(g)
-        g.unsqueeze_(dim=1)  # (B, 1, U, H)
-
+        f = f.unsqueeze(dim=2)  # (B, T, 1, H)
+        g = g.unsqueeze(dim=1)  # (B, 1, U, H)
         inp = f + g  # [B, T, U, H]
 
         del f
```
> **Reviewer:** Why remove the preemptive `enc()` / `pred()`? This is shown to be equivalent to RNNT and saves a ton of memory.
>
> **Author:** Due to separating the projections, I needed to replace the in-place `unsqueeze_` with the out-of-place `unsqueeze`. This adds no memory cost: `unsqueeze` returns a new tensor, but the storage is the same (only the metadata is new). You can check it manually:

```python
import torch

device = torch.device('cuda:0')

def print_allocated(device, prefix=""):
    allocated_mb = torch.cuda.max_memory_allocated(device) / 1024 / 1024
    print(f"{prefix}{allocated_mb:.0f}MB")

print_allocated(device, prefix="Before: ")  # 0MB

# allocate memory comparable to a projection result
data = torch.rand([128, 30 * 1000 // 10 // 8, 640], device=device)
print_allocated(device, prefix="After projecting encoder output: ")  # 118MB

# unsqueeze returns a new tensor, but storage is the same (only metadata is new!)
data2 = data.unsqueeze(-1)
print_allocated(device, prefix="After unsqueeze: ")  # same, 118MB
```
---

```diff
@@ -398,6 +398,22 @@ def batch_copy_states(
 
         return old_states
 
+    def mask_select_states(
+        self, states: Optional[List[torch.Tensor]], mask: torch.Tensor
+    ) -> Optional[List[torch.Tensor]]:
+        """
+        Return states by mask selection
+        Args:
+            states: states for the batch
+            mask: boolean mask for selecting states; batch dimension should be the same as for states
+
+        Returns:
+            states filtered by mask
+        """
+        if states is None:
+            return None
+        return [states[0][mask]]
+
     def batch_score_hypothesis(
         self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
     ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
```
```diff
@@ -1047,6 +1063,21 @@ def batch_copy_states(
 
         return old_states
 
+    def mask_select_states(
+        self, states: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Return states by mask selection
+        Args:
+            states: states for the batch
+            mask: boolean mask for selecting states; batch dimension should be the same as for states
+
+        Returns:
+            states filtered by mask
+        """
+        # LSTM in PyTorch returns a tuple of 2 tensors as a state
+        return states[0][:, mask], states[1][:, mask]
+
     # Adapter method overrides
     def add_adapter(self, name: str, cfg: DictConfig):
         # Update the config with correct input dim
```
```diff
@@ -1382,9 +1413,33 @@ def forward(
 
         return losses, wer, wer_num, wer_denom
 
-    def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
+    def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor:
+        """
+        Project the encoder output to the joint hidden dimension.
+
+        Args:
+            encoder_output: A torch.Tensor of shape [B, T, D]
+
+        Returns:
+            A torch.Tensor of shape [B, T, H]
+        """
+        return self.enc(encoder_output)
+
+    def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor:
+        """
+        Project the Prediction Network (Decoder) output to the joint hidden dimension.
+
+        Args:
+            prednet_output: A torch.Tensor of shape [B, U, D]
+
+        Returns:
+            A torch.Tensor of shape [B, U, H]
+        """
+        return self.pred(prednet_output)
+
+    def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
         """
-        Compute the joint step of the network.
+        Compute the joint step of the network after projection.
 
         Here,
         B = Batch size
```

> **Reviewer:** Revert name change.
>
> **Author:** It is essential to separate the projections from the rest of the joint computation. It introduces no memory or computational overhead. See details in Slack.
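To illustrate the author's argument, here is a toy sketch (assumed module and dimensions, not the NeMo implementation) of the decoding pattern that separate projections enable: the encoder projection runs once per utterance batch instead of inside every per-frame joint call.

```python
import torch

class ToyJoint(torch.nn.Module):
    """Toy joint with the split API from this PR; dimensions are illustrative."""

    def __init__(self, enc_dim=256, pred_dim=320, hidden=640, vocab=28):
        super().__init__()
        self.enc = torch.nn.Linear(enc_dim, hidden)    # encoder projection
        self.pred = torch.nn.Linear(pred_dim, hidden)  # prednet projection
        self.out = torch.nn.Linear(hidden, vocab + 1)  # joint net (+ blank)

    def project_encoder(self, encoder_output):  # [B, T, D] -> [B, T, H]
        return self.enc(encoder_output)

    def project_prednet(self, prednet_output):  # [B, U, D] -> [B, U, H]
        return self.pred(prednet_output)

    def joint_after_projection(self, f, g):
        # f: [B, T, H], g: [B, U, H] -> logits [B, T, U, V + 1]
        return self.out(torch.relu(f.unsqueeze(dim=2) + g.unsqueeze(dim=1)))

joint = ToyJoint()
enc_proj = joint.project_encoder(torch.rand(8, 100, 256))  # projected once
pred_proj = joint.project_prednet(torch.rand(8, 1, 320))   # per decoding step
for t in range(enc_proj.shape[1]):
    # The hot loop reuses cached projections; no Linear is re-applied here.
    logits = joint.joint_after_projection(enc_proj[:, t : t + 1], pred_proj)
```

With the previous API, a greedy loop calling `joint(f_t, g)` per frame would re-run both projection layers on every call; caching the projected tensors moves that work outside the loop while, as shown in the memory check above, `unsqueeze` adds no extra storage.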
```diff
@@ -1412,14 +1467,8 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
         Returns:
             Logits / log softmaxed tensor of shape (B, T, U, V + 1).
         """
-        # f = [B, T, H1]
-        f = self.enc(f)
-        f.unsqueeze_(dim=2)  # (B, T, 1, H)
-
-        # g = [B, U, H2]
-        g = self.pred(g)
-        g.unsqueeze_(dim=1)  # (B, 1, U, H)
-
+        f = f.unsqueeze(dim=2)  # (B, T, 1, H)
+        g = g.unsqueeze(dim=1)  # (B, 1, U, H)
         inp = f + g  # [B, T, U, H]
 
         del f, g
```
```diff
@@ -1536,7 +1585,7 @@ def set_fuse_loss_wer(self, fuse_loss_wer, loss=None, metric=None):
 
     @property
     def fused_batch_size(self):
-        return self._fuse_loss_wer
+        return self._fused_batch_size
 
     def set_fused_batch_size(self, fused_batch_size):
         self._fused_batch_size = fused_batch_size
```
---

```diff
@@ -28,6 +28,45 @@ class AbstractRNNTJoint(NeuralModule, ABC):
     """
 
+    @abstractmethod
+    def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Any:
+        """
+        Compute the joint step of the network after the projection step.
+        Args:
+            f: Output of the Encoder model after projection. A torch.Tensor of shape [B, T, H]
+            g: Output of the Decoder model (Prediction Network) after projection. A torch.Tensor of shape [B, U, H]
+
+        Returns:
+            Logits / log softmaxed tensor of shape (B, T, U, V + 1).
+            Arbitrary return type, preferably torch.Tensor, but not limited to (e.g., see HatJoint)
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor:
+        """
+        Project the encoder output to the joint hidden dimension.
+
+        Args:
+            encoder_output: A torch.Tensor of shape [B, T, D]
+
+        Returns:
+            A torch.Tensor of shape [B, T, H]
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor:
+        """
+        Project the Prediction Network (Decoder) output to the joint hidden dimension.
+
+        Args:
+            prednet_output: A torch.Tensor of shape [B, U, D]
+
+        Returns:
+            A torch.Tensor of shape [B, U, H]
+        """
+        raise NotImplementedError()
+
     def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
         """
         Compute the joint step of the network.
```

> **Reviewer:** Revert name change. It's fine to keep `joint`.
>
> **Author:** See the comments above.
```diff
@@ -58,7 +97,7 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
         Returns:
             Logits / log softmaxed tensor of shape (B, T, U, V + 1).
         """
-        raise NotImplementedError()
+        return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g))
 
     @property
     def num_classes_with_blank(self):
```
```diff
@@ -277,3 +316,15 @@ def batch_copy_states(
             (L x B x H, L x B x H)
         """
         raise NotImplementedError()
+
+    def mask_select_states(self, states: Any, mask: torch.Tensor) -> Any:
+        """
+        Return states by mask selection
+        Args:
+            states: states for the batch (preferably a list of tensors, but not limited to)
+            mask: boolean mask for selecting states; batch dimension should be the same as for states
+
+        Returns:
+            states filtered by mask (same type as `states`)
+        """
+        raise NotImplementedError()
```
> **Reviewer:** It would be better to have a similar API name across the RNNT Joints. Is it necessary to change this?
>
> **Author:** The API is changed for all Joints, starting from `AbstractRNNTJoint` (see details in Slack). Now it is the following:
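The structure being referred to is presumably the default `joint` implementation from the diff above, which composes the three new methods:

```python
def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # Project both inputs, then run the joint network. Decoding code can
    # instead call the projections once and reuse the projected tensors.
    return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g))
```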