diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6f866f989a67..47e9cb2f23e0 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections
+import functools
 import gc
 import importlib.metadata
 import inspect
@@ -1848,16 +1849,31 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
 
         self.base_model._prune_heads(heads_to_prune)
 
-    def gradient_checkpointing_enable(self):
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         """
         Activates gradient checkpointing for the current model.
 
         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         activations".
+
+        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
+        Args:
+            gradient_checkpointing_kwargs (dict, *optional*):
+                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
         """
         if not self.supports_gradient_checkpointing:
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {}
+
+        gradient_checkpointing_func = functools.partial(
+            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
+        )
+
+        self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func))
 
         if getattr(self, "_hf_peft_config_loaded", False):
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
@@ -1874,7 +1890,7 @@ def gradient_checkpointing_disable(self):
         activations".
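A rough usage sketch of the new argument (not part of the diff; the checkpoint name below is only a placeholder): `gradient_checkpointing_kwargs` is forwarded verbatim to `torch.utils.checkpoint.checkpoint` through `functools.partial`, so a caller could, for example, opt into the non-reentrant checkpoint implementation.

# Illustrative only -- assumes a model that supports gradient checkpointing;
# "gpt2" is just a placeholder checkpoint name.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Previous behaviour is unchanged: with no arguments, the plain
# torch.utils.checkpoint.checkpoint function is used.
model.gradient_checkpointing_enable()

# New: extra kwargs are passed along to torch.utils.checkpoint.checkpoint,
# e.g. selecting the non-reentrant implementation in recent PyTorch versions.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

Note that, as the added docstring explains, the per-model code below checkpoints `module.__call__` rather than `module.forward`, so that forward hooks attached to the module still run under checkpointing.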
""" if self.supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, value=False)) + self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None)) if getattr(self, "_hf_peft_config_loaded", False): self.disable_input_require_grads() diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 6cbf01a3432c..58dc2a892009 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1095,20 +1095,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1197,9 +1192,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (AlignTextModel, AlignVisionModel)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index c4e32de55d9c..e6229165aace 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -646,20 +646,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -960,18 +955,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1089,11 +1078,13 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def 
_set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, AltCLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, AltRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 28969f50b672..a1f85e2a09eb 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -336,17 +336,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -395,9 +389,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.weight.data.fill_(1.0) # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST - def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ASTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 96298c77a344..40e300231085 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -946,9 +946,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (AutoformerDecoder, AutoformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUTOFORMER_START_DOCSTRING = r""" @@ -1207,18 +1208,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and 
self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1425,16 +1420,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1442,6 +1429,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 649719e0eefa..2708b00d05c4 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -313,9 +313,10 @@ def device(self) -> torch.device: return get_parameter_device(self) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BARK_MODEL_START_DOCSTRING = """ @@ -637,20 +638,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 9e7763ca23d8..73eca72e5d12 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -521,9 +521,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BartDecoder, BartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -854,18 +855,12 @@ def forward( layer_outputs = (None, None) else: if 
self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1110,16 +1105,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1127,6 +1114,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index d698cff88b14..3ba3d4911b0f 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -510,17 +510,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( @@ -572,9 +566,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BeitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 1b0fad3f9d65..91380e13a055 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -593,20 +593,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -762,9 
+757,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index abe2d828b28b..123cb2212e19 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -401,20 +401,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -607,9 +602,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BERT_GENERATION_START_DOCSTRING = r""" diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e266b1a67b7d..0ba2119e6844 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1617,15 +1617,8 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -1635,6 +1628,8 @@ def custom_forward(*inputs): from_mask, to_mask, blocked_encoder_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1784,9 +1779,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BigBirdEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIG_BIRD_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 4e279f9dc059..98ff51032bad 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1609,9 +1609,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1943,15 +1944,8 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1960,6 +1954,7 @@ def custom_forward(*inputs): to_mask, blocked_encoder_mask, blocked_encoder_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2289,16 +2284,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2306,6 +2293,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index ca084db5c7d0..2bbdbed348a1 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -376,9 +376,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BioGptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIOGPT_START_DOCSTRING = r""" @@ -590,20 +591,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = 
torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index 12a5ecd42b74..d02861d6343d 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -669,9 +669,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1db81905210b..51a947af0a83 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -483,9 +483,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -777,18 +778,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1032,16 +1027,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1049,6 +1036,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 129de3dd1456..88a9b52de909 100755 --- 
a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -480,9 +480,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -775,18 +776,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1029,16 +1024,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1046,6 +1033,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 9fca7c28a1a0..efd986299c29 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -34,7 +34,7 @@ replace_return_docstrings, ) from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig -from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel +from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel logger = logging.get_logger(__name__) @@ -461,9 +461,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BlipEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (BlipEncoder, BlipTextEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BLIP_START_DOCSTRING = r""" @@ -622,17 +623,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - 
- return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 49b958afc2eb..e0aa4e17f146 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -422,20 +422,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index bd56b17e55c2..2f7f00b3dd59 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -297,9 +297,14 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, Blip2Encoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) BLIP_2_START_DOCSTRING = r""" @@ -473,17 +478,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -944,15 +943,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d90bb6ad8fdf..583367c9ab55 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -496,9 +496,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, BloomModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_standard_cache( @@ -761,21 +762,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index ce569157b811..0f272a21e21d 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -804,20 +804,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index a8ea8d49195b..c10f83505676 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -651,21 +651,15 @@ def forward( "`use_cache=False`..." 
) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, bbox_pos_emb, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 8d7d279579e3..2e0a6c12fe64 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -524,20 +524,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -625,9 +620,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CamembertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CAMEMBERT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 657104ad6965..adc875910320 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -795,18 +795,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -919,9 +913,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CanineEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CANINE_START_DOCSTRING = r""" diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 7bab0aea6eb9..ef1c265723b6 100644 --- 
a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -742,9 +742,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CHINESE_CLIP_START_DOCSTRING = r""" @@ -909,20 +910,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1018,16 +1014,10 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 1d17a5188387..025b59ae4b97 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -939,15 +939,8 @@ def forward( input_dimensions = self.input_resolutions[i] if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -1595,20 +1588,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1701,9 +1689,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def 
_set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ClapTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class ClapAudioModel(ClapPreTrainedModel): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 3a894b9727c9..56f24c157f83 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -467,9 +467,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIP_START_DOCSTRING = r""" @@ -639,18 +640,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 96f13217aaf8..7a0e52926983 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -479,9 +479,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPSegEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIPSEG_START_DOCSTRING = r""" @@ -648,18 +649,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 05699ef15c72..9a5509a9ed86 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -339,9 +339,10 @@ def _init_weights(self, module): module.bias.data.zero_() 
module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CodeGenModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CODEGEN_START_DOCSTRING = r""" @@ -542,21 +543,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 69937afefce4..01dbf8ecd59c 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1171,9 +1171,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConditionalDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONDITIONAL_DETR_START_DOCSTRING = r""" @@ -1518,15 +1519,8 @@ def forward( # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, object_queries, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index a6fccf5b72b4..da577a589614 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -264,9 +264,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SeparableConv1D(nn.Module): @@ -632,20 +633,14 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - 
- layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index e6cf336517a5..e11112b53222 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -296,9 +296,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index 3a268c713d50..f1ff89bb1243 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -317,9 +317,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXTV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 6d2dc596fa65..8a6c744ed69e 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -556,9 +556,10 @@ def _init_weights(self, module): elif isinstance(module, CpmAntSegmentPositionEmbedding): module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CpmAntEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CPMANT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 4435e9b8d017..a99b6f3a6dc3 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -293,15 +293,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) 
else: @@ -593,17 +586,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -761,9 +748,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_AUDIO_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index a521ccb39aaf..507c2fc464d8 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -510,20 +510,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -613,9 +608,10 @@ def _init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VECTEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index f8fe59587af0..2742d5ffc37b 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -522,17 +522,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) 
else: relative_position_bias = ( @@ -585,9 +579,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecVisionEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_VISION_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 6f6c2af63a67..65ec497cecd8 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -457,20 +457,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: hidden_states = layer_module( @@ -839,9 +833,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index eda4f406cb31..2245ac549ada 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -501,20 +501,14 @@ def forward( all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - output_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + output_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( @@ -938,9 +932,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 
8e5053a4160d..19c2731a50a7 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -469,9 +469,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DecisionTransformerGPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): @@ -631,22 +632,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index ea4555d5ae21..220fcf0d0660 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1088,9 +1088,10 @@ def _init_weights(self, module): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DeformableDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEFORMABLE_DETR_START_DOCSTRING = r""" @@ -1383,15 +1384,8 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 38c28dbbedc6..6e97e932b533 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -357,17 +357,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = 
self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -415,9 +409,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, DeiTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index eca5ba014e51..9e7a73c5880b 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -504,9 +504,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MCTCTEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MCTCT_START_DOCSTRING = r""" @@ -616,18 +617,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 3ace323e8224..fb1cc7f0fb84 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -456,9 +456,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OpenLlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPEN_LLAMA_INPUTS_DOCSTRING = r""" @@ -665,20 +666,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - 
- layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, None, + output_attentions, + None, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 75415dbe77bf..c9f31c714446 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -163,9 +163,10 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): main_input_name = "trajectories" supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrajectoryTransformerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): @@ -550,15 +551,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, layer_past, use_cache, diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index 4ef18f54158f..52c9e1242422 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -387,9 +387,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VanModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VAN_START_DOCSTRING = r""" diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 1aab38c28913..a6f979eaeea6 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -979,9 +979,10 @@ def _init_weights(self, module): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetaDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETA_START_DOCSTRING = r""" @@ -1275,15 +1276,8 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, 
output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index d2b6ea07d7b7..1c09e3e3d7b2 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -927,9 +927,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETR_START_DOCSTRING = r""" @@ -1253,15 +1254,8 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 89c6ed2e2a88..eb4d3f2ff296 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -660,7 +660,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DinatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 6e4446faddd5..1440b6d615fb 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -447,17 +447,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -516,9 +510,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: Dinov2Encoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None: if isinstance(module, Dinov2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None 
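Illustrative sketch, not part of the diff: the snippet below condenses the pattern the hunks above and below apply to every model. `ToyLayer`/`ToyEncoder`/`ToyModel` are hypothetical stand-ins for a real architecture, and `use_reentrant=False` is only one example of a kwarg that could be pre-bound into the checkpointing callable. The point it demonstrates is the shared refactor: the encoder no longer builds a `create_custom_forward` closure around `torch.utils.checkpoint.checkpoint`; it calls a stored `gradient_checkpointing_func` with the layer's `__call__` and passes every argument positionally, while the per-model `_set_gradient_checkpointing` hook stores that callable and toggles the `gradient_checkpointing` flag from it.

# --- illustrative sketch (assumed names, not part of the diff) ---
import functools

import torch
import torch.utils.checkpoint
from torch import nn


class ToyLayer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, output_attentions=False):
        hidden_states = torch.relu(self.linear(hidden_states))
        # Mimic the (hidden_states, attentions) tuple convention used in the diff.
        return (hidden_states, None) if output_attentions else (hidden_states,)


class ToyEncoder(nn.Module):
    def __init__(self, hidden_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(hidden_size) for _ in range(num_layers))
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states, output_attentions=False):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Pass __call__ (not forward) so module hooks still run, and pass all
                # arguments positionally instead of capturing them in a closure.
                layer_outputs = self.gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    output_attentions,
                )
            else:
                layer_outputs = layer(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]
        return hidden_states


class ToyModel(nn.Module):
    def __init__(self, hidden_size=32, num_layers=2):
        super().__init__()
        self.encoder = ToyEncoder(hidden_size, num_layers)

    # Same shape as the per-model hooks rewritten in this diff: a None func means
    # checkpointing is disabled; anything else both stores the func and enables it.
    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        if isinstance(module, ToyEncoder):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = gradient_checkpointing_func is not None

    def forward(self, hidden_states):
        return self.encoder(hidden_states)


if __name__ == "__main__":
    model = ToyModel().train()
    # Bind checkpoint kwargs once, then distribute the resulting callable to submodules.
    checkpointing_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
    model.apply(
        functools.partial(model._set_gradient_checkpointing, gradient_checkpointing_func=checkpointing_func)
    )
    hidden_states = torch.randn(4, 8, 32)
    model(hidden_states).sum().backward()
# --- end of sketch ---

Because extra layer arguments (such as output_attentions or past_key_value) now travel as positional arguments to the checkpointed callable rather than being closed over, the same call shape works for any layer signature, which is why each model's hunk simply appends the previously-captured arguments to the argument list.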
DINOV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index f26b5846972d..3768dd6e91ca 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -358,18 +358,12 @@ def forward( all_hidden_states = all_hidden_states + (hidden_state,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_state, attn_mask, head_mask[i], + output_attentions, ) else: layer_outputs = layer_module( @@ -430,9 +424,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Transformer): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DISTILBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 0d833406e259..76d525717f8c 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -749,15 +749,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -826,9 +819,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DonutSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 944ce142b0ad..c258343f6cfd 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -164,9 +164,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DPREncoder(DPRPreTrainedModel): diff 
--git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 187a6c36656a..2621fa338015 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -528,17 +528,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -818,9 +812,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DPTViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DPT_START_DOCSTRING = r""" diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index 478aeecee02b..d1b2c9940343 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -500,9 +500,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EfficientNetBlock): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index da3ee8e51d36..fde5632c09c3 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -571,20 +571,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -692,9 +687,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ElectraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func 
is not None @dataclass diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 697fb3c94fbb..28c20da3d5eb 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -473,9 +473,10 @@ def _init_weights(self, module): elif "bias" in name: nn.init.constant_(param, 0.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (EncodecEncoder, EncodecDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ENCODEC_START_DOCSTRING = r""" diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 787db7272642..a13fd19a900c 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -265,10 +265,10 @@ def tie_weights(self): self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index d55155f80093..330cb5033160 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -506,20 +506,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -680,9 +675,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/ernie_m/modeling_ernie_m.py b/src/transformers/models/ernie_m/modeling_ernie_m.py index 9c53ddd73c85..b26ee0fcafd1 100755 --- a/src/transformers/models/ernie_m/modeling_ernie_m.py +++ 
b/src/transformers/models/ernie_m/modeling_ernie_m.py @@ -429,9 +429,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ERNIE_M_START_DOCSTRING = r""" diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 7a07495ba7e5..86bd20a46480 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -605,20 +605,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -710,9 +705,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EsmEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ESM_START_DOCSTRING = r""" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 6eaeed419977..642e60a72f91 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1097,9 +1097,10 @@ def _init_weights(self, module: nn.Module): module.weight.data.fill_(1.0) # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, FalconModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_cache_to_standard_format( @@ -1278,21 +1279,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, attention_mask, position_ids, head_mask[i], + layer_past, + use_cache, + output_attentions, ) else: outputs = block( diff --git 
a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 8de647c8299a..1fbf49f9e127 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -663,18 +663,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -879,9 +873,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: FlavaEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, FlavaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 45042147761d..b84761536bac 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -292,14 +292,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states) + layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states) else: layer_outputs = layer_module(hidden_states) @@ -431,9 +424,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 8d18a8c63fda..87ec98169626 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -586,15 +586,8 @@ def forward( for i, stage_module in enumerate(self.stages): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - stage_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(stage_module), + stage_outputs = self.gradient_checkpointing_func( + stage_module.__call__, hidden_states, input_dimensions, ) @@ -659,9 +652,10 @@ def 
_init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FocalNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FOCALNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 03312420ca6e..37f9890ee3dd 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -70,9 +70,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FuyuForCausalLM): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FUYU_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 00707e42dd08..293b9c789d56 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -452,18 +452,13 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -533,9 +528,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GitEncoder, GitVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GIT_START_DOCSTRING = r""" @@ -878,18 +874,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 838e7ca29925..24826a76bc04 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -480,9 +480,10 @@ def _init_weights(self, module): # Special Scaled 
Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -877,22 +878,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( @@ -1623,7 +1618,6 @@ def __init__(self, config): # Model parallel self.model_parallel = False self.device_map = None - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index be90f61e45bf..37c51b40c9a7 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -405,9 +405,10 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTBigCodeModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_BIGCODE_START_DOCSTRING = r""" @@ -650,22 +651,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3ad49554c0ac..ed1e62bf175f 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -384,9 +384,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoModel): - module.gradient_checkpointing = value + 
module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_NEO_START_DOCSTRING = r""" @@ -604,20 +605,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9391805a77b8..cf0aa0645ae0 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -78,9 +78,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXAttention(nn.Module): @@ -641,20 +642,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for layer_past - return module(*inputs, use_cache, None, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_ids, head_mask[i], + use_cache, + None, + output_attentions, ) else: outputs = layer( diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 98753edeb544..c1c5527a4655 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -66,9 +66,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXJapaneseModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXJapaneseAttention(nn.Module): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index acdbb8c49211..2910f9535f66 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -363,9 +363,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTJModel): - 
module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPTJ_START_DOCSTRING = r""" @@ -669,21 +670,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 24917fcfdb07..84d956c9f57e 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -759,9 +759,10 @@ def _init_weights(self, module): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GPTSanJapaneseAttention,)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right def _shift_right(self, input_ids): diff --git a/src/transformers/models/graphormer/modeling_graphormer.py b/src/transformers/models/graphormer/modeling_graphormer.py index 8247745a3bc3..68ed6d265e70 100755 --- a/src/transformers/models/graphormer/modeling_graphormer.py +++ b/src/transformers/models/graphormer/modeling_graphormer.py @@ -772,9 +772,10 @@ def _init_weights( module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GraphormerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GraphormerModel(GraphormerPreTrainedModel): diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 59ff60ed765a..a9de67143846 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -492,7 +492,6 @@ def __init__( self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size)) else: self.group_token = None - self.gradient_checkpointing = False self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)]) if num_group_token > 0: @@ -805,9 +804,10 @@ def _init_weights(self, module): nn.init.normal_(module.fc1.weight, std=fc_std) nn.init.normal_(module.fc2.weight, std=in_proj_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, 
gradient_checkpointing_func=None): if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GROUPVIT_START_DOCSTRING = r""" @@ -1031,18 +1031,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 1a7bde45efc1..732e6be2f8dd 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -346,15 +346,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -731,17 +724,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -821,17 +808,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -895,9 +876,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = 
gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 316f36561308..28841903a1a3 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -40,7 +40,7 @@ ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from .vision import IdeficsVisionTransformer +from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer logger = logging.get_logger(__name__) @@ -978,9 +978,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, IdeficsModel): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -1098,7 +1099,6 @@ def __init__(self, config: IdeficsConfig): self.norm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1339,7 +1339,7 @@ def vblock( ) use_cache = False - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( vblock, decoder_layer, hidden_states, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index d4966a240d84..24dc3e9396aa 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -401,18 +401,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 54edcd30fc87..a365731ed53d 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -525,9 +525,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ImageGPTModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None IMAGEGPT_START_DOCSTRING = r""" @@ -816,22 +817,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if 
self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index e7b35174ca7e..53518760cc00 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -924,9 +924,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (InformerDecoder, InformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None INFORMER_START_DOCSTRING = r""" @@ -1215,21 +1216,15 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) if conv_layer is not None: - output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0]) + output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) layer_outputs = (output,) + layer_outputs[1:] else: layer_outputs = encoder_layer( @@ -1438,16 +1433,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1455,6 +1442,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 082900a6652f..d4cb7a1fa00b 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -304,9 +304,14 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, 
InstructBlipEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) INSTRUCTBLIP_START_DOCSTRING = r""" @@ -462,17 +467,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -939,15 +938,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 884a2799728b..ce6d4302bccc 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -487,20 +487,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -638,9 +633,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LAYOUTLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index ef970edfdc91..8f6260fdda49 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -439,18 +439,12 @@ def forward( layer_head_mask = 
head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, ) @@ -514,9 +508,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def my_convert_sync_batchnorm(module, process_group=None): diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 30ab0a5e8620..e387707e52da 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -657,19 +657,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos) - # The above line will cause error: - # RuntimeError: Trying to backward through the graph a second time - # (or directly access saved tensors after they have already been freed). 
- - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index f0c22ed9502c..61bbd4156b46 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1155,9 +1155,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (LEDDecoder, LEDEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1876,20 +1877,15 @@ def forward( layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2142,16 +2138,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -2159,6 +2147,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 0accc28391bd..5acaaeba9004 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -507,9 +507,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LevitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 46fe2d3e9cd7..4fd7a85affd7 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -514,19 +514,13 
@@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layout_inputs, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( @@ -607,9 +601,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LiltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 541455d86afd..279884dc164f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -827,9 +827,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -1013,16 +1014,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 33bf9a6f9268..b4f20b452558 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1304,20 +1304,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = layer_module( @@ -1439,9 +1434,10 @@ def 
_init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LongformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index c80d2349832c..9abbfa2f2001 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -775,7 +775,6 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = Fal if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() - self.gradient_checkpointing = False # Relativen attention bias & Layer norm for global attention if self.has_relative_attention_bias: @@ -1340,10 +1339,10 @@ def _init_weights(self, module): mean=0.0, std=factor * ((d_model) ** -0.5) ) - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LongT5Attention, LongT5Stack)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 def _shift_right(self, input_ids): @@ -1510,15 +1509,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1528,6 +1520,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 6913ede09d1c..3b5f4d0bf71d 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -788,19 +788,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, word_hidden_states, entity_hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( @@ -920,9 +914,10 @@ def _init_weights(self, module: nn.Module): 
module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LukeEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 6db8bbb5213b..4ebe11f3f3b3 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -552,9 +552,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (M2M100Decoder, M2M100Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None M2M_100_START_DOCSTRING = r""" @@ -820,18 +821,12 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1066,16 +1061,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -1083,6 +1070,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 69de5b2e7d0e..e2e09b564b0e 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -500,9 +500,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalP if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MarianDecoder, MarianEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func 
is not None @property def dummy_inputs(self): @@ -788,18 +789,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1037,16 +1032,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1054,6 +1041,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 530c66a0c80b..80498efb3cad 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -648,20 +648,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index e839b16f6257..86eccc478753 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1864,20 +1864,14 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, None, None, + output_attentions, ) else: diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 87b91ed64b62..7df8b60792a0 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -848,20 +848,14 @@ def forward( continue if 
self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, encoder_attention_mask, None, + output_attentions, ) else: layer_outputs = decoder_layer( @@ -1619,11 +1613,13 @@ def _init_weights(self, module: nn.Module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerPixelLevelModule): - module.encoder.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 357ac9d4aaca..89c6a0c0e0b4 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -688,15 +688,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, layer_head_mask + layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, ) else: layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( @@ -752,9 +748,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index b53ad8848dd3..7c4c9bdf9598 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -516,9 +516,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (MBartDecoder, MBartDecoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, 
(MBartDecoder, MBartEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -828,18 +829,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1086,16 +1081,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1103,6 +1090,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 5d0ad6e3410c..c23666f10b72 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -551,20 +551,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -728,9 +723,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MegatronBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index 5d1f5bea7bfd..1257b4df39c0 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -333,9 +333,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def 
_set_gradient_checkpointing(self, module: MgpstrEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: MgpstrEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, MgpstrEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MGP_STR_START_DOCSTRING = r""" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 53667b6a82c3..fbedb20dbc15 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -816,9 +816,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MistralModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MISTRAL_INPUTS_DOCSTRING = r""" @@ -1020,19 +1021,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index c3accb21e05e..c664c02a883b 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -626,15 +626,8 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: @@ -672,9 +665,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 5a0e08d7344d..b88925f41b83 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -582,15 +582,8 @@ def forward( for i, 
layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: @@ -629,9 +622,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVITV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index d760bec9854a..ede306e71b86 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -294,9 +294,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, MptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_mpt_cache( @@ -523,20 +524,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index d400fea6d23d..f6cb65889a37 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -766,15 +766,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, ) @@ -871,9 +864,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None 
MRA_START_DOCSTRING = r""" diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 186db94dad7f..2951ffc889dc 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -845,9 +845,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MT5Attention, MT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1073,15 +1074,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1091,6 +1085,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index bcc6bc82a2f5..a740ed47074b 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -475,9 +475,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MusicgenDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MUSICGEN_START_DOCSTRING = r""" @@ -826,16 +827,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -843,6 +836,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( @@ -1562,10 +1557,10 @@ def tie_weights(self): self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - 
self.text_encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.text_encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_audio_encoder(self): return self.audio_encoder diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 5c1ed05249ef..122b49287872 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -563,9 +563,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MvpDecoder, MvpEncoder, MvpPrompt)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -949,19 +950,13 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), (self_attn_prompt[idx] if self.use_prompt else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1227,16 +1222,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1246,6 +1233,8 @@ def custom_forward(*inputs): self_attn_prompt[idx] if self.use_prompt else None, cross_attn_prompt[idx] if self.use_prompt else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py index ecc745b558dd..4f7206a5e8ec 100644 --- a/src/transformers/models/nat/modeling_nat.py +++ b/src/transformers/models/nat/modeling_nat.py @@ -639,7 +639,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: NatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: NatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index fa31e94f4d2e..cd43688e3f74 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -577,20 +577,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None 
else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -752,9 +747,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NezhaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index a88d53a340f4..cbed1e1b1530 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -874,9 +874,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (NllbMoeDecoder, NllbMoeEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NLLB_MOE_START_DOCSTRING = r""" @@ -1153,18 +1154,12 @@ def forward( layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1426,15 +1421,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -1442,6 +1430,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 51ee73ab72d3..9b2052eb6ca4 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -370,17 +370,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) @@ -477,9 +471,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NystromformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NYSTROMFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 5b6220f88169..165684542d85 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2616,7 +2616,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor): for layer in self.layers: if self.use_checkpoint: - hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states) + hidden_states = self.gradient_checkpointing_func(layer, hidden_states) else: hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f3f24652434..9925e7b4a46b 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -411,9 +411,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (OPTDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPT_INPUTS_DOCSTRING = r""" @@ -691,20 +692,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing 
and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 451cc4a69126..a1491d15ea55 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -584,9 +584,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Owlv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLV2_START_DOCSTRING = r""" @@ -764,18 +765,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1378,6 +1373,7 @@ def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor): def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor: """Predicts the probability that each image feature token is an object. + Args: image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)): Features extracted from the image. 
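Note (illustrative, not part of the patch): the hunks above all converge on the same two-part pattern, a `_set_gradient_checkpointing` hook that stores the checkpointing callable on matching submodules, and a forward pass that routes layers through `self.gradient_checkpointing_func(layer.__call__, ...)` while training. Below is a minimal self-contained sketch of that pattern using hypothetical toy classes (`ToyEncoder`, `ToyModel`) rather than any real model in this diff; the `functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)` wiring is an assumption made for the example.

import functools

import torch
import torch.utils.checkpoint
from torch import nn


class ToyEncoder(nn.Module):
    """Toy stand-in for the per-model encoders touched in this diff (hypothetical class)."""

    def __init__(self, hidden_size=32, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # __call__ rather than forward so module hooks still run under checkpointing
                hidden_states = self.gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


class ToyModel(nn.Module):
    """Toy stand-in for a pretrained-model subclass that owns the encoder (hypothetical class)."""

    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()

    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        # Same shape as the hooks rewritten throughout this diff: store the callable
        # on matching submodules and derive the boolean flag from its presence.
        if isinstance(module, ToyEncoder):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = gradient_checkpointing_func is not None


model = ToyModel().train()
# Assumption for this sketch: the installed callable is a partial over
# torch.utils.checkpoint.checkpoint; here it is wired up by hand for the toy model.
checkpoint_fn = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
model.apply(functools.partial(model._set_gradient_checkpointing, gradient_checkpointing_func=checkpoint_fn))

output = model.encoder(torch.randn(4, 32))
output.sum().backward()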
diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 66cfb8092df5..68037d13950e 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -576,9 +576,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OwlViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLVIT_START_DOCSTRING = r""" @@ -753,18 +754,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 55856f7b06b6..058ecd1775a9 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -500,9 +500,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusDecoder, PegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_START_DOCSTRING = r""" @@ -803,18 +804,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1087,16 +1082,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1104,6 +1091,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + 
output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index e87e9c7164ab..6eaddf642a8b 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -780,9 +780,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_X_START_DOCSTRING = r""" @@ -1071,18 +1072,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, global_hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1331,21 +1326,15 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index d73cc4484484..8043fc8699a6 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -467,9 +467,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PersimmonModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PERSIMMON_INPUTS_DOCSTRING = r""" @@ -668,19 +669,13 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_value, + 
output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 58041820c1fb..cfc2b137c579 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -343,18 +343,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -563,9 +557,10 @@ def __init__(self, config: Pix2StructConfig): # Initialize weights and apply final processing self.post_init() - def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, value: bool = False) -> None: - if isinstance(module, Pix2StructVisionEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, gradient_checkpointing_func=None) -> None: + if isinstance(module, (Pix2StructVisionEncoder, Pix2StructVisionAttention)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def get_input_embeddings(self): return self.embeddings.patch_projection @@ -1320,9 +1315,10 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pix2StructTextAttention, Pix2StructTextModel)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def __init__(self, config): super().__init__(config) @@ -1495,15 +1491,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1513,6 +1502,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 3a880839236d..1e047fd37267 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -517,9 +517,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PLBartDecoder, PLBartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PLBART_START_DOCSTRING = r""" @@ -807,18 +808,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1064,16 +1059,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1081,6 +1068,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 6acc8ec98e69..209533e31990 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -282,9 +282,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PoolFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None 
POOLFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5a67b8044b09..5cf7039e9f0c 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -739,9 +739,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pop2PianoAttention, Pop2PianoStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -902,15 +903,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -920,6 +914,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 241a9efea36a..e4c28659cb48 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -557,9 +557,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1329,18 +1330,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1569,16 +1564,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward 
- - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -1590,6 +1577,8 @@ def custom_forward(*inputs): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 2dd452ec1df1..356b7c14afa8 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -489,9 +489,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ) - def _set_gradient_checkpointing(self, module: PvtEncoder, value: bool = False): + def _set_gradient_checkpointing(self, module: PvtEncoder, gradient_checkpointing_func=None): if isinstance(module, PvtEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PVT_START_DOCSTRING = r""" diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index fead8fc0cf7f..0a2546a9b64e 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -581,20 +581,15 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -757,9 +752,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, QDQBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None QDQBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index aa738d782b7b..86b37b21560b 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -586,20 +586,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = 
layer_module( diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 07ef29fd3332..21050f07fda4 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -293,9 +293,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RegNetModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None REGNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 235bff89f8a3..e5e662a9b556 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -543,20 +543,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -673,9 +668,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RemBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None REMBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index f2d207c2189f..e6b1d85b2a46 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -283,9 +283,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ResNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None RESNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 6d4cc991d22c..32a19c088317 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -510,20 +510,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, 
past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -612,9 +607,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index da1cd6331bc3..78ca20684540 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -512,20 +512,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -615,9 +610,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaPreLayerNormEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index a5b1b63050b1..3a58efa9140c 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -644,20 +644,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -796,9 +791,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def 
_set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoCBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROC_BERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b9c36a305ff1..3893e27b028f 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -578,21 +578,16 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, sinusoidal_pos, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -715,9 +710,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index db41bd3c9538..275233321372 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -466,9 +466,10 @@ def _init_weights(self, module): module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RwkvModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -676,16 +677,8 @@ def forward( all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - hidden_states, state, attentions = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), hidden_states, state + hidden_states, state, attentions = self.gradient_checkpointing_func( + block.__call__, hidden_states, state, use_cache, output_attentions ) else: hidden_states, state, attentions = block( diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index abf5544a5b4d..1bd6fcdc2a8f 100644 --- 
a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1042,15 +1042,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 91ec6a8f9b87..ea79c7341883 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -892,15 +892,8 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, @@ -1547,9 +1540,10 @@ def _init_weights(self, module): k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TSpeechEncoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TConformerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride @@ -1856,18 +1850,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2130,16 +2118,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2147,6 +2127,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if 
cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 34f9c84235cc..36416c168c36 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -360,15 +360,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -673,17 +666,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -756,9 +743,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SEWEncoder, SEWFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 661a8c03b1a5..39c9641b9489 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -453,15 +453,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -1134,20 +1127,14 @@ def forward( all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - output_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + output_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( @@ -1322,9 +1309,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = 
attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, SEWDTransformerEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (SEWDEncoder, SEWDFeatureEncoder, SEWDTransformerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SEWD_START_DOCSTRING = r""" diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index e80c26e2698d..ec255fab9bc7 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -249,10 +249,10 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 31c9b6cfe935..73a02fe66df7 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -559,9 +559,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -817,18 +818,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1065,16 +1060,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - 
layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1082,6 +1069,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index f9b5dec42092..acee2b15a44f 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -437,9 +437,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Speech2Text2Decoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPEECH_TO_TEXT_2_START_DOCSTRING = r""" @@ -669,16 +670,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 9b8ab3d3805a..b8fea796647b 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -520,15 +520,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -1281,9 +1274,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SpeechT5Encoder, SpeechT5Decoder, SpeechT5FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SpeechT5Encoder(SpeechT5PreTrainedModel): @@ -1386,19 +1380,13 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def 
create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), position_bias, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1439,7 +1427,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5SpeechEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1476,7 +1463,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5TextEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1519,7 +1505,6 @@ class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1715,16 +1700,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1732,6 +1709,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( @@ -1788,7 +1767,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5SpeechDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1836,7 +1814,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5TextDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1889,7 +1866,6 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f72ffb10111b..1bdf8f3f5f91 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -459,20 +459,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def 
custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -544,9 +539,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SplinterEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPLINTER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index ff72f87506d3..4170ce153bbf 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -442,9 +442,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No nn.init.constant_(module.bias, 0) nn.init.constant_(module.weight, 1.0) - def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, SwiftFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIFTFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 45a7aa718cf0..c2f15dbbf273 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -825,15 +825,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -901,9 +894,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 02ec39edb0fe..5d5356144245 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -951,11 
+951,6 @@ class TFSwinPreTrainedModel(TFPreTrainedModel): config_class = SwinConfig base_model_prefix = "swin" main_input_name = "pixel_values" - supports_gradient_checkpointing = True - - def _set_gradient_checkpointing(self, module, value=False) -> None: - if isinstance(module, TFSwinEncoder): - module.gradient_checkpointing = value SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index a8a17bdf584b..47ce01d16916 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -746,15 +746,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) @@ -802,9 +795,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swin2SREncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN2SR_START_DOCSTRING = r""" diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index a4224e16df3c..6daad938a623 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -906,15 +906,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -983,9 +976,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swinv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWINV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 0a402ea2d6af..32d030728de5 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ 
b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -865,9 +865,10 @@ def _init_weights(self, module): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1039,15 +1040,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1057,6 +1051,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0e7237ea36b6..c796a9cf24cf 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -873,9 +873,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (T5Attention, T5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1100,15 +1101,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1118,6 +1112,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index b6012700ee7d..e1da557b0017 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -837,9 +837,10 @@ def _init_weights(self, module): if module.padding_idx is not None: 
module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TableTransformerDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TABLE_TRANSFORMER_START_DOCSTRING = r""" @@ -1149,15 +1150,8 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index cdaa4b3e2725..de05d77ec943 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -646,20 +646,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_values, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_values, + output_attentions, ) else: layer_outputs = layer_module( @@ -778,9 +773,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TapasEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TAPAS_START_DOCSTRING = r""" diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 2caca5bd1051..1fa6a963f58f 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -663,9 +663,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (TimeSeriesTransformerDecoder, TimeSeriesTransformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIME_SERIES_TRANSFORMER_START_DOCSTRING = r""" @@ -946,18 +947,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return 
module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1163,16 +1158,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1180,6 +1167,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 676bcf7a5e27..044705c35e54 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -439,16 +439,10 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions) @@ -494,9 +488,10 @@ def _init_weights(self, module): nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) module.patch_embeddings.apply(self._init_weights) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TimesformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIMESFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index c0541814be46..ada8638a03b6 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -454,9 +454,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrOCRDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TROCR_START_DOCSTRING = r""" @@ -701,16 +702,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None 
else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -718,6 +711,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 464c3e76a11f..a37265f37c7a 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -560,18 +560,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -616,9 +610,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, TvltEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (TvltEncoder, TvltDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TVLT_START_DOCSTRING = r""" @@ -877,17 +872,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index ffafd1581140..a5b58444fe4e 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -556,9 +556,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UMT5Attention, UMT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not 
None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -709,15 +710,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -725,6 +719,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index c475ab7f80f8..db14d5bca51f 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -384,15 +384,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -767,17 +760,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -857,17 +844,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1039,9 +1020,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechEncoder, UniSpeechEncoderStableLayerNorm, UniSpeechFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 3fcc9549bbdc..8a9a63804b56 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -398,15 +398,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -781,17 +774,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -871,17 +858,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1053,9 +1034,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_SAT_START_DOCSTRING = r""" diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index b56b508d14ae..04b8c94e1351 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -315,9 +315,10 @@ def init_weights(self): if self.auxiliary_head is not None: self.auxiliary_head.init_weights() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UPERNET_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 07c32d149290..277280954fd6 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -434,17 +434,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -489,9 +483,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, VideoMAEEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (VideoMAEEncoder, VideoMAEDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIDEOMAE_START_DOCSTRING = r""" @@ -726,17 +721,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index a36d58bd235b..482bd08359bd 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -531,18 +531,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -591,9 +585,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ViltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py 
b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index d3e464cbfffa..84275cc33a76 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -225,10 +225,10 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 81ad1068483a..425a125a0b89 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -418,18 +418,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -547,9 +541,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VisualBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 8fdacdddf04c..67dbddf8766a 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -397,17 +397,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -467,9 +461,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: 
ViTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 008f6b3c9db5..959522843f7a 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -415,17 +415,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -486,9 +480,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTHybridEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTHybridEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index ef0c7c9f3686..e156fdc3292c 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -536,17 +536,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -591,9 +585,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, ViTMAEEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (ViTMAEEncoder, ViTMAEDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MAE_START_DOCSTRING = r""" @@ -793,17 +788,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def 
custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 46639e7d622c..b727c331cfb4 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -387,17 +387,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -444,9 +438,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTMSNEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTMSNEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MSN_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index e89fdbd7a336..9bb3991fabf1 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -565,17 +565,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -666,9 +660,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.norm3.weight.data.zero_() module.norm3.bias.data.zero_() - def _set_gradient_checkpointing(self, module: VitDetEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: VitDetEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, VitDetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITDET_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index b23bdd21d56b..f5025a37e71c 100644 --- 
a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -86,9 +86,15 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + for backbone_module in module.modules(): + if hasattr(backbone_module, "gradient_checkpointing"): + backbone_module.gradient_checkpointing_func = gradient_checkpointing_func + backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None class VitMatteBasicConv3x3(nn.Module): diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 49b9a1f1ae15..b621bde35e61 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1167,18 +1167,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, padding_mask, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1296,9 +1290,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (VitsTextEncoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, VitsEncoder): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITS_START_DOCSTRING = r""" diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index fd35668572a7..50cb82fb4e18 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -338,17 +338,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -414,9 +408,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Parameter): module.data.normal_(mean=0.0, std=self.config.initializer_range) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, 
VivitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index a6e02a0476f1..9f48e529627e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -451,15 +451,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -803,17 +796,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -892,17 +879,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1173,9 +1154,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_adapters(self): if self.config.adapter_attn_dim is None: diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index f162c5142970..5fba773ee0cb 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -518,15 +518,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward 
- - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -911,18 +904,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, + output_attentions, ) else: layer_outputs = layer( @@ -1178,9 +1165,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAV2VEC2_CONFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 5013837cbdce..55b19e4c4143 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -354,15 +354,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -713,18 +706,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( @@ -804,18 +791,12 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( @@ -1052,9 +1033,10 @@ def _get_feature_vector_attention_mask( 
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAVLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8962324471ca..d6d0302727cb 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -685,9 +685,10 @@ def _init_weights(self, module): embed_positions = module.embed_positions.weight embed_positions.copy_(sinusoids(*embed_positions.shape)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WhisperDecoder, WhisperEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -942,18 +943,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, None, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1174,16 +1169,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1191,6 +1178,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, # past_key_value + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index da7eddff8df8..6c9cc02db9c8 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -534,9 +534,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + 
module.gradient_checkpointing = gradient_checkpointing_func is not None X_CLIP_START_DOCSTRING = r""" @@ -703,18 +704,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -950,18 +945,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 0c769dbbb5f3..1880a7832193 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -503,9 +503,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XGLMModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( @@ -674,16 +675,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -691,6 +684,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index cde05cfe8a8a..9a9f02b74a65 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -570,9 +570,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XLMProphetNetDecoder, XLMProphetNetEncoder)): - 
module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1349,18 +1350,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1592,16 +1587,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -1613,6 +1600,8 @@ def custom_forward(*inputs): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index da454b1e3331..da99b2806fb6 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -511,20 +511,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -614,9 +609,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XLMRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None XLM_ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 26e0361abdb5..49f7c0751721 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -499,20 +499,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is 
not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 28fddc2fdbd6..5f7b42f266fb 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -573,21 +573,16 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, lang_ids, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -680,9 +675,10 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel._set_gradient_checkpointing with Roberta->Xmod - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XmodEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def set_default_language(self, language: str): """ diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index e3cb02ceae6e..f6cbaecd014e 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -492,17 +492,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -551,9 +545,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: YolosEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: YolosEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, YolosEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None 
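# --- Illustrative aside (not part of the patch) --------------------------------
# A minimal, self-contained sketch of the pattern the hunks above apply to each
# model: `_set_gradient_checkpointing` now receives a callable instead of a
# boolean, and the encoder invokes that callable with `layer.__call__` so module
# hooks still fire under checkpointing. `ToyLayer`, `ToyEncoder`, and the
# free-standing `_set_gradient_checkpointing` below are hypothetical stand-ins,
# not code from the repository.
import functools

import torch
import torch.utils.checkpoint
from torch import nn


class ToyLayer(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return self.linear(hidden_states)


class ToyEncoder(nn.Module):
    def __init__(self, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer() for _ in range(num_layers))
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Pass __call__ (not forward) so forward/backward hooks still run.
                hidden_states = self.gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


def _set_gradient_checkpointing(module, gradient_checkpointing_func=None):
    # Same shape as the new model-level helpers: store the callable, derive the flag.
    if isinstance(module, ToyEncoder):
        module.gradient_checkpointing_func = gradient_checkpointing_func
        module.gradient_checkpointing = gradient_checkpointing_func is not None


encoder = ToyEncoder().train()
checkpoint_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
encoder.apply(functools.partial(_set_gradient_checkpointing, gradient_checkpointing_func=checkpoint_func))
encoder(torch.randn(2, 4, 8, requires_grad=True)).sum().backward()
# --------------------------------------------------------------------------------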
YOLOS_START_DOCSTRING = r""" diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 5edd7f883542..8db66d221061 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -561,17 +561,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) @@ -668,9 +662,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, YosoEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOSO_START_DOCSTRING = r""" diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 02fcb7d2f511..0b5af845c9aa 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -544,19 +544,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -679,9 +675,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2024,9 +2021,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, 
{{cookiecutter.camelcase_modelname}}Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2312,18 +2310,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2551,15 +2543,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2567,6 +2552,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 34f5bae3746f..7e1c471badf4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -349,10 +349,24 @@ def test_gradient_checkpointing_enable_disable(self): model.gradient_checkpointing_enable() self.assertTrue(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to True + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertTrue( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True" + ) + # check disable works model.gradient_checkpointing_disable() self.assertFalse(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to False + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertFalse( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" + ) + def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.__class__ not in MODEL_MAPPING: @@ -569,6 +583,13 @@ def test_training_gradient_checkpointing(self): loss = model(**inputs).loss loss.backward() + model.gradient_checkpointing_disable() + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_attention_outputs(self): if not self.has_attentions: self.skipTest(reason="Model does not output attentions")
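The added test exercises the `use_reentrant` switch end to end. For reference, here is a minimal usage sketch of the public API this patch introduces; it assumes any checkpoint whose model supports gradient checkpointing (the tiny testing checkpoint below is only an example), and the assertions mirror the checks added to `test_gradient_checkpointing_enable_disable`.

from transformers import AutoModel

# Example checkpoint; any model that supports gradient checkpointing works.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# Keyword arguments are forwarded to torch.utils.checkpoint.checkpoint via functools.partial.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
assert model.is_gradient_checkpointing

# Modules that expose the flag also carry the partially applied checkpoint function.
for name, module in model.named_modules():
    if hasattr(module, "gradient_checkpointing"):
        assert module.gradient_checkpointing
        assert module.gradient_checkpointing_func is not None

model.gradient_checkpointing_disable()
assert not model.is_gradient_checkpointing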