[ssl] bestrq support multiple codebooks (#1754)
Mddct authored Mar 16, 2023
1 parent 0bfdf49 commit cce20d6
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions wenet/ssl/bestrq/bestqr_model.py
@@ -13,6 +13,7 @@ def __init__(
         input_dim: int = 256,
         embedding_dim: int = 256,
         num_embeddings: int = 8192,
+        num_codebooks: int = 1,
         dropout_rate: float = 0.1,
         mask_prob: float = 0.01,
         mask_length: int = 10,
@@ -31,7 +32,8 @@ def __init__(
         self.input_dropout = torch.nn.Dropout(dropout_rate)
 
         # [embedding_dim, num_embeddings]
-        random_embedding_weight = torch.empty(embedding_dim,
+        random_embedding_weight = torch.empty(num_codebooks,
+                                              embedding_dim,
                                               num_embeddings,
                                               requires_grad=False)
         self.embeddings = torch.nn.init.normal_(random_embedding_weight)
@@ -49,8 +51,9 @@ def __init__(
         self.input_layer_norm = torch.nn.LayerNorm(input_dim,
                                                    layer_norm_epsilon)
         self.encoder = encoder
-        self.encoder_top_linear = torch.nn.Linear(self.encoder.output_size(),
-                                                  num_embeddings)
+        self.encoder_top_n_out = torch.nn.parameter.Parameter(
+            torch.Tensor(num_codebooks, self.encoder.output_size(),
+                         num_embeddings))
 
     def forward(
         self,
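
Aside (editor's sketch, not part of the commit): a minimal shape check of the two new multi-codebook tensors introduced above; the names mirror the diff, the concrete sizes are made up.

```python
import torch

# illustrative sizes only
num_codebooks, embedding_dim, num_embeddings, encoder_dim = 4, 16, 32, 24

# frozen random codebooks: [num_codebooks, embedding_dim, num_embeddings]
embeddings = torch.nn.init.normal_(
    torch.empty(num_codebooks, embedding_dim, num_embeddings,
                requires_grad=False))

# trainable per-codebook projection that replaces the single nn.Linear:
# [num_codebooks, encoder_dim, num_embeddings]
encoder_top_n_out = torch.nn.parameter.Parameter(
    torch.empty(num_codebooks, encoder_dim, num_embeddings).normal_())

print(embeddings.shape)         # torch.Size([4, 16, 32])
print(encoder_top_n_out.shape)  # torch.Size([4, 24, 32])
```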
@@ -74,7 +77,11 @@ def forward(
         out, out_mask = self._forward_encoder_blocks(masked_xs, masks, pos_emb,
                                                      masks)
         # 4 get logits
-        out = self.encoder_top_linear(out)  # [B, T', num_embedding]
+        out = out.unsqueeze(1)  # [B, 1, T', dim]
+        top_n_out = self.encoder_top_n_out.unsqueeze(
+            0)  # [1, num_codebooks, dim, num_embeddings]
+        out = torch.matmul(out,
+                           top_n_out)  # [B, num_codebooks, T', num_embeddings]
 
         # 5 compute loss
         loss = self._compute_loss(out, target_ids,
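
Aside (editor's sketch, not part of the commit): a quick broadcast check for the new per-codebook logits, with illustrative sizes.

```python
import torch

B, T, dim, num_codebooks, num_embeddings = 2, 5, 24, 4, 32

out = torch.randn(B, T, dim)                                 # encoder output
top_n_out = torch.randn(num_codebooks, dim, num_embeddings)  # projection

logits = torch.matmul(out.unsqueeze(1),         # [B, 1, T, dim]
                      top_n_out.unsqueeze(0))   # [1, num_codebooks, dim, num_embeddings]
assert logits.shape == (B, num_codebooks, T, num_embeddings)
```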
@@ -83,13 +90,12 @@
 
     def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
                       mask: torch.Tensor):
-        input = input.transpose(1, 2)  # [B,C,T]
-        entropy = torch.nn.functional.cross_entropy(input,
-                                                    target,
-                                                    reduction='none')  # [B,T]
+        input = input.transpose(1, 3)  # [B, num_embeddings, T', num_codebooks]
+        entropy = torch.nn.functional.cross_entropy(
+            input, target, reduction='none')  # [B, T', num_codebooks]
         # stop gradient for non mask area
-        loss = entropy * mask
-        return loss.sum() / loss.size(0)
+        loss = entropy * mask.unsqueeze(2)
+        return loss.sum() / (mask.sum() * loss.size(2))
 
     def _forward_encoder_blocks(self, xs: torch.Tensor, xs_masks: torch.Tensor,
                                 pos_emb: torch.Tensor, mask_pad: torch.Tensor):
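
Aside (editor's sketch, not part of the commit): a small sketch of the reworked masked loss, assuming logits shaped [B, num_codebooks, T', num_embeddings], per-codebook targets [B, T', num_codebooks], and a 0/1 frame mask [B, T'].

```python
import torch

B, T, num_codebooks, num_embeddings = 2, 5, 4, 32
logits = torch.randn(B, num_codebooks, T, num_embeddings)
target = torch.randint(0, num_embeddings, (B, T, num_codebooks))
mask = (torch.rand(B, T) > 0.5).float()   # 1.0 on masked frames

# cross_entropy expects the class dim at index 1, hence transpose(1, 3)
inp = logits.transpose(1, 3)              # [B, num_embeddings, T, num_codebooks]
entropy = torch.nn.functional.cross_entropy(inp, target,
                                            reduction='none')  # [B, T, num_codebooks]

loss = entropy * mask.unsqueeze(2)               # keep masked frames only
loss = loss.sum() / (mask.sum() * loss.size(2))  # mean over masked frames and codebooks
```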
@@ -111,13 +117,17 @@ def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
 
         B, T, C = xs.size()
         flattened_input = xs.view(-1, C)
-        embeddings = self.embeddings.to(xs.device)
-        distance = (torch.sum(flattened_input**2, dim=1, keepdim=True) +
-                    torch.sum(embeddings**2, dim=0, keepdim=False) -
-                    2 * torch.matmul(flattened_input, embeddings))
-
-        out = torch.argmin(distance, dim=-1)
-        return out.reshape(B, T)
+        embeddings = self.embeddings.to(
+            xs.device)  # [num_codebooks, embedding_dim, num_embeddings]
+        # [num_codebooks, B*T, num_embeddings]
+        distance = (
+            torch.sum(flattened_input**2, dim=1, keepdim=True).unsqueeze(0) +
+            torch.sum(embeddings**2, dim=1, keepdim=True) -
+            2 * torch.matmul(flattened_input.unsqueeze(0), embeddings))
+
+        out = torch.argmin(distance, dim=-1)  # [num_codebooks, B*T]
+        out = out.transpose(0, 1)  # [B*T, num_codebooks]
+        return out.reshape(B, T, -1)  # [B, T, num_codebooks]
 
     def _apply_mask(self,
                     xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
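
Aside (editor's sketch, not part of the commit): a check that the broadcasted distance above picks the same nearest codeword as a per-codebook torch.cdist lookup; sizes are illustrative.

```python
import torch

B, T, C, num_codebooks, num_embeddings = 2, 5, 16, 4, 32
xs = torch.randn(B, T, C)
embeddings = torch.randn(num_codebooks, C, num_embeddings)

flat = xs.view(-1, C)                                                 # [B*T, C]
distance = (torch.sum(flat**2, dim=1, keepdim=True).unsqueeze(0) +    # [1, B*T, 1]
            torch.sum(embeddings**2, dim=1, keepdim=True) -           # [num_codebooks, 1, num_embeddings]
            2 * torch.matmul(flat.unsqueeze(0), embeddings))          # [num_codebooks, B*T, num_embeddings]
ids = torch.argmin(distance, dim=-1).transpose(0, 1).reshape(B, T, -1)

# reference: nearest codeword per codebook via torch.cdist
ref = torch.stack([
    torch.cdist(flat, embeddings[k].t()).argmin(dim=-1)
    for k in range(num_codebooks)
], dim=1).reshape(B, T, -1)
print(torch.equal(ids, ref))  # True barring exact distance ties
```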