diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index db5c57cc79bf..f9ba83006492 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -19,7 +19,7 @@
 import six
 import logging
 import inspect
-from typing import Optional
+from typing import Any, Optional
 
 import paddle
 import numpy as np
@@ -546,11 +546,31 @@ def resize_token_embeddings(self,
         self.base_model.config['vocab_size'] = new_num_tokens
         self.vocab_size = new_num_tokens
 
+        # Keep init_config in sync with the resized vocabulary.
+        self._update_init_config(self.init_config, 'vocab_size', new_num_tokens)
+
         # TODO(westfish@126.com): add tie_weight.
         # TODO(westfish) Add tie_weight to tie the weights between the input embeddings and the output embeddings if needed.
 
         return new_embeddings
 
+    def _update_init_config(self, init_config: dict, key: str, value: Any):
+        """Update a key in init_config, recursing into nested models.
+
+        Args:
+            init_config (dict): the init_config dict to update
+            key (str): the config key to update
+            value (Any): the new value for the key
+        """
+        if key in init_config:
+            init_config[key] = value
+            return
+
+        for arg in init_config.get('init_args', []):
+            if not isinstance(arg, PretrainedModel):
+                continue
+            self._update_init_config(arg.init_config, key, value)
+
     def _get_resized_embeddings(
         self,
         old_embeddings: nn.Embedding,
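
For context, here is a minimal, self-contained sketch of how the new _update_init_config traversal behaves. It is illustrative only: FakeModel is a hypothetical stand-in for PretrainedModel so the snippet runs without PaddleNLP installed. If the key is missing at the top level, the update recurses into any nested models stored under init_args, which is presumably how the resized vocab_size reaches the base model's config inside a task-specific wrapper, so that a later save_pretrained() persists the new value rather than the stale pre-resize one.

from typing import Any

class FakeModel:
    # Hypothetical stand-in for PretrainedModel, used only so the
    # traversal can be exercised without a PaddleNLP install.
    def __init__(self, init_config: dict):
        self.init_config = init_config

    def _update_init_config(self, init_config: dict, key: str, value: Any):
        # Same logic as the patch: update in place if the key lives here...
        if key in init_config:
            init_config[key] = value
            return
        # ...otherwise recurse into nested models kept under 'init_args'.
        for arg in init_config.get('init_args', []):
            if not isinstance(arg, FakeModel):
                continue
            self._update_init_config(arg.init_config, key, value)

# A task wrapper whose own config lacks 'vocab_size' but holds the
# base model (which has it) in 'init_args'.
base = FakeModel({'vocab_size': 50000})
wrapper = FakeModel({'num_classes': 2, 'init_args': [base]})

wrapper._update_init_config(wrapper.init_config, 'vocab_size', 50002)
assert base.init_config['vocab_size'] == 50002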