[squeezeformer] Support ONNX GPU export. (#1634)
* [squeezeformer] Support ONNX GPU export.

* fix Lint

Co-authored-by: [email protected] <[email protected]>
yygle and takinaudiollm authored Dec 26, 2022
1 parent 31d7bee commit 1975ea9
Showing 2 changed files with 194 additions and 2 deletions.
193 changes: 192 additions & 1 deletion wenet/bin/export_onnx_gpu.py
@@ -164,6 +164,8 @@ def forward(self, chunk_xs, chunk_lens, offset,
            r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        if not self.transformer:
@@ -184,6 +186,190 @@ def forward(self, chunk_xs, chunk_lens, offset,
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \
            r_offset, r_att_cache, r_cnn_cache, r_cache_mask


class StreamingSqueezeformerEncoder(torch.nn.Module):
    def __init__(self, model, required_cache_size, beam_size):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.reduce_idx = model.encoder.reduce_idx
        self.recover_idx = model.encoder.recover_idx
        if self.reduce_idx is None:
            self.time_reduce = None
        else:
            if self.recover_idx is None:
                self.time_reduce = 'normal'  # no recovery at the end
            else:
                self.time_reduce = 'recover'  # recovery at the end
                assert len(self.reduce_idx) == len(self.recover_idx)

    def calculate_downsampling_factor(self, i: int) -> int:
        if self.reduce_idx is None:
            return 1
        else:
            reduce_exp, recover_exp = 0, 0
            for exp, rd_idx in enumerate(self.reduce_idx):
                if i >= rd_idx:
                    reduce_exp = exp + 1
            if self.recover_idx is not None:
                for exp, rc_idx in enumerate(self.recover_idx):
                    if i >= rc_idx:
                        recover_exp = exp + 1
            return int(2 ** (reduce_exp - recover_exp))
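    # Worked example (hypothetical config, for illustration only): with
    # reduce_idx=[3] and recover_idx=[9], layers 0-2 get factor
    # 2 ** (0 - 0) = 1, layers 3-8 get 2 ** (1 - 0) = 2, and layers 9 and
    # later recover to 2 ** (1 - 1) = 1.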

    def forward(self, chunk_xs, chunk_lens, offset,
                att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                subsample.right_context + 1`
            chunk_lens (torch.Tensor): chunk input lengths with shape (b,)
            offset (torch.Tensor): offset with shape (b, 1)
                1 is retained for triton deployment
            required_cache_size (int): cache size required for next chunk
                computation
                > 0: actual cache size
                <= 0: not allowed in streaming gpu encoder
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                with shape (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape
                (b, required_cache_size). In a batch of requests, each request
                may have a different history cache; the cache mask indicates
                the effective cache for each request.
        Returns:
            torch.Tensor: log probabilities of ctc output, cut off by beam size,
                with shape (b, chunk_size, beam)
            torch.Tensor: index of top beam size probabilities for each timestep,
                with shape (b, chunk_size, beam)
            torch.Tensor: output of current input xs,
                with shape (b, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                the same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for next chunk, with
                the same shape as the original cnn_cache.
            torch.Tensor: new cache mask, with the same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # the chunk mask is important for batch inference since
        # different sequences in a batch have different lengths
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        elayers, cache_size = att_cache.size(0), att_cache.size(3)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        next_cache_start = -self.required_cache_size
        r_cache_mask = att_mask[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        mask_pad = torch.ones(1,
                              xs.size(1),
                              device=xs.device,
                              dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)
        max_att_len: int = 0
        recover_activations: \
            List[Tuple[torch.Tensor, torch.Tensor,
                       torch.Tensor, torch.Tensor]] = []
        index = 0
        xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
        xs = self.encoder.preln(xs)
        for i, layer in enumerate(self.encoder.encoders):
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append(
                        (xs, att_mask, pos_emb, mask_pad))
                    xs, xs_lens, att_mask, mask_pad = \
                        self.encoder.time_reduction_layer(
                            xs, xs_lens, att_mask, mask_pad)
                    pos_emb = pos_emb[:, ::2, :]
                    if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
                        pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :]
                    index += 1

            if self.recover_idx is not None:
                if self.time_reduce == 'recover' and i in self.recover_idx:
                    index -= 1
                    (recover_tensor, recover_att_mask,
                     recover_pos_emb, recover_mask_pad) \
                        = recover_activations[index]
                    # recover output length for ctc decode
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.encoder.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    att_mask = recover_att_mask
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad

            factor = self.calculate_downsampling_factor(i)

            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, att_mask, pos_emb,
                att_cache=att_cache[i][:, :, ::factor, :]
                [:, :, :pos_emb.size(1) - xs.size(1), :] if
                elayers > 0 else att_cache[:, :, ::factor, :],
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
            )
            cached_att \
                = new_att_cache[:, :, next_cache_start // factor:, :]
            cached_cnn = new_cnn_cache.unsqueeze(1)
            cached_att = cached_att.unsqueeze(3). \
                repeat(1, 1, 1, factor, 1).flatten(2, 3)
            if i == 0:
                # record length for the first block as max length
                max_att_len = cached_att.size(2)
            r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
            r_cnn_cache.append(cached_cnn)

        chunk_out = xs
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers

        # <---------forward_chunk END--------->

        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size,
                                              dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # the ops below are not supported in TensorRT
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                            rounding_mode='floor')
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \
            r_offset, r_att_cache, r_cnn_cache, r_cache_mask

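For orientation, here is a minimal sketch of the streaming inputs this forward expects, using hypothetical sizes; the real export script derives them from the model config, and `encoder` stands for a StreamingSqueezeformerEncoder instance.

import torch

# Hypothetical sizes, for illustration only.
batch, elayers, head, d_k = 1, 12, 4, 64
cache_t1, cache_t2 = 64, 14              # required_cache_size, cnn.lorder - 1
chunk = (16 - 1) * 4 + 6 + 1             # chunk_size 16 at 4x subsampling

chunk_xs = torch.randn(batch, chunk, 80)               # (b, time, mel-dim)
chunk_lens = torch.tensor([chunk], dtype=torch.int32)  # (b,)
offset = torch.zeros(batch, 1, dtype=torch.int64)      # (b, 1)
att_cache = torch.zeros(batch, elayers, head, cache_t1, d_k * 2)
cnn_cache = torch.zeros(batch, elayers, head * d_k, cache_t2)
cache_mask = torch.zeros(batch, 1, cache_t1)           # effective-cache mask

# encoder = StreamingSqueezeformerEncoder(model, cache_t1, beam_size=10)
(log_probs, log_probs_idx, chunk_out, chunk_out_lens, r_offset,
 r_att_cache, r_cnn_cache, r_cache_mask) = encoder(
     chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask)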
@@ -356,7 +542,12 @@ def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
    transformer = True
    num_decoding_left_chunks = args.num_decoding_left_chunks
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-    encoder = StreamingEncoder(model, required_cache_size, args.beam_size, transformer)
+    if configs['encoder'] == 'squeezeformer':
+        encoder = StreamingSqueezeformerEncoder(
+            model, required_cache_size, args.beam_size)
+    else:
+        encoder = StreamingEncoder(
+            model, required_cache_size, args.beam_size, transformer)
    encoder.eval()

    # begin to export encoder
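The export call itself sits below the truncated hunk ("# begin to export encoder"). As a hedged sketch of what exporting the streaming encoder with torch.onnx.export can look like — the input/output names and dynamic axes below are illustrative assumptions, not the script's exact values:

# Illustrative export sketch; names, axes, and opset are assumptions.
torch.onnx.export(
    encoder,
    (chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask),
    encoder_onnx_path,
    opset_version=13,
    input_names=['chunk_xs', 'chunk_lens', 'offset',
                 'att_cache', 'cnn_cache', 'cache_mask'],
    output_names=['log_probs', 'log_probs_idx', 'chunk_out', 'chunk_out_lens',
                  'r_offset', 'r_att_cache', 'r_cnn_cache', 'r_cache_mask'],
    dynamic_axes={name: {0: 'B'} for name in
                  ['chunk_xs', 'chunk_lens', 'offset', 'att_cache',
                   'cnn_cache', 'cache_mask', 'log_probs', 'log_probs_idx']},
)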
3 changes: 2 additions & 1 deletion wenet/squeezeformer/encoder.py
@@ -113,6 +113,7 @@ def __init__(
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.pos_enc_layer_type = pos_enc_layer_type
        activation = get_activation(activation_type)

        # self-attention module definition
@@ -236,7 +237,7 @@ def forward(
                     recover_pos_emb, recover_mask_pad) \
                        = recover_activations[index]
                    # recover output length for ctc decode
-                    xs = torch.repeat_interleave(xs, repeats=2, dim=1)
+                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
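The one-line change above swaps torch.repeat_interleave for an unsqueeze/repeat/flatten sequence, presumably because repeat_interleave exports poorly to ONNX/TensorRT backends. The two formulations are equivalent, as a quick sanity check shows:

import torch

xs = torch.randn(2, 5, 8)                             # (batch, time, dim)
a = torch.repeat_interleave(xs, repeats=2, dim=1)     # old formulation
b = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)  # ONNX-friendly one
assert torch.equal(a, b)                              # both are (2, 10, 8)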
