Commit
[FIX] Corrects typos in the SpecialTokensEmbeddings code
GerrySant committed Nov 30, 2024
1 parent 675bda0 commit 47f1cdb
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions multimodalhugs/modules/special_tokens_embeddings.py
@@ -23,14 +23,13 @@ def __init__(
     def forward(self, x, encoder_padding_mask, src_prompt):
         """
         It adds and/or corrects the special tokens of the input sequence:
-        # '<src_lang>', ..., '</s>', '<pad>', '<pad>'
+        # '<prompt_token_1>', ..., '<prompt_token_N>', ..., '</s>', '<pad>', '<pad>'
         INPUTS:
             - x: B x N_tokens x Embed_dim
             - encoder_padding_mask: B x N_tokens <- 0 indicates padding elements
-            - src_prompt: B x 1
+            - src_prompt: B x N_tokens_prompt
         """
-        print(f"encoder_padding_mask_before: {encoder_padding_mask.shape}")
         # Append <src_prompt>:
         if src_prompt is not None:
             src_prompt = self.special_tokens_embeddings(src_prompt)
@@ -39,7 +39,7 @@ def forward(self, x, encoder_padding_mask, src_prompt):

         # Correct Padding Mask
         new_mask_entry = torch.full((encoder_padding_mask.size(0), x.size(1) - encoder_padding_mask.size(1)), 1, dtype=encoder_padding_mask.dtype, device=encoder_padding_mask.device) # torch.Size([B, N_tokens_prompt])
-        encoder_padding_mask = torch.cat([new_mask_entry, encoder_padding_mask], dim=1) # torch.Size([B, 1]) + torch.Size([B, N_tokens]) = torch.Size([B, N_tokens + 1])
+        encoder_padding_mask = torch.cat([new_mask_entry, encoder_padding_mask], dim=1) # torch.Size([B, N_tokens_prompt]) + torch.Size([B, N_tokens]) = torch.Size([B, N_tokens + N_tokens_prompt])

         # Adjust <pad> tokens and add <eos> token to every sequence in the batch:
         if self.pad_idx is not None and self.eos_idx is not None:
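
For context, a minimal runnable sketch of the padding-mask correction that this commit's comment fix describes. The shapes and names below (B, N_tokens, N_tokens_prompt, embed_dim) are illustrative assumptions mirroring the diff's comments, not the module's actual code:

    import torch

    # Illustrative sizes only (assumed, not from the repository).
    B, N_tokens, N_tokens_prompt, embed_dim = 2, 5, 3, 8

    x = torch.randn(B, N_tokens + N_tokens_prompt, embed_dim)  # x after the prompt was prepended
    encoder_padding_mask = torch.randint(0, 2, (B, N_tokens))  # 0 marks padding elements

    # One mask column per prepended prompt token, all set to 1 (non-padding).
    new_mask_entry = torch.full(
        (encoder_padding_mask.size(0), x.size(1) - encoder_padding_mask.size(1)),
        1,
        dtype=encoder_padding_mask.dtype,
        device=encoder_padding_mask.device,
    )  # torch.Size([B, N_tokens_prompt])

    encoder_padding_mask = torch.cat([new_mask_entry, encoder_padding_mask], dim=1)
    assert encoder_padding_mask.shape == (B, N_tokens + N_tokens_prompt)

The point of the corrected comment: since src_prompt carries N_tokens_prompt tokens rather than a single one, the mask grows by N_tokens_prompt columns, which x.size(1) - encoder_padding_mask.size(1) computes automatically.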
