From b5b1441478cd89f6c61f81db56cf8435cd80d060 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 13:07:34 +0000 Subject: [PATCH 01/12] v1 --- src/transformers/modeling_utils.py | 19 ++++++++++++++++--- .../models/align/modeling_align.py | 7 ++++--- .../models/altclip/modeling_altclip.py | 12 +++++++----- .../modeling_audio_spectrogram_transformer.py | 7 ++++--- .../models/autoformer/modeling_autoformer.py | 9 +++++---- src/transformers/models/bark/modeling_bark.py | 7 ++++--- src/transformers/models/bart/modeling_bart.py | 9 +++++---- src/transformers/models/beit/modeling_beit.py | 7 ++++--- src/transformers/models/bert/modeling_bert.py | 7 ++++--- .../modeling_bert_generation.py | 7 ++++--- .../models/big_bird/modeling_big_bird.py | 7 ++++--- .../modeling_bigbird_pegasus.py | 9 +++++---- .../models/biogpt/modeling_biogpt.py | 7 ++++--- src/transformers/models/bit/modeling_bit.py | 5 +++-- .../models/blenderbot/modeling_blenderbot.py | 9 +++++---- .../modeling_blenderbot_small.py | 9 +++++---- src/transformers/models/blip/modeling_blip.py | 7 ++++--- .../models/blip/modeling_blip_text.py | 2 +- .../models/blip_2/modeling_blip_2.py | 9 +++++---- .../models/bloom/modeling_bloom.py | 7 ++++--- .../bridgetower/modeling_bridgetower.py | 2 +- src/transformers/models/bros/modeling_bros.py | 2 +- .../models/camembert/modeling_camembert.py | 7 ++++--- .../models/canine/modeling_canine.py | 7 ++++--- .../chinese_clip/modeling_chinese_clip.py | 9 +++++---- src/transformers/models/clap/modeling_clap.py | 9 +++++---- src/transformers/models/clip/modeling_clip.py | 7 ++++--- .../models/clipseg/modeling_clipseg.py | 7 ++++--- .../models/codegen/modeling_codegen.py | 7 ++++--- .../modeling_conditional_detr.py | 7 ++++--- .../models/convbert/modeling_convbert.py | 7 ++++--- .../models/convnext/modeling_convnext.py | 5 +++-- .../models/convnextv2/modeling_convnextv2.py | 5 +++-- .../models/cpmant/modeling_cpmant.py | 5 +++-- .../data2vec/modeling_data2vec_audio.py | 9 +++++---- .../models/data2vec/modeling_data2vec_text.py | 7 ++++--- .../data2vec/modeling_data2vec_vision.py | 7 ++++--- .../models/deberta/modeling_deberta.py | 7 ++++--- .../models/deberta_v2/modeling_deberta_v2.py | 7 ++++--- .../modeling_decision_transformer.py | 7 ++++--- .../modeling_deformable_detr.py | 7 ++++--- src/transformers/models/deit/modeling_deit.py | 7 ++++--- .../models/deprecated/mctct/modeling_mctct.py | 7 ++++--- .../open_llama/modeling_open_llama.py | 7 ++++--- .../modeling_trajectory_transformer.py | 7 ++++--- .../models/deprecated/van/modeling_van.py | 5 +++-- src/transformers/models/deta/modeling_deta.py | 7 ++++--- src/transformers/models/detr/modeling_detr.py | 7 ++++--- .../models/dinat/modeling_dinat.py | 2 +- .../models/dinov2/modeling_dinov2.py | 7 ++++--- .../models/distilbert/modeling_distilbert.py | 7 ++++--- .../models/donut/modeling_donut_swin.py | 7 ++++--- src/transformers/models/dpr/modeling_dpr.py | 5 +++-- src/transformers/models/dpt/modeling_dpt.py | 7 ++++--- .../efficientnet/modeling_efficientnet.py | 5 +++-- .../models/electra/modeling_electra.py | 7 ++++--- .../models/encodec/modeling_encodec.py | 5 +++-- .../modeling_encoder_decoder.py | 6 +++--- .../models/ernie/modeling_ernie.py | 7 ++++--- .../models/ernie_m/modeling_ernie_m.py | 5 +++-- src/transformers/models/esm/modeling_esm.py | 7 ++++--- .../models/falcon/modeling_falcon.py | 7 ++++--- .../models/flava/modeling_flava.py | 7 ++++--- src/transformers/models/fnet/modeling_fnet.py | 7 ++++--- 
.../models/focalnet/modeling_focalnet.py | 7 ++++--- src/transformers/models/fuyu/modeling_fuyu.py | 5 +++-- src/transformers/models/git/modeling_git.py | 9 +++++---- src/transformers/models/gpt2/modeling_gpt2.py | 7 ++++--- .../gpt_bigcode/modeling_gpt_bigcode.py | 7 ++++--- .../models/gpt_neo/modeling_gpt_neo.py | 7 ++++--- .../models/gpt_neox/modeling_gpt_neox.py | 7 ++++--- .../modeling_gpt_neox_japanese.py | 5 +++-- src/transformers/models/gptj/modeling_gptj.py | 7 ++++--- .../modeling_gptsan_japanese.py | 5 +++-- .../models/graphormer/modeling_graphormer.py | 5 +++-- .../models/groupvit/modeling_groupvit.py | 7 ++++--- .../models/hubert/modeling_hubert.py | 11 ++++++----- .../models/idefics/modeling_idefics.py | 7 ++++--- src/transformers/models/idefics/vision.py | 2 +- .../models/imagegpt/modeling_imagegpt.py | 7 ++++--- .../models/informer/modeling_informer.py | 11 ++++++----- .../instructblip/modeling_instructblip.py | 9 +++++---- .../models/layoutlm/modeling_layoutlm.py | 7 ++++--- .../models/layoutlmv2/modeling_layoutlmv2.py | 7 ++++--- .../models/layoutlmv3/modeling_layoutlmv3.py | 2 +- src/transformers/models/led/modeling_led.py | 9 +++++---- .../models/levit/modeling_levit.py | 5 +++-- src/transformers/models/lilt/modeling_lilt.py | 7 ++++--- .../models/llama/modeling_llama.py | 7 ++++--- .../models/longformer/modeling_longformer.py | 7 ++++--- .../models/longt5/modeling_longt5.py | 5 +++-- src/transformers/models/luke/modeling_luke.py | 7 ++++--- .../models/m2m_100/modeling_m2m_100.py | 9 +++++---- .../models/marian/modeling_marian.py | 9 +++++---- .../models/markuplm/modeling_markuplm.py | 2 +- .../mask2former/modeling_mask2former.py | 2 +- .../models/maskformer/modeling_maskformer.py | 10 ++++++---- .../maskformer/modeling_maskformer_swin.py | 7 ++++--- .../models/mbart/modeling_mbart.py | 9 +++++---- .../megatron_bert/modeling_megatron_bert.py | 7 ++++--- .../models/mgp_str/modeling_mgp_str.py | 5 +++-- .../models/mistral/modeling_mistral.py | 7 ++++--- .../models/mobilevit/modeling_mobilevit.py | 7 ++++--- .../mobilevitv2/modeling_mobilevitv2.py | 7 ++++--- src/transformers/models/mpt/modeling_mpt.py | 7 ++++--- src/transformers/models/mra/modeling_mra.py | 7 ++++--- src/transformers/models/mt5/modeling_mt5.py | 5 +++-- .../models/musicgen/modeling_musicgen.py | 11 ++++++----- src/transformers/models/mvp/modeling_mvp.py | 9 +++++---- src/transformers/models/nat/modeling_nat.py | 2 +- .../models/nezha/modeling_nezha.py | 7 ++++--- .../models/nllb_moe/modeling_nllb_moe.py | 7 ++++--- .../nystromformer/modeling_nystromformer.py | 7 ++++--- .../models/oneformer/modeling_oneformer.py | 2 +- src/transformers/models/opt/modeling_opt.py | 7 ++++--- .../models/owlv2/modeling_owlv2.py | 8 +++++--- .../models/owlvit/modeling_owlvit.py | 7 ++++--- .../models/pegasus/modeling_pegasus.py | 9 +++++---- .../models/pegasus_x/modeling_pegasus_x.py | 9 +++++---- .../models/persimmon/modeling_persimmon.py | 7 ++++--- .../models/pix2struct/modeling_pix2struct.py | 12 +++++++----- .../models/plbart/modeling_plbart.py | 9 +++++---- .../models/poolformer/modeling_poolformer.py | 5 +++-- .../models/pop2piano/modeling_pop2piano.py | 5 +++-- .../models/prophetnet/modeling_prophetnet.py | 9 +++++---- src/transformers/models/pvt/modeling_pvt.py | 5 +++-- .../models/qdqbert/modeling_qdqbert.py | 7 ++++--- .../models/realm/modeling_realm.py | 2 +- .../models/regnet/modeling_regnet.py | 5 +++-- .../models/rembert/modeling_rembert.py | 7 ++++--- .../models/resnet/modeling_resnet.py | 5 +++-- 
.../models/roberta/modeling_roberta.py | 7 ++++--- .../modeling_roberta_prelayernorm.py | 7 ++++--- .../models/roc_bert/modeling_roc_bert.py | 7 ++++--- .../models/roformer/modeling_roformer.py | 7 ++++--- src/transformers/models/rwkv/modeling_rwkv.py | 7 ++++--- src/transformers/models/sam/modeling_sam.py | 2 +- src/transformers/models/sew/modeling_sew.py | 9 +++++---- .../models/sew_d/modeling_sew_d.py | 9 +++++---- .../modeling_speech_encoder_decoder.py | 6 +++--- .../speech_to_text/modeling_speech_to_text.py | 9 +++++---- .../modeling_speech_to_text_2.py | 7 ++++--- .../models/speecht5/modeling_speecht5.py | 11 ++++++----- .../models/splinter/modeling_splinter.py | 7 ++++--- .../swiftformer/modeling_swiftformer.py | 5 +++-- src/transformers/models/swin/modeling_swin.py | 7 ++++--- .../models/swin/modeling_tf_swin.py | 5 ----- .../models/swin2sr/modeling_swin2sr.py | 7 ++++--- .../models/swinv2/modeling_swinv2.py | 7 ++++--- .../modeling_switch_transformers.py | 5 +++-- src/transformers/models/t5/modeling_t5.py | 5 +++-- .../modeling_table_transformer.py | 7 ++++--- .../models/tapas/modeling_tapas.py | 7 ++++--- .../modeling_time_series_transformer.py | 9 +++++---- .../timesformer/modeling_timesformer.py | 7 ++++--- .../models/trocr/modeling_trocr.py | 7 ++++--- src/transformers/models/tvlt/modeling_tvlt.py | 9 +++++---- src/transformers/models/umt5/modeling_umt5.py | 5 +++-- .../models/unispeech/modeling_unispeech.py | 11 ++++++----- .../unispeech_sat/modeling_unispeech_sat.py | 11 ++++++----- .../models/upernet/modeling_upernet.py | 5 +++-- .../models/videomae/modeling_videomae.py | 9 +++++---- src/transformers/models/vilt/modeling_vilt.py | 7 ++++--- .../modeling_vision_encoder_decoder.py | 6 +++--- .../visual_bert/modeling_visual_bert.py | 7 ++++--- src/transformers/models/vit/modeling_vit.py | 7 ++++--- .../models/vit_hybrid/modeling_vit_hybrid.py | 7 ++++--- .../models/vit_mae/modeling_vit_mae.py | 9 +++++---- .../models/vit_msn/modeling_vit_msn.py | 7 ++++--- .../models/vitdet/modeling_vitdet.py | 7 ++++--- .../models/vitmatte/modeling_vitmatte.py | 5 +++-- src/transformers/models/vits/modeling_vits.py | 7 ++++--- .../models/vivit/modeling_vivit.py | 7 ++++--- .../models/wav2vec2/modeling_wav2vec2.py | 11 ++++++----- .../modeling_wav2vec2_conformer.py | 9 +++++---- .../models/wavlm/modeling_wavlm.py | 11 ++++++----- .../models/whisper/modeling_whisper.py | 9 +++++---- .../models/x_clip/modeling_x_clip.py | 9 +++++---- src/transformers/models/xglm/modeling_xglm.py | 7 ++++--- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 9 +++++---- .../xlm_roberta/modeling_xlm_roberta.py | 7 ++++--- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +- src/transformers/models/xmod/modeling_xmod.py | 7 ++++--- .../models/yolos/modeling_yolos.py | 7 ++++--- src/transformers/models/yoso/modeling_yoso.py | 7 ++++--- ...ng_{{cookiecutter.lowercase_modelname}}.py | 16 +++++++++------- tests/test_modeling_common.py | 7 +++++++ 187 files changed, 750 insertions(+), 562 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0317695f2096..4da8bfcc9043 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import collections
+import functools
 import gc
 import importlib.metadata
 import inspect
@@ -1819,16 +1820,28 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
 
         self.base_model._prune_heads(heads_to_prune)
 
-    def gradient_checkpointing_enable(self):
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         """
         Activates gradient checkpointing for the current model.
 
         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         activations".
+
+        Args:
+            gradient_checkpointing_kwargs (dict, *optional*):
+                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
         """
         if not self.supports_gradient_checkpointing:
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {}
+
+        gradient_checkpointing_func = functools.partial(
+            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
+        )
+
+        self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func))
 
         if getattr(self, "_hf_peft_config_loaded", False):
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
@@ -1845,7 +1858,7 @@ def gradient_checkpointing_disable(self):
         activations".
         """
         if self.supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, value=False))
+            self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))
 
         if getattr(self, "_hf_peft_config_loaded", False):
             self.disable_input_require_grads()
diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py
index 6cbf01a3432c..bad7db0150c8 100644
--- a/src/transformers/models/align/modeling_align.py
+++ b/src/transformers/models/align/modeling_align.py
@@ -1102,7 +1102,7 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = self.gradient_checkpointing_func(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
@@ -1197,9 +1197,10 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (AlignTextModel, AlignVisionModel)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 
 
 @add_start_docstrings(
diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py
index c4e32de55d9c..56b9657aecb7 100755
--- a/src/transformers/models/altclip/modeling_altclip.py
+++ b/src/transformers/models/altclip/modeling_altclip.py
@@ -653,7 +653,7 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = self.gradient_checkpointing_func(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
@@ -967,7 +967,7 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = self.gradient_checkpointing_func(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1089,11 +1089,13 @@ def
_init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, AltCLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, AltRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 28969f50b672..2a895dc073ba 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -343,7 +343,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -395,9 +395,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.weight.data.fill_(1.0) # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST - def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ASTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 96298c77a344..278811d23d9a 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -946,9 +946,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (AutoformerDecoder, AutoformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUTOFORMER_START_DOCSTRING = r""" @@ -1214,7 +1215,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1433,7 +1434,7 @@ def 
custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 649719e0eefa..8ffb22fd3e5d 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -313,9 +313,10 @@ def device(self) -> torch.device: return get_parameter_device(self) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BARK_MODEL_START_DOCSTRING = """ @@ -645,7 +646,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 9e7763ca23d8..2af67b87b739 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -521,9 +521,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BartDecoder, BartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -861,7 +862,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1118,7 +1119,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index d698cff88b14..d30eff63f541 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -517,7 +517,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -572,9 +572,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BeitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BEIT_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 1b0fad3f9d65..993160d6998a 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -600,7 +600,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -762,9 +762,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index abe2d828b28b..c811f2d19d31 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -408,7 +408,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -607,9 +607,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BERT_GENERATION_START_DOCSTRING = r""" diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e266b1a67b7d..6677e658b8dd 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1624,7 +1624,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1784,9 +1784,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BigBirdEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIG_BIRD_START_DOCSTRING = r""" diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 4e279f9dc059..98d8ae83179f 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1609,9 +1609,10 @@ def _init_weights(self, module): if module.padding_idx is not None: 
module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1950,7 +1951,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2297,7 +2298,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index ca084db5c7d0..6597d2ea04e6 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -376,9 +376,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BioGptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIOGPT_START_DOCSTRING = r""" @@ -598,7 +599,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index 12a5ecd42b74..d02861d6343d 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -669,9 +669,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1db81905210b..caaf59d289a2 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -483,9 +483,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is 
not None @property def dummy_inputs(self): @@ -784,7 +785,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1040,7 +1041,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 129de3dd1456..d72ee4ceb558 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -480,9 +480,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -782,7 +783,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1037,7 +1038,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 9fca7c28a1a0..59f2590d04ee 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -461,9 +461,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BlipEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BLIP_START_DOCSTRING = r""" @@ -629,7 +630,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 49b958afc2eb..317eea1e1b6e 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -429,7 +429,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py 
b/src/transformers/models/blip_2/modeling_blip_2.py index bd56b17e55c2..bcb6f4f7b6c6 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -297,9 +297,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Blip2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BLIP_2_START_DOCSTRING = r""" @@ -480,7 +481,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -951,7 +952,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d90bb6ad8fdf..688415ac1121 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -496,9 +496,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, BloomModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_standard_cache( @@ -769,7 +770,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, alibi, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index ce569157b811..d64532170bbf 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -811,7 +811,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index a8ea8d49195b..603dc2a52b8f 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -658,7 +658,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, bbox_pos_emb, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 8d7d279579e3..44764f900abb 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ 
b/src/transformers/models/camembert/modeling_camembert.py @@ -531,7 +531,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -625,9 +625,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CamembertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CAMEMBERT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 657104ad6965..9625e97ea28b 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -802,7 +802,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -919,9 +919,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CanineEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CANINE_START_DOCSTRING = r""" diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 7bab0aea6eb9..1f4a42732d7d 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -742,9 +742,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CHINESE_CLIP_START_DOCSTRING = r""" @@ -916,7 +917,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1025,7 +1026,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, ) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 1d17a5188387..ccee38322c0b 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -946,7 +946,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = 
torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -1602,7 +1602,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1701,9 +1701,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ClapTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class ClapAudioModel(ClapPreTrainedModel): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 3a894b9727c9..e179244a1c32 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -467,9 +467,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIP_START_DOCSTRING = r""" @@ -646,7 +647,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 96f13217aaf8..385737bafe33 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -479,9 +479,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPSegEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIPSEG_START_DOCSTRING = r""" @@ -655,7 +656,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 172a45544bac..464eeebc9ba0 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -337,9 +337,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, 
gradient_checkpointing_func=None): if isinstance(module, CodeGenModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CODEGEN_START_DOCSTRING = r""" @@ -548,7 +549,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 15f24084f469..2a4812eaf048 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1169,9 +1169,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConditionalDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONDITIONAL_DETR_START_DOCSTRING = r""" @@ -1523,7 +1524,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index a6fccf5b72b4..927c026df777 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -264,9 +264,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SeparableConv1D(nn.Module): @@ -639,7 +640,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index e6cf336517a5..e11112b53222 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -296,9 +296,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py 
b/src/transformers/models/convnextv2/modeling_convnextv2.py index 3a268c713d50..f1ff89bb1243 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -317,9 +317,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXTV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 6d2dc596fa65..8a6c744ed69e 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -556,9 +556,10 @@ def _init_weights(self, module): elif isinstance(module, CpmAntSegmentPositionEmbedding): module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CpmAntEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CPMANT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 4435e9b8d017..6d8bb5c2058c 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -300,7 +300,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -600,7 +600,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -761,9 +761,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_AUDIO_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index a521ccb39aaf..66588647f61b 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -517,7 +517,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -613,9 +613,10 @@ def _init_weights(self, 
module): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VECTEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index f8fe59587af0..e7fd98091f97 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -529,7 +529,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -585,9 +585,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecVisionEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_VISION_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 6f6c2af63a67..06a33a7dd85c 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -464,7 +464,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), next_kv, attention_mask, @@ -839,9 +839,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index eda4f406cb31..2172f5d22eef 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -508,7 +508,7 @@ def custom_forward(*inputs): return custom_forward - output_states = torch.utils.checkpoint.checkpoint( + output_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), next_kv, attention_mask, @@ -938,9 +938,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaV2Encoder): - module.gradient_checkpointing = value + 
module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 8e5053a4160d..3865fe523f71 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -469,9 +469,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DecisionTransformerGPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): @@ -639,7 +640,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index f541ca130544..7e04d2a1c760 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1086,9 +1086,10 @@ def _init_weights(self, module): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DeformableDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEFORMABLE_DETR_START_DOCSTRING = r""" @@ -1388,7 +1389,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 38c28dbbedc6..ff95a458ad77 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -364,7 +364,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -415,9 +415,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, DeiTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = 
gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index eca5ba014e51..e38b89a0a444 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -504,9 +504,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MCTCTEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MCTCT_START_DOCSTRING = r""" @@ -623,7 +624,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 6853f5333f13..f021714be250 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -456,9 +456,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OpenLlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPEN_LLAMA_INPUTS_DOCSTRING = r""" @@ -673,7 +674,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 75415dbe77bf..13a26b6c05d5 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -163,9 +163,10 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): main_input_name = "trajectories" supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrajectoryTransformerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): @@ -557,7 +558,7 @@ def custom_forward(*inputs): return custom_forward - outputs = 
torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, layer_past, diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index 4ef18f54158f..52c9e1242422 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -387,9 +387,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VanModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VAN_START_DOCSTRING = r""" diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 9cd29e940887..2c5890e0a357 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -977,9 +977,10 @@ def _init_weights(self, module): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetaDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETA_START_DOCSTRING = r""" @@ -1280,7 +1281,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 3dda00a20082..4200e6556d50 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -925,9 +925,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETR_START_DOCSTRING = r""" @@ -1258,7 +1259,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 89c6ed2e2a88..eb4d3f2ff296 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -660,7 +660,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DinatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git 
a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 6e4446faddd5..656a3022c96f 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -454,7 +454,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -516,9 +516,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: Dinov2Encoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None: if isinstance(module, Dinov2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DINOV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index f26b5846972d..de3c125abbac 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -365,7 +365,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_state, attn_mask, @@ -430,9 +430,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Transformer): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DISTILBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 0d833406e259..1a1e215f9a6d 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -756,7 +756,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -826,9 +826,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DonutSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 944ce142b0ad..c258343f6cfd 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -164,9 +164,10 @@ def _init_weights(self, module): module.bias.data.zero_() 
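Every hunk in this patch repeats the same two-part change: _set_gradient_checkpointing now receives a checkpointing callable rather than a boolean, stores it on the target submodule as gradient_checkpointing_func, derives the existing gradient_checkpointing flag from whether that callable is None, and the forward paths call self.gradient_checkpointing_func(...) where they previously hard-coded torch.utils.checkpoint.checkpoint(...). The minimal sketch below shows that pattern end to end; ToyEncoder, ToyPreTrainedModel, and the direct call to the hook are illustrative assumptions, not code from this patch (in the library the hook is driven from modeling_utils, e.g. through gradient_checkpointing_enable()).

# Editorial sketch, not part of the patch: the pattern the hunks above apply,
# reduced to a hypothetical ToyEncoder/ToyPreTrainedModel pair.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyEncoder(nn.Module):
    def __init__(self, hidden_size=32, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        # Both attributes are written by _set_gradient_checkpointing below.
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # The call-site change in this patch: use the injected callable
                # instead of hard-coding torch.utils.checkpoint.checkpoint.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = self.gradient_checkpointing_func(
                    create_custom_forward(layer), hidden_states
                )
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


class ToyPreTrainedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()

    # The hook signature introduced by this patch: a callable replaces the old boolean.
    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        if isinstance(module, ToyEncoder):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = gradient_checkpointing_func is not None


model = ToyPreTrainedModel()
model.train()
# Enable by passing a checkpointing function; disable by passing None.
model._set_gradient_checkpointing(model.encoder, torch.utils.checkpoint.checkpoint)
hidden = torch.randn(4, 32, requires_grad=True)
model.encoder(hidden).sum().backward()
model._set_gradient_checkpointing(model.encoder, None)

Handing the function in, rather than toggling a flag, lets the base class choose the checkpointing implementation (for instance a functools.partial of torch.utils.checkpoint.checkpoint with non-default arguments) without touching each model again; the concrete callable is presumably supplied by the src/transformers/modeling_utils.py changes listed in this patch's diffstat.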
module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DPREncoder(DPRPreTrainedModel): diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 187a6c36656a..b13ca04626cf 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -535,7 +535,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -818,9 +818,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DPTViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DPT_START_DOCSTRING = r""" diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index 478aeecee02b..d1b2c9940343 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -500,9 +500,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EfficientNetBlock): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index da3ee8e51d36..a7d943450a86 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -578,7 +578,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -692,9 +692,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ElectraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 697fb3c94fbb..28c20da3d5eb 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -473,9 +473,10 @@ def 
_init_weights(self, module): elif "bias" in name: nn.init.constant_(param, 0.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (EncodecEncoder, EncodecDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ENCODEC_START_DOCSTRING = r""" diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 3548e48c595a..d64860d6263e 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -265,10 +265,10 @@ def tie_weights(self): self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index d55155f80093..b178ca354495 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -513,7 +513,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -680,9 +680,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/ernie_m/modeling_ernie_m.py b/src/transformers/models/ernie_m/modeling_ernie_m.py index 9c53ddd73c85..b26ee0fcafd1 100755 --- a/src/transformers/models/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/ernie_m/modeling_ernie_m.py @@ -429,9 +429,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ERNIE_M_START_DOCSTRING = r""" diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 7a07495ba7e5..3115a1357ea6 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ 
b/src/transformers/models/esm/modeling_esm.py @@ -612,7 +612,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -710,9 +710,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EsmEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ESM_START_DOCSTRING = r""" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e9dca6df9894..29873a39457f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -945,9 +945,10 @@ def _init_weights(self, module: nn.Module): module.weight.data.fill_(1.0) # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, FalconModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_cache_to_standard_format( @@ -1155,7 +1156,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, alibi, diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 8de647c8299a..9b5faaeb15f6 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -670,7 +670,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -879,9 +879,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: FlavaEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, FlavaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 45042147761d..299b607b6b8a 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -299,7 +299,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states) + layer_outputs = 
self.gradient_checkpointing_func(create_custom_forward(layer_module), hidden_states) else: layer_outputs = layer_module(hidden_states) @@ -431,9 +431,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 8d18a8c63fda..0e33dc4f66f4 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -593,7 +593,7 @@ def custom_forward(*inputs): return custom_forward - stage_outputs = torch.utils.checkpoint.checkpoint( + stage_outputs = self.gradient_checkpointing_func( create_custom_forward(stage_module), hidden_states, input_dimensions, @@ -659,9 +659,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FocalNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FOCALNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index b14b1b0b871d..141976ef21c3 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -70,9 +70,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FuyuForCausalLM): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FUYU_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 00707e42dd08..bcbee566fa24 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -459,7 +459,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -533,9 +533,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GitEncoder, GitVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GIT_START_DOCSTRING = r""" @@ -885,7 +886,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = 
torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 838e7ca29925..fd726627bb1b 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -480,9 +480,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -885,7 +886,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index be90f61e45bf..3bcb4a865812 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -405,9 +405,10 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTBigCodeModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_BIGCODE_START_DOCSTRING = r""" @@ -658,7 +659,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3ad49554c0ac..494187a33aa4 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -384,9 +384,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_NEO_START_DOCSTRING = r""" @@ -612,7 +613,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9391805a77b8..19560dc6c975 100755 --- 
a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -78,9 +78,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXAttention(nn.Module): @@ -649,7 +650,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 98753edeb544..c1c5527a4655 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -66,9 +66,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXJapaneseModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXJapaneseAttention(nn.Module): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 6b5607f235b1..a51d4bdd094c 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -361,9 +361,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTJModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPTJ_START_DOCSTRING = r""" @@ -675,7 +676,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 24917fcfdb07..84d956c9f57e 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -759,9 +759,10 @@ def _init_weights(self, module): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GPTSanJapaneseAttention,)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = 
gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right def _shift_right(self, input_ids): diff --git a/src/transformers/models/graphormer/modeling_graphormer.py b/src/transformers/models/graphormer/modeling_graphormer.py index 8247745a3bc3..68ed6d265e70 100755 --- a/src/transformers/models/graphormer/modeling_graphormer.py +++ b/src/transformers/models/graphormer/modeling_graphormer.py @@ -772,9 +772,10 @@ def _init_weights( module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GraphormerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GraphormerModel(GraphormerPreTrainedModel): diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 59ff60ed765a..d4199891f6c9 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -805,9 +805,10 @@ def _init_weights(self, module): nn.init.normal_(module.fc1.weight, std=fc_std) nn.init.normal_(module.fc2.weight, std=in_proj_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GROUPVIT_START_DOCSTRING = r""" @@ -1038,7 +1039,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 1a7bde45efc1..9acb52c2aedb 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -353,7 +353,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -738,7 +738,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -828,7 +828,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -895,9 +895,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = 
gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 316f36561308..d3f9c5da4d2d 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -978,9 +978,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, IdeficsModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -1339,7 +1340,7 @@ def vblock( ) use_cache = False - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( vblock, decoder_layer, hidden_states, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index d4966a240d84..eb2b836169d6 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -408,7 +408,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 54edcd30fc87..f3ebc9324260 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -525,9 +525,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ImageGPTModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None IMAGEGPT_START_DOCSTRING = r""" @@ -824,7 +825,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index e7b35174ca7e..5b93a16d3e02 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -924,9 +924,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (InformerDecoder, InformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + 
module.gradient_checkpointing = gradient_checkpointing_func is not None INFORMER_START_DOCSTRING = r""" @@ -1222,14 +1223,14 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), ) if conv_layer is not None: - output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0]) + output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) layer_outputs = (output,) + layer_outputs[1:] else: layer_outputs = encoder_layer( @@ -1446,7 +1447,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 082900a6652f..7b02ee85020c 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -304,9 +304,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, InstructBlipEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None INSTRUCTBLIP_START_DOCSTRING = r""" @@ -469,7 +470,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -946,7 +947,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 884a2799728b..82531ab7a455 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -494,7 +494,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -638,9 +638,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LAYOUTLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index ef970edfdc91..30ff103bea7d 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ 
b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -446,7 +446,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -514,9 +514,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def my_convert_sync_batchnorm(module, process_group=None): diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 30ab0a5e8620..42162dcfb2e5 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -668,7 +668,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index f0c22ed9502c..1029a7950a2e 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1155,9 +1155,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (LEDDecoder, LEDEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1883,7 +1884,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2150,7 +2151,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 0accc28391bd..5acaaeba9004 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -507,9 +507,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LevitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 
46fe2d3e9cd7..65c381fc50a9 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -521,7 +521,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layout_inputs, @@ -607,9 +607,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LiltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b67719ac3271..5664a581ffb7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -705,9 +705,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -921,7 +922,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids ) else: diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 33bf9a6f9268..3b77ad46aed3 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1311,7 +1311,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1439,9 +1439,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LongformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index c80d2349832c..4c6ff76cc95d 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1341,9 +1341,10 @@ def _init_weights(self, module): ) # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, 
gradient_checkpointing_func=None): if isinstance(module, (LongT5Attention, LongT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 def _shift_right(self, input_ids): diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 6913ede09d1c..fde39d0999af 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -795,7 +795,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), word_hidden_states, entity_hidden_states, @@ -920,9 +920,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LukeEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 6db8bbb5213b..264aff5b4aac 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -552,9 +552,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (M2M100Decoder, M2M100Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None M2M_100_START_DOCSTRING = r""" @@ -827,7 +828,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1074,7 +1075,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 69de5b2e7d0e..a0ab7192718b 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -500,9 +500,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalP if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MarianDecoder, MarianEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = 
gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -795,7 +796,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1045,7 +1046,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 530c66a0c80b..9686b0a1d305 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -655,7 +655,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index e839b16f6257..9ec586a17bb3 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1871,7 +1871,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 87b91ed64b62..8502a6a368ea 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -855,7 +855,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, @@ -1619,11 +1619,13 @@ def _init_weights(self, module: nn.Module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerPixelLevelModule): - module.encoder.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 357ac9d4aaca..fe9dbc91f801 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -695,7 +695,7 @@ def custom_forward(*inputs): return custom_forward - layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( + layer_hidden_states, output_dimensions, layer_all_hidden_states = 
self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask ) else: @@ -752,9 +752,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index b53ad8848dd3..644c5d292b0e 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -516,9 +516,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MBartDecoder, MBartDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -835,7 +836,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1094,7 +1095,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 5d0ad6e3410c..16d463dcb470 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -558,7 +558,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -728,9 +728,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MegatronBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index 5d1f5bea7bfd..1257b4df39c0 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -333,9 +333,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: MgpstrEncoder, value: bool = False) 
-> None: + def _set_gradient_checkpointing(self, module: MgpstrEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, MgpstrEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MGP_STR_START_DOCSTRING = r""" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d650d60b8a55..1544ebeaaf81 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -689,9 +689,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MistralModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MISTRAL_INPUTS_DOCSTRING = r""" @@ -926,7 +927,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index c3accb21e05e..0653321df9c3 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -633,7 +633,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, ) @@ -672,9 +672,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 5a0e08d7344d..5aca04266e46 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -589,7 +589,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, ) @@ -629,9 +629,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not 
None MOBILEVITV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index d760bec9854a..279b0bc903a5 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -294,9 +294,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, MptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_mpt_cache( @@ -531,7 +532,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, alibi, diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index d400fea6d23d..672e2666533d 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -773,7 +773,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -871,9 +871,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MRA_START_DOCSTRING = r""" diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 186db94dad7f..2e2b68060dc9 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -845,9 +845,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MT5Attention, MT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index bcc6bc82a2f5..6bee6f35dc7d 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -475,9 +475,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MusicgenDecoder): - 
module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MUSICGEN_START_DOCSTRING = r""" @@ -1562,10 +1563,10 @@ def tie_weights(self): self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.text_encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.text_encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_audio_encoder(self): return self.audio_encoder diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 5c1ed05249ef..f44e067aac31 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -563,9 +563,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MvpDecoder, MvpEncoder, MvpPrompt)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -956,7 +957,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1235,7 +1236,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py index ecc745b558dd..4f7206a5e8ec 100644 --- a/src/transformers/models/nat/modeling_nat.py +++ b/src/transformers/models/nat/modeling_nat.py @@ -639,7 +639,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: NatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: NatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index fa31e94f4d2e..5a94e43291cb 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -584,7 +584,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -752,9 +752,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, 
gradient_checkpointing_func=None): if isinstance(module, NezhaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 3701bbecef2e..6c42ffa95b2d 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -874,9 +874,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (NllbMoeDecoder, NllbMoeEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NLLB_MOE_START_DOCSTRING = r""" @@ -1160,7 +1161,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 51ee73ab72d3..3c5df5dedd2e 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -377,7 +377,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -477,9 +477,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NystromformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NYSTROMFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 5b6220f88169..165684542d85 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2616,7 +2616,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor): for layer in self.layers: if self.use_checkpoint: - hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states) + hidden_states = self.gradient_checkpointing_func(layer, hidden_states) else: hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f3f24652434..c97d57fa236f 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -411,9 +411,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, 
gradient_checkpointing_func=None): if isinstance(module, (OPTDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPT_INPUTS_DOCSTRING = r""" @@ -699,7 +700,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, causal_attention_mask, diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 451cc4a69126..5aee16cc8106 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -584,9 +584,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Owlv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLV2_START_DOCSTRING = r""" @@ -771,7 +772,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1378,6 +1379,7 @@ def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor): def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor: """Predicts the probability that each image feature token is an object. + Args: image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)): Features extracted from the image. 
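Every file in this patch applies the same two-part change: `_set_gradient_checkpointing` now receives a callable (`gradient_checkpointing_func`) instead of a boolean, stores it on the checkpointed submodule, and derives the old `gradient_checkpointing` flag from whether a callable was supplied; the encoder/decoder forward loops then call `self.gradient_checkpointing_func(...)` where they previously hard-coded `torch.utils.checkpoint.checkpoint(...)`. The sketch below illustrates that contract in isolation. It is not code from this patch: the `ToyEncoder`/`ToyModel` names and the `functools.partial(torch.utils.checkpoint.checkpoint, ...)` default are assumptions made for the example, and the real wiring lives in `modeling_utils.py` (changed by this PR but not shown in this hunk).

import functools

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyEncoder(nn.Module):
    """Stands in for the per-model encoders touched by this patch (hypothetical class)."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # The injected callable replaces the hard-coded torch.utils.checkpoint.checkpoint call.
                hidden_states = self.gradient_checkpointing_func(layer, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()

    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        # Same contract as the patched models: store the callable and derive the flag from it.
        if isinstance(module, ToyEncoder):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = gradient_checkpointing_func is not None

    def gradient_checkpointing_enable(self, **checkpoint_kwargs):
        # Assumed wiring: build the checkpointing callable once and hand it to every submodule.
        func = functools.partial(checkpoint, **checkpoint_kwargs)
        for module in self.modules():
            self._set_gradient_checkpointing(module, gradient_checkpointing_func=func)


model = ToyModel()
model.gradient_checkpointing_enable(use_reentrant=False)
model.train()
loss = model.encoder(torch.randn(2, 8)).sum()
loss.backward()  # intermediate activations are recomputed layer by layer during backward

Passing a callable instead of a flag lets the caller decide how checkpointing runs (for example reentrant vs. non-reentrant, or extra checkpoint kwargs) without every model re-importing `torch.utils.checkpoint`, which is presumably why the boolean `value` argument is dropped throughout.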
diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 66cfb8092df5..b5317ea1c1b8 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -576,9 +576,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OwlViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLVIT_START_DOCSTRING = r""" @@ -760,7 +761,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 55856f7b06b6..705cf956f784 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -500,9 +500,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusDecoder, PegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_START_DOCSTRING = r""" @@ -810,7 +811,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1095,7 +1096,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index e87e9c7164ab..5f5888429231 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -780,9 +780,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_X_START_DOCSTRING = r""" @@ -1078,7 +1079,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, global_hidden_states, @@ -1339,7 +1340,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs 
= torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a0bc57263823..c6092e158c93 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -467,9 +467,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PersimmonModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PERSIMMON_INPUTS_DOCSTRING = r""" @@ -677,7 +678,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 58041820c1fb..31cedc13359f 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -350,7 +350,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -563,9 +563,10 @@ def __init__(self, config: Pix2StructConfig): # Initialize weights and apply final processing self.post_init() - def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, Pix2StructVisionEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def get_input_embeddings(self): return self.embeddings.patch_projection @@ -1320,9 +1321,10 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pix2StructTextAttention, Pix2StructTextModel)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 3a880839236d..a079b0bf0cf5 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -517,9 +517,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, 
gradient_checkpointing_func=None): if isinstance(module, (PLBartDecoder, PLBartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PLBART_START_DOCSTRING = r""" @@ -814,7 +815,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1072,7 +1073,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 6acc8ec98e69..209533e31990 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -282,9 +282,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PoolFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None POOLFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5a67b8044b09..acb43f824b7b 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -739,9 +739,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pop2PianoAttention, Pop2PianoStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 241a9efea36a..04d2b946eafc 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -557,9 +557,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1336,7 +1337,7 @@ def custom_forward(*inputs): return 
custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, extended_attention_mask, @@ -1577,7 +1578,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 2dd452ec1df1..356b7c14afa8 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -489,9 +489,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ) - def _set_gradient_checkpointing(self, module: PvtEncoder, value: bool = False): + def _set_gradient_checkpointing(self, module: PvtEncoder, gradient_checkpointing_func=None): if isinstance(module, PvtEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PVT_START_DOCSTRING = r""" diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index fead8fc0cf7f..cf307fb35009 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -588,7 +588,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -757,9 +757,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, QDQBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None QDQBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index aa738d782b7b..8f7d0a656002 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -593,7 +593,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 07ef29fd3332..21050f07fda4 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -293,9 +293,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RegNetModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None 
REGNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 235bff89f8a3..6dd04ed4030c 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -550,7 +550,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -673,9 +673,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RemBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None REMBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index f2d207c2189f..e6b1d85b2a46 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -283,9 +283,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ResNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None RESNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 6d4cc991d22c..d7ead17b4544 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -517,7 +517,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -612,9 +612,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index da1cd6331bc3..4ae7a308f68e 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -519,7 +519,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -615,9 +615,10 @@ def _init_weights(self, module): 
module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaPreLayerNormEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index a5b1b63050b1..d1b84ab31f63 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -651,7 +651,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -796,9 +796,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoCBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROC_BERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b9c36a305ff1..e860ff34eb52 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -585,7 +585,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -715,9 +715,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index db41bd3c9538..bbe21949f0e6 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -466,9 +466,10 @@ def _init_weights(self, module): module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RwkvModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -684,7 +685,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states, state, attentions = torch.utils.checkpoint.checkpoint( + hidden_states, state, attentions = 
self.gradient_checkpointing_func( create_custom_forward(block), hidden_states, state ) else: diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index abf5544a5b4d..f5cd7cf0a45b 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1049,7 +1049,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 34f9c84235cc..44cbcec5267a 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -367,7 +367,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -680,7 +680,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -756,9 +756,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SEWEncoder, SEWFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 661a8c03b1a5..74374e1a4eb9 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -460,7 +460,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -1141,7 +1141,7 @@ def custom_forward(*inputs): return custom_forward - output_states = torch.utils.checkpoint.checkpoint( + output_states = self.gradient_checkpointing_func( create_custom_forward(layer_module), next_kv, attention_mask, @@ -1322,9 +1322,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SEWDTransformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SEWD_START_DOCSTRING = r""" diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index e80c26e2698d..ec255fab9bc7 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ 
b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -249,10 +249,10 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 31c9b6cfe935..acdcc2f902cb 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -559,9 +559,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -824,7 +825,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1073,7 +1074,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index f9b5dec42092..9d863ba3e2f2 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -437,9 +437,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Speech2Text2Decoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPEECH_TO_TEXT_2_START_DOCSTRING = r""" @@ -677,7 +678,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 9b8ab3d3805a..ef374bbb32e7 100644 --- 
a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -527,7 +527,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -1281,9 +1281,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SpeechT5Encoder, SpeechT5Decoder, SpeechT5FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SpeechT5Encoder(SpeechT5PreTrainedModel): @@ -1393,7 +1394,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1723,7 +1724,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f72ffb10111b..f1ab50179dea 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -466,7 +466,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -544,9 +544,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SplinterEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPLINTER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index ff72f87506d3..4170ce153bbf 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -442,9 +442,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No nn.init.constant_(module.bias, 0) nn.init.constant_(module.weight, 1.0) - def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, SwiftFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIFTFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 
45a7aa718cf0..228d962dea1d 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -832,7 +832,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -901,9 +901,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 02ec39edb0fe..5d5356144245 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -951,11 +951,6 @@ class TFSwinPreTrainedModel(TFPreTrainedModel): config_class = SwinConfig base_model_prefix = "swin" main_input_name = "pixel_values" - supports_gradient_checkpointing = True - - def _set_gradient_checkpointing(self, module, value=False) -> None: - if isinstance(module, TFSwinEncoder): - module.gradient_checkpointing = value SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index a8a17bdf584b..db8ff6a652eb 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -753,7 +753,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -802,9 +802,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swin2SREncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN2SR_START_DOCSTRING = r""" diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index a4224e16df3c..fda0e080d0d8 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -913,7 +913,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -983,9 +983,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swinv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + 
module.gradient_checkpointing = gradient_checkpointing_func is not None SWINV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 0a402ea2d6af..ed0d59abb8bc 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -865,9 +865,10 @@ def _init_weights(self, module): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0e7237ea36b6..603c6a4730e8 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -873,9 +873,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (T5Attention, T5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 8f59bd4b6e17..fb42673ae5c5 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -831,9 +831,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TableTransformerDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TABLE_TRANSFORMER_START_DOCSTRING = r""" @@ -1150,7 +1151,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index cdaa4b3e2725..e6ce415899fa 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -653,7 +653,7 @@ def 
custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -778,9 +778,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TapasEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TAPAS_START_DOCSTRING = r""" diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 2caca5bd1051..c550f89e9504 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -663,9 +663,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (TimeSeriesTransformerDecoder, TimeSeriesTransformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIME_SERIES_TRANSFORMER_START_DOCSTRING = r""" @@ -953,7 +954,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1171,7 +1172,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 676bcf7a5e27..df7dd1c953f5 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -446,7 +446,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, ) @@ -494,9 +494,10 @@ def _init_weights(self, module): nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) module.patch_embeddings.apply(self._init_weights) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TimesformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIMESFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index c0541814be46..6971b4dfb21a 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ 
b/src/transformers/models/trocr/modeling_trocr.py @@ -454,9 +454,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrOCRDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TROCR_START_DOCSTRING = r""" @@ -709,7 +710,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 464c3e76a11f..8852083c4694 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -567,7 +567,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -616,9 +616,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TvltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TVLT_START_DOCSTRING = r""" @@ -884,7 +885,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index ffafd1581140..e6e9aaa26a38 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -556,9 +556,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UMT5Attention, UMT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index c475ab7f80f8..8f667d3d564c 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -391,7 +391,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -774,7 +774,7 @@ def custom_forward(*inputs): return custom_forward - 
layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -864,7 +864,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1039,9 +1039,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechEncoder, UniSpeechEncoderStableLayerNorm, UniSpeechFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_START_DOCSTRING = r""" diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 3fcc9549bbdc..5584929ab11c 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -405,7 +405,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -788,7 +788,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -878,7 +878,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1053,9 +1053,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_SAT_START_DOCSTRING = r""" diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index b56b508d14ae..04b8c94e1351 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -315,9 +315,10 @@ def init_weights(self): if self.auxiliary_head is not None: self.auxiliary_head.init_weights() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UPERNET_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 07c32d149290..9657a747fbc3 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -441,7 +441,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -489,9 +489,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VideoMAEEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIDEOMAE_START_DOCSTRING = r""" @@ -733,7 +734,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index a36d58bd235b..9c8fee6c79f3 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -538,7 +538,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -591,9 +591,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ViltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index d3e464cbfffa..84275cc33a76 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -225,10 +225,10 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. 
Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 81ad1068483a..c2eaa90b4864 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -425,7 +425,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -547,9 +547,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VisualBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 8fdacdddf04c..050d02ee2990 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -404,7 +404,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -467,9 +467,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 008f6b3c9db5..fa4b3471e375 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -422,7 +422,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -486,9 +486,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, 
module: ViTHybridEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTHybridEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTHybridEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index ef0c7c9f3686..b468075d08ff 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -543,7 +543,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -591,9 +591,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ViTMAEEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MAE_START_DOCSTRING = r""" @@ -800,7 +801,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 46639e7d622c..87779dd3ae94 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -394,7 +394,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -444,9 +444,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTMSNEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTMSNEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MSN_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index e89fdbd7a336..fd9f26923444 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -572,7 +572,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -666,9 +666,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.norm3.weight.data.zero_() module.norm3.bias.data.zero_() - def _set_gradient_checkpointing(self, module: VitDetEncoder, value: bool = 
False) -> None: + def _set_gradient_checkpointing(self, module: VitDetEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, VitDetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITDET_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index b23bdd21d56b..18b8b80b328c 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -86,9 +86,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class VitMatteBasicConv3x3(nn.Module): diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 49b9a1f1ae15..7b7899ee287f 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1174,7 +1174,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, padding_mask, @@ -1296,9 +1296,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (VitsTextEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITS_START_DOCSTRING = r""" diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index fd35668572a7..5e07b1544b2b 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -345,7 +345,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -414,9 +414,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Parameter): module.data.normal_(mean=0.0, std=self.config.initializer_range) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VivitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index a6e02a0476f1..c02e23660b56 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -458,7 +458,7 @@ def 
custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -810,7 +810,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -899,7 +899,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1173,9 +1173,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_adapters(self): if self.config.adapter_attn_dim is None: diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index f162c5142970..edcdcf4a22ac 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -525,7 +525,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -918,7 +918,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1178,9 +1178,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAV2VEC2_CONFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 5013837cbdce..182482dfd83a 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -361,7 +361,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self.gradient_checkpointing_func( create_custom_forward(conv_layer), hidden_states, ) @@ -720,7 +720,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -811,7 +811,7 @@ def custom_forward(*inputs): return 
custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer), hidden_states, attention_mask, @@ -1052,9 +1052,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAVLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8962324471ca..f4b8fb4852a9 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -685,9 +685,10 @@ def _init_weights(self, module): embed_positions = module.embed_positions.weight embed_positions.copy_(sinusoids(*embed_positions.shape)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WhisperDecoder, WhisperEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -949,7 +950,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, None, @@ -1182,7 +1183,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index da7eddff8df8..025533ab4178 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -534,9 +534,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None X_CLIP_START_DOCSTRING = r""" @@ -710,7 +711,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -957,7 +958,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 
0c769dbbb5f3..16f0402abf98 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -503,9 +503,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XGLMModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( @@ -682,7 +683,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index cde05cfe8a8a..e599cc3cede7 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -570,9 +570,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XLMProphetNetDecoder, XLMProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1356,7 +1357,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, extended_attention_mask, @@ -1600,7 +1601,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index da454b1e3331..b195ee43723e 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -518,7 +518,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -614,9 +614,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XLMRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None XLM_ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py 
b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 26e0361abdb5..0e3ed4eeb986 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -506,7 +506,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 28fddc2fdbd6..2eb0ba83d726 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -580,7 +580,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, lang_ids, @@ -680,9 +680,10 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel._set_gradient_checkpointing with Roberta->Xmod - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XmodEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def set_default_language(self, language: str): """ diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index e3cb02ceae6e..0884529777c2 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -499,7 +499,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -551,9 +551,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: YolosEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: YolosEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, YolosEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOLOS_START_DOCSTRING = r""" diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 5edd7f883542..0159d7fb76d3 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -568,7 +568,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -668,9 +668,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, YosoEncoder): - module.gradient_checkpointing = value + 
module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOSO_START_DOCSTRING = r""" diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 02fcb7d2f511..ee583cec3548 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -550,7 +550,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -679,9 +679,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2024,9 +2025,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2319,7 +2321,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2558,7 +2560,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 34f5bae3746f..74ea66fd6fa0 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -569,6 +569,13 @@ def test_training_gradient_checkpointing(self): loss = model(**inputs).loss loss.backward() + model.gradient_checkpointing_disable() + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_attention_outputs(self): if not self.has_attentions: self.skipTest(reason="Model does not output attentions") From 449b4a4c3cbd3d5edeab0dd11d34570fddcb7073 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 13:32:57 +0000 Subject: [PATCH 02/12] fix --- 
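Note (after the `---`, so not part of the commit message): across this series, the per-model pattern `torch.utils.checkpoint.checkpoint(create_custom_forward(layer), ...)` is replaced by an injected `self.gradient_checkpointing_func(...)`, and `_set_gradient_checkpointing` now receives that callable instead of a boolean. The sketch below is a minimal, self-contained illustration of the resulting wiring, not transformers code: `ToyLayer`/`ToyEncoder` are made-up names, the hook signature is simplified, and the `functools.partial` plumbing is an assumption about what `gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": ...})` (exercised in the test above) does inside `modeling_utils.py`, which is not shown in these hunks.

# Hypothetical sketch of the gradient-checkpointing wiring this series moves every model to.
# Requires a PyTorch version whose checkpoint() accepts `use_reentrant`.
import functools

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(hidden_states))


class ToyEncoder(nn.Module):
    def __init__(self, hidden_size: int = 32, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer(hidden_size) for _ in range(num_layers)])
        # Old API: a bare boolean toggled via _set_gradient_checkpointing(module, value).
        # New API: the parent model also injects the checkpointing callable itself.
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def _set_gradient_checkpointing(self, gradient_checkpointing_func=None):
        # Mirrors the per-model hook after this patch (the real hook also receives the
        # target module and isinstance-checks it): store the function, derive the flag.
        self.gradient_checkpointing_func = gradient_checkpointing_func
        self.gradient_checkpointing = gradient_checkpointing_func is not None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Before: torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
                # After: call the injected function on layer.forward directly (PATCH 02/03 drop the wrapper).
                hidden_states = self.gradient_checkpointing_func(layer.forward, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


if __name__ == "__main__":
    encoder = ToyEncoder().train()
    # Assumed equivalent of gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}):
    # bind the kwargs onto torch.utils.checkpoint.checkpoint and hand the result to the submodules.
    checkpoint_func = functools.partial(checkpoint, use_reentrant=False)
    encoder._set_gradient_checkpointing(gradient_checkpointing_func=checkpoint_func)

    x = torch.randn(4, 32, requires_grad=True)
    encoder(x).sum().backward()  # gradients flow through the checkpointed layers

Deriving `gradient_checkpointing` from `gradient_checkpointing_func is not None` keeps code paths that only test the boolean working unchanged, while letting callers choose reentrant or non-reentrant checkpointing per call instead of hard-coding `torch.utils.checkpoint.checkpoint` in every model file.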
.../models/seamless_m4t/modeling_seamless_m4t.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 7df6fcd98907..6b0e7a7ff079 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -900,7 +900,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, relative_position_embeddings, @@ -1547,9 +1547,10 @@ def _init_weights(self, module): k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TSpeechEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride @@ -1864,7 +1865,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -2139,7 +2140,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, From 6b4ab9f0c2a43e3a729982824bb00e481815a063 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 14:37:13 +0000 Subject: [PATCH 03/12] remove `create_custom_forward` --- .../models/align/modeling_align.py | 11 +++------- .../models/altclip/modeling_altclip.py | 21 +++++-------------- .../modeling_audio_spectrogram_transformer.py | 10 ++------- .../models/autoformer/modeling_autoformer.py | 12 +++-------- src/transformers/models/bark/modeling_bark.py | 12 +++-------- src/transformers/models/bart/modeling_bart.py | 12 +++-------- src/transformers/models/beit/modeling_beit.py | 10 ++------- src/transformers/models/bert/modeling_bert.py | 11 +++------- .../modeling_bert_generation.py | 11 +++------- .../models/big_bird/modeling_big_bird.py | 11 +++------- .../modeling_bigbird_pegasus.py | 12 +++-------- .../models/biogpt/modeling_biogpt.py | 12 +++-------- .../models/blenderbot/modeling_blenderbot.py | 12 +++-------- .../modeling_blenderbot_small.py | 12 +++-------- src/transformers/models/blip/modeling_blip.py | 10 ++------- .../models/blip/modeling_blip_text.py | 11 +++------- .../models/blip_2/modeling_blip_2.py | 12 +++-------- .../models/bloom/modeling_bloom.py | 12 +++-------- .../bridgetower/modeling_bridgetower.py | 11 +++------- src/transformers/models/bros/modeling_bros.py | 10 ++------- .../models/camembert/modeling_camembert.py | 11 +++------- .../models/canine/modeling_canine.py | 10 ++------- .../chinese_clip/modeling_chinese_clip.py | 13 ++++-------- src/transformers/models/clap/modeling_clap.py | 20 ++++-------------- src/transformers/models/clip/modeling_clip.py | 10 
++------- .../models/clipseg/modeling_clipseg.py | 10 ++------- .../models/codegen/modeling_codegen.py | 12 +++-------- .../modeling_conditional_detr.py | 9 +------- .../models/convbert/modeling_convbert.py | 3 ++- .../data2vec/modeling_data2vec_audio.py | 11 ++-------- .../models/data2vec/modeling_data2vec_text.py | 11 +++------- .../data2vec/modeling_data2vec_vision.py | 10 ++------- .../models/deberta/modeling_deberta.py | 10 ++------- .../models/deberta_v2/modeling_deberta_v2.py | 10 ++------- .../modeling_decision_transformer.py | 12 +++-------- .../modeling_deformable_detr.py | 8 +------ src/transformers/models/deit/modeling_deit.py | 10 ++------- .../models/deprecated/mctct/modeling_mctct.py | 10 ++------- .../open_llama/modeling_open_llama.py | 12 +++-------- .../modeling_trajectory_transformer.py | 9 +------- src/transformers/models/deta/modeling_deta.py | 8 +------ src/transformers/models/detr/modeling_detr.py | 9 +------- .../models/dinov2/modeling_dinov2.py | 10 ++------- .../models/distilbert/modeling_distilbert.py | 10 ++------- .../models/donut/modeling_donut_swin.py | 9 +------- src/transformers/models/dpt/modeling_dpt.py | 10 ++------- .../models/electra/modeling_electra.py | 11 +++------- .../models/ernie/modeling_ernie.py | 11 +++------- src/transformers/models/esm/modeling_esm.py | 11 +++------- .../models/falcon/modeling_falcon.py | 13 ++++-------- .../models/flava/modeling_flava.py | 10 ++------- src/transformers/models/fnet/modeling_fnet.py | 9 +------- .../models/focalnet/modeling_focalnet.py | 9 +------- src/transformers/models/git/modeling_git.py | 21 +++++-------------- src/transformers/models/gpt2/modeling_gpt2.py | 12 +++-------- .../gpt_bigcode/modeling_gpt_bigcode.py | 12 +++-------- .../models/gpt_neo/modeling_gpt_neo.py | 12 +++-------- .../models/gpt_neox/modeling_gpt_neox.py | 13 ++++-------- src/transformers/models/gptj/modeling_gptj.py | 12 +++-------- .../models/groupvit/modeling_groupvit.py | 10 ++------- .../models/hubert/modeling_hubert.py | 13 +++--------- src/transformers/models/idefics/vision.py | 10 ++------- .../models/imagegpt/modeling_imagegpt.py | 12 +++-------- .../models/informer/modeling_informer.py | 12 +++-------- .../instructblip/modeling_instructblip.py | 12 +++-------- .../models/layoutlm/modeling_layoutlm.py | 11 +++------- .../models/layoutlmv2/modeling_layoutlmv2.py | 3 ++- .../models/layoutlmv3/modeling_layoutlmv3.py | 13 +----------- src/transformers/models/led/modeling_led.py | 13 ++++-------- src/transformers/models/lilt/modeling_lilt.py | 10 ++------- .../models/llama/modeling_llama.py | 17 +++++++-------- .../models/longformer/modeling_longformer.py | 11 +++------- .../models/longt5/modeling_longt5.py | 11 +++------- src/transformers/models/luke/modeling_luke.py | 10 ++------- .../models/m2m_100/modeling_m2m_100.py | 12 +++-------- .../models/marian/modeling_marian.py | 12 +++-------- .../models/markuplm/modeling_markuplm.py | 11 +++------- .../mask2former/modeling_mask2former.py | 10 ++------- .../models/maskformer/modeling_maskformer.py | 10 ++------- .../maskformer/modeling_maskformer_swin.py | 12 ++++------- .../models/mbart/modeling_mbart.py | 12 +++-------- .../megatron_bert/modeling_megatron_bert.py | 11 +++------- .../models/mistral/modeling_mistral.py | 14 +++++-------- .../models/mobilevit/modeling_mobilevit.py | 9 +------- .../mobilevitv2/modeling_mobilevitv2.py | 9 +------- src/transformers/models/mpt/modeling_mpt.py | 12 +++-------- src/transformers/models/mra/modeling_mra.py | 9 +------- 
src/transformers/models/mt5/modeling_mt5.py | 11 +++------- .../models/musicgen/modeling_musicgen.py | 12 +++-------- src/transformers/models/mvp/modeling_mvp.py | 12 +++-------- .../models/nezha/modeling_nezha.py | 11 +++------- .../models/nllb_moe/modeling_nllb_moe.py | 12 +++-------- .../nystromformer/modeling_nystromformer.py | 10 ++------- src/transformers/models/opt/modeling_opt.py | 12 +++-------- .../models/owlv2/modeling_owlv2.py | 10 ++------- .../models/owlvit/modeling_owlvit.py | 10 ++------- .../models/pegasus/modeling_pegasus.py | 12 +++-------- .../models/pegasus_x/modeling_pegasus_x.py | 12 +++-------- .../models/persimmon/modeling_persimmon.py | 12 +++-------- .../models/pix2struct/modeling_pix2struct.py | 12 +++-------- .../models/plbart/modeling_plbart.py | 12 +++-------- .../models/pop2piano/modeling_pop2piano.py | 11 +++------- .../models/prophetnet/modeling_prophetnet.py | 12 +++-------- .../models/qdqbert/modeling_qdqbert.py | 11 +++------- .../models/realm/modeling_realm.py | 11 +++------- .../models/rembert/modeling_rembert.py | 11 +++------- .../models/roberta/modeling_roberta.py | 11 +++------- .../modeling_roberta_prelayernorm.py | 11 +++------- .../models/roc_bert/modeling_roc_bert.py | 11 +++------- .../models/roformer/modeling_roformer.py | 11 +++------- src/transformers/models/rwkv/modeling_rwkv.py | 10 +-------- src/transformers/models/sam/modeling_sam.py | 9 +------- .../seamless_m4t/modeling_seamless_m4t.py | 8 +------ src/transformers/models/sew/modeling_sew.py | 11 ++-------- .../models/sew_d/modeling_sew_d.py | 19 +++-------------- .../speech_to_text/modeling_speech_to_text.py | 12 +++-------- .../modeling_speech_to_text_2.py | 10 +-------- .../models/speecht5/modeling_speecht5.py | 13 +++--------- .../models/splinter/modeling_splinter.py | 11 +++------- src/transformers/models/swin/modeling_swin.py | 9 +------- .../models/swin2sr/modeling_swin2sr.py | 9 +------- .../models/swinv2/modeling_swinv2.py | 9 +------- .../modeling_switch_transformers.py | 11 +++------- src/transformers/models/t5/modeling_t5.py | 11 +++------- .../modeling_table_transformer.py | 9 +------- .../models/tapas/modeling_tapas.py | 11 +++------- .../modeling_time_series_transformer.py | 12 +++-------- .../timesformer/modeling_timesformer.py | 10 ++------- .../models/trocr/modeling_trocr.py | 12 +++-------- src/transformers/models/tvlt/modeling_tvlt.py | 5 +++-- src/transformers/models/umt5/modeling_umt5.py | 11 +++------- .../models/unispeech/modeling_unispeech.py | 13 +++--------- .../unispeech_sat/modeling_unispeech_sat.py | 13 +++--------- .../models/videomae/modeling_videomae.py | 12 +++-------- src/transformers/models/vilt/modeling_vilt.py | 3 ++- .../visual_bert/modeling_visual_bert.py | 10 ++------- src/transformers/models/vit/modeling_vit.py | 10 ++------- .../models/vit_hybrid/modeling_vit_hybrid.py | 10 ++------- .../models/vit_mae/modeling_vit_mae.py | 12 +++-------- .../models/vit_msn/modeling_vit_msn.py | 10 ++------- .../models/vitdet/modeling_vitdet.py | 10 ++------- src/transformers/models/vits/modeling_vits.py | 10 ++------- .../models/vivit/modeling_vivit.py | 10 ++------- .../models/wav2vec2/modeling_wav2vec2.py | 13 +++--------- .../modeling_wav2vec2_conformer.py | 11 ++-------- .../models/wavlm/modeling_wavlm.py | 13 +++--------- .../models/whisper/modeling_whisper.py | 12 +++-------- .../models/x_clip/modeling_x_clip.py | 12 +++-------- src/transformers/models/xglm/modeling_xglm.py | 11 +++------- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 12 +++-------- 
.../xlm_roberta/modeling_xlm_roberta.py | 11 +++------- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 11 +++------- src/transformers/models/xmod/modeling_xmod.py | 10 +++------ .../models/yolos/modeling_yolos.py | 9 ++------ src/transformers/models/yoso/modeling_yoso.py | 10 ++------- ...ng_{{cookiecutter.lowercase_modelname}}.py | 14 +++++-------- 156 files changed, 406 insertions(+), 1317 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index bad7db0150c8..e132fae5c660 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1095,20 +1095,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 56b9657aecb7..71e650adba1b 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -646,20 +646,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -960,18 +955,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 2a895dc073ba..1c79f3cfd78b 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -336,17 +336,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, 
hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 278811d23d9a..520a9ddbc131 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1208,18 +1208,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1435,7 +1429,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 8ffb22fd3e5d..11c53ccbdb21 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -638,20 +638,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 2af67b87b739..70013fba27df 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -855,18 +855,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1120,7 +1114,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index d30eff63f541..860de96323be 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -510,17 +510,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs 
= self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 993160d6998a..b251c9c9b559 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -593,20 +593,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index c811f2d19d31..97fb89e95413 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -401,20 +401,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 6677e658b8dd..890eb8c6875f 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1617,15 +1617,8 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, @@ -1635,6 +1628,8 @@ def custom_forward(*inputs): from_mask, to_mask, blocked_encoder_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 98d8ae83179f..ad0640b04451 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1944,15 +1944,8 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = 
self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1961,6 +1954,7 @@ def custom_forward(*inputs): to_mask, blocked_encoder_mask, blocked_encoder_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2299,7 +2293,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 6597d2ea04e6..7dc72aa6368e 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -591,20 +591,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index caaf59d289a2..4a3248e5d443 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -778,18 +778,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1042,7 +1036,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index d72ee4ceb558..ef9d0e9643f5 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -776,18 +776,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1039,7 +1033,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - 
create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 59f2590d04ee..229afec0a81f 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -623,17 +623,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 317eea1e1b6e..a9decd052d37 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -422,20 +422,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index bcb6f4f7b6c6..8339016efcb9 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -474,17 +474,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -953,7 +947,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 688415ac1121..83998421e131 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -762,21 +762,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + 
output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index d64532170bbf..ea4c3cc285ba 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -804,20 +804,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 603dc2a52b8f..60e753c95f8d 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -651,21 +651,15 @@ def forward( "`use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, bbox_pos_emb, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 44764f900abb..d5d9f0ae488f 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -524,20 +524,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 9625e97ea28b..198e3376731a 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -795,18 +795,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py 
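The hunks above all apply the same refactor: instead of wrapping each layer in a `create_custom_forward` closure that captures `past_key_value` and `output_attentions`, the layer's bound `forward` is handed to `self.gradient_checkpointing_func` and the previously captured values are appended as explicit positional arguments, in the order of the layer's `forward` signature. A minimal, self-contained sketch of the before/after call, using a hypothetical `ToyLayer` (not the real BERT layer) and calling `torch.utils.checkpoint.checkpoint` directly in place of `self.gradient_checkpointing_func`:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):  # hypothetical stand-in for a BERT-style layer
    def __init__(self, hidden_size=8):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask, output_attentions):
        out = self.dense(hidden_states)
        return (out, None) if output_attentions else (out,)


layer = ToyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
attention_mask = torch.ones(2, 4)
output_attentions = True

# Old style: a closure re-created on every forward captures the flag, so only
# tensors cross the checkpoint boundary.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions)  # flag captured from the enclosing scope
    return custom_forward

old = checkpoint(create_custom_forward(layer), hidden_states, attention_mask, use_reentrant=True)

# New style (the pattern the patch moves to): pass the bound forward directly and
# append the remaining arguments positionally, matching the forward signature.
new = checkpoint(layer.forward, hidden_states, attention_mask, output_attentions, use_reentrant=True)

assert torch.allclose(old[0], new[0])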
b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 1f4a42732d7d..c96521493fd5 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -910,20 +910,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1027,7 +1022,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, ) else: diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index ccee38322c0b..7c6c9618c453 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -939,15 +939,8 @@ def forward( input_dimensions = self.input_resolutions[i] if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -1595,20 +1588,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index e179244a1c32..9e179753157b 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -640,18 +640,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 385737bafe33..0bded11f9bc1 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -649,18 +649,12 @@ def forward( if 
output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 464eeebc9ba0..0a01e05044e4 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -541,21 +541,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 2a4812eaf048..d663e080df93 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1517,15 +1517,8 @@ def forward( # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, object_queries, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 927c026df777..c040715c3630 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -641,12 +641,13 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 6d8bb5c2058c..71fd5a705990 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -293,15 +293,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -601,7 +594,7 @@ def custom_forward(*inputs): 
return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 66588647f61b..ba5c6b97a965 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -510,20 +510,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index e7fd98091f97..6c5c39e4957c 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -522,17 +522,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 06a33a7dd85c..a7816bae558b 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -457,20 +457,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: hidden_states = layer_module( diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 2172f5d22eef..e536d376c591 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -501,20 +501,14 @@ def forward( all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - output_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( diff --git 
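For the GPT-style blocks above (codegen, decision_transformer, and the gpt2 family further down), the checkpointed call keeps passing `None` for the past key/value slot, since caching is not used while recomputing activations, and then appends `use_cache` and `output_attentions` positionally. A sketch with a hypothetical `ToyBlock`, again using `torch.utils.checkpoint.checkpoint` directly; the `use_reentrant` value is only illustrative:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):  # hypothetical GPT-2-style block
    def forward(self, hidden_states, layer_past=None, attention_mask=None,
                head_mask=None, use_cache=False, output_attentions=False):
        present = hidden_states if use_cache else None
        return (hidden_states * 2, present)


block = ToyBlock()
hidden_states = torch.randn(1, 3, 4, requires_grad=True)
attention_mask = torch.ones(1, 3)

outputs = checkpoint(
    block.forward,
    hidden_states,
    None,              # layer_past: caching is skipped under checkpointing
    attention_mask,
    None,              # head_mask
    False,             # use_cache
    True,              # output_attentions
    use_reentrant=True,
)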
a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 3865fe523f71..8146436cdc51 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -632,22 +632,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 7e04d2a1c760..fb8ed41ce712 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1383,14 +1383,8 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index ff95a458ad77..4cd8785ce535 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -357,17 +357,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index e38b89a0a444..779b409470d9 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -617,18 +617,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 
f021714be250..5ab949b11ce3 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -666,20 +666,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, position_ids, None, + output_attentions, + None, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 13a26b6c05d5..8081a96430bc 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -551,15 +551,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, layer_past, use_cache, diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 2c5890e0a357..ff24ed74856a 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1275,14 +1275,8 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 4200e6556d50..cc370a5a0c7c 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1252,15 +1252,8 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 656a3022c96f..1fd39703bce3 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -447,17 +447,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - 
create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index de3c125abbac..db48ac56fee3 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -358,18 +358,12 @@ def forward( all_hidden_states = all_hidden_states + (hidden_state,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_state, attn_mask, head_mask[i], + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 1a1e215f9a6d..a789b7ef57ba 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -749,15 +749,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index b13ca04626cf..513892740ed7 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -528,17 +528,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index a7d943450a86..eee30624719e 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -571,20 +571,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git 
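All of these call sites assume the model carries a `self.gradient_checkpointing_func` attribute. How that attribute gets set lives in modeling_utils.py, which is not shown in this part of the patch, so the wiring below is only a guessed sketch of the idea: enabling checkpointing binds `torch.utils.checkpoint.checkpoint`, optionally with user-supplied keyword arguments such as `use_reentrant`. `ToyEncoder` and `enable_gradient_checkpointing` are hypothetical names, not the library's API.

import functools

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyEncoder(nn.Module):  # stand-in, not a real Transformers class
    def __init__(self):
        super().__init__()
        self.layer = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def enable_gradient_checkpointing(self, **checkpoint_kwargs):
        self.gradient_checkpointing = True
        self.gradient_checkpointing_func = functools.partial(checkpoint, **checkpoint_kwargs)

    def forward(self, hidden_states):
        for layer_module in self.layer:
            if self.gradient_checkpointing and self.training:
                hidden_states = self.gradient_checkpointing_func(layer_module.forward, hidden_states)
            else:
                hidden_states = layer_module(hidden_states)
        return hidden_states


encoder = ToyEncoder().train()
encoder.enable_gradient_checkpointing(use_reentrant=False)
out = encoder(torch.randn(2, 4, requires_grad=True))
out.sum().backward()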
a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index b178ca354495..d88563e778c7 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -506,20 +506,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 3115a1357ea6..21e480c8212b 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -605,20 +605,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 29873a39457f..7f8e7db562d7 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1148,21 +1148,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, alibi, causal_mask, position_ids, head_mask[i], + layer_past, + use_cache, + output_attentions, padding_mask, ) else: diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 9b5faaeb15f6..614314632151 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -663,18 +663,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 299b607b6b8a..f9ec022845f0 100755 --- 
a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -292,14 +292,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - layer_outputs = self.gradient_checkpointing_func(create_custom_forward(layer_module), hidden_states) + layer_outputs = self.gradient_checkpointing_func(layer_module.forward, hidden_states) else: layer_outputs = layer_module(hidden_states) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 0e33dc4f66f4..5ff1c99b94f3 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -586,15 +586,8 @@ def forward( for i, stage_module in enumerate(self.stages): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - stage_outputs = self.gradient_checkpointing_func( - create_custom_forward(stage_module), + stage_module.forward, hidden_states, input_dimensions, ) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index bcbee566fa24..0e44931eb99e 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -452,18 +452,13 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -879,18 +874,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index fd726627bb1b..ee84cb1bc88f 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -878,22 +878,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git 
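In the bloom, falcon and mpt hunks the old closures forwarded `use_cache=use_cache, output_attentions=output_attentions` as keyword arguments; after the change everything is positional. That matters beyond style: on recent PyTorch versions the reentrant checkpoint implementation only forwards positional arguments and raises on unexpected keyword arguments, so flags have to be threaded through positionally (or a non-reentrant checkpoint function has to be configured). A sketch with a hypothetical `ToyBlock`:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):  # hypothetical Bloom-style block
    def forward(self, hidden_states, alibi=None, attention_mask=None,
                layer_past=None, head_mask=None, use_cache=False, output_attentions=False):
        return (hidden_states + 1,)


block = ToyBlock()
hidden_states = torch.randn(1, 2, 4, requires_grad=True)

try:
    # Keyword forwarding is rejected by the reentrant path on recent PyTorch.
    checkpoint(block.forward, hidden_states, use_cache=False, output_attentions=False,
               use_reentrant=True)
except ValueError as err:
    print("keyword arguments rejected:", err)

# Positional arguments in signature order go through fine.
outputs = checkpoint(block.forward, hidden_states, None, None, None, None, False, False,
                     use_reentrant=True)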
a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 3bcb4a865812..7d4e77a4674f 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -651,22 +651,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 494187a33aa4..6ede0829cd03 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -605,20 +605,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 19560dc6c975..860552cde485 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -642,20 +642,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for layer_past - return module(*inputs, use_cache, None, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, position_ids, head_mask[i], + use_cache, + None, + output_attentions, ) else: outputs = layer( diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a51d4bdd094c..c0302b6c21a0 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -668,21 +668,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index d4199891f6c9..e4aeaf70996f 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ 
b/src/transformers/models/groupvit/modeling_groupvit.py @@ -1032,18 +1032,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 9acb52c2aedb..94b2a205d8ca 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -346,15 +346,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -739,7 +732,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) @@ -829,7 +822,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index eb2b836169d6..cb604909e192 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -401,18 +401,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index f3ebc9324260..187f39248fbc 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -817,22 +817,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 5b93a16d3e02..d5cd57dea7cc 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ 
b/src/transformers/models/informer/modeling_informer.py @@ -1216,18 +1216,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) if conv_layer is not None: output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) @@ -1448,7 +1442,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 7b02ee85020c..74c6d875f222 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -463,17 +463,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -948,7 +942,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 82531ab7a455..dc094bd8ba0b 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -487,20 +487,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 30ff103bea7d..fcb1dd37de72 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -447,10 +447,11 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, ) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py 
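The layoutlmv2 hunk above is an exception: `rel_pos=rel_pos` and `rel_2d_pos=rel_2d_pos` are still forwarded to `self.gradient_checkpointing_func` as keyword arguments. That only works if the configured checkpoint function forwards keyword arguments, which the non-reentrant `torch.utils.checkpoint.checkpoint` does on recent PyTorch and the reentrant one does not, so the sketch below assumes a non-reentrant configuration. `ToyRelPosLayer` is a hypothetical stand-in:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyRelPosLayer(nn.Module):  # hypothetical LayoutLMv2-style layer
    def forward(self, hidden_states, attention_mask=None, rel_pos=None, rel_2d_pos=None):
        if rel_pos is not None:
            hidden_states = hidden_states + rel_pos
        return (hidden_states,)


layer = ToyRelPosLayer()
hidden_states = torch.randn(1, 3, 4, requires_grad=True)
rel_pos = torch.randn(1, 3, 4)

# Keyword arguments survive checkpointing only on the non-reentrant path.
outputs = checkpoint(layer.forward, hidden_states, None, rel_pos=rel_pos, use_reentrant=False)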
index 42162dcfb2e5..9afc855417fa 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -657,19 +657,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos) - # The above line will cause error: - # RuntimeError: Trying to backward through the graph a second time - # (or directly access saved tensors after they have already been freed). - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 1029a7950a2e..5850923ffdca 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1877,20 +1877,15 @@ def forward( layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2152,7 +2147,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 65c381fc50a9..2c7085aa8228 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -514,19 +514,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layout_inputs, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5664a581ffb7..340f02abea07 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -914,16 +914,15 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + decoder_layer.forward, + hidden_states, + 
attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + padding_mask, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 3b77ad46aed3..6ca8f61cfa4c 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1304,20 +1304,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 4c6ff76cc95d..b4d3c3ba495f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1511,15 +1511,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1529,6 +1522,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index fde39d0999af..143932f924bf 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -788,19 +788,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, word_hidden_states, entity_hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 264aff5b4aac..080949adbeb6 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -821,18 +821,12 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if 
head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1076,7 +1070,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index a0ab7192718b..7bf8aac0aef6 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -789,18 +789,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1047,7 +1041,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 9686b0a1d305..fc15c86e7a94 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -648,20 +648,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 9ec586a17bb3..7d00b6b6d871 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1864,20 +1864,14 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, None, None, + output_attentions, ) else: diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 8502a6a368ea..a941c0508a94 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -848,20 +848,14 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - 
layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, encoder_attention_mask, None, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index fe9dbc91f801..dd6c45de8a56 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -688,15 +688,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, layer_head_mask + layer_module.forward, + hidden_states, + layer_head_mask, + output_attentions, ) else: layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 644c5d292b0e..aa5f17215e90 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -829,18 +829,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1096,7 +1090,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 16d463dcb470..a2e2a39ec966 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -551,20 +551,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1544ebeaaf81..e9c17cc25ccf 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -919,19 +919,15 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else 
None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, + use_cache, + padding_mask, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 0653321df9c3..1e8a8afa07dd 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -626,15 +626,8 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, ) else: diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 5aca04266e46..c857915a8cca 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -582,15 +582,8 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, ) else: diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 279b0bc903a5..897a90ce0486 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -524,20 +524,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - outputs = self.gradient_checkpointing_func( - create_custom_forward(block), + block.forward, hidden_states, alibi, causal_mask, layer_past, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 672e2666533d..1da9da2af915 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -766,15 +766,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 2e2b68060dc9..2951ffc889dc 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ 
b/src/transformers/models/mt5/modeling_mt5.py @@ -1074,15 +1074,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1092,6 +1085,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 6bee6f35dc7d..a740ed47074b 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -827,16 +827,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -844,6 +836,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index f44e067aac31..71a1a166d848 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -950,19 +950,13 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), (self_attn_prompt[idx] if self.use_prompt else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1237,7 +1231,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index 5a94e43291cb..a8ad52d26988 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -577,20 +577,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, 
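# A sketch of the argument-order constraint behind the T5-style hunk above (mt5 here;
# the analogous t5/pop2piano/switch_transformers/umt5 changes appear later in this
# patch): the checkpointed call passes everything positionally, so `None` stays in the
# `past_key_value` slot and `use_cache` / `output_attentions` are appended last, in
# signature order. The block signature below is a simplified stand-in, not the real
# T5 layer.
import inspect


def toy_t5_block_forward(
    hidden_states,
    attention_mask=None,
    position_bias=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    encoder_decoder_position_bias=None,
    layer_head_mask=None,
    cross_attn_layer_head_mask=None,
    past_key_value=None,
    use_cache=False,
    output_attentions=False,
):
    return hidden_states


# Mirror the checkpointed call site: positional args in the same order as the hunk.
bound = inspect.signature(toy_t5_block_forward).bind(
    "hidden_states",
    "extended_attention_mask",
    "position_bias",
    "encoder_hidden_states",
    "encoder_extended_attention_mask",
    "encoder_decoder_position_bias",
    "layer_head_mask",
    "cross_attn_layer_head_mask",
    None,  # past_key_value is always None with gradient checkpointing
    True,  # use_cache
    False,  # output_attentions
)
assert bound.arguments["past_key_value"] is None
assert bound.arguments["use_cache"] is True
assert bound.arguments["output_attentions"] is False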
encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 72cf7e3a3005..883589de2410 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1154,18 +1154,12 @@ def forward( layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1435,7 +1429,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 3c5df5dedd2e..9a023cbc91ef 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -370,17 +370,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index c97d57fa236f..5782d796566a 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -692,20 +692,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 5aee16cc8106..351a1a77d59a 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -765,18 +765,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - 
create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index b5317ea1c1b8..63e1570a1106 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -754,18 +754,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 705cf956f784..dbe93bc18bec 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -804,18 +804,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1097,7 +1091,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 5f5888429231..a29a1250a976 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1072,18 +1072,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, global_hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1341,7 +1335,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c6092e158c93..28a12d5eb338 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -670,19 +670,13 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def 
custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 31cedc13359f..acbe0996d5ae 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -343,18 +343,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -1505,7 +1499,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index a079b0bf0cf5..59803ed363e1 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -808,18 +808,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1074,7 +1068,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index acb43f824b7b..5cf7039e9f0c 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -903,15 +903,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -921,6 +914,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py 
b/src/transformers/models/prophetnet/modeling_prophetnet.py index 04d2b946eafc..f0016c8c206d 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1330,18 +1330,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1579,7 +1573,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, extended_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index cf307fb35009..69be03b93bde 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -581,20 +581,15 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 8f7d0a656002..a63e3a9e9bce 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -586,20 +586,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 6dd04ed4030c..6471653da7bf 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -543,20 +543,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, 
encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index d7ead17b4544..aedfc5ef8077 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -510,20 +510,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 4ae7a308f68e..1bcdb8724518 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -512,20 +512,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index d1b84ab31f63..3627944fab4b 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -644,20 +644,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index e860ff34eb52..6773a6f967ad 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -578,21 +578,16 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + 
layer_module.forward, hidden_states, attention_mask, sinusoidal_pos, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index bbe21949f0e6..d7c7df9a8390 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -677,16 +677,8 @@ def forward( all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - hidden_states, state, attentions = self.gradient_checkpointing_func( - create_custom_forward(block), hidden_states, state + block.forward, hidden_states, state, use_cache, output_attentions ) else: hidden_states, state, attentions = block( diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index f5cd7cf0a45b..d384747af336 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1042,15 +1042,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, ) else: diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 6b0e7a7ff079..b0b56b80d268 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1857,18 +1857,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 44cbcec5267a..6a3cde064ff8 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -360,15 +360,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -681,7 +674,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 
74374e1a4eb9..f18622538e41 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -453,15 +453,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -1134,20 +1127,14 @@ def forward( all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - output_states = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index acdcc2f902cb..9d75dc4f3da0 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -818,18 +818,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1075,7 +1069,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 9d863ba3e2f2..486dda2f46b4 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -670,16 +670,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index ef374bbb32e7..b470cab687d2 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -520,15 +520,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def 
custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -1395,7 +1388,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1725,7 +1718,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f1ab50179dea..d766f435f150 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -459,20 +459,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 228d962dea1d..25432478abea 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -825,15 +825,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index db8ff6a652eb..d7b248b11359 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -746,15 +746,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask + stage_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index fda0e080d0d8..c00ae39e0bec 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ 
b/src/transformers/models/swinv2/modeling_swinv2.py @@ -906,15 +906,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index ed0d59abb8bc..32d030728de5 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1040,15 +1040,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1058,6 +1051,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 603c6a4730e8..c796a9cf24cf 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1101,15 +1101,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1119,6 +1112,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index fb42673ae5c5..e72975a200a0 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -1144,15 +1144,8 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index e6ce415899fa..ae22bbd8449d 100644 --- 
a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -646,20 +646,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_values, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_values, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index c550f89e9504..9b44713dc64a 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -947,18 +947,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1173,7 +1167,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index df7dd1c953f5..ccc65287cdc2 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -439,16 +439,10 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 6971b4dfb21a..9b7fab8e2f3d 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -702,16 +702,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -719,6 +711,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, 
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 8852083c4694..fcf61142ced6 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -568,10 +568,11 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -886,7 +887,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, None, ) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index e6e9aaa26a38..a5b58444fe4e 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -710,15 +710,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -726,6 +719,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 8f667d3d564c..4708ff4173dd 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -384,15 +384,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -775,7 +768,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) @@ -865,7 +858,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 5584929ab11c..2d57b0638198 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -398,15 +398,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def 
custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -789,7 +782,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) @@ -879,7 +872,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 9657a747fbc3..27e09730cfde 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -434,17 +434,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -735,7 +729,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, None, ) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 9c8fee6c79f3..1d9db412d37f 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -539,10 +539,11 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index c2eaa90b4864..36a1292fc9fd 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -418,18 +418,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 050d02ee2990..b06ab62113a7 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -397,17 +397,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def 
create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index fa4b3471e375..7b54e6c1535b 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -415,17 +415,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index b468075d08ff..0e27a335ddb6 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -536,17 +536,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -802,7 +796,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, None, ) diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 87779dd3ae94..91e13c7b6adc 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -387,17 +387,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index fd9f26923444..8e20f17e0709 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -565,17 +565,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def 
custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 7b7899ee287f..f3cda24b85cc 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1167,18 +1167,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, padding_mask, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 5e07b1544b2b..b4ed99bd9e98 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -338,17 +338,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index c02e23660b56..2e18a26633cf 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -451,15 +451,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -811,7 +804,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) @@ -900,7 +893,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, ) diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index edcdcf4a22ac..6f2c28624df7 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -518,15 +518,8 @@ def forward(self, input_values): for conv_layer in 
self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -919,7 +912,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, relative_position_embeddings, diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 182482dfd83a..defe32a5103c 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -354,15 +354,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - hidden_states = self.gradient_checkpointing_func( - create_custom_forward(conv_layer), + conv_layer.forward, hidden_states, ) else: @@ -721,7 +714,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, position_bias, @@ -812,7 +805,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer), + layer.forward, hidden_states, attention_mask, position_bias, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f4b8fb4852a9..0b7341fa6d3d 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -943,18 +943,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, None, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1184,7 +1178,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 025533ab4178..de3a4376e4a6 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -704,18 +704,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -959,7 +953,7 @@ def 
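# The conv feature-extractor hunks above (sew, sew_d, speecht5, unispeech, unispeech_sat,
# wav2vec2, wav2vec2_conformer, wavlm) are the simplest case: the removed closure added
# no extra arguments, so the bound `conv_layer.forward` is a drop-in replacement. A tiny
# self-contained illustration with a stand-in Conv1d (hypothetical shapes):
import torch
from torch.utils.checkpoint import checkpoint

conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=4, kernel_size=3)
hidden_states = torch.randn(2, 1, 16, requires_grad=True)

# Equivalent to the removed `create_custom_forward(conv_layer)` wrapper:
hidden_states = checkpoint(conv_layer.forward, hidden_states, use_reentrant=False)
hidden_states.sum().backward()  # gradients still flow through the checkpointed layer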
custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 16f0402abf98..f6f518e8f5ce 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -676,15 +676,8 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -692,6 +685,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index e599cc3cede7..e07d343b62cb 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -1350,18 +1350,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1602,7 +1596,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, extended_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index b195ee43723e..1bc22ca10045 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -511,20 +511,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 0e3ed4eeb986..3477d709ae0e 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -499,20 
+499,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 2eb0ba83d726..c6fcc0bb7c21 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -574,20 +574,16 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, lang_ids, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 0884529777c2..4e1825a457bc 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -493,16 +493,11 @@ def forward( if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 0159d7fb76d3..b0cbd589b293 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -561,17 +561,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index ee583cec3548..2071c90a83bb 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -544,19 +544,15 @@ def forward( past_key_value = past_key_values[i] if 
past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -2322,7 +2318,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -2561,7 +2557,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = self.gradient_checkpointing_func( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, From 2b5a6695333dcd5a36fad3f5c547e88bdf03de15 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 14:38:45 +0000 Subject: [PATCH 04/12] fixup --- src/transformers/models/seamless_m4t/modeling_seamless_m4t.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index ebd1361ed2e9..f7f74201d2d3 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1862,7 +1862,7 @@ def forward( hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), - output_attentions + output_attentions, ) else: layer_outputs = encoder_layer( From 6fbe101677c9d06450a18f04be33dac8a41cf205 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 23 Oct 2023 14:46:07 +0000 Subject: [PATCH 05/12] fixup --- .../models/deformable_detr/modeling_deformable_detr.py | 1 - src/transformers/models/deta/modeling_deta.py | 1 - src/transformers/models/xglm/modeling_xglm.py | 1 - src/transformers/models/xmod/modeling_xmod.py | 1 - src/transformers/models/yolos/modeling_yolos.py | 1 - 5 files changed, 5 deletions(-) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index fb8ed41ce712..507a18151490 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1382,7 +1382,6 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index ff24ed74856a..0853a3f82208 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1274,7 +1274,6 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index f6f518e8f5ce..075e0c315970 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ 
b/src/transformers/models/xglm/modeling_xglm.py @@ -675,7 +675,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index c6fcc0bb7c21..4ca4adeec995 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -573,7 +573,6 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 4e1825a457bc..a378e96f9909 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -492,7 +492,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, From 634b5e7fe959f280176c0d5710bfe783f808974f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 24 Oct 2023 09:17:42 +0000 Subject: [PATCH 06/12] add test and fix all failing GC tests --- src/transformers/models/align/modeling_align.py | 2 +- src/transformers/models/blip/modeling_blip.py | 4 ++-- src/transformers/models/blip_2/modeling_blip_2.py | 13 +++++-------- src/transformers/models/gpt2/modeling_gpt2.py | 1 - .../models/groupvit/modeling_groupvit.py | 1 - src/transformers/models/hubert/modeling_hubert.py | 2 +- .../models/instructblip/modeling_instructblip.py | 13 +++++-------- src/transformers/models/longt5/modeling_longt5.py | 4 +--- src/transformers/models/mbart/modeling_mbart.py | 2 +- .../models/pix2struct/modeling_pix2struct.py | 2 +- .../models/seamless_m4t/modeling_seamless_m4t.py | 2 +- src/transformers/models/sew_d/modeling_sew_d.py | 2 +- .../models/speecht5/modeling_speecht5.py | 6 ------ src/transformers/models/tvlt/modeling_tvlt.py | 2 +- .../models/videomae/modeling_videomae.py | 2 +- .../models/vit_mae/modeling_vit_mae.py | 2 +- .../models/vitmatte/modeling_vitmatte.py | 5 +++++ src/transformers/models/vits/modeling_vits.py | 2 +- tests/test_modeling_common.py | 14 ++++++++++++++ 19 files changed, 42 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index e132fae5c660..7b141b5f65a3 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1193,7 +1193,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (AlignTextModel, AlignVisionModel)): + if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 229afec0a81f..927c33f9927c 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -34,7 +34,7 @@ 
replace_return_docstrings, ) from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig -from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel +from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel logger = logging.get_logger(__name__) @@ -462,7 +462,7 @@ def _init_weights(self, module): module.bias.data.zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BlipEncoder): + if isinstance(module, (BlipEncoder, BlipTextEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 8339016efcb9..735b81bc4229 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -298,10 +298,14 @@ def _init_weights(self, module): module.bias.data.zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, Blip2Encoder): + if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) + BLIP_2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -939,13 +943,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index ee84cb1bc88f..dc28ed3640f4 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1618,7 +1618,6 @@ def __init__(self, config): # Model parallel self.model_parallel = False self.device_map = None - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index e4aeaf70996f..332b14d9961c 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -492,7 +492,6 @@ def __init__( self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size)) else: self.group_token = None - self.gradient_checkpointing = False self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)]) if num_group_token > 0: diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 94b2a205d8ca..b215063090d0 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -889,7 +889,7 @@ def _init_weights(self, module): module.bias.data.zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): + if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 74c6d875f222..3cc44efbe361 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -305,10 +305,14 @@ def _init_weights(self, module): module.bias.data.zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, InstructBlipEncoder): + if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) + INSTRUCTBLIP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -934,13 +938,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index b4d3c3ba495f..9abbfa2f2001 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -775,7 +775,6 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = Fal if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() - self.gradient_checkpointing = False # Relativen attention bias & Layer norm for global attention if self.has_relative_attention_bias: @@ -1340,9 +1339,8 @@ def _init_weights(self, module): mean=0.0, std=factor * ((d_model) ** -0.5) ) - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (LongT5Attention, LongT5Stack)): + if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index aa5f17215e90..29bcd445a8f2 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -517,7 +517,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (MBartDecoder, MBartDecoder)): + if isinstance(module, (MBartDecoder, MBartEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index acbe0996d5ae..0efaab7cec59 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -558,7 +558,7 @@ def __init__(self, config: Pix2StructConfig): self.post_init() def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, Pix2StructVisionEncoder): + if isinstance(module, (Pix2StructVisionEncoder, Pix2StructVisionAttention)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index f7f74201d2d3..b0bc0a135091 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1548,7 +1548,7 @@ def _init_weights(self, module): nn.init.uniform_(module.bias, a=-k, b=k) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TSpeechEncoder)): + if 
isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TConformerEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index f18622538e41..2dc2231e6073 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1310,7 +1310,7 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti return attention_mask def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, SEWDTransformerEncoder): + if isinstance(module, (SEWDEncoder, SEWDFeatureEncoder, SEWDTransformerEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index b470cab687d2..c1fef6df94d1 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1433,7 +1433,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5SpeechEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1470,7 +1469,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5TextEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1513,7 +1511,6 @@ class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1782,7 +1779,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5SpeechDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1830,7 +1826,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5TextDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1883,7 +1878,6 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index fcf61142ced6..65da6a46339a 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -618,7 +618,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, TvltEncoder): + if isinstance(module, (TvltEncoder, TvltDecoder)): module.gradient_checkpointing_func = 
gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 27e09730cfde..203aacc3f365 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -484,7 +484,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, VideoMAEEncoder): + if isinstance(module, (VideoMAEEncoder, VideoMAEDecoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 0e27a335ddb6..14e047c5acc8 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -586,7 +586,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ViTMAEEncoder): + if isinstance(module, (ViTMAEEncoder, ViTMAEDecoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 18b8b80b328c..f5025a37e71c 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -91,6 +91,11 @@ def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None + for backbone_module in module.modules(): + if hasattr(backbone_module, "gradient_checkpointing"): + backbone_module.gradient_checkpointing_func = gradient_checkpointing_func + backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None + class VitMatteBasicConv3x3(nn.Module): """ diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index f3cda24b85cc..49b8e1a6a40a 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1291,7 +1291,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (VitsTextEncoder)): + if isinstance(module, VitsEncoder): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 74ea66fd6fa0..7e1c471badf4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -349,10 +349,24 @@ def test_gradient_checkpointing_enable_disable(self): model.gradient_checkpointing_enable() self.assertTrue(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to True + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertTrue( + m.gradient_checkpointing, f"Module {n} does not have 
gradient_checkpointing set to True" + ) + # check disable works model.gradient_checkpointing_disable() self.assertFalse(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to False + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertFalse( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" + ) + def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.__class__ not in MODEL_MAPPING: From 476d261db8fca68b04d652b70d785686666b0e77 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 24 Oct 2023 09:28:07 +0000 Subject: [PATCH 07/12] remove all remaining `create_custom_forward` methods --- .../models/autoformer/modeling_autoformer.py | 10 ++------- src/transformers/models/bart/modeling_bart.py | 10 ++------- .../modeling_bigbird_pegasus.py | 10 ++------- .../models/blenderbot/modeling_blenderbot.py | 10 ++------- .../modeling_blenderbot_small.py | 10 ++------- .../chinese_clip/modeling_chinese_clip.py | 8 +------ .../models/convbert/modeling_convbert.py | 7 ------- .../data2vec/modeling_data2vec_audio.py | 8 +------ .../models/hubert/modeling_hubert.py | 16 ++------------ .../models/informer/modeling_informer.py | 10 ++------- .../models/layoutlmv2/modeling_layoutlmv2.py | 7 ------- src/transformers/models/led/modeling_led.py | 10 ++------- .../models/m2m_100/modeling_m2m_100.py | 10 ++------- .../models/marian/modeling_marian.py | 10 ++------- .../models/mbart/modeling_mbart.py | 10 ++------- src/transformers/models/mvp/modeling_mvp.py | 10 ++------- .../models/nllb_moe/modeling_nllb_moe.py | 9 ++------ .../models/pegasus/modeling_pegasus.py | 10 ++------- .../models/pegasus_x/modeling_pegasus_x.py | 10 ++------- .../models/pix2struct/modeling_pix2struct.py | 9 ++------ .../models/plbart/modeling_plbart.py | 10 ++------- .../models/prophetnet/modeling_prophetnet.py | 10 ++------- .../seamless_m4t/modeling_seamless_m4t.py | 21 ++++--------------- src/transformers/models/sew/modeling_sew.py | 8 +------ .../speech_to_text/modeling_speech_to_text.py | 10 ++------- .../models/speecht5/modeling_speecht5.py | 18 +++------------- .../modeling_time_series_transformer.py | 10 ++------- src/transformers/models/tvlt/modeling_tvlt.py | 15 +------------ .../models/unispeech/modeling_unispeech.py | 16 ++------------ .../unispeech_sat/modeling_unispeech_sat.py | 16 ++------------ .../models/videomae/modeling_videomae.py | 8 +------ src/transformers/models/vilt/modeling_vilt.py | 7 ------- .../models/vit_mae/modeling_vit_mae.py | 8 +------ .../models/wav2vec2/modeling_wav2vec2.py | 16 ++------------ .../modeling_wav2vec2_conformer.py | 8 +------ .../models/wavlm/modeling_wavlm.py | 16 ++------------ .../models/whisper/modeling_whisper.py | 10 ++------- .../models/x_clip/modeling_x_clip.py | 8 +------ .../xlm_prophetnet/modeling_xlm_prophetnet.py | 10 ++------- ...ng_{{cookiecutter.lowercase_modelname}}.py | 17 +++------------ 40 files changed, 70 insertions(+), 366 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 520a9ddbc131..29073c3d57dd 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1420,14 +1420,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1437,6 +1429,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 70013fba27df..390af1a825a7 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1105,14 +1105,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1122,6 +1114,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index ad0640b04451..03ef911970ad 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2284,14 +2284,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -2301,6 +2293,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 4a3248e5d443..35879ac1500a 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1027,14 +1027,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1044,6 +1036,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git 
a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index ef9d0e9643f5..59ba6b9dd874 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1024,14 +1024,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1041,6 +1033,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index c96521493fd5..a010d82fd9de 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -1014,16 +1014,10 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( encoder_layer.forward, hidden_states, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index c040715c3630..e24083021425 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -633,13 +633,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 71fd5a705990..5a2491571efa 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -586,17 +586,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index b215063090d0..e5b1b1742e74 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ 
-724,17 +724,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -814,17 +808,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index d5cd57dea7cc..423de7d81976 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1433,14 +1433,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1450,6 +1442,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index fcb1dd37de72..03900bff907c 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -439,13 +439,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 5850923ffdca..3d4e3c26188c 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2138,14 +2138,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -2155,6 +2147,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if 
cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 080949adbeb6..b9b672ca2829 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1061,14 +1061,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1078,6 +1070,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 7bf8aac0aef6..81a4d7b6f6b5 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1032,14 +1032,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1049,6 +1041,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 29bcd445a8f2..341260efe45c 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1081,14 +1081,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1098,6 +1090,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 71a1a166d848..d8622fca9582 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1222,14 +1222,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - 
return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1241,6 +1233,8 @@ def custom_forward(*inputs): self_attn_prompt[idx] if self.use_prompt else None, cross_attn_prompt[idx] if self.use_prompt else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 883589de2410..51bbd56d2b58 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1421,13 +1421,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( decoder_layer.forward, hidden_states, @@ -1437,6 +1430,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index dbe93bc18bec..5fc671f25f46 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1082,14 +1082,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1099,6 +1091,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index a29a1250a976..f35bef20969d 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1326,14 +1326,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1341,6 +1333,8 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 0efaab7cec59..9b4444c56ee2 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1491,13 +1491,6 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( layer_module.forward, hidden_states, @@ -1509,6 +1502,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 59803ed363e1..cdd73be66d7a 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1059,14 +1059,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1076,6 +1068,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index f0016c8c206d..eb1b319fb19a 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1564,14 +1564,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1585,6 +1577,8 @@ def custom_forward(*inputs): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index b0bc0a135091..a930d60ec9da 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -892,14 +892,7 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, @@ -2125,15 +2118,7 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = 
torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, attention_mask, @@ -2142,6 +2127,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 6a3cde064ff8..883fab34fce2 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -666,17 +666,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 9d75dc4f3da0..030358ff033a 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1060,14 +1060,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1077,6 +1069,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index c1fef6df94d1..40d30f366a20 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1380,19 +1380,13 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), position_bias, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1706,14 +1700,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1723,6 +1709,8 @@ def custom_forward(*inputs): 
head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 9b44713dc64a..349bc5d48adf 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -1158,14 +1158,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1175,6 +1167,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 65da6a46339a..086cf66fd40d 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -560,13 +560,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, @@ -879,17 +872,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 4708ff4173dd..bcfc4069c8a3 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -760,17 +760,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -850,17 +844,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def 
custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 2d57b0638198..778dbfad18a9 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -774,17 +774,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -864,17 +858,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 203aacc3f365..84ff258c58b8 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -721,17 +721,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 1d9db412d37f..a93dc99903e1 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -531,13 +531,6 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 14e047c5acc8..5fa10ca9d137 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -788,17 +788,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def 
custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer_module.forward, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 2e18a26633cf..ec38d6a11570 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -796,17 +796,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -885,17 +879,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 6f2c28624df7..5d7235925568 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -904,18 +904,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, relative_position_embeddings, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index defe32a5103c..ef76b4333089 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -706,18 +706,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( @@ -797,18 +791,12 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not 
sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( layer.forward, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 0b7341fa6d3d..c868abe44c0e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1169,14 +1169,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1186,6 +1178,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, # past_key_value + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index de3a4376e4a6..46ad1fb719e7 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -945,18 +945,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( encoder_layer.forward, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index e07d343b62cb..bc86fd6ff5c8 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -1587,14 +1587,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -1608,6 +1600,8 @@ def custom_forward(*inputs): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 2071c90a83bb..e2e36ac36823 100755 --- 
a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -2310,18 +2310,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2549,13 +2543,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = self.gradient_checkpointing_func( decoder_layer.forward, hidden_states, @@ -2565,6 +2552,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: From 465849c02d1b72a6974f5a3d47facbfeb04f45b4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 24 Oct 2023 09:48:05 +0000 Subject: [PATCH 08/12] fix idefics bug --- src/transformers/models/idefics/modeling_idefics.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index d3f9c5da4d2d..5c2d6f996319 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -40,7 +40,7 @@ ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from .vision import IdeficsVisionTransformer +from .vision import IdeficsVisionTransformer, IdeficsVisionEncoder logger = logging.get_logger(__name__) @@ -979,7 +979,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, IdeficsModel): + if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)): module.gradient_checkpointing_func = gradient_checkpointing_func module.gradient_checkpointing = gradient_checkpointing_func is not None @@ -1099,7 +1099,6 @@ def __init__(self, config: IdeficsConfig): self.norm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() From 7e5eeda035a5041904a59a23afeca827c3c541e5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 24 Oct 2023 09:56:11 +0000 Subject: [PATCH 09/12] fixup --- src/transformers/models/idefics/modeling_idefics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 5c2d6f996319..28841903a1a3 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -40,7 +40,7 @@ ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from 
.vision import IdeficsVisionTransformer, IdeficsVisionEncoder +from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer logger = logging.get_logger(__name__) From 967ed0db036766dc1e9a964eca66a8c07df1b902 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 25 Oct 2023 09:04:43 +0000 Subject: [PATCH 10/12] replace with `__call__` --- src/transformers/models/align/modeling_align.py | 2 +- src/transformers/models/altclip/modeling_altclip.py | 4 ++-- .../modeling_audio_spectrogram_transformer.py | 2 +- src/transformers/models/autoformer/modeling_autoformer.py | 4 ++-- src/transformers/models/bark/modeling_bark.py | 2 +- src/transformers/models/bart/modeling_bart.py | 4 ++-- src/transformers/models/beit/modeling_beit.py | 2 +- src/transformers/models/bert/modeling_bert.py | 2 +- .../models/bert_generation/modeling_bert_generation.py | 2 +- src/transformers/models/big_bird/modeling_big_bird.py | 2 +- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 4 ++-- src/transformers/models/biogpt/modeling_biogpt.py | 2 +- src/transformers/models/blenderbot/modeling_blenderbot.py | 4 ++-- .../models/blenderbot_small/modeling_blenderbot_small.py | 4 ++-- src/transformers/models/blip/modeling_blip.py | 2 +- src/transformers/models/blip/modeling_blip_text.py | 2 +- src/transformers/models/blip_2/modeling_blip_2.py | 4 ++-- src/transformers/models/bloom/modeling_bloom.py | 2 +- src/transformers/models/bridgetower/modeling_bridgetower.py | 2 +- src/transformers/models/bros/modeling_bros.py | 2 +- src/transformers/models/camembert/modeling_camembert.py | 2 +- src/transformers/models/canine/modeling_canine.py | 2 +- .../models/chinese_clip/modeling_chinese_clip.py | 4 ++-- src/transformers/models/clap/modeling_clap.py | 4 ++-- src/transformers/models/clip/modeling_clip.py | 2 +- src/transformers/models/clipseg/modeling_clipseg.py | 2 +- src/transformers/models/codegen/modeling_codegen.py | 2 +- .../models/conditional_detr/modeling_conditional_detr.py | 2 +- src/transformers/models/convbert/modeling_convbert.py | 2 +- src/transformers/models/data2vec/modeling_data2vec_audio.py | 4 ++-- src/transformers/models/data2vec/modeling_data2vec_text.py | 2 +- .../models/data2vec/modeling_data2vec_vision.py | 2 +- src/transformers/models/deberta/modeling_deberta.py | 2 +- src/transformers/models/deberta_v2/modeling_deberta_v2.py | 2 +- .../decision_transformer/modeling_decision_transformer.py | 2 +- .../models/deformable_detr/modeling_deformable_detr.py | 2 +- src/transformers/models/deit/modeling_deit.py | 2 +- src/transformers/models/deprecated/mctct/modeling_mctct.py | 2 +- .../models/deprecated/open_llama/modeling_open_llama.py | 2 +- .../modeling_trajectory_transformer.py | 2 +- src/transformers/models/deta/modeling_deta.py | 2 +- src/transformers/models/detr/modeling_detr.py | 2 +- src/transformers/models/dinov2/modeling_dinov2.py | 2 +- src/transformers/models/distilbert/modeling_distilbert.py | 2 +- src/transformers/models/donut/modeling_donut_swin.py | 2 +- src/transformers/models/dpt/modeling_dpt.py | 2 +- src/transformers/models/electra/modeling_electra.py | 2 +- src/transformers/models/ernie/modeling_ernie.py | 2 +- src/transformers/models/esm/modeling_esm.py | 2 +- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/flava/modeling_flava.py | 2 +- src/transformers/models/fnet/modeling_fnet.py | 2 +- src/transformers/models/focalnet/modeling_focalnet.py | 2 +- src/transformers/models/git/modeling_git.py | 4 ++-- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- 
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- src/transformers/models/groupvit/modeling_groupvit.py | 2 +- src/transformers/models/hubert/modeling_hubert.py | 6 +++--- src/transformers/models/idefics/vision.py | 2 +- src/transformers/models/imagegpt/modeling_imagegpt.py | 2 +- src/transformers/models/informer/modeling_informer.py | 4 ++-- .../models/instructblip/modeling_instructblip.py | 4 ++-- src/transformers/models/layoutlm/modeling_layoutlm.py | 2 +- src/transformers/models/layoutlmv2/modeling_layoutlmv2.py | 2 +- src/transformers/models/layoutlmv3/modeling_layoutlmv3.py | 2 +- src/transformers/models/led/modeling_led.py | 4 ++-- src/transformers/models/lilt/modeling_lilt.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/longformer/modeling_longformer.py | 2 +- src/transformers/models/luke/modeling_luke.py | 2 +- src/transformers/models/m2m_100/modeling_m2m_100.py | 4 ++-- src/transformers/models/marian/modeling_marian.py | 4 ++-- src/transformers/models/markuplm/modeling_markuplm.py | 2 +- src/transformers/models/mask2former/modeling_mask2former.py | 2 +- src/transformers/models/maskformer/modeling_maskformer.py | 2 +- .../models/maskformer/modeling_maskformer_swin.py | 2 +- src/transformers/models/mbart/modeling_mbart.py | 4 ++-- .../models/megatron_bert/modeling_megatron_bert.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mobilevit/modeling_mobilevit.py | 2 +- src/transformers/models/mobilevitv2/modeling_mobilevitv2.py | 2 +- src/transformers/models/mpt/modeling_mpt.py | 2 +- src/transformers/models/mra/modeling_mra.py | 2 +- src/transformers/models/mvp/modeling_mvp.py | 4 ++-- src/transformers/models/nezha/modeling_nezha.py | 2 +- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 2 +- .../models/nystromformer/modeling_nystromformer.py | 2 +- src/transformers/models/opt/modeling_opt.py | 2 +- src/transformers/models/owlv2/modeling_owlv2.py | 2 +- src/transformers/models/owlvit/modeling_owlvit.py | 2 +- src/transformers/models/pegasus/modeling_pegasus.py | 4 ++-- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 4 ++-- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/pix2struct/modeling_pix2struct.py | 2 +- src/transformers/models/plbart/modeling_plbart.py | 4 ++-- src/transformers/models/prophetnet/modeling_prophetnet.py | 4 ++-- src/transformers/models/qdqbert/modeling_qdqbert.py | 2 +- src/transformers/models/realm/modeling_realm.py | 2 +- src/transformers/models/rembert/modeling_rembert.py | 2 +- src/transformers/models/roberta/modeling_roberta.py | 2 +- .../roberta_prelayernorm/modeling_roberta_prelayernorm.py | 2 +- src/transformers/models/roc_bert/modeling_roc_bert.py | 2 +- src/transformers/models/roformer/modeling_roformer.py | 2 +- src/transformers/models/rwkv/modeling_rwkv.py | 2 +- src/transformers/models/sam/modeling_sam.py | 2 +- .../models/seamless_m4t/modeling_seamless_m4t.py | 4 ++-- src/transformers/models/sew/modeling_sew.py | 4 ++-- src/transformers/models/sew_d/modeling_sew_d.py | 4 ++-- .../models/speech_to_text/modeling_speech_to_text.py | 4 ++-- .../models/speech_to_text_2/modeling_speech_to_text_2.py | 2 +- src/transformers/models/speecht5/modeling_speecht5.py | 6 +++--- src/transformers/models/splinter/modeling_splinter.py | 2 +- 
src/transformers/models/swin/modeling_swin.py | 2 +- src/transformers/models/swin2sr/modeling_swin2sr.py | 2 +- src/transformers/models/swinv2/modeling_swinv2.py | 2 +- .../models/table_transformer/modeling_table_transformer.py | 2 +- src/transformers/models/tapas/modeling_tapas.py | 2 +- .../modeling_time_series_transformer.py | 4 ++-- src/transformers/models/timesformer/modeling_timesformer.py | 2 +- src/transformers/models/trocr/modeling_trocr.py | 2 +- src/transformers/models/tvlt/modeling_tvlt.py | 4 ++-- src/transformers/models/unispeech/modeling_unispeech.py | 6 +++--- .../models/unispeech_sat/modeling_unispeech_sat.py | 6 +++--- src/transformers/models/videomae/modeling_videomae.py | 4 ++-- src/transformers/models/vilt/modeling_vilt.py | 2 +- src/transformers/models/visual_bert/modeling_visual_bert.py | 2 +- src/transformers/models/vit/modeling_vit.py | 2 +- src/transformers/models/vit_hybrid/modeling_vit_hybrid.py | 2 +- src/transformers/models/vit_mae/modeling_vit_mae.py | 4 ++-- src/transformers/models/vit_msn/modeling_vit_msn.py | 2 +- src/transformers/models/vitdet/modeling_vitdet.py | 2 +- src/transformers/models/vits/modeling_vits.py | 2 +- src/transformers/models/vivit/modeling_vivit.py | 2 +- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 6 +++--- .../wav2vec2_conformer/modeling_wav2vec2_conformer.py | 4 ++-- src/transformers/models/wavlm/modeling_wavlm.py | 6 +++--- src/transformers/models/whisper/modeling_whisper.py | 4 ++-- src/transformers/models/x_clip/modeling_x_clip.py | 4 ++-- src/transformers/models/xglm/modeling_xglm.py | 2 +- .../models/xlm_prophetnet/modeling_xlm_prophetnet.py | 4 ++-- src/transformers/models/xlm_roberta/modeling_xlm_roberta.py | 2 +- .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +- src/transformers/models/xmod/modeling_xmod.py | 2 +- src/transformers/models/yolos/modeling_yolos.py | 2 +- src/transformers/models/yoso/modeling_yoso.py | 2 +- .../modeling_{{cookiecutter.lowercase_modelname}}.py | 6 +++--- 149 files changed, 197 insertions(+), 197 deletions(-) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 7b141b5f65a3..58dc2a892009 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1096,7 +1096,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 71e650adba1b..e6229165aace 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -647,7 +647,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -956,7 +956,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py 
index 1c79f3cfd78b..a1f85e2a09eb 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -337,7 +337,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 29073c3d57dd..40e300231085 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1209,7 +1209,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1421,7 +1421,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 11c53ccbdb21..2708b00d05c4 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -639,7 +639,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 390af1a825a7..73eca72e5d12 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -856,7 +856,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1106,7 +1106,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 860de96323be..3ba3d4911b0f 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -511,7 +511,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index b251c9c9b559..91380e13a055 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -594,7 +594,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py 
b/src/transformers/models/bert_generation/modeling_bert_generation.py index 97fb89e95413..123cb2212e19 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -402,7 +402,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 890eb8c6875f..0ba2119e6844 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1618,7 +1618,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 03ef911970ad..98ff51032bad 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1945,7 +1945,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -2285,7 +2285,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 7dc72aa6368e..2bbdbed348a1 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -592,7 +592,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 35879ac1500a..51a947af0a83 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -779,7 +779,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1028,7 +1028,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 59ba6b9dd874..88a9b52de909 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ 
-777,7 +777,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1025,7 +1025,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 927c33f9927c..efd986299c29 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -624,7 +624,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index a9decd052d37..e0aa4e17f146 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -423,7 +423,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 735b81bc4229..2f7f00b3dd59 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -479,7 +479,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, output_attentions, @@ -944,7 +944,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 83998421e131..583367c9ab55 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -763,7 +763,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, alibi, causal_mask, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index ea4c3cc285ba..0f272a21e21d 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -805,7 +805,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 60e753c95f8d..c10f83505676 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -652,7 
+652,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, bbox_pos_emb, attention_mask, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index d5d9f0ae488f..2e0a6c12fe64 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -525,7 +525,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 198e3376731a..adc875910320 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -796,7 +796,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index a010d82fd9de..ef1c265723b6 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -911,7 +911,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -1015,7 +1015,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, output_attentions, ) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 7c6c9618c453..025b59ae4b97 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -940,7 +940,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -1589,7 +1589,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 9e179753157b..56f24c157f83 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -641,7 +641,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 0bded11f9bc1..7a0e52926983 
100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -650,7 +650,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 0a01e05044e4..340719e1fb78 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -542,7 +542,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index c887b170c9cd..01dbf8ecd59c 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1520,7 +1520,7 @@ def forward( query_sine_embed = query_sine_embed_before_transformation * pos_transformation if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, object_queries, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index e24083021425..da577a589614 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -634,7 +634,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 5a2491571efa..a99b6f3a6dc3 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -294,7 +294,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -587,7 +587,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index ba5c6b97a965..507c2fc464d8 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -511,7 +511,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py 
b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 6c5c39e4957c..2742d5ffc37b 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -523,7 +523,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index a7816bae558b..65ec497cecd8 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -458,7 +458,7 @@ def forward( if self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, next_kv, attention_mask, query_states, diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index e536d376c591..2245ac549ada 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -502,7 +502,7 @@ def forward( if self.gradient_checkpointing and self.training: output_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, next_kv, attention_mask, query_states, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 8146436cdc51..19c2731a50a7 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -633,7 +633,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index b33ba3a5fa2a..220fcf0d0660 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1385,7 +1385,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 4cd8785ce535..6e97e932b533 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -358,7 +358,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 779b409470d9..9e7a73c5880b 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -618,7 +618,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing 
and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 80f27e4d666c..fb1cc7f0fb84 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -667,7 +667,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 8081a96430bc..c9f31c714446 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -552,7 +552,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, layer_past, use_cache, diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 9e0954736963..a6f979eaeea6 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -1277,7 +1277,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 7781298b0137..1c09e3e3d7b2 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1255,7 +1255,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 1fd39703bce3..1440b6d615fb 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -448,7 +448,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index db48ac56fee3..3768dd6e91ca 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -359,7 +359,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_state, attn_mask, head_mask[i], diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py 
index a789b7ef57ba..76d525717f8c 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -750,7 +750,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 513892740ed7..2621fa338015 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -529,7 +529,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index eee30624719e..fde5632c09c3 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -572,7 +572,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index d88563e778c7..330cb5033160 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -507,7 +507,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 21e480c8212b..86bd20a46480 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -606,7 +606,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 5c011ddf3d5c..642e60a72f91 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1280,7 +1280,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, alibi, attention_mask, diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 614314632151..1fbf49f9e127 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -664,7 +664,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index f9ec022845f0..b84761536bac 
100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -292,7 +292,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func(layer_module.forward, hidden_states) + layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states) else: layer_outputs = layer_module(hidden_states) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 5ff1c99b94f3..87ec98169626 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -587,7 +587,7 @@ def forward( for i, stage_module in enumerate(self.stages): if self.gradient_checkpointing and self.training: stage_outputs = self.gradient_checkpointing_func( - stage_module.forward, + stage_module.__call__, hidden_states, input_dimensions, ) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 0e44931eb99e..293b9c789d56 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -453,7 +453,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -875,7 +875,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index dc28ed3640f4..24826a76bc04 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -879,7 +879,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 7d4e77a4674f..37c51b40c9a7 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -652,7 +652,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 6ede0829cd03..ed1e62bf175f 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -606,7 +606,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 860552cde485..cf0aa0645ae0 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ 
b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -643,7 +643,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index c0302b6c21a0..65f805b71716 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -669,7 +669,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 332b14d9961c..a9de67143846 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -1032,7 +1032,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index e5b1b1742e74..732e6be2f8dd 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -347,7 +347,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -725,7 +725,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, @@ -809,7 +809,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index cb604909e192..24dc3e9396aa 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -402,7 +402,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 187f39248fbc..a365731ed53d 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -818,7 +818,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, None, attention_mask, diff --git a/src/transformers/models/informer/modeling_informer.py 
b/src/transformers/models/informer/modeling_informer.py index 423de7d81976..53518760cc00 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1217,7 +1217,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1434,7 +1434,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 3cc44efbe361..d4cb7a1fa00b 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -468,7 +468,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, output_attentions, @@ -939,7 +939,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index dc094bd8ba0b..ce6d4302bccc 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -488,7 +488,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 03900bff907c..8f6260fdda49 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -440,7 +440,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 9afc855417fa..e387707e52da 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -658,7 +658,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 3d4e3c26188c..61bbd4156b46 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1878,7 +1878,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, 
attention_mask, head_mask[idx] if head_mask is not None else None, @@ -2139,7 +2139,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 2c7085aa8228..4fd7a85affd7 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -515,7 +515,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layout_inputs, attention_mask, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8f55982565ce..279884dc164f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1015,7 +1015,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 6ca8f61cfa4c..b4f20b452558 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1305,7 +1305,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 143932f924bf..3b5f4d0bf71d 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -789,7 +789,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, word_hidden_states, entity_hidden_states, attention_mask, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index b9b672ca2829..4ebe11f3f3b3 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -822,7 +822,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1062,7 +1062,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 81a4d7b6f6b5..e2e09b564b0e 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -790,7 +790,7 @@ def forward( else: if self.gradient_checkpointing and self.training: 
layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1033,7 +1033,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index fc15c86e7a94..80498efb3cad 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -649,7 +649,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 7d00b6b6d871..86eccc478753 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1865,7 +1865,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index a941c0508a94..7df8b60792a0 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -849,7 +849,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index dd6c45de8a56..89c6a0c0e0b4 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -689,7 +689,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 341260efe45c..7c4c9bdf9598 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -830,7 +830,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1082,7 +1082,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 
a2e2a39ec966..c23666f10b72 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -552,7 +552,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a95215f7641a..36b5a4b66bb5 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1021,7 +1021,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 1e8a8afa07dd..c664c02a883b 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -627,7 +627,7 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index c857915a8cca..b88925f41b83 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -583,7 +583,7 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 897a90ce0486..ede306e71b86 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -525,7 +525,7 @@ def forward( if self.gradient_checkpointing and self.training: outputs = self.gradient_checkpointing_func( - block.forward, + block.__call__, hidden_states, alibi, causal_mask, diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 1da9da2af915..f6cb65889a37 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -767,7 +767,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, ) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index d8622fca9582..122b49287872 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -951,7 +951,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1223,7 +1223,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = 
self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index a8ad52d26988..cd43688e3f74 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -578,7 +578,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 51bbd56d2b58..cbed1e1b1530 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1155,7 +1155,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 9a023cbc91ef..9b2052eb6ca4 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -371,7 +371,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 5782d796566a..9925e7b4a46b 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -693,7 +693,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 351a1a77d59a..a1491d15ea55 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -766,7 +766,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 63e1570a1106..68037d13950e 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -755,7 +755,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 5fc671f25f46..058ecd1775a9 100755 --- 
a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -805,7 +805,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1083,7 +1083,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index f35bef20969d..6eaddf642a8b 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1073,7 +1073,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, global_hidden_states, attention_mask, @@ -1327,7 +1327,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index fda50ca47690..8043fc8699a6 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -670,7 +670,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, position_ids, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 9b4444c56ee2..cfc2b137c579 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -344,7 +344,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index cdd73be66d7a..1e047fd37267 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -809,7 +809,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1060,7 +1060,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index eb1b319fb19a..e4c28659cb48 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1331,7 +1331,7 @@ 
def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1565,7 +1565,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 69be03b93bde..0a2546a9b64e 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -582,7 +582,7 @@ def forward( ) use_cache = False layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index a63e3a9e9bce..86b37b21560b 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -587,7 +587,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 6471653da7bf..e5e662a9b556 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -544,7 +544,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index aedfc5ef8077..32a19c088317 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -511,7 +511,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 1bcdb8724518..78ca20684540 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -513,7 +513,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 3627944fab4b..3a58efa9140c 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -645,7 +645,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, 
hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 6773a6f967ad..3893e27b028f 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -579,7 +579,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, sinusoidal_pos, diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index d7c7df9a8390..275233321372 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -678,7 +678,7 @@ def forward( for idx, block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: hidden_states, state, attentions = self.gradient_checkpointing_func( - block.forward, hidden_states, state, use_cache, output_attentions + block.__call__, hidden_states, state, use_cache, output_attentions ) else: hidden_states, state, attentions = block( diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index d384747af336..1bd6fcdc2a8f 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1043,7 +1043,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index a930d60ec9da..ea79c7341883 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -893,7 +893,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, @@ -2119,7 +2119,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 883fab34fce2..36416c168c36 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -361,7 +361,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -667,7 +667,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 2dc2231e6073..39c9641b9489 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -454,7 +454,7 @@ def 
forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -1128,7 +1128,7 @@ def forward( if self.gradient_checkpointing and self.training: output_states = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, next_kv, attention_mask, query_states, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 030358ff033a..73a02fe66df7 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -819,7 +819,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1061,7 +1061,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 486dda2f46b4..acee2b15a44f 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -671,7 +671,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 40d30f366a20..b8fea796647b 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -521,7 +521,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -1381,7 +1381,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1701,7 +1701,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index d766f435f150..1bdf8f3f5f91 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -460,7 +460,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff 
--git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 25432478abea..c2f15dbbf273 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -826,7 +826,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index d7b248b11359..47ce01d16916 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -747,7 +747,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - stage_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index c00ae39e0bec..6daad938a623 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -907,7 +907,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, hidden_states, input_dimensions, layer_head_mask, output_attentions + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index fc9f001ff060..e1da557b0017 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -1151,7 +1151,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index ae22bbd8449d..de05d77ec943 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -647,7 +647,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 349bc5d48adf..1fa6a963f58f 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -948,7 +948,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = 
self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1159,7 +1159,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index ccc65287cdc2..044705c35e54 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -440,7 +440,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, output_attentions, ) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 9b7fab8e2f3d..ada8638a03b6 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -703,7 +703,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 086cf66fd40d..a37265f37c7a 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -561,7 +561,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -873,7 +873,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, None, output_attentions, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index bcfc4069c8a3..db14d5bca51f 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -385,7 +385,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -761,7 +761,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, @@ -845,7 +845,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 778dbfad18a9..8a9a63804b56 100755 --- 
a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -399,7 +399,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -775,7 +775,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, @@ -859,7 +859,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 84ff258c58b8..277280954fd6 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -435,7 +435,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, @@ -722,7 +722,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, None, output_attentions, diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index a93dc99903e1..482bd08359bd 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -532,7 +532,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 36a1292fc9fd..425a125a0b89 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -419,7 +419,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index b06ab62113a7..67dbddf8766a 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -398,7 +398,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 7b54e6c1535b..959522843f7a 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -416,7 
+416,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 5fa10ca9d137..e156fdc3292c 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -537,7 +537,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, @@ -789,7 +789,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, None, output_attentions, diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 91e13c7b6adc..b727c331cfb4 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -388,7 +388,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index 8e20f17e0709..9bb3991fabf1 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -566,7 +566,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 49b8e1a6a40a..b621bde35e61 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1168,7 +1168,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, padding_mask, attention_mask, diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index b4ed99bd9e98..50cb82fb4e18 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -339,7 +339,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ec38d6a11570..9f48e529627e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -452,7 +452,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ 
-797,7 +797,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, @@ -880,7 +880,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, output_attentions, diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 5d7235925568..5fba773ee0cb 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -519,7 +519,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -905,7 +905,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index ef76b4333089..55b19e4c4143 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -355,7 +355,7 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: hidden_states = self.gradient_checkpointing_func( - conv_layer.forward, + conv_layer.__call__, hidden_states, ) else: @@ -707,7 +707,7 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, position_bias, @@ -792,7 +792,7 @@ def forward( # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer.forward, + layer.__call__, hidden_states, attention_mask, position_bias, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index c868abe44c0e..d6d0302727cb 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -944,7 +944,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, None, (head_mask[idx] if head_mask is not None else None), @@ -1170,7 +1170,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 
46ad1fb719e7..6c9cc02db9c8 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -705,7 +705,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, @@ -946,7 +946,7 @@ def forward( encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 075e0c315970..1880a7832193 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -676,7 +676,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index bc86fd6ff5c8..9a9f02b74a65 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -1351,7 +1351,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1588,7 +1588,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 1bc22ca10045..da99b2806fb6 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -512,7 +512,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 3477d709ae0e..49f7c0751721 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -500,7 +500,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 4ca4adeec995..5f7b42f266fb 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -574,7 +574,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( 
- layer_module.forward, + layer_module.__call__, hidden_states, lang_ids, attention_mask, diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index a378e96f9909..f6cbaecd014e 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -493,7 +493,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, layer_head_mask, output_attentions, diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index b0cbd589b293..8db66d221061 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -562,7 +562,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, output_attentions, diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index e2e36ac36823..0b5af845c9aa 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -545,7 +545,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - layer_module.forward, + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -2311,7 +2311,7 @@ def forward( else: if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - encoder_layer.forward, + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -2544,7 +2544,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self.gradient_checkpointing_func( - decoder_layer.forward, + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, From cec602a669f973e7d8cd790986ac4071a566ddf5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 25 Oct 2023 09:43:31 +0000 Subject: [PATCH 11/12] add comment --- src/transformers/modeling_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index aabb8e34cbb0..73255b021f5f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1856,6 +1856,9 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint activations". + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks + of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + Args: gradient_checkpointing_kwargs (dict, *optional*): Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. 
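
The rationale in the comment added above is plain PyTorch behaviour: `nn.Module.__call__` runs any registered forward pre-hooks and forward hooks and then dispatches to `forward`, whereas calling `forward` directly bypasses the hooks. A minimal standalone sketch of that difference (not part of the patch; the linear layer, the print hook and `use_reentrant=False` are illustrative choices only):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    layer = nn.Linear(4, 4)
    # A simple forward hook; anything registered this way only runs when the
    # module is invoked through __call__.
    layer.register_forward_hook(lambda module, inputs, output: print("hook fired"))

    x = torch.randn(2, 4, requires_grad=True)

    checkpoint(layer.forward, x, use_reentrant=False)   # hook does not fire: __call__ is bypassed
    checkpoint(layer.__call__, x, use_reentrant=False)  # hook fires, same as layer(x)

Passing `layer.__call__` to the checkpointing function therefore keeps hook side effects identical to a normal, non-checkpointed forward pass, which is exactly what the hunks above change in every model.
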
From ded70d40d071f574d060f913431100c6f7be5fe4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 25 Oct 2023 09:52:45 +0000 Subject: [PATCH 12/12] quality --- src/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 73255b021f5f..47e9cb2f23e0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1856,8 +1856,8 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint activations". - We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks - of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 Args: gradient_checkpointing_kwargs (dict, *optional*):
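
As a usage note to close the series: the `gradient_checkpointing_kwargs` documented above are forwarded to `torch.utils.checkpoint.checkpoint`. A sketch, assuming a model whose implementation supports gradient checkpointing ("gpt2" is used purely as an example checkpoint):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint only
    # Forwarded to torch.utils.checkpoint.checkpoint, per the docstring above.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model.train()  # the diffs gate checkpointing on `self.gradient_checkpointing and self.training`
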