[core] Refactor of gradient_checkpointing (huggingface#27020)
* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
younesbelkada authored and EduardoPach committed Nov 19, 2023
1 parent 0a5af9f commit b7c7429
Showing 188 changed files with 1,276 additions and 2,296 deletions.
22 changes: 19 additions & 3 deletions src/transformers/modeling_utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections
+import functools
 import gc
 import importlib.metadata
 import inspect
@@ -1848,16 +1849,31 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

         self.base_model._prune_heads(heads_to_prune)

-    def gradient_checkpointing_enable(self):
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         """
         Activates gradient checkpointing for the current model.

         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         activations".
+
+        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
+        Args:
+            gradient_checkpointing_kwargs (dict, *optional*):
+                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
         """
         if not self.supports_gradient_checkpointing:
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {}
+
+        gradient_checkpointing_func = functools.partial(
+            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
+        )
+
+        self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func))

         if getattr(self, "_hf_peft_config_loaded", False):
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
@@ -1874,7 +1890,7 @@ def gradient_checkpointing_disable(self):
         activations".
         """
         if self.supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, value=False))
+            self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))

         if getattr(self, "_hf_peft_config_loaded", False):
             self.disable_input_require_grads()
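With the new signature, user-supplied options flow from `gradient_checkpointing_enable` straight into `torch.utils.checkpoint.checkpoint`. A minimal usage sketch (the checkpoint name and kwargs below are illustrative, not taken from this commit):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model that supports gradient checkpointing

# Forward kwargs to torch.utils.checkpoint.checkpoint, e.g. select the non-reentrant implementation.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# Disabling routes gradient_checkpointing_func=None through the same _set_gradient_checkpointing hook.
model.gradient_checkpointing_disable()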
20 changes: 8 additions & 12 deletions src/transformers/models/align/modeling_align.py
@@ -1095,20 +1095,15 @@ def forward(
             past_key_value = past_key_values[i] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -1197,9 +1192,10 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (AlignTextModel, AlignVisionModel)):
-            module.gradient_checkpointing = value
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
+        if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)):
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 @add_start_docstrings(
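The same substitution repeats in every model file below: the per-call `create_custom_forward` closure disappears and the pre-bound checkpoint function is invoked with the layer's `__call__` plus all forward arguments passed positionally. A self-contained sketch of that mechanism outside Transformers (toy `nn.Linear` layer, assumed kwargs), showing why `__call__` rather than `forward` is checkpointed (hooks attached to the module still fire):

import functools

import torch
import torch.nn as nn
import torch.utils.checkpoint

layer = nn.Linear(4, 4)
layer.register_forward_hook(lambda module, args, output: print("forward hook fired"))

# Pre-bind user kwargs once, as gradient_checkpointing_enable now does via functools.partial.
gradient_checkpointing_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

hidden_states = torch.randn(2, 4, requires_grad=True)
# Checkpointing __call__ (not forward) keeps the hook inside the checkpointed segment.
out = gradient_checkpointing_func(layer.__call__, hidden_states)
out.sum().backward()  # activations inside `layer` are recomputed during the backward pass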
33 changes: 12 additions & 21 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -646,20 +646,15 @@ def forward(
             past_key_value = past_key_values[i] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -960,18 +955,12 @@ def forward(
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     causal_attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
@@ -1089,11 +1078,13 @@ def _init_weights(self, module):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, AltCLIPEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
         if isinstance(module, AltRobertaEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
@@ -336,17 +336,11 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -395,9 +389,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
             module.weight.data.fill_(1.0)

     # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
-    def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None:
         if isinstance(module, ASTEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
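Across the per-model `_set_gradient_checkpointing` overrides, the boolean toggle is replaced by a callable: the override stores the function and derives the `gradient_checkpointing` flag from whether it is `None`. A stripped-down sketch of how the base-class `self.apply(partial(...))` call drives such an override (toy classes and names assumed, not Transformers code):

import functools

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TinyEncoder()

    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        # Mirrors the per-model overrides in this diff: store the callable, derive the flag.
        if isinstance(module, TinyEncoder):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = gradient_checkpointing_func is not None


model = TinyModel()
checkpoint_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

# Enable: nn.Module.apply visits every submodule, so the encoder picks up the function.
model.apply(functools.partial(model._set_gradient_checkpointing, gradient_checkpointing_func=checkpoint_func))
assert model.encoder.gradient_checkpointing is True

# Disable: passing None through the same hook resets the flag.
model.apply(functools.partial(model._set_gradient_checkpointing, gradient_checkpointing_func=None))
assert model.encoder.gradient_checkpointing is False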
31 changes: 10 additions & 21 deletions src/transformers/models/autoformer/modeling_autoformer.py
@@ -946,9 +946,10 @@ def _init_weights(self, module):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (AutoformerDecoder, AutoformerEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 AUTOFORMER_START_DOCSTRING = r"""
@@ -1207,18 +1208,12 @@ def forward(
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1425,23 +1420,17 @@ def forward(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
19 changes: 7 additions & 12 deletions src/transformers/models/bark/modeling_bark.py
@@ -313,9 +313,10 @@ def device(self) -> torch.device:

         return get_parameter_device(self)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 BARK_MODEL_START_DOCSTRING = """
@@ -637,20 +638,14 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self.gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     None,
                     attention_mask,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
31 changes: 10 additions & 21 deletions src/transformers/models/bart/modeling_bart.py
@@ -521,9 +521,10 @@ def _init_weights(self, module):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (BartDecoder, BartEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

     @property
     def dummy_inputs(self):
@@ -854,18 +855,12 @@ def forward(
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1110,23 +1105,17 @@ def forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
17 changes: 6 additions & 11 deletions src/transformers/models/beit/modeling_beit.py
@@ -510,17 +510,11 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 relative_position_bias = (
@@ -572,9 +566,10 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BeitEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 BEIT_START_DOCSTRING = r"""
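The commit message also mentions adding a shared test for gradient checkpointing; that test is not part of this excerpt, but a hypothetical smoke check along these lines illustrates the behaviour the refactor standardizes (model name and helper are illustrative, not from the commit):

from transformers import AutoModelForCausalLM


def check_gradient_checkpointing_toggle(model_name="gpt2"):
    """Hypothetical smoke check: enable/disable should flip the flag on supporting submodules."""
    model = AutoModelForCausalLM.from_pretrained(model_name)

    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
    assert any(getattr(m, "gradient_checkpointing", False) for m in model.modules())

    model.gradient_checkpointing_disable()
    assert not any(getattr(m, "gradient_checkpointing", False) for m in model.modules())


check_gradient_checkpointing_toggle()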