[Model Averaging] Support hierarchical model averaging (pytorch#73285)
Summary:
Implement hierarchical model averaging proposed in pytorch#71325.

Unit tests are added. Since I don't have access to 4-GPU machines in the open-source environment, I expect that a branch with the `ci-all` prefix can run the test that requires 4 GPUs.

In the future, the internals of `PeriodicModelAverager` can be simplified to an implementation of a specialized hierarchical model averaging, where `period_group_size_dict` only has a single pair of period and world size.
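
To make that note concrete, here is a minimal sketch (not part of this PR) under the assumption that the default process group has already been initialized: a one-entry `period_group_size_dict` gives `HierarchicalModelAverager` the same global-averaging schedule that `PeriodicModelAverager` produces when `warmup_steps` is a multiple of `period`.

```python
# Illustrative sketch only; assumes dist.init_process_group() has already been called.
from collections import OrderedDict

import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD

world_size = dist.get_world_size()

# Global model averaging every 4 steps after 100 warm-up steps.
periodic_averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)

# A single (period, world_size) entry creates only the overall process group,
# i.e., a single level of averaging over all ranks.
one_level_averager = hierarchicalSGD.HierarchicalModelAverager(
    period_group_size_dict=OrderedDict([(4, world_size)]), warmup_steps=100
)

# Both are driven the same way in the training loop:
#     averager.average_parameters(model.parameters())
```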

Pull Request resolved: pytorch#73285

Reviewed By: mrshenli

Differential Revision: D34457792

Pulled By: rohan-varma

fbshipit-source-id: 39a6c5bf8a2852b6394a56abbad17b8a909b9fba
(cherry picked from commit 5f543d4)
wayi1 authored and pytorchmergebot committed Mar 4, 2022
1 parent bcd0843 commit 0bb3b06
Showing 4 changed files with 262 additions and 2 deletions.
4 changes: 4 additions & 0 deletions LICENSE
@@ -28,6 +28,10 @@ All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain

All contributions by Cruise LLC:
Copyright (c) 2022 Cruise LLC.
All rights reserved.

All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
2 changes: 1 addition & 1 deletion torch/distributed/algorithms/model_averaging/averagers.py
@@ -94,7 +94,7 @@ def __init__(
warnings.warn(
"When period is 1, no need to use model averaging because the communication cost "
"of all-reducing parameters will be no less than the cost of all-reducing gradients "
"by DistributedDataParall in the backward pass. Therefore, only "
"by DistributedDataParallel in the backward pass. Therefore, only "
"DistributedDataParallel should be used for this case."
)
self.period = period
159 changes: 159 additions & 0 deletions torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py
@@ -0,0 +1,159 @@
# Copyright 2022 Cruise LLC
import warnings
from collections import OrderedDict
import logging

import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils

logger = logging.getLogger(__name__)


class HierarchicalModelAverager:
r"""
A group of model averagers used for hierarchical model averaging (hierarchical SGD).
Process groups of different sizes are organized in a hierarchy, and they average parameters
by using different periods concurrently after the warm-up stage.
This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports
a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
Similarly, the process groups within this class do not have such an intra-machine process
subgroup, which should be embedded by the post-local SGD communication hook instead.
Args:
period_group_size_dict: An ordered dict mapping model averaging periods to
process group sizes, used for initializing process groups of
different sizes in a hierarchy to average parameters concurrently.
In particular, at each iteration there will be at most a single
process group that runs averaging: the group whose period is the
largest one that divides the current step.
For example, if the dict has three keys, 2, 4, and 8,
then a total of three process groups will be created to
average parameters every 2, 4, and 8 iterations, respectively.
At the 4th iteration, only the second process group runs
averaging, because the first process group is a subset of the
second, so there is no need to run the first one redundantly.
The third process group, on the other hand, is only triggered
every 8 iterations, so it does not run at the 4th iteration.
warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
process_group (ProcessGroup, optional): The overall process group containing all the processes that run model averaging.
If ``None``, the default process group, which is created
by :func:`torch.distributed.init_process_group`, will be used.
(default: ``None``)
Example::
>>> from collections import OrderedDict
>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>> PostLocalSGDState,
>>> post_localSGD_hook,
>>> )
>>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
>>> import torch.nn as nn
>>>
>>> dist.init_process_group("nccl", rank=rank, world_size=16)
>>> torch.cuda.set_device(rank)
>>> module = nn.Linear(1, 1, bias=False).to(rank)
>>> model = nn.parallel.DistributedDataParallel(
>>> module, device_ids=[rank], output_device=rank
>>> )
>>> # Register a post-localSGD communication hook.
>>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4.
>>> subgroup, _ = dist.new_subgroups()
>>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Average parameters among each group of 8 processes every 4 iterations, and among all
>>> # the 16 processes every 16 iterations.
>>> averager = hierarchicalSGD.HierarchicalModelAverager(
>>> period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
>>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
>>> # After 100 steps, run model averaging at two levels.
>>> for step in range(0, 200):
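>>> # The forward pass is elided here; assume ``output``, ``labels``, and ``loss_fn`` are defined.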
>>> optimizer.zero_grad()
>>> loss = loss_fn(output, labels)
>>> loss.backward()
>>> optimizer.step()
>>> # Average parameters after ``optimizer.step()``.
>>> # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
>>> averager.average_parameters(model.parameters())
.. warning ::
The last group size in the dict must be the size of the provided ``process_group``,
which indicates model averaging at the highest level of the hierarchy.
If ``process_group`` is not provided, then the last group size should be equal to the world size.
.. warning ::
`HierarchicalModelAverager` is experimental and subject to change.
"""

def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
if not period_group_size_dict:
raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
self._periods = list(period_group_size_dict.keys())
if self._periods[0] <= 0:
raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.")
elif self._periods[-1] == 1:
warnings.warn(
"When the maximum period in arg ``period_group_size_dict`` is 1, "
"no need to use model averaging because the communication cost "
"of all-reducing parameters will be no less than the cost of all-reducing gradients "
"by DistributedDataParallel in the backward pass. Therefore, only "
"DistributedDataParallel should be used for this case."
)
overall_group: dist.ProcessGroup = (
process_group if process_group is not None else dist.group.WORLD
)
overall_group_size = dist.get_world_size(group=overall_group)
if list(period_group_size_dict.values())[-1] != overall_group_size:
raise ValueError(
"The last value in arg ``period_group_size_dict`` "
"must be equal to the size of arg ``process_group``.")

self.period_process_group_dict = OrderedDict()
logger.info("Model averaging hierarchy:")
for period, group_size in period_group_size_dict.items():
logger.info(
f"\tEach group that has {group_size} processes average parameters every {period} iterations, "
"if no higher-level averaging.")
if group_size != overall_group_size:
self.period_process_group_dict[period], _ = dist.new_subgroups(
group_size=group_size, group=overall_group)
else:
self.period_process_group_dict[period] = overall_group

if warmup_steps < 0:
raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
self.warmup_steps = warmup_steps
self.step = 0

def _find_process_group(self):
"""
Returns a tuple of (bool, process group or ``None``): whether ``step`` can be
divided by a period in the keys of ``period_process_group_dict``, and the
associated process group if any.
If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
then the returned process group is the one corresponding to the largest such period,
since this process group will be used for averaging parameters at this ``step``.
"""
for period in reversed(self._periods):
if self.step % period == 0:
return (True, self.period_process_group_dict[period])
return (False, None)

def average_parameters(self, params):
r"""
Averages parameters if ``step`` is no less than ``warmup_steps``
and it can be divided by a period in the keys of ``period_process_group_dict``,
where ``step`` is increased by 1 at each iteration in the training loop.
If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
only the largest period is used, and the corresponding process group is used for averaging parameters.
"""
if self.step >= self.warmup_steps:
found, group = self._find_process_group()
if found:
utils.average_parameters(iter(params), group)
self.step += 1
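
The period-selection rule described in the class docstring above can also be sketched in isolation. The helper name `averaging_group_size` and the example dict below are illustrative assumptions, not part of the module; the loop mirrors `_find_process_group`.

```python
# Minimal sketch of the selection rule: at each step, only the group with the
# largest period that divides the step runs averaging.
from collections import OrderedDict

period_group_size_dict = OrderedDict([(2, 2), (4, 4), (8, 8)])
periods = list(period_group_size_dict.keys())

def averaging_group_size(step):
    # Mirror _find_process_group: scan periods from largest to smallest.
    for period in reversed(periods):
        if step % period == 0:
            return period_group_size_dict[period]
    return None  # no averaging at this step

assert averaging_group_size(4) == 4      # the size-4 group runs; the size-2 group is skipped
assert averaging_group_size(6) == 2      # only the size-2 group runs
assert averaging_group_size(8) == 8      # global averaging across all 8 processes
assert averaging_group_size(7) is None   # no group runs at this step
```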
99 changes: 98 additions & 1 deletion torch/testing/_internal/distributed/distributed_test.py
@@ -6,7 +6,7 @@
import sys
import tempfile
import time
from collections import namedtuple
from collections import namedtuple, OrderedDict
from contextlib import contextmanager, suppress
from datetime import timedelta
from functools import reduce
@@ -16,6 +16,7 @@
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
import torch.nn as nn
import torch.nn.functional as F
@@ -1033,6 +1034,102 @@ def test_periodic_model_averager(self):
# No model averaging, so the parameters are not updated.
self.assertEqual(param.data, tensor)

@sandcastle_skip_if(
BACKEND not in DistTestCases.backend_feature["subgroup"],
f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
)
@skip_if_lt_x_gpu(2)
def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
device_id = rank_to_GPU[rank][0]

model = nn.Linear(1, 5, bias=False).cuda(device_id)
param = next(model.parameters())
tensor = torch.ones_like(param.data) * rank
expected_avg_tensor = (
torch.ones_like(param.data) * sum(range(world_size)) / world_size
)
period = 4
for warmup_steps in [12, 13, 14, 15]:
averager = hierarchicalSGD.HierarchicalModelAverager(
# Run the global averaging at a period of 4,
# which is equivalent to the above periodic model averaging test case.
period_group_size_dict=OrderedDict([(period, world_size)]), warmup_steps=warmup_steps
)

for step in range(0, 20):
# Reset the parameters at every step.
param.data = copy.deepcopy(tensor)
averager.average_parameters(model.parameters())
# ``HierarchicalModelAverager`` runs averaging whenever ``step`` itself is divisible by the period.
if step >= warmup_steps and step % period == 0:
self.assertEqual(param.data, expected_avg_tensor)
else:
# No model averaging, so the parameters are not updated.
self.assertEqual(param.data, tensor)

@sandcastle_skip_if(
BACKEND not in DistTestCases.backend_feature["subgroup"],
f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
)
@require_world_size(4)
@skip_if_lt_x_gpu(4)
def test_3_level_hierarchical_model_averager(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
device_id = rank_to_GPU[rank][0]

model = nn.Linear(1, 5, bias=False).cuda(device_id)
param = next(model.parameters())
tensor = torch.ones_like(param.data) * rank
# Set up such a hierarchical model averaging as follows:
# after the first 10 warmup steps,
# run model averaging every 2 steps within each subgroup of size 2,
# run model averaging every 4 steps within each subgroup of size 4,
# and run the global model averaging every 8 steps.
# If there is a conflict in model averaging at a step, only run the highest-level model averaging.
warmup_steps = 10
subgroup_size1 = 2
subgroup_avg_period1 = 2
subgroup_size2 = 4
subgroup_avg_period2 = 4
global_avg_period = 8
period_group_size_dict = OrderedDict(
[(subgroup_avg_period1, subgroup_size1),
(subgroup_avg_period2, subgroup_size2),
(global_avg_period, world_size)])
averager = hierarchicalSGD.HierarchicalModelAverager(
period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
)
# ``dist.new_subgroups`` groups consecutive ranks, so compute the expected
# average over the ranks in this rank's own subgroup at each level.
subgroup1_start = (rank // subgroup_size1) * subgroup_size1
expected_avg_tensor_within_subgroup1 = (
torch.ones_like(param.data)
* sum(range(subgroup1_start, subgroup1_start + subgroup_size1)) / subgroup_size1
)
subgroup2_start = (rank // subgroup_size2) * subgroup_size2
expected_avg_tensor_within_subgroup2 = (
torch.ones_like(param.data)
* sum(range(subgroup2_start, subgroup2_start + subgroup_size2)) / subgroup_size2
)
expected_global_avg_tensor = (
torch.ones_like(param.data) * sum(range(world_size)) / world_size
)
for step in range(0, 25):
# Reset the parameters at every step.
param.data = copy.deepcopy(tensor)
averager.average_parameters(model.parameters())
if step == 16 or step == 24:
# Run global model averaging when `step` can be divided by 8.
self.assertEqual(param.data, expected_global_avg_tensor)
elif step == 12 or step == 20:
# Run model averaging within each subgroup of size 4 when `step` can be divided by 4 but not by 8.
self.assertEqual(param.data, expected_avg_tensor_within_subgroup2)
elif step == 10 or step == 14 or step == 18 or step == 22:
# Run model averaging within each subgroup of size 2 when `step` can be divided by 2 but not by 4 or 8.
self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
else:
# No model averaging, so the parameters are not updated.
self.assertEqual(param.data, tensor)

# NCCL Batch SEND RECV
@skip_if_no_gpu
@sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
