diff --git a/LICENSE b/LICENSE index 9cb8cbef5a9f8..04f9ad1105655 100644 --- a/LICENSE +++ b/LICENSE @@ -28,6 +28,10 @@ All rights reserved. All contributions by Kakao Brain: Copyright 2019-2020 Kakao Brain +All contributions by Cruise LLC: +Copyright (c) 2022 Cruise LLC. +All rights reserved. + All contributions from Caffe: Copyright(c) 2013, 2014, 2015, the respective contributors All rights reserved. diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py index cb67057bcdeb4..0fba0c9a5a57e 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -94,7 +94,7 @@ def __init__( warnings.warn( "When period is 1, no need to use model averaging because the communication cost " "of all-reducing parameters will be no less than the cost of all-reducing gradients " - "by DistributedDataParall in the backward pass. Therefore, only " + "by DistributedDataParallel in the backward pass. Therefore, only " "DistributedDataParallel should be used for this case." ) self.period = period diff --git a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py new file mode 100644 index 0000000000000..f70cf0fe2d857 --- /dev/null +++ b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -0,0 +1,159 @@ +# Copyright 2022 Cruise LLC +import warnings +from collections import OrderedDict +import logging + +import torch.distributed as dist +import torch.distributed.algorithms.model_averaging.utils as utils + +logger = logging.getLogger(__name__) + + +class HierarchicalModelAverager: + r""" + A group of model averagers used for hierarchical model averaging (hierarchical SGD). + Process groups of different sizes are organized in a hierarhicy, and they average parameters + by using different periods concurrently after the warm-up stage. + This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager` + that supports `post-local SGD `_, which essentially only supports + a two-level hierarchy: the intra-machine level and the global level, where the intra-machine + level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`. + Similarly, the process groups within this class do not have such an intra-machine process + subgroup, which should be embedded by the post-local SGD communication hook instead. + + Args: + period_group_size_dict: An ordered dict mapping keys of model averaging period to + process group size, used for initializing process groups of + different sizes in a hierarchy to average parameters concurrently. + Particularly, at each iteration, there will be at most a single + process group that runs averaging -- the period of such group should + have the largest period which the current step can be divided by. + For example, if the dict has three keys: 2, 4, and 8, + then this means totally three process groups will be created to + average parameters every 2, 4, and 8 iterations, respectively. + At the 4th iteration, only the second process group will run + averaging, because the first process group should be a + subset of the second process group, and no need to execute the first + process group redundantly. + On the other hand, the third process group can only be triggered + every 8 iterations, so it will not be triggered at the 4th iteration. + warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped. + process_group (ProcessGroup, optional): The overall process group containing all the processes that runs model averaging. + If ``None``, the default process group, which is created + by :func:`torch.distributed.init_process_group`, will be used. + (default: ``None``) + + Example:: + >>> from collections import OrderedDict + >>> import torch + >>> import torch.distributed as dist + >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) + >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD + >>> import torch.nn as nn + >>> + >>> dist.init_process_group("nccl", rank=rank, world_size=16) + >>> torch.cuda.set_device(rank) + >>> module = nn.Linear(1, 1, bias=False).to(rank) + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> # Register a post-localSGD communication hook. + >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4. + >>> subgroup, _ = dist.new_subgroups() + >>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # Average parameters among each group of 8 processes every 4 iterations, and among all + >>> # the 16 processes every 16 iterations. + >>> averager = hierarchicalSGD.HierarchicalModelAverager( + >>> period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100) + >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. + >>> # After 100 steps, run model averaging at two levels. + >>> for step in range(0, 200): + >>> optimizer.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> optimizer.step() + >>> # Average parameters after ``optimizer.step()``. + >>> # Thus, the inter-node communication only occurs periodically after ``warmup_steps``. + >>> averager.average_parameters(model.parameters()) + + .. warning :: + The last group size in the dict must be the size of the provided ``process_group``, + which indicates model averaging at the highest level of the hierarchy. + If ``process_group`` is not provided, then the last group size should be equal to the world size. + + .. warning :: + `HierarchicalModelAverager` is experimental and subject to change. + """ + + def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None): + if not period_group_size_dict: + raise ValueError("Arg ``period_group_size_dict`` must not be empty.") + self._periods = list(period_group_size_dict.keys()) + if self._periods[0] <= 0: + raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.") + elif self._periods[-1] == 1: + warnings.warn( + "When the maximum period in arg ``period_group_size_dict`` is 1, " + "no need to use model averaging because the communication cost " + "of all-reducing parameters will be no less than the cost of all-reducing gradients " + "by DistributedDataParallel in the backward pass. Therefore, only " + "DistributedDataParallel should be used for this case." + ) + ovall_group : dist.ProcessGroup = ( + process_group if process_group is not None else dist.group.WORLD + ) + overall_group_size = dist.get_world_size(group=ovall_group) + if list(period_group_size_dict.values())[-1] != overall_group_size: + raise ValueError( + "The last value in arg ``period_process_group_dict`` " + "must be equal to the size of arg ``process_group``.") + + self.period_process_group_dict = OrderedDict() + logger.info("Model averaging hierarchy:") + for period, group_size in period_group_size_dict.items(): + logger.info( + f"\tEach group that has {group_size} processes average parameters every {period} iterations, " + "if no higher-level averaging.") + if group_size != overall_group_size: + self.period_process_group_dict[period], _ = dist.new_subgroups( + group_size=group_size, group=ovall_group) + else: + self.period_process_group_dict[period] = ovall_group + + if warmup_steps < 0: + raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") + self.warmup_steps = warmup_steps + self.step = 0 + + def _find_process_group(self): + """ + Returns a tuple consisting of whether ``step`` can be divided by + a period in the keys of ``period_process_group_dict`` and the associated process group if any. + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, + then the returned process group is the one corresponding to the largest period, + since this process group will be used for averaging parameters at this ``step``. + """ + for period in reversed(self._periods): + if self.step % period == 0: + return (True, self.period_process_group_dict[period]) + return (False, None) + + def average_parameters(self, params): + r""" + Averages parameters if ``step`` is no less than ``warmup_steps`` + and it can be divided by a period in the keys of ``period_process_group_dict``, + where ``step`` is increased by 1 at each iteration in the training loop. + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, + only the largest period is used, and the corresponding process group is used for averaging parameters. + """ + if self.step >= self.warmup_steps: + found, group = self._find_process_group() + if found: + utils.average_parameters(iter(params), group) + self.step += 1 diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 5d7ef22b8bf45..b3f98560e42a0 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -6,7 +6,7 @@ import sys import tempfile import time -from collections import namedtuple +from collections import namedtuple, OrderedDict from contextlib import contextmanager, suppress from datetime import timedelta from functools import reduce @@ -16,6 +16,7 @@ import torch.cuda import torch.distributed as dist import torch.distributed.algorithms.model_averaging.averagers as averagers +import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils import torch.nn as nn import torch.nn.functional as F @@ -1033,6 +1034,102 @@ def test_periodic_model_averager(self): # No model averaging, so the parameters are not updated. self.assertEqual(param.data, tensor) + @sandcastle_skip_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices" + ) + @skip_if_lt_x_gpu(2) + def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + + model = nn.Linear(1, 5, bias=False).cuda(device_id) + param = next(model.parameters()) + tensor = torch.ones_like(param.data) * rank + expected_avg_tensor = ( + torch.ones_like(param.data) * sum(range(world_size)) / world_size + ) + period = 4 + for warmup_steps in [12, 13, 14, 15]: + averager = hierarchicalSGD.HierarchicalModelAverager( + # Run the global averaging at a period of 4, + # which is equivalent to the above periodic model averaging test case. + period_group_size_dict=OrderedDict([(period, world_size)]), warmup_steps=warmup_steps + ) + + averager = averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps) + for step in range(0, 20): + # Reset the parameters at every step. + param.data = copy.deepcopy(tensor) + averager.average_parameters(model.parameters()) + if step >= warmup_steps and (step - warmup_steps) % period == 0: + self.assertEqual(param.data, expected_avg_tensor) + else: + # No model averaging, so the parameters are not updated. + self.assertEqual(param.data, tensor) + + @sandcastle_skip_if( + BACKEND not in DistTestCases.backend_feature["subgroup"], + f"The {BACKEND} backend does not support creating subgroups on CUDA devices" + ) + @require_world_size(4) + @skip_if_lt_x_gpu(4) + def test_3_level_hierarchical_model_averager(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + rank_to_GPU = init_multigpu_helper(world_size, BACKEND) + device_id = rank_to_GPU[rank][0] + + model = nn.Linear(1, 5, bias=False).cuda(device_id) + param = next(model.parameters()) + tensor = torch.ones_like(param.data) * rank + # Set up such a hierarchical model averaging as follows: + # after the first 10 warmup steps, + # run model averaging every 2 steps within each subgroup of size 2, + # run model averaging every 4 steps within each subgroup of size 3, + # and run the global model averaging every 8 steps. + # If there is a conflict in model averaging at a step, only run the highest-level model averaging. + warmup_steps = 10 + subgroup_size1 = 2 + subgroup_avg_period1 = 2 + subgroup_size2 = 4 + subgroup_avg_period2 = 4 + global_avg_period = 8 + period_group_size_dict = OrderedDict( + [(subgroup_avg_period1, subgroup_size1), + (subgroup_avg_period2, subgroup_size2), + (global_avg_period, world_size)]) + averager = hierarchicalSGD.HierarchicalModelAverager( + period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps + ) + expected_avg_tensor_within_subgroup1 = ( + torch.ones_like(param.data) * sum(range(subgroup_size1)) / subgroup_size1 + ) + expected_avg_tensor_within_subgroup2 = ( + torch.ones_like(param.data) * sum(range(subgroup_size2)) / subgroup_size2 + ) + expected_global_avg_tensor = ( + torch.ones_like(param.data) * sum(range(world_size)) / world_size + ) + for step in range(0, 25): + # Reset the parameters at every step. + param.data = copy.deepcopy(tensor) + averager.average_parameters(model.parameters()) + if step == 16 or step == 24: + # Run global model averaging when `step` can be divided by 8. + self.assertEqual(param.data, expected_global_avg_tensor) + elif step == 12 or step == 20: + # Run model averaging within subgroup when `step` can be divided by 4 but not by 8. + self.assertEqual(param.data, expected_avg_tensor_within_subgroup1) + elif step == 10 or step == 14 or step == 18 or step == 22: + # Run model averaging within subgroup when `step` can be divided by 2 but not by 4 or 8. + self.assertEqual(param.data, expected_avg_tensor_within_subgroup1) + else: + # No model averaging, so the parameters are not updated. + self.assertEqual(param.data, tensor) + # NCCL Batch SEND RECV @skip_if_no_gpu @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")