[Deepspeed] Allow HF optimizer and scheduler to be passed to deepspeed #10464

Merged
32 commits merged on Mar 16, 2021
Changes from 17 commits
Commits
32 commits
0541df1
pass hf optimizer and scheduler to deepspeed if not specified in ds c…
cli99 Mar 1, 2021
30ebb6f
pass hf optimizer and scheduler to deepspeed if not specified in ds c…
cli99 Mar 1, 2021
8416a78
Merge branch 'deepspeed' of https://github.com/cli99/transformers int…
cli99 Mar 1, 2021
aec38cb
update
cli99 Mar 2, 2021
1ed68e1
make init_deepspeed support config dict
stas00 Mar 2, 2021
98a1562
fix docstring formatting
stas00 Mar 2, 2021
333d8dc
clean up trainer's comments
stas00 Mar 2, 2021
9daef95
add new tests
stas00 Mar 2, 2021
c0060e9
fix type
stas00 Mar 2, 2021
14cdc4b
composit argparse doesn't work
stas00 Mar 2, 2021
83e4897
style
stas00 Mar 2, 2021
9c73ce3
add a new test, rename others
stas00 Mar 2, 2021
1aeb2f2
document new functionality
stas00 Mar 2, 2021
4cc0679
Merge remote-tracking branch 'origin/master' into deepspeed
stas00 Mar 8, 2021
e78f40e
complete tests, add docs
stas00 Mar 8, 2021
605358d
style
stas00 Mar 8, 2021
a17c77a
correct level
stas00 Mar 8, 2021
c5f06b6
Apply suggestions from code review
stas00 Mar 9, 2021
f6d0067
add new methods to the doc
stas00 Mar 9, 2021
bb448d6
Merge remote-tracking branch 'origin/master' into deepspeed
stas00 Mar 12, 2021
20f395c
must tell DS we are using a non-native optimizer
stas00 Mar 12, 2021
8e20811
add protection against cpu_offload + HF optimizer combo
stas00 Mar 13, 2021
a2d877d
fix the cli overrides
stas00 Mar 13, 2021
e4abec8
sync docs + tests
stas00 Mar 13, 2021
dccb770
restore AdamW
stas00 Mar 13, 2021
eb4051f
better docs
stas00 Mar 13, 2021
3b09360
need new version
stas00 Mar 13, 2021
a354f42
no longer needed
stas00 Mar 13, 2021
da2fe96
remove outdate information
stas00 Mar 13, 2021
dfb0d57
refactor duplicated code
stas00 Mar 13, 2021
e758a3e
Merge remote-tracking branch 'origin/master' into deepspeed
stas00 Mar 16, 2021
fb84a93
Merge branch 'master' into deepspeed
cli99 Mar 16, 2021
46 changes: 44 additions & 2 deletions docs/source/main_classes/trainer.rst
@@ -686,6 +686,31 @@ to achieve the same configuration as provided by the longer json file in the fir
When you execute the program, DeepSpeed will log the configuration it received from the :class:`~transformers.Trainer`
to the console, so you can see exactly what final configuration was passed to it.


Passing Configuration
=======================================================================================================================

As discussed elsewhere in this document, the DeepSpeed configuration is normally passed as a path to a json file.
However, if you're not using the command line interface to configure the training, and instead instantiate the
:class:`~transformers.Trainer` via :class:`~transformers.TrainingArguments`, then for the ``deepspeed`` argument you
can pass a nested ``dict``. This allows you to create the configuration on the fly and doesn't require writing it to
the file system before passing it to :class:`~transformers.TrainingArguments`.

To summarize, you can do either:

.. code-block:: python

    TrainingArguments(..., deepspeed="/path/to/ds_config.json")

or:

.. code-block:: python

    ds_config_dict = dict(scheduler=scheduler_params, optimizer=optimizer_params)
    TrainingArguments(..., deepspeed=ds_config_dict)
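
A slightly fuller sketch of such a dict might look as follows. The values are purely illustrative and not a
recommended configuration; the keys are standard DeepSpeed config sections:

.. code-block:: python

    ds_config_dict = {
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 2},
        "optimizer": {"type": "Adam", "params": {"lr": 3e-5}},
        "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 500}},
    }
    TrainingArguments(..., deepspeed=ds_config_dict)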



Shared Configuration
=======================================================================================================================

@@ -750,9 +775,26 @@ no equivalent command line arguments.



Optimizer
Optimizer and Scheduler
=======================================================================================================================

You can mix and match DeepSpeed and HuggingFace schedulers and optimizers, with the exception of combining a
HuggingFace scheduler with a DeepSpeed optimizer:

+--------------+--------------+--------------+
| Combos | HF Scheduler | DS Scheduler |
+--------------+--------------+--------------+
| HF Optimizer | Yes | Yes |
+--------------+--------------+--------------+
| DS Optimizer | No | Yes |
+--------------+--------------+--------------+
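
For example, to use a HuggingFace optimizer with a DeepSpeed scheduler, include a ``scheduler`` section in the
DeepSpeed config but leave out the ``optimizer`` section; the :class:`~transformers.Trainer` then creates its default
optimizer and passes it to DeepSpeed. A minimal sketch with illustrative values:

.. code-block:: python

    ds_config_dict = {
        "scheduler": {
            "type": "WarmupLR",
            "params": {"warmup_min_lr": 0, "warmup_max_lr": 3e-5, "warmup_num_steps": 500},
        },
    }
    TrainingArguments(..., deepspeed=ds_config_dict)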




Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""


DeepSpeed's main optimizers are Adam, OneBitAdam, and Lamb. These have been thoroughly tested with ZeRO and are thus
recommended. DeepSpeed can, however, import other optimizers from ``torch``. The full documentation is `here
@@ -787,7 +829,7 @@ make sure to adjust the values. e.g. if you use Adam you will want ``weight_decay``


Scheduler
=======================================================================================================================
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

DeepSpeed supports LRRangeTest, OneCycle, WarmupLR and WarmupDecayLR LR schedulers. The full documentation is `here
<https://www.deepspeed.ai/docs/config-json/#scheduler-parameters>`__.
47 changes: 45 additions & 2 deletions examples/tests/deepspeed/test_deepspeed.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import os
import sys
import unittest
from copy import deepcopy

from transformers.integrations import is_deepspeed_available
from transformers.testing_utils import (
@@ -67,17 +69,58 @@ def setUp(self):
MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
)
self.ds_config_file = f"{self.test_file_dir_str}/ds_config.json"
with io.open(self.ds_config_file, "r", encoding="utf-8") as f:
self.ds_config_dict = json.load(f)

def test_fake_notebook_no_launcher(self):

# this setup emulates a notebook where a launcher needs to be emulated by hand

with CaptureStd() as cs:
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file)
trainer.train()
assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"

# Test various combos
# 1. DS scheduler + DS optimizer: this is already tested by most other tests
# 2. HF scheduler + HF optimizer:
# 3. DS scheduler + HF optimizer:
# 4. HF scheduler + DS optimizer:

def test_hf_scheduler_hf_optimizer(self):
a = 0
with mockenv_context(**self.dist_env_1_gpu):
ds_config_dict = deepcopy(self.ds_config_dict)
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
del ds_config_dict["scheduler"] # force default HF Trainer scheduler
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_dict)
trainer.train()
new_a = trainer.model.a.item()
self.assertNotEqual(new_a, a)

def test_ds_scheduler_hf_optimizer(self):
a = 0
with mockenv_context(**self.dist_env_1_gpu):
ds_config_dict = deepcopy(self.ds_config_dict)
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_dict)
trainer.train()
new_a = trainer.model.a.item()
self.assertNotEqual(new_a, a)

def test_hf_scheduler_ds_optimizer(self):
# this combo is not possible at the moment
a = 0
with mockenv_context(**self.dist_env_1_gpu):
ds_config_dict = deepcopy(self.ds_config_dict)
del ds_config_dict["scheduler"] # force default HF Trainer scheduler
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_dict)
with self.assertRaises(Exception) as context:
trainer.train()
self.assertTrue("HF Scheduler + DeepSpeed Optimizer combination is not possible" in str(context.exception))

def test_early_get_last_lr(self):
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
67 changes: 26 additions & 41 deletions src/transformers/integrations.py
Expand Up @@ -24,7 +24,6 @@
from pathlib import Path
from types import SimpleNamespace

from .trainer_utils import SchedulerType
from .utils import logging


@@ -285,8 +284,13 @@ def init_deepspeed(trainer, num_training_steps):
ds_config_file = args.deepspeed
model = trainer.model

with io.open(ds_config_file, "r", encoding="utf-8") as f:
config = json.load(f)
if isinstance(args.deepspeed, dict):
config = args.deepspeed
elif isinstance(args.deepspeed, str):
with io.open(ds_config_file, "r", encoding="utf-8") as f:
config = json.load(f)
else:
raise ValueError("expecting either a path to a config file or a pre-populated dict")

# The following code translates relevant trainer's cl args into the DS config

@@ -318,6 +322,14 @@ def init_deepspeed(trainer, num_training_steps):
else: # override only if the ds config doesn't already have this section
config["gradient_clipping"] = args.max_grad_norm

# Optimizer + Scheduler
# Currently support combos:
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Yes
# 3. DS scheduler + HF optimizer: Yes
# 4. HF scheduler + DS optimizer: No

optimizer = None
if "optimizer" in config:
logger.info(
f"Keeping the `optimizer` config from {ds_config_file} intact, ignoring any optimizer-specific cl args"
@@ -326,22 +338,8 @@
# ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
# But trainer uses AdamW by default.
# To use other optimizers or schedulers requires voiding warranty with: `zero_allow_untested_optimizer`

optimizer_configs = {
"AdamW": {
"lr": args.learning_rate,
"betas": [args.adam_beta1, args.adam_beta2],
"eps": args.adam_epsilon,
"weight_decay": args.weight_decay,
}
}
optimizer = "AdamW"

config["zero_allow_untested_optimizer"] = True
config["optimizer"] = {
"type": optimizer,
"params": optimizer_configs[optimizer],
}
trainer.create_optimizer()
optimizer = trainer.optimizer

# DS schedulers (deepspeed/runtime/lr_schedules.py):
#
@@ -351,34 +349,19 @@
# OneCycle | na | na | 1CLR
# WarmupLR | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
# WarmupDecayLR| linear | get_linear_schedule_with_warmup |
lr_scheduler = None
if "scheduler" in config:
logger.info(
f"Keeping the `scheduler` config from {ds_config_file} intact, ignoring any scheduler-specific cl args"
)
else: # override only if the ds config doesn't already have this section
if args.lr_scheduler_type == SchedulerType.LINEAR:
scheduler = "WarmupDecayLR"
params = {
"last_batch_iteration": -1,
"total_num_steps": num_training_steps,
"warmup_min_lr": 0,
"warmup_max_lr": args.learning_rate,
"warmup_num_steps": args.warmup_steps,
}
elif args.lr_scheduler_type == SchedulerType.CONSTANT_WITH_WARMUP:
scheduler = "WarmupLR"
params = {
"warmup_min_lr": 0,
"warmup_max_lr": args.learning_rate,
"warmup_num_steps": args.warmup_steps,
}
if "optimizer" in config:
# to make this option work, we need to init DS optimizer first, then init HF scheduler,
# then pass the HF scheduler to DS init
raise ValueError("At the moment HF Scheduler + DeepSpeed Optimizer combination is not possible")
else:
raise ValueError(f"{args.lr_scheduler_type} scheduler type is not supported by DeepSpeed")

config["scheduler"] = {
"type": scheduler,
"params": params,
}
trainer.create_scheduler(num_training_steps=num_training_steps)
lr_scheduler = trainer.lr_scheduler

# fp16
if trainer.fp16_backend is not None:
@@ -415,6 +398,8 @@ def init_deepspeed(trainer, num_training_steps):
model=model,
model_parameters=model_parameters,
config_params=config,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)

return model, optimizer, lr_scheduler
10 changes: 7 additions & 3 deletions src/transformers/testing_utils.py
Expand Up @@ -446,10 +446,14 @@ def assert_screenout(out, what):
class CaptureStd:
"""
Context manager to capture:
stdout, clean it up and make it available via obj.out stderr, and make it available via obj.err

init arguments: - out - capture stdout: True/False, default True - err - capture stdout: True/False, default
True
- stdout, clean it up and make it available via obj.out
- stderr, and make it available via obj.err

init arguments:

- out - capture stdout: True/False, default True
- err - capture stderr: True/False, default True

Examples::

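A minimal usage sketch for CaptureStd, consistent with the docstring above and with test_fake_notebook_no_launcher
earlier in this PR (illustrative only, not part of the diff):

    with CaptureStd() as cs:
        print("hello")  # captured, cleaned up and exposed as cs.out
    assert "hello" in cs.out
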
29 changes: 24 additions & 5 deletions src/transformers/trainer.py
Expand Up @@ -301,6 +301,12 @@ def __init__(
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

# one place to sort out whether to place the model on device or not
# postpone switching model to cuda when:
# 1. MP - since we are trying to fit a model that is much bigger than 1 gpu
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
# and we only use deepspeed for training at the moment
# 3. full fp16 eval - since the model needs to be half'ed first
# 4. Sharded DDP - same as MP
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
@@ -316,10 +322,6 @@
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer

# postpone switching model to cuda when:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
# and we only use deepspeed for training at the moment
if self.place_model_on_device:
model = model.to(args.device)

@@ -609,6 +611,16 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Setup the optimizer and the learning rate scheduler.

We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through :obj:`optimizers`, or subclass and override this method.
"""
self.create_optimizer()
self.create_scheduler(num_training_steps)

def create_optimizer(self):
"""
Setup the optimizer.

We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through :obj:`optimizers`, or subclass and override this method.
"""
@@ -644,6 +656,13 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

def create_scheduler(self, num_training_steps: int):
"""
Setup the scheduler. The optimizer of the trainer must have been set up.

Args:
num_training_steps (int): The number of training steps to do.
"""
if self.lr_scheduler is None:
warmup_steps = (
self.args.warmup_steps
@@ -889,7 +908,7 @@ def train(
if self.args.deepspeed:
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
self.model = model.module
self.model_wrapped = model # will get further wrapped in DDP
self.model_wrapped = model
self.deepspeed = model # DeepSpeedEngine object
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
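The split of create_optimizer_and_scheduler into create_optimizer and create_scheduler makes it possible to override
just one of the two in a subclass. A minimal sketch, assuming the Trainer API as shown in this diff (SGDTrainer and
the choice of SGD are purely illustrative, not part of the PR):

    import torch
    from transformers import Trainer

    class SGDTrainer(Trainer):
        def create_optimizer(self):
            # swap in a plain SGD optimizer instead of the default AdamW;
            # create_scheduler() will still build the default LR scheduler on top of it
            if self.optimizer is None:
                self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate)
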
9 changes: 6 additions & 3 deletions src/transformers/training_args.py
Expand Up @@ -260,9 +260,10 @@ class TrainingArguments:

If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
list for :obj:`False` and :obj:`["simple"]` for :obj:`True`.
deepspeed (:obj:`str`, `optional`):
deepspeed (:obj:`str` or :obj:`dict`, `optional`):
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
evolve in the future. The value is either the location of the DeepSpeed json config file (e.g.,
``ds_config.json``) or an already loaded json file as a :obj:`dict`.
label_smoothing_factor (:obj:`float`, `optional`, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 -
@@ -478,7 +479,9 @@ class TrainingArguments:
)
deepspeed: Optional[str] = field(
default=None,
metadata={"help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json)"},
metadata={
"help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already loaded json file as a dict"
},
)
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}