[Deepspeed] Allow HF optimizer and scheduler to be passed to deepspeed #10464

Merged
32 commits merged on Mar 16, 2021
Changes from all commits
Commits (32)
0541df1
pass hf optimizer and scheduler to deepspeed if not specified in ds c…
cli99 Mar 1, 2021
30ebb6f
pass hf optimizer and scheduler to deepspeed if not specified in ds c…
cli99 Mar 1, 2021
8416a78
Merge branch 'deepspeed' of https://github.com/cli99/transformers int…
cli99 Mar 1, 2021
aec38cb
update
cli99 Mar 2, 2021
1ed68e1
make init_deepspeed support config dict
stas00 Mar 2, 2021
98a1562
fix docstring formatting
stas00 Mar 2, 2021
333d8dc
clean up trainer's comments
stas00 Mar 2, 2021
9daef95
add new tests
stas00 Mar 2, 2021
c0060e9
fix type
stas00 Mar 2, 2021
14cdc4b
composit argparse doesn't work
stas00 Mar 2, 2021
83e4897
style
stas00 Mar 2, 2021
9c73ce3
add a new test, rename others
stas00 Mar 2, 2021
1aeb2f2
document new functionality
stas00 Mar 2, 2021
4cc0679
Merge remote-tracking branch 'origin/master' into deepspeed
stas00 Mar 8, 2021
e78f40e
complete tests, add docs
stas00 Mar 8, 2021
605358d
style
stas00 Mar 8, 2021
a17c77a
correct level
stas00 Mar 8, 2021
c5f06b6
Apply suggestions from code review
stas00 Mar 9, 2021
f6d0067
add new methods to the doc
stas00 Mar 9, 2021
bb448d6
Merge remote-tracking branch 'origin/master' into deepspeed
stas00 Mar 12, 2021
20f395c
must tell DS we are using a non-native optimizer
stas00 Mar 12, 2021
8e20811
add protection against cpu_offload + HF optimizer combo
stas00 Mar 13, 2021
a2d877d
fix the cli overrides
stas00 Mar 13, 2021
e4abec8
sync docs + tests
stas00 Mar 13, 2021
dccb770
restore AdamW
stas00 Mar 13, 2021
eb4051f
better docs
stas00 Mar 13, 2021
3b09360
need new version
stas00 Mar 13, 2021
a354f42
no longer needed
stas00 Mar 13, 2021
da2fe96
remove outdate information
stas00 Mar 13, 2021
dfb0d57
refactor duplicated code
stas00 Mar 13, 2021
e758a3e
Merge remote-tracking branch 'origin/master' into deepspeed
stas00 Mar 16, 2021
fb84a93
Merge branch 'master' into deepspeed
cli99 Mar 16, 2021
155 changes: 96 additions & 59 deletions docs/source/main_classes/trainer.rst
@@ -31,7 +31,10 @@ the above features. To inject custom behavior you can subclass them and override
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
init.
init. Note that you can also subclass or override the ``create_optimizer`` and ``create_scheduler`` methods
  separately (see the sketch after this list).
- **create_optimizer** -- Sets up the optimizer if it wasn't passed at init.
- **create_scheduler** -- Sets up the learning rate scheduler if it wasn't passed at init.
- **compute_loss** -- Computes the loss on a batch of training inputs.
- **training_step** -- Performs a training step.
- **prediction_step** -- Performs an evaluation/test step.
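
For instance, here is a minimal sketch of overriding ``create_optimizer`` to plug in a different optimizer. The
``MyTrainer`` name and the choice of ``torch.optim.Adagrad`` are purely illustrative and not part of the library:

.. code-block:: python

    import torch
    from transformers import Trainer

    class MyTrainer(Trainer):
        def create_optimizer(self):
            # mirror the default behavior: only build an optimizer if none was passed at init
            if self.optimizer is None:
                self.optimizer = torch.optim.Adagrad(self.model.parameters(), lr=self.args.learning_rate)
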
@@ -542,8 +545,6 @@ cell with:
"cpu_offload": true
},

"zero_allow_untested_optimizer": true,

"optimizer": {
"type": "AdamW",
"params": {
@@ -612,17 +613,11 @@ example ``.json`` files with:

Some more examples are to be found in the `main repo <https://github.com/microsoft/DeepSpeed>`__ as well.

While you always have to supply the DeepSpeed configuration file, you can configure the DeepSpeed integration in
several ways:

1. Supply most of the configuration inside the file, and just use a few required command line arguments. This is the
recommended way as it puts most of the configuration params in one place.
2. Supply just the ZeRO configuration params inside the file, and configure the rest using the normal
:class:`~transformers.Trainer` command line arguments.
3. Any variation of the first two ways.
When using DeepSpeed you always need to supply a DeepSpeed configuration file, yet some configuration parameters have
to be configured via the command line. You will find the nuances in the rest of this guide.

To get an idea of what a DeepSpeed configuration file looks like, here is one that activates ZeRO stage 2 features,
enables FP16, uses AdamW optimizer and WarmupLR scheduler:
enables FP16, uses ``AdamW`` optimizer and ``WarmupLR`` scheduler:

.. code-block:: json

@@ -666,36 +661,33 @@ enables FP16, uses AdamW optimizer and WarmupLR scheduler:
}
}

If you already have a command line that you have been using with :class:`transformers.Trainer` args, you can continue
using those and the :class:`~transformers.Trainer` will automatically convert them into the corresponding DeepSpeed
configuration at run time. For example, you could use the following configuration file:
When you execute the program, DeepSpeed will log the configuration it received from the :class:`~transformers.Trainer`
to the console, so you can see exactly what final configuration was passed to it.

.. code-block:: json

{
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"cpu_offload": true
}
}
Passing Configuration
=======================================================================================================================

and the following command line arguments:
As discussed in this document, the DeepSpeed configuration is normally passed as a path to a json file. However, if
you're not using the command line interface to configure the training, and instead instantiate the
:class:`~transformers.Trainer` via :class:`~transformers.TrainingArguments`, then for the ``deepspeed`` argument you
can pass a nested ``dict``. This allows you to create the configuration on the fly and doesn't require you to write it
to the file system before passing it to :class:`~transformers.TrainingArguments`.

.. code-block:: bash
To summarize, you can do:

--learning_rate 3e-5 --warmup_steps 500 --adam_beta1 0.8 --adam_beta2 0.999 --adam_epsilon 1e-8 \
--weight_decay 3e-7 --lr_scheduler_type constant_with_warmup --fp16 --fp16_backend amp
.. code-block:: python

TrainingArguments(..., deepspeed="/path/to/ds_config.json")

or:

.. code-block:: python

ds_config_dict = dict(scheduler=scheduler_params, optimizer=optimizer_params)
TrainingArguments(..., deepspeed=ds_config_dict)

to achieve the same configuration as provided by the longer json file in the first example.

When you execute the program, DeepSpeed will log the configuration it received from the :class:`~transformers.Trainer`
to the console, so you can see exactly what final configuration was passed to it.

Shared Configuration
=======================================================================================================================
@@ -761,9 +753,27 @@ no equivalent command line arguments.



Optimizer
Optimizer and Scheduler
=======================================================================================================================

As long as you don't enable ``cpu_offload``, you can mix and match DeepSpeed and HuggingFace schedulers and
optimizers, with the exception of the combination of a HuggingFace scheduler with a DeepSpeed optimizer:

+--------------+--------------+--------------+
| Combos | HF Scheduler | DS Scheduler |
+--------------+--------------+--------------+
| HF Optimizer | Yes | Yes |
+--------------+--------------+--------------+
| DS Optimizer | No | Yes |
+--------------+--------------+--------------+

If ``cpu_offload`` is enabled, you must use both the DeepSpeed scheduler and the DeepSpeed optimizer.
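
For example, here is a rough sketch of building a configuration ``dict`` that uses the HF optimizer and HF scheduler
by simply omitting the corresponding DeepSpeed entries, mirroring what the new tests in this PR do (the file name is
illustrative):

.. code-block:: python

    import json
    from copy import deepcopy

    from transformers import TrainingArguments

    with open("ds_config.json") as f:  # your full DeepSpeed config
        ds_config = json.load(f)

    hf_opt_hf_sched = deepcopy(ds_config)
    hf_opt_hf_sched.pop("optimizer", None)  # no "optimizer" entry -> HF Trainer optimizer (AdamW)
    hf_opt_hf_sched.pop("scheduler", None)  # no "scheduler" entry -> HF Trainer scheduler
    # cpu_offload must stay off when not using a DeepSpeed optimizer (assuming a "zero_optimization" section exists)
    hf_opt_hf_sched["zero_optimization"]["cpu_offload"] = False

    args = TrainingArguments(output_dir="output_dir", deepspeed=hf_opt_hf_sched)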



Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""


DeepSpeed's main optimizers are Adam, AdamW, OneBitAdam, and Lamb. These have been thoroughly tested with ZeRO and are
thus recommended. DeepSpeed can, however, import other optimizers from ``torch``. The full documentation is `here
@@ -773,7 +783,7 @@ If you don't configure the ``optimizer`` entry in the configuration file, the :c
automatically set it to ``AdamW`` and will use the supplied values or the defaults for the following command line
arguments: ``--learning_rate``, ``--adam_beta1``, ``--adam_beta2``, ``--adam_epsilon`` and ``--weight_decay``.

Here is an example of the pre-configured ``optimizer`` entry for AdamW:
Here is an example of the pre-configured ``optimizer`` entry for ``AdamW``:

.. code-block:: json

Expand All @@ -789,6 +799,17 @@ Here is an example of the pre-configured ``optimizer`` entry for AdamW:
}
}

Note that the command line arguments will override the values in the configuration file. This is so that there is one
definitive source of the values and to avoid hard-to-find errors when, for example, the learning rate is set to
different values in different places. The command line takes precedence. The values that get overridden are:

- ``lr`` with the value of ``--learning_rate``
- ``betas`` with the value of ``--adam_beta1 --adam_beta2``
- ``eps`` with the value of ``--adam_epsilon``
- ``weight_decay`` with the value of ``--weight_decay``

Therefore, please remember to tune the shared hyperparameters on the command line.
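
As a rough illustration of this mapping (the values are arbitrary), these :class:`~transformers.TrainingArguments`
settings are the ones that end up in the ``optimizer`` params shown above:

.. code-block:: python

    from transformers import TrainingArguments

    # lr <- learning_rate, betas <- (adam_beta1, adam_beta2), eps <- adam_epsilon, weight_decay <- weight_decay
    args = TrainingArguments(
        output_dir="output_dir",
        learning_rate=3e-5,
        adam_beta1=0.8,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        weight_decay=3e-7,
        deepspeed="ds_config.json",
    )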

If you want to use another optimizer that is not listed above, you will have to add ``"zero_allow_untested_optimizer":
true`` to the top level of the configuration.
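
For example, a minimal sketch of where the flag goes when the configuration is built as a ``dict`` (the remaining
entries are elided):

.. code-block:: python

    ds_config = {
        "zero_allow_untested_optimizer": True,  # must sit at the top level of the configuration
        # "optimizer": {...}, "scheduler": {...}, "zero_optimization": {...}, "fp16": {...}, etc.
    }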

@@ -797,55 +818,71 @@ make sure to adjust the values. e.g. if you use Adam you will want ``weight_decay``


Scheduler
=======================================================================================================================
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

DeepSpeed supports LRRangeTest, OneCycle, WarmupLR and WarmupDecayLR LR schedulers. The full documentation is `here
<https://www.deepspeed.ai/docs/config-json/#scheduler-parameters>`__.

If you don't configure the ``scheduler`` entry in the configuration file, the :class:`~transformers.Trainer` will use
the value of ``--lr_scheduler_type`` to configure it. Currently the :class:`~transformers.Trainer` supports only 2 LR
schedulers that are also supported by DeepSpeed:

Here is where the schedulers overlap between 🤗 Transformers and DeepSpeed:

* ``WarmupLR`` via ``--lr_scheduler_type constant_with_warmup``
* ``WarmupDecayLR`` via ``--lr_scheduler_type linear``. This is also the default value for ``--lr_scheduler_type``,
therefore, if you don't configure the scheduler, this is the scheduler that will get configured by default.

In either case, the values of ``--learning_rate`` and ``--warmup_steps`` will be used for the configuration.

In other words, if you don't use the configuration file to set the ``scheduler`` entry, provide either:

.. code-block:: bash
If you don't configure the ``scheduler`` entry in the configuration file, the :class:`~transformers.Trainer` will use
the values of ``--lr_scheduler_type``, ``--learning_rate`` and ``--warmup_steps`` to configure a 🤗 Transformers version
of it.

--lr_scheduler_type constant_with_warmup --learning_rate 3e-5 --warmup_steps 500
Here is an example of the pre-configured ``scheduler`` entry for ``WarmupLR``:

or
.. code-block:: json

.. code-block:: bash
{
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}

--lr_scheduler_type linear --learning_rate 3e-5 --warmup_steps 500
Note that the command line arguments will override the values in the configuration file. This is so that there is one
definitive source of the values and to avoid hard-to-find errors when, for example, the learning rate is set to
different values in different places. The command line takes precedence. The values that get overridden are:

with the desired values. If you don't pass these arguments, reasonable default values will be used instead.
- ``warmup_max_lr`` with the value of ``--learning_rate``
- ``warmup_num_steps`` with the value of ``--warmup_steps``
- ``total_num_steps`` with either the value of ``--max_steps`` or if it is not provided, derived automatically at run
time based on the environment and the size of the dataset and other command line arguments (needed for
``WarmupDecayLR``).

In the case of WarmupDecayLR ``total_num_steps`` gets set either via the ``--max_steps`` command line argument, or if
it is not provided, derived automatically at run time based on the environment and the size of the dataset and other
command line arguments.
Therefore, please remember to tune the shared hyperparameters on the command line.

Here is an example of the pre-configured ``scheduler`` entry for WarmupLR (``constant_with_warmup`` in the
:class:`~transformers.Trainer` API):
For example, for ``WarmupDecayLR``, you can use the following entry:

.. code-block:: json

{
"scheduler": {
"type": "WarmupLR",
"type": "WarmupDecayLR",
"params": {
"total_num_steps": 10,
"last_batch_iteration": -1,
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}

and ``warmup_max_lr``, ``warmup_num_steps`` and ``total_num_steps`` will be corrected at loading time.
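
As with the optimizer, here is a rough sketch of the :class:`~transformers.TrainingArguments` values that feed these
scheduler params (the values are arbitrary):

.. code-block:: python

    from transformers import TrainingArguments

    # warmup_max_lr <- learning_rate, warmup_num_steps <- warmup_steps,
    # total_num_steps <- max_steps (or derived at run time if --max_steps isn't passed)
    args = TrainingArguments(
        output_dir="output_dir",
        lr_scheduler_type="linear",  # maps to DeepSpeed's WarmupDecayLR
        learning_rate=3e-5,
        warmup_steps=500,
        max_steps=10,
        deepspeed="ds_config.json",
    )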



Automatic Mixed Precision
=======================================================================================================================

@@ -933,9 +970,9 @@ Notes
* While DeepSpeed has a pip installable PyPI package, it is highly recommended that it gets installed from `source
<https://github.com/microsoft/deepspeed#installation>`__ to best match your hardware and also if you need to enable
certain features, like 1-bit Adam, which aren't available in the PyPI distribution.
* You don't have to use the :class:`~transformers.Trainer` to use DeepSpeed with HuggingFace ``transformers`` - you can
use any model with your own trainer, and you will have to adapt the latter according to `the DeepSpeed integration
instructions <https://www.deepspeed.ai/getting-started/#writing-deepspeed-models>`__.
* You don't have to use the :class:`~transformers.Trainer` to use DeepSpeed with 🤗 Transformers - you can use any model
with your own trainer, and you will have to adapt the latter according to `the DeepSpeed integration instructions
<https://www.deepspeed.ai/getting-started/#writing-deepspeed-models>`__.

Main DeepSpeed Resources
=======================================================================================================================
70 changes: 66 additions & 4 deletions examples/tests/deepspeed/test_deepspeed.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import os
import sys
import unittest
from copy import deepcopy

from transformers.integrations import is_deepspeed_available
from transformers.testing_utils import (
@@ -67,16 +69,76 @@ def setUp(self):
MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
)
self.ds_config_file = f"{self.test_file_dir_str}/ds_config.json"
with io.open(self.ds_config_file, "r", encoding="utf-8") as f:
self.ds_config_dict = json.load(f)

def test_fake_notebook_no_launcher(self):

# this setup emulates a notebook where a launcher needs to be emulated by hand

with CaptureStd() as cs:
with CaptureStd() as cs: # noqa
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file)
trainer.train()
assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"
# fixme:
# assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"

# Test various combos
# 1. DS scheduler + DS optimizer: this is already tested by most other tests
# 2. HF scheduler + HF optimizer:
# 3. DS scheduler + HF optimizer:
# 4. HF scheduler + DS optimizer:

def test_hf_scheduler_hf_optimizer(self):
a = 0
with mockenv_context(**self.dist_env_1_gpu):
ds_config_dict = deepcopy(self.ds_config_dict)
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
del ds_config_dict["scheduler"] # force default HF Trainer scheduler
ds_config_dict["zero_optimization"]["cpu_offload"] = False
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_dict)
trainer.train()
new_a = trainer.model.a.item()
self.assertNotEqual(new_a, a)

def test_ds_scheduler_hf_optimizer(self):
a = 0
with mockenv_context(**self.dist_env_1_gpu):
ds_config_dict = deepcopy(self.ds_config_dict)
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
ds_config_dict["zero_optimization"]["cpu_offload"] = False
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_dict)
trainer.train()
new_a = trainer.model.a.item()
self.assertNotEqual(new_a, a)

def test_hf_scheduler_ds_optimizer(self):
# this combo is not possible at the moment
with mockenv_context(**self.dist_env_1_gpu):
ds_config_dict = deepcopy(self.ds_config_dict)
del ds_config_dict["scheduler"] # force default HF Trainer scheduler
ds_config_dict["zero_optimization"]["cpu_offload"] = False
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
with self.assertRaises(Exception) as context:
trainer.train()
self.assertTrue("HF scheduler + DeepSpeed optimizer combination is not possible" in str(context.exception))

def test_hf_optimizer_with_offload(self):
# must not allow non-DS optimizer when using ZERO-offload
with mockenv_context(**self.dist_env_1_gpu):
ds_config_dict = deepcopy(self.ds_config_dict)
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
ds_config_dict["zero_optimization"]["cpu_offload"] = True
# sanity check - should the default config change
assert (
"cpu_offload" in ds_config_dict["zero_optimization"]
and ds_config_dict["zero_optimization"]["cpu_offload"] is True
), "ensure the config is set up correctly"
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
with self.assertRaises(Exception) as context:
trainer.train()
self.assertTrue("ZeRO Offload can only work with DeepSpeed optimizers" in str(context.exception))

def test_early_get_last_lr(self):
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may