Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Trainer / GC] Add gradient_checkpointing_kwargs in trainer and training arguments #27068

Merged
merged 6 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,7 +1616,12 @@ def _inner_training_loop(

# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
if args.gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
else:
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs

self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

model = self._wrap_model(self.model_wrapped)

Expand Down
8 changes: 8 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,8 @@ class TrainingArguments:
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
gradient_checkpointing_args (`dict`, *optional*, defaults to `None`):
Key word arguments to be passed to the `gradient_checkpointing_enable` method.
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class.
Expand Down Expand Up @@ -1119,6 +1121,12 @@ class TrainingArguments:
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
gradient_checkpointing_kwargs: dict = field(
default=None,
metadata={
"help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
},
)
include_inputs_for_metrics: bool = field(
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
)
Expand Down
59 changes: 58 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,38 @@ def forward(self, input_x, labels=None, **kwargs):
loss = nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)

class RegressionPreTrainedModelWithGradientCheckpointing(PreTrainedModel):
config_class = RegressionModelConfig
base_model_prefix = "regression"
supports_gradient_checkpointing = True

def __init__(self, config):
super().__init__(config)
self.layers = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) for _ in range(4)])
self.head = nn.Linear(config.hidden_size, 1)
self.gradient_checkpointing = False
self.double_output = config.double_output

def forward(self, input_x, labels=None, **kwargs):
y = input_x.unsqueeze(0)

for layer in self.layers:
if self.training and self.gradient_checkpointing:
outputs = self._gradient_checkpointing_func(layer.__call__, y)
else:
outputs = layer(y)

y = outputs * 3

logits = self.head(y)

if labels is None:
return (logits, logits) if self.double_output else (logits,)

loss = nn.functional.mse_loss(logits, labels)

return (loss, y, y) if self.double_output else (loss, y)

class RegressionRandomPreTrainedModel(PreTrainedModel):
config_class = RegressionModelConfig
base_model_prefix = "regression"
Expand Down Expand Up @@ -327,6 +359,7 @@ def get_regression_trainer(
a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, keep_report_to=False, **kwargs
):
label_names = kwargs.get("label_names", None)
gradient_checkpointing = kwargs.get("gradient_checkpointing", False)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)

Expand All @@ -336,7 +369,13 @@ def get_regression_trainer(
else:
if pretrained:
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
model = RegressionPreTrainedModel(config)
# We infer the correct model class if one uses gradient_checkpointing or not
target_cls = (
RegressionPreTrainedModel
if not gradient_checkpointing
else RegressionPreTrainedModelWithGradientCheckpointing
)
model = target_cls(config)
else:
model = RegressionModel(a=a, b=b, double_output=double_output)

Expand Down Expand Up @@ -548,6 +587,24 @@ def test_gradient_accumulation(self):
trainer.train()
self.check_trained_model(trainer.model)

def test_gradient_checkpointing(self):
trainer = get_regression_trainer(
per_device_train_batch_size=1,
learning_rate=0.1,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
)
previous_params = {k: v.detach().clone() for k, v in trainer.model.named_parameters()}

trainer.train()

# Check if model weights have been updated
for k, v in trainer.model.named_parameters():
self.assertFalse(
torch.allclose(previous_params[k], v, rtol=1e-4, atol=1e-4),
f"Model weights for {k} have not been updated",
)

def test_training_loss(self):
n_gpus = max(1, get_gpu_count())

Expand Down