[torch.distributed][DDP] Disable DDP bucketing for the first iteration (pytorch#72843)

Summary:
Pull Request resolved: pytorch#72843

# [Debug Story] Training Hanging and DDP Bucketing

**What are the characteristics of the hanging training instance?**

The model uses TorchRec `PooledEmbeddingArch` and the corresponding sharding solution.

The model config change that triggers this hanging issue is turning on position-weighted embedding tables.

A feature processor module, `GroupedPositionWeightedModule`, is constructed on all ranks, but `GroupedPositionWeightedModule.forward(...)` is only [called on a subset of ranks in the whole world](https://fburl.com/code/yqrmtvli).

**What was the initial manifested error?**

The training was stuck in the first iteration.

**What are useful debugging tools this time?**

After turning off [static_graph in DDP](https://fburl.com/code/4io81p5i), we saw sparse feature lengths become negative after all-to-all collectives; the hang turned into a fatal failure.

After turning on [torch.distributed DETAIL debugging mode](https://fburl.com/code/cp8e28mm), we saw that 2 trainers sent out mismatched collectives, one doing all-to-all and the other doing all-reduce. So we knew the negative values came from an all-to-all being matched with an all-reduce, and that the real error had happened earlier: one rank issued its all-reduce (or all-to-all) at the wrong time.

With more logging added inside DDP, it turned out that DDP decided to launch all-reduce at different points across different ranks.
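
For context, here is a minimal sketch (not taken from this PR) of how these two debugging knobs can be enabled. `build_model()` is a placeholder, and the exact spelling of the `static_graph` argument and the `TORCH_DISTRIBUTED_DEBUG` environment variable may differ across PyTorch versions.

```python
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# DETAIL mode makes torch.distributed log and cross-check collectives, so a
# collective mismatch across ranks surfaces as a loud error instead of a hang.
# It must be set before the process group is initialized.
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

# Assumes rank / world size are provided by the launcher (e.g. torchrun).
dist.init_process_group("nccl")

model = build_model().cuda()  # build_model() is a placeholder for the real model

# Turning static_graph off (with find_unused_parameters on) makes DDP re-detect
# unused parameters every iteration; in this investigation that converted the
# silent hang into a visible failure.
ddp_model = DDP(model, static_graph=False, find_unused_parameters=True)
```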

**What is DDP bucketing?**

Once a gradient is ready on a rank, DDP uses all-reduce to synchronize the average of this gradient across all ranks.

Say we have 4 tensor ops: A, B, C, D.

In the most naive version, we could do one synchronization when all gradients in the full backward graph are ready.

The time sequence would be,

* D.grad
* C.grad
* B.grad
* A.grad
* All reduce on [D.grad, C.grad, B.grad, A.grad].

But that would be a huge waste of communication channel bandwidth.
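
For concreteness, a minimal sketch of this naive scheme (a toy, not DDP's actual reducer; `model` is assumed to be any module whose backward pass has already finished):

```python
import torch.distributed as dist

def naive_grad_sync(model):
    # Wait until the entire backward graph has run, then average every
    # gradient across ranks. Correct, but all communication happens strictly
    # after all computation, with no overlap at all.
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)

# usage: loss.backward(); naive_grad_sync(model); optimizer.step()
```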

With DDP bucketing, we can start some of the gradient synchronization earlier, bucket by bucket. The above time sequence now becomes,

* D.grad
* C.grad
* All reduce on [D.grad, C.grad].
* B.grad
* A.grad
* All reduce on [B.grad, A.grad].

Because gradient computation now overlaps with communication, the bucketing technique brings better DDP execution performance.
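
Below is a hedged sketch of that overlap (again a toy, not the real DDP reducer): each bucket's all-reduces are launched asynchronously as soon as the bucket's gradients are ready, so communication for earlier buckets runs while later gradients are still being computed.

```python
import torch.distributed as dist

def sync_bucket_async(bucket_grads, world_size):
    # Kick off one async all-reduce per gradient in this bucket and return the
    # work handles; the caller keeps computing gradients for later buckets
    # while this communication is in flight.
    handles = []
    for g in bucket_grads:
        g.div_(world_size)  # pre-divide so that SUM across ranks == average
        handles.append(dist.all_reduce(g, async_op=True))
    return handles

# Conceptually, during backward:
#   handles  = sync_bucket_async([D.grad, C.grad], world_size)  # bucket_0
#   ...B.grad and A.grad are still being computed while bucket_0 is reduced...
#   handles += sync_bucket_async([B.grad, A.grad], world_size)  # bucket_1
#   for h in handles:
#       h.wait()  # all averaging must finish before optimizer.step()
```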

**What exactly went wrong in this case?**

1. The bucketing doesn’t honor the backward graph execution order.
2. There are other collective comm ops in the backward graph.
3. There are unused parameters (i.e., an unused sub-module) on a subset of ranks in the whole world.

Using the above example again, we have 4 tensor ops: A, B, C, D.

Say we have 2 trainers,

B is the feature processor module.

B only runs on trainer_0 (both forward and backward), not on trainer_1.

C is the All-to-all (Pooled embeddings distribution).

C sends out all-to-all collective in both its forward and backward pass.

Keep assuming all other ops run on both trainers.

trainer_0 op sequence is,

A, B (feature preproc), C (all-to-all), D | D.grad, C.grad (reverse all-to-all), B.grad (feature proc grads), A.grad

trainer_1 op sequence is,

A, C (all-to-all), D | D.grad, C.grad (reverse all-to-all), A.grad

The correct bucketing (the same on both ranks) should be,

* bucket_0, [D.grad, C.grad]
* bucket_1, [B.grad, A.grad]

but because of 1), they end up like,

* bucket_0, [B.grad, D.grad]
* bucket_1, [C.grad, A.grad]

Adding 2) and 3), the time sequence could look like,

(a check mark ✓ means the gradient is ready)

(a bucket is ready for synchronization once all the gradients it contains are ready)

* trainer_0
   * t0,
      * D.grad
      * bucket_0, [B.grad, D.grad ✓]
   * t1,
      * **C.grad all-to-all**
      * C.grad ✓
      * bucket_1, [C.grad ✓, A.grad]
   * t2
      * B.grad
      * bucket_0, [B.grad ✓, D.grad ✓] ✓
   * t3
      * All-reduce for bucket_0
   * t4
      * A.grad
      * bucket_1, [C.grad ✓, A.grad ✓] ✓
* trainer_1
   * t0,
      * D.grad
      * bucket_0, [B.grad ✓, D.grad ✓] ✓. (Because B is not used on trainer_1, DDP marks its gradient as ready immediately.)
   * t1,
      * **All-reduce for bucket_0**
   * t2
      * C.grad all-to-all
      * bucket_1, [C.grad ✓, A.grad]
   * t3
      * A.grad
      * bucket_1, [C.grad ✓, A.grad ✓] ✓

This is why trainer_0's all-to-all gets matched up with trainer_1's all-reduce.

**What is the solution for fixing DDP?**

Disable DDP bucketing for the first iteration. D34051938

This works because, after the first iteration, buckets are rebuilt based on the real backward-graph execution order.

So the slow gradient synchronization only affects the first iteration.
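
To see why a `sys.maxsize` cap amounts to "no bucketing", here is a toy, order-preserving version of size-capped bucket assignment (the real logic lives in `dist._compute_bucket_assignment_by_size`; the parameter sizes below are invented for illustration):

```python
import sys

def assign_buckets(param_sizes_bytes, cap_bytes):
    # Greedily pack parameters (in the given order) into buckets whose total
    # size does not exceed cap_bytes.
    buckets, current, current_bytes = [], [], 0
    for idx, size in enumerate(param_sizes_bytes):
        if current and current_bytes + size > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets

sizes = [4 * 1024 * 1024] * 8                   # eight 4 MB parameters
print(assign_buckets(sizes, 25 * 1024 * 1024))  # two buckets with a 25 MB cap
print(assign_buckets(sizes, sys.maxsize))       # one bucket: bucketing effectively disabled
```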

Test Plan:
buck build mode/dev-nosan caffe2/test/distributed:distributed_gloo_spawn
BACKEND=gloo WORLD_SIZE=3 buck-out/gen/caffe2/test/distributed/distributed_gloo_spawn\#binary.par -r test_ddp_logging_data_cpu

P484179296

buck build mode/dev-nosan caffe2/test/distributed:distributed_nccl_spawn
BACKEND=nccl WORLD_SIZE=2 buck-out/gen/caffe2/test/distributed/distributed_nccl_spawn\#binary.par -r test_ddp_logging_data_cpu -r test_ddp_get_bucket_sizes
P484177200

Reviewed By: zhaojuanmao

Differential Revision: D34051938

fbshipit-source-id: 0c7f35875687095c3199f19990e73a8349b6e5b9
(cherry picked from commit bb8f113)
xush6528 authored and pytorchmergebot committed Mar 4, 2022
1 parent 727debb commit bcd0843
Showing 2 changed files with 35 additions and 8 deletions.
36 changes: 30 additions & 6 deletions torch/nn/parallel/distributed.py
@@ -1,3 +1,4 @@
import sys
import collections.abc
import copy
from dataclasses import dataclass
@@ -687,14 +688,32 @@ def _ddp_init_helper(
(5) passing a handle of DDP to SyncBatchNorm Layer
"""
self.num_iterations = 0
# The bucket size limit is specified in the constructor.
# Additionally, we allow for a single small bucket for parameters
# that are defined first, such that their gradients don't spill into
# a much larger bucket, adding unnecessary latency after gradient
# computation finishes. Experiments showed 1MB is a reasonable value.
# Note that the parameter order is not necessarily the order in which the
# parameters are actually used, especially in models with control flow.
#
# Since the parameters are not presented in the real execution order,
# if a certain model happens to also
# 1) have other collective comm ops in its backward graph, and
# 2) have unused parameters on a subset of ranks in the whole world,
# bucketing could insert an ALL-REDUCE comm op too early on the ranks with unused parameters,
# matching it up with other collective comm ops on other ranks unexpectedly.
#
# In order to handle this corner case, when the parameters are not in the real execution order,
# we don't do bucketing, thus only one ALL-REDUCE is inserted after all the gradients
# of the whole graph are computed.
#
# Note that we only disable bucketing for the first iteration.
# After the first iteration, it is OK to rebuild buckets,
# because "bucket rebuild" bucketizes parameters based on their real execution order in the backward graph.

# Can remove this branching once #73732 is landed.
if static_graph is True or self.find_unused_parameters is False:
    bucket_size_limits = [sys.maxsize]
else:
    bucket_size_limits = [dist._DEFAULT_FIRST_BUCKET_BYTES, self.bucket_bytes_cap]
bucket_indices, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
parameters,
[dist._DEFAULT_FIRST_BUCKET_BYTES, self.bucket_bytes_cap],
bucket_size_limits,
expect_sparse_gradient,
)

@@ -707,6 +726,11 @@ def _ddp_init_helper(
list(reversed(per_bucket_size_limits)),
self.process_group,
expect_sparse_gradient,
# The bucket size limit is specified in the constructor.
# Additionally, we allow for a single small bucket for parameters
# that are defined first, such that their gradients don't spill into
# a much larger bucket, adding unnecessary latency after gradient
# computation finishes. Experiments showed 1MB is a reasonable value.
self.bucket_bytes_cap,
self.find_unused_parameters,
self.gradient_as_bucket_view,
7 changes: 5 additions & 2 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -5179,7 +5179,7 @@ def parse_env(var):
rebuilt_bucket_lims = ddp_logging_data.get("rebuilt_bucket_size_limits")
self.assertEqual(
int(init_bucket_lims),
dist._DEFAULT_FIRST_BUCKET_BYTES,
-1,
)
self.assertEqual(
int(rebuilt_bucket_lims),
@@ -8161,7 +8161,10 @@ def forward(self, x):
]
# first_bucket_bytes is actually the last because we reverse
# parameter bucket order under DDP_SET_LAST_BUCKET_CAP flag.
self.assertEqual(bucket_size_limits[-1], first_bucket_bytes_mb)
if i <= 1:
    self.assertEqual(bucket_size_limits[-1], -1)
else:
    self.assertEqual(bucket_size_limits[-1], first_bucket_bytes_mb)
for j, bucket_size in enumerate(bucket_size_limits):
    if j != len(bucket_size_limits) - 1:
        self.assertEqual(bucket_size, default_bucket_cap_mb)
