Add mean_resizing to every VLM's resize_token_embeddings() (#35717)
* refine all resize_token_embedding()

* ruff format

* hotfix
YenFuLin authored Feb 3, 2025
1 parent 7eecdf2 commit 9d2056f
Showing 16 changed files with 71 additions and 32 deletions.
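
Each file below makes the same change: the model-specific `resize_token_embeddings()` override gains a `mean_resizing` argument and forwards it to the underlying implementation, matching the base `PreTrainedModel` API. A minimal usage sketch (the checkpoint and added tokens are illustrative, not taken from this commit):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # Illustrative checkpoint; any model touched by this commit behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

    tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])

    # mean_resizing=True (the default) initializes the added rows from a multivariate normal
    # fitted to the existing embeddings' mean and covariance; mean_resizing=False falls back
    # to N(0, config.initializer_range**2) initialization.
    model.resize_token_embeddings(len(tokenizer), mean_resizing=True)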
3 changes: 2 additions & 1 deletion examples/modular-transformers/modeling_new_task_model.py
@@ -455,8 +455,9 @@ def resize_token_embeddings(
         self,
         new_num_tokens: Optional[int] = None,
         pad_to_multiple_of=None,
+        mean_resizing=True
     ) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

         # Update vocab size
         self.config.text_config.vocab_size = model_embeds.num_embeddings
3 changes: 2 additions & 1 deletion examples/modular-transformers/modular_new_task_model.py
@@ -73,8 +73,9 @@ def resize_token_embeddings(
         self,
         new_num_tokens: Optional[int] = None,
         pad_to_multiple_of=None,
+        mean_resizing=True
     ) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

         # Update vocab size
         self.config.text_config.vocab_size = model_embeds.num_embeddings
19 changes: 15 additions & 4 deletions src/transformers/models/bark/modeling_bark.py
@@ -1189,11 +1189,11 @@ def set_output_embeddings(self, new_output_embeddings):
         # one lm_head for each codebook
         self.lm_heads = new_output_embeddings

-    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
         old_embeddings_list = self.get_input_embeddings()
         new_embeddings_list = nn.ModuleList(
             [
-                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
+                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
                 for old_embeddings in old_embeddings_list
             ]
         )
@@ -1211,7 +1211,10 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
         return self.get_input_embeddings()

     def resize_token_embeddings(
-        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+        self,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        mean_resizing: bool = True,
     ) -> nn.Embedding:
         """
         Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -1230,11 +1233,19 @@ def resize_token_embeddings(
                 `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                 details about this, or help on choosing the correct value for resizing, refer to this guide:
                 https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
+                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html

         Return:
             `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
         """
-        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         if new_num_tokens is None and pad_to_multiple_of is None:
             return model_embeds

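
The docstring above describes the statistics behind `mean_resizing`; the heavy lifting happens in the base class's `_get_resized_embeddings`. A simplified sketch of the idea (not the library's exact code; the shrinkage factor and the diagonal jitter here are assumptions kept for numerical safety):

    import torch

    def append_rows_with_mean_resizing(old_weight: torch.Tensor, num_new: int) -> torch.Tensor:
        """Append num_new embedding rows sampled from a normal fitted to the old rows (sketch)."""
        mean = old_weight.mean(dim=0)
        centered = old_weight - mean
        covariance = centered.T @ centered / old_weight.shape[0]
        # Shrink the covariance so the new rows stay close to the mean; the real implementation
        # also falls back to mean-only initialization when the covariance is not positive definite.
        jitter = 1e-12 * torch.eye(old_weight.shape[1])
        dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=1e-9 * covariance + jitter)
        new_rows = dist.sample((num_new,))
        return torch.cat([old_weight, new_rows], dim=0)

    resized = append_rows_with_mean_resizing(torch.randn(1000, 64), num_new=8)
    print(resized.shape)  # torch.Size([1008, 64])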
6 changes: 4 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -1577,8 +1577,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

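
The BART-style heads in this diff also call `self._resize_final_logits_bias(...)` after resizing, because they carry a `final_logits_bias` buffer whose width must match the vocabulary size. Roughly, that existing helper truncates or zero-pads the buffer (paraphrased from the BART head, not code added by this commit):

    import torch

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        # Truncate or zero-pad the (1, vocab_size) bias buffer to the new vocabulary size.
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)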
@@ -2457,8 +2457,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1232,8 +1232,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

@@ -1184,8 +1184,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

@@ -1295,8 +1295,10 @@ def prepare_inputs_for_generation(
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return self._shift_right(labels)

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/led/modeling_led.py
@@ -2319,8 +2319,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.led.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/lxmert/modeling_lxmert.py
@@ -1072,9 +1072,11 @@ def __init__(self, config):
         }
         self.visual_losses = visual_losses

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
         # Adding the following steps to resize bias to match the shape of resized embeddings
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self.cls.predictions.bias = self._resize_bias(self.cls.predictions.bias, new_num_tokens)
         return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/marian/modeling_marian.py
@@ -1252,8 +1252,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         if self.config.share_encoder_decoder_embeddings:
             self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings
6 changes: 4 additions & 2 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -1546,8 +1546,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/mvp/modeling_mvp.py
@@ -1370,8 +1370,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
@@ -1658,9 +1658,11 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.language_backbone.model.set_input_embeddings(value)

-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing: bool = True
+    ) -> nn.Embedding:
         model_embeds = self.language_backbone.model.resize_token_embeddings(
-            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
+            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of, mean_resizing=mean_resizing
         )
         self.config.text_config.vocab_size = model_embeds.num_embeddings
         self.vocab_size = model_embeds.num_embeddings
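
Composite models like this one do not own the token embeddings themselves; they forward the resize to their language backbone and then mirror the new vocabulary size into their own config. A hedged sketch of that delegation pattern (the wrapper class below is illustrative, not a transformers class):

    from typing import Optional

    import torch.nn as nn

    class CompositeModelSketch(nn.Module):
        """Illustrative wrapper that delegates embedding resizing to a text backbone."""

        def __init__(self, language_model: nn.Module, config):
            super().__init__()
            self.language_model = language_model
            self.config = config

        def resize_token_embeddings(
            self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing: bool = True
        ) -> nn.Embedding:
            # Delegate to the backbone, then keep the wrapper's bookkeeping in sync.
            model_embeds = self.language_model.resize_token_embeddings(
                new_num_tokens, pad_to_multiple_of, mean_resizing
            )
            self.config.text_config.vocab_size = model_embeds.num_embeddings
            self.vocab_size = model_embeds.num_embeddings
            return model_embeds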
6 changes: 4 additions & 2 deletions src/transformers/models/pegasus/modeling_pegasus.py
@@ -1265,8 +1265,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/plbart/modeling_plbart.py
@@ -1274,8 +1274,10 @@ def get_encoder(self):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings

