[Model Averaging] Support hierarchical model averaging (pytorch#73285)
Summary: Implement hierarchical model averaging proposed in pytorch#71325. Unit tests are added. Since I don't have access to 4-GPU machines in the open-source environment, I expect that a branch with the `ci-all` prefix can run the test that requires 4 GPUs. In the future, the internals of `PeriodicModelAverager` can be simplified as an implementation of a specialized hierarchical model averaging, where `period_group_size_dict` only has a single pair of period and world size.

Pull Request resolved: pytorch#73285
Reviewed By: mrshenli
Differential Revision: D34457792
Pulled By: rohan-varma
fbshipit-source-id: 39a6c5bf8a2852b6394a56abbad17b8a909b9fba
(cherry picked from commit 5f543d4)
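A minimal sketch of that simplification (hypothetical usage, not part of this commit): a hierarchy whose `period_group_size_dict` holds a single (period, world size) pair degenerates to plain periodic averaging. It assumes the default process group has already been initialized:

    from collections import OrderedDict

    import torch.distributed as dist
    import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD

    # With one entry, exactly one process group exists, and it averages the
    # parameters of all ranks every 4 steps after the 100-step warm-up,
    # matching what PeriodicModelAverager(period=4, warmup_steps=100) provides.
    averager = hierarchicalSGD.HierarchicalModelAverager(
        period_group_size_dict=OrderedDict([(4, dist.get_world_size())]),
        warmup_steps=100,
    )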
1 parent bcd0843 · commit 0bb3b06
Showing 4 changed files with 262 additions and 2 deletions.
159 changes: 159 additions & 0 deletions
torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py
@@ -0,0 +1,159 @@
# Copyright 2022 Cruise LLC
import logging
import warnings
from collections import OrderedDict

import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils

logger = logging.getLogger(__name__)

class HierarchicalModelAverager:
    r"""
    A group of model averagers used for hierarchical model averaging (hierarchical SGD).
    Process groups of different sizes are organized in a hierarchy, and after the warm-up
    stage they average parameters concurrently, each using a different period.
    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
    that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports
    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
    Similarly, the process groups within this class do not have such an intra-machine process
    subgroup, which should be embedded by the post-local SGD communication hook instead.
    Args:
        period_group_size_dict: An ordered dict mapping model averaging periods to
                                process group sizes, used for initializing process groups of
                                different sizes in a hierarchy to average parameters concurrently.
                                In particular, at each iteration, at most a single process group
                                runs averaging: the one whose period is the largest period that
                                divides the current step.
                                For example, if the dict has three keys, 2, 4, and 8,
                                then three process groups in total will be created to
                                average parameters every 2, 4, and 8 iterations, respectively.
                                At the 4th iteration, only the second process group runs
                                averaging, because the first process group should be a
                                subset of the second one, so there is no need to execute the first
                                process group redundantly.
                                On the other hand, the third process group can only be triggered
                                every 8 iterations, so it is not triggered at the 4th iteration.
        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
        process_group (ProcessGroup, optional): The overall process group containing all the processes that run model averaging.
                                                If ``None``, the default process group, which is created
                                                by :func:`torch.distributed.init_process_group`, will be used.
                                                (default: ``None``)
Example:: | ||
>>> from collections import OrderedDict | ||
>>> import torch | ||
>>> import torch.distributed as dist | ||
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( | ||
>>> PostLocalSGDState, | ||
>>> post_localSGD_hook, | ||
>>> ) | ||
>>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD | ||
>>> import torch.nn as nn | ||
>>> | ||
>>> dist.init_process_group("nccl", rank=rank, world_size=16) | ||
>>> torch.cuda.set_device(rank) | ||
>>> module = nn.Linear(1, 1, bias=False).to(rank) | ||
>>> model = nn.parallel.DistributedDataParallel( | ||
>>> module, device_ids=[rank], output_device=rank | ||
>>> ) | ||
>>> # Register a post-localSGD communication hook. | ||
>>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4. | ||
>>> subgroup, _ = dist.new_subgroups() | ||
>>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100) | ||
>>> model.register_comm_hook(state, post_localSGD_hook) | ||
>>> | ||
>>> # Average parameters among each group of 8 processes every 4 iterations, and among all | ||
>>> # the 16 processes every 16 iterations. | ||
>>> averager = hierarchicalSGD.HierarchicalModelAverager( | ||
>>> period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100) | ||
>>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. | ||
>>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. | ||
>>> # After 100 steps, run model averaging at two levels. | ||
>>> for step in range(0, 200): | ||
>>> optimizer.zero_grad() | ||
>>> loss = loss_fn(output, labels) | ||
>>> loss.backward() | ||
>>> optimizer.step() | ||
>>> # Average parameters after ``optimizer.step()``. | ||
>>> # Thus, the inter-node communication only occurs periodically after ``warmup_steps``. | ||
>>> averager.average_parameters(model.parameters()) | ||
    .. warning ::
        The last group size in the dict must be the size of the provided ``process_group``,
        which indicates model averaging at the highest level of the hierarchy.
        If ``process_group`` is not provided, then the last group size should be equal to the world size.

    .. warning ::
        `HierarchicalModelAverager` is experimental and subject to change.
    """
    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
        if not period_group_size_dict:
            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
        self._periods = list(period_group_size_dict.keys())
        if self._periods[0] <= 0:
            raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.")
        elif self._periods[-1] == 1:
            warnings.warn(
                "When the maximum period in arg ``period_group_size_dict`` is 1, "
                "there is no need to use model averaging, because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        overall_group: dist.ProcessGroup = (
            process_group if process_group is not None else dist.group.WORLD
        )
        overall_group_size = dist.get_world_size(group=overall_group)
        if list(period_group_size_dict.values())[-1] != overall_group_size:
            raise ValueError(
                "The last value in arg ``period_group_size_dict`` "
                "must be equal to the size of arg ``process_group``.")

        self.period_process_group_dict = OrderedDict()
        logger.info("Model averaging hierarchy:")
        for period, group_size in period_group_size_dict.items():
            logger.info(
                f"\tEach group of {group_size} processes averages parameters every {period} iterations, "
                "if there is no higher-level averaging.")
            if group_size != overall_group_size:
                self.period_process_group_dict[period], _ = dist.new_subgroups(
                    group_size=group_size, group=overall_group)
            else:
                self.period_process_group_dict[period] = overall_group

        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps
        self.step = 0
    def _find_process_group(self):
        """
        Returns a tuple of (bool, process group), where the bool indicates whether
        ``step`` is divisible by any period in the keys of ``period_process_group_dict``,
        and the process group, if any, is the one to use at this ``step``.
        If ``step`` is divisible by multiple periods in the keys of ``period_process_group_dict``,
        then the returned process group is the one corresponding to the largest such period,
        since this process group will be used for averaging parameters at this ``step``.
        """
        for period in reversed(self._periods):
            if self.step % period == 0:
                return (True, self.period_process_group_dict[period])
        return (False, None)

    def average_parameters(self, params):
        r"""
        Averages parameters if ``step`` is no less than ``warmup_steps``
        and it is divisible by a period in the keys of ``period_process_group_dict``,
        where ``step`` is increased by 1 at each iteration in the training loop.
        If ``step`` is divisible by multiple periods in the keys of ``period_process_group_dict``,
        only the largest period is used, and the corresponding process group is used for averaging parameters.
        """
        if self.step >= self.warmup_steps:
            found, group = self._find_process_group()
            if found:
                utils.average_parameters(iter(params), group)
        self.step += 1