Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ddp sync for logging in result step #2822

Merged
merged 7 commits into from
Aug 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numbers
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
from torch import Tensor
import torch
from copy import copy
from pytorch_lightning.metrics.converters import _sync_ddp_if_available


class Result(Dict):
Expand Down Expand Up @@ -89,11 +91,18 @@ def log(
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
value = value.detach()

# sync across ddp
if sync_ddp and isinstance(value, (torch.Tensor, numbers.Number)):
value = _sync_ddp_if_available(value, group=sync_ddp_group, reduce_op=sync_ddp_op)

if 'meta' not in self:
self.__setitem__('meta', {})

Expand Down Expand Up @@ -338,6 +347,9 @@ def log(
on_epoch: bool = False,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a key, value
Expand Down Expand Up @@ -369,7 +381,8 @@ def log(
reduce_fx: Torch.mean by default
enable_graph: if True, will not auto detach the graph
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

def log_dict(
self,
Expand All @@ -380,6 +393,9 @@ def log_dict(
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a dictonary of values at once
Expand All @@ -399,7 +415,8 @@ def log_dict(
enable_graph:
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)


class EvalResult(Result):
Expand Down Expand Up @@ -446,6 +463,9 @@ def log(
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a key, value
Expand Down Expand Up @@ -476,7 +496,8 @@ def log(
reduce_fx: Torch.mean by default
enable_graph: if True, will not auto detach the graph :
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

def log_dict(
self,
Expand All @@ -487,6 +508,9 @@ def log_dict(
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a dictonary of values at once
Expand All @@ -506,7 +530,8 @@ def log_dict(
enable_graph:
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

def get_callback_metrics(self) -> dict:
result = {
Expand Down
37 changes: 37 additions & 0 deletions tests/core/test_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
import tests.base.develop_utils as tutils
import sys


def _setup_ddp(rank, worldsize):
import os

os.environ["MASTER_ADDR"] = "localhost"

# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize, result_cls: Result):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.0])

res = result_cls()
res.log("test_tensor", tensor, sync_ddp=True, sync_ddp_op=torch.distributed.ReduceOp.SUM)

assert res["test_tensor"].item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"


@pytest.mark.parametrize("result_cls", [Result, TrainResult, EvalResult])
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_result_reduce_ddp(result_cls):
"""Make sure result logging works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()

worldsize = 2
mp.spawn(_ddp_test_fn, args=(worldsize, result_cls), nprocs=worldsize)