diff --git a/wenet/ssl/bestrq/bestqr_model.py b/wenet/ssl/bestrq/bestqr_model.py
index 7ecf3c642..053ed3ab7 100644
--- a/wenet/ssl/bestrq/bestqr_model.py
+++ b/wenet/ssl/bestrq/bestqr_model.py
@@ -13,6 +13,7 @@ def __init__(
         input_dim: int = 256,
         embedding_dim: int = 256,
         num_embeddings: int = 8192,
+        num_codebooks: int = 1,
         dropout_rate: float = 0.1,
         mask_prob: float = 0.01,
         mask_length: int = 10,
@@ -31,7 +32,8 @@ def __init__(
         self.input_dropout = torch.nn.Dropout(dropout_rate)
 
         # [embedding_dim, num_embeddings]
-        random_embedding_weight = torch.empty(embedding_dim,
+        random_embedding_weight = torch.empty(num_codebooks,
+                                              embedding_dim,
                                               num_embeddings,
                                               requires_grad=False)
         self.embeddings = torch.nn.init.normal_(random_embedding_weight)
@@ -49,8 +51,9 @@ def __init__(
         self.input_layer_norm = torch.nn.LayerNorm(input_dim,
                                                    layer_norm_epsilon)
         self.encoder = encoder
-        self.encoder_top_linear = torch.nn.Linear(self.encoder.output_size(),
-                                                  num_embeddings)
+        self.encoder_top_n_out = torch.nn.parameter.Parameter(
+            torch.Tensor(num_codebooks, self.encoder.output_size(),
+                         num_embeddings))
 
     def forward(
         self,
@@ -74,7 +77,11 @@ def forward(
         out, out_mask = self._forward_encoder_blocks(masked_xs, masks,
                                                      pos_emb, masks)
         # 4 get logits
-        out = self.encoder_top_linear(out)  # [B, T', num_embedding]
+        out = out.unsqueeze(1)  # [B, 1, T', dim]
+        top_n_out = self.encoder_top_n_out.unsqueeze(
+            0)  # [num_codebooks, dim, num_embeddings]
+        out = torch.matmul(out,
+                           top_n_out)  # [B, num_codebooks, T', num_embeddings]
 
         # 5 compute loss
         loss = self._compute_loss(out, target_ids,
@@ -83,13 +90,12 @@ def forward(
 
     def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
                       mask: torch.Tensor):
-        input = input.transpose(1, 2)  # [B,C,T]
-        entropy = torch.nn.functional.cross_entropy(input,
-                                                    target,
-                                                    reduction='none')  # [B,T]
+        input = input.transpose(1, 3)  # [B, num_embeddings, T', num_codebooks]
+        entropy = torch.nn.functional.cross_entropy(
+            input, target, reduction='none')  # [B, T', num_codebooks]
         # stop gradient for non mask area
-        loss = entropy * mask
-        return loss.sum() / loss.size(0)
+        loss = entropy * mask.unsqueeze(2)
+        return loss.sum() / (mask.sum() * loss.size(2))
 
     def _forward_encoder_blocks(self, xs: torch.Tensor, xs_masks: torch.Tensor,
                                 pos_emb: torch.Tensor, mask_pad: torch.Tensor):
@@ -111,13 +117,17 @@ def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
 
         B, T, C = xs.size()
         flattened_input = xs.view(-1, C)
-        embeddings = self.embeddings.to(xs.device)
-        distance = (torch.sum(flattened_input**2, dim=1, keepdim=True) +
-                    torch.sum(embeddings**2, dim=0, keepdim=False) -
-                    2 * torch.matmul(flattened_input, embeddings))
-
-        out = torch.argmin(distance, dim=-1)
-        return out.reshape(B, T)
+        embeddings = self.embeddings.to(
+            xs.device)  # [num_codebooks, embedding_dim, num_embeddings]
+        # [num_codebooks, B*T, num_embeddings]
+        distance = (
+            torch.sum(flattened_input**2, dim=1, keepdim=True).unsqueeze(0) +
+            torch.sum(embeddings**2, dim=1, keepdim=True) -
+            2 * torch.matmul(flattened_input.unsqueeze(0), embeddings))
+
+        out = torch.argmin(distance, dim=-1)  # [num_codebooks, B*T]
+        out = out.transpose(0, 1)  # [B*T, num_codebooks]
+        return out.reshape(B, T, -1)  # [B, T, num_codebooks]
 
     def _apply_mask(self,
                     xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
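For reference, here is a minimal standalone sketch of the multi-codebook nearest-neighbour lookup that `_nearest_embedding_idx` now performs. The sizes (`B`, `T`, `C`, `num_codebooks`, `num_embeddings`) and the tensor names are illustrative placeholders, not values taken from the model config; the point is only to show how the broadcast over the leading codebook dimension yields one target index per frame and per codebook.

```python
import torch

# Illustrative sizes only (assumptions, not values from the model config).
B, T, C = 2, 5, 16                      # batch, frames, embedding_dim
num_codebooks, num_embeddings = 4, 32

xs = torch.randn(B, T, C)                                   # projected features
embeddings = torch.randn(num_codebooks, C, num_embeddings)  # frozen random codebooks

flattened_input = xs.view(-1, C)                            # [B*T, C]

# Squared Euclidean distance, broadcast over the codebook dimension.
distance = (
    torch.sum(flattened_input**2, dim=1, keepdim=True).unsqueeze(0)   # [1, B*T, 1]
    + torch.sum(embeddings**2, dim=1, keepdim=True)                   # [num_codebooks, 1, num_embeddings]
    - 2 * torch.matmul(flattened_input.unsqueeze(0), embeddings))     # [num_codebooks, B*T, num_embeddings]

# Closest codebook entry per frame, then move the codebook axis last.
target_ids = distance.argmin(dim=-1).transpose(0, 1).reshape(B, T, -1)
assert target_ids.shape == (B, T, num_codebooks)
```

Broadcasting the `[1, B*T, C]` input against the `[num_codebooks, C, num_embeddings]` codebook tensor lets a single distance computation serve all codebooks at once, instead of looping over them.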
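Similarly, a sketch of the per-codebook projection and the masked cross-entropy that `forward`/`_compute_loss` compute after this change. Again the names (`encoder_out`, `mask`, the sizes) are assumptions for illustration, under the shape comments in the diff.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only (assumptions, not values from the model config).
B, Tp, D = 2, 5, 16                     # batch, subsampled frames T', encoder dim
num_codebooks, num_embeddings = 4, 32

encoder_out = torch.randn(B, Tp, D)
encoder_top_n_out = torch.randn(num_codebooks, D, num_embeddings)
target_ids = torch.randint(num_embeddings, (B, Tp, num_codebooks))
mask = torch.zeros(B, Tp)
mask[:, 1:4] = 1.0                      # pretend frames 1..3 were masked

# One output projection per codebook via a broadcast batched matmul.
logits = torch.matmul(encoder_out.unsqueeze(1),        # [B, 1, T', D]
                      encoder_top_n_out.unsqueeze(0))  # [1, num_codebooks, D, num_embeddings]
# logits: [B, num_codebooks, T', num_embeddings]

# cross_entropy expects the class dimension at index 1, hence transpose(1, 3).
entropy = F.cross_entropy(logits.transpose(1, 3),      # [B, num_embeddings, T', num_codebooks]
                          target_ids,                  # [B, T', num_codebooks]
                          reduction='none')            # [B, T', num_codebooks]

# Keep only masked frames, then normalise by masked-frame count times codebooks.
loss = (entropy * mask.unsqueeze(2)).sum() / (mask.sum() * num_codebooks)
print(loss)
```

Note the changed normalisation: the old code divided the masked loss sum by the batch size, while the new code divides by `mask.sum() * num_codebooks`, i.e. it averages over the positions that actually contribute to the loss.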