Support Mean in DDP Sync #2568

Merged
merged 13 commits on Aug 4, 2020
9 changes: 9 additions & 0 deletions pytorch_lightning/metrics/converters.py
@@ -234,23 +234,32 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
result: the value to sync and reduce (typically tensor or number)
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum.
Can also be the string 'avg' or 'mean' to compute the mean during the reduction.

Return:
reduced value
"""

if torch.distributed.is_available() and torch.distributed.is_initialized():
divide_by_world_size = False

if group is None:
group = torch.distributed.group.WORLD

if reduce_op is None:
reduce_op = torch.distributed.ReduceOp.SUM
elif isinstance(reduce_op, str) and reduce_op in ('avg', 'mean'):
reduce_op = torch.distributed.ReduceOp.SUM
divide_by_world_size = True

# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=reduce_op, group=group,
async_op=False)

if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)

return result
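
As context for the change above, here is a minimal usage sketch (not part of the diff): it assumes torch.distributed has already been initialized in a DDP worker and simply passes the new string reduce_op through the helper.

```python
import torch

from pytorch_lightning.metrics.converters import _sync_ddp_if_available


def report_mean_loss(local_loss: torch.Tensor) -> torch.Tensor:
    # 'mean' (or 'avg') now triggers a SUM all_reduce followed by a division
    # by the world size, so every rank ends up with the global average.
    # If torch.distributed is unavailable or uninitialized, the tensor is
    # returned unchanged.
    return _sync_ddp_if_available(local_loss, reduce_op='mean')


# e.g. with 4 ranks holding losses 0., 1., 2., 3., every rank receives
# tensor([1.5]) after the call.
```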


41 changes: 30 additions & 11 deletions tests/metrics/test_converters.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
@@ -114,26 +115,44 @@ def _setup_ddp(rank, worldsize):
dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize):
def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.], device='cuda:0')

reduced_tensor = _sync_ddp_if_available(tensor)
if add_offset:
tensor = torch.tensor([float(rank)])
else:
tensor = torch.tensor([1.])
if reduction_mean:
reduced_tensor = _sync_ddp_if_available(tensor, reduce_op='avg')

manual_reduction = sum([i for i in range(dist.get_world_size())]) / dist.get_world_size()
print(reduced_tensor)
print(manual_reduction)
assert reduced_tensor.item() == manual_reduction
else:
reduced_tensor = _sync_ddp_if_available(tensor)

assert reduced_tensor.item() == dist.get_world_size(), \
'Sync-Reduce does not work properly with DDP and Tensors'
assert reduced_tensor.item() == dist.get_world_size(), \
'Sync-Reduce does not work properly with DDP and Tensors'


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_reduce_ddp():
"""Make sure sync-reduce works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()

worldsize = 2
mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)
mp.spawn(_ddp_test_fn, args=(worldsize, False), nprocs=worldsize)

# dist.destroy_process_group()

@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_reduce_ddp_mean():
"""Make sure sync-reduce works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()

worldsize = 2
mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize)


def test_sync_reduce_simple():
@@ -172,7 +191,7 @@ def _ddp_test_tensor_metric(rank, worldsize):
_test_tensor_metric(True)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_tensor_metric_ddp():
tutils.reset_seed()
tutils.set_random_master_port()
@@ -212,7 +231,7 @@ def _ddp_test_numpy_metric(rank, worldsize):
_test_numpy_metric(True)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_numpy_metric_ddp():
tutils.reset_seed()
tutils.set_random_master_port()
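
For reference, a self-contained sketch of the spawn-and-check pattern the new mean test follows (not from the PR; the master address/port values are placeholders, and the reduction is done with raw torch.distributed calls instead of the Lightning helper):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, worldsize: int):
    # placeholder rendezvous settings for illustration; the repo's tests pick
    # a free port via tutils.set_random_master_port() instead
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)

    # each rank contributes its own rank id, so with 2 ranks the
    # mean is (0 + 1) / 2 = 0.5
    tensor = torch.tensor([float(rank)])
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    mean = tensor / dist.get_world_size()
    assert mean.item() == 0.5

    dist.destroy_process_group()


if __name__ == "__main__":
    worldsize = 2
    mp.spawn(_worker, args=(worldsize,), nprocs=worldsize)
```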