[Release-1.9.1] [torch] Various improvements to `torch.distributed.launch` and `torch.distributed.run` (#60925) (#64797)

* Minor bug fix in the warning message (#61127)

Summary:
The current example code does not work; the correct version is here: https://github.com/pytorch/pytorch/blob/cb7d813275a13a4233951e7cbcbb8351dbb0fd87/torch/distributed/run.py#L266

Pull Request resolved: #61127

Reviewed By: cbalioglu

Differential Revision: D29572003

Pulled By: mrshenli

fbshipit-source-id: 05b470230f3d70f8a6164edb5f92894a1112069f

* Small change for torch.distributed launcher (#59152)

Summary:
Pull Request resolved: #59152

Small change for https://fb.workplace.com/groups/319878845696681

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D28773682

Pulled By: H-Huang

fbshipit-source-id: acf82273e8622b7ffd3088d8d766bdf49273754c

* [torch] Various improvements to `torch.distributed.launch` and `torch.distributed.run` (#61294)

Summary:
Pull Request resolved: #61294

Pull Request resolved: #60925

* Set the number of restarts for `torch.distributed.launch` to 0
* Remove the unnecessary `--use_env` warning and move the remaining `--use_env` warnings to `torch.distributed.launch`
* Make the default log level WARNING
* Add a new doc section on transitioning to `torch.distributed.run`
* Make `torch.distributed.launch` not use error propagation
* Set the default events handler to `null`, which does not print events to the console
* Add a reference from `torch.distributed.launch` to `torch.distributed.run`
* Set a correct preexec function that sends SIGTERM to child processes when the parent dies (see the sketch after this list)
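
The last bullet fixes a bug where the prctl helper was invoked eagerly and its return value was passed as `preexec_fn`; the commit now passes a proper callable (see `_pr_set_pdeathsig` in the diff below). For illustration only, here is a minimal sketch of the same PR_SET_PDEATHSIG pattern using plain `ctypes` instead of torch's private helper; the constant, helper name, and Linux/glibc assumption are mine, not part of torch:

    import ctypes
    import signal
    import subprocess

    PR_SET_PDEATHSIG = 1  # from <linux/prctl.h>; Linux/glibc only

    def _set_pdeathsig() -> None:
        # Runs in the child between fork() and exec(). Note that the function
        # object itself is passed as preexec_fn, not the result of calling it,
        # which is the bug fixed in the handler class shown in the diff below.
        libc = ctypes.CDLL("libc.so.6", use_errno=True)
        libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM)

    proc = subprocess.Popen(["sleep", "600"], preexec_fn=_set_pdeathsig)
    proc.terminate()
    proc.wait()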

Issues resolved:

#60716
#60754

Test Plan:
sandcastle

    python -m torch.distributed.launch --nproc_per_node 2 main.py -> uses 0 restarts
    python -m torch.distributed.run --nproc_per_node 2 main.py -> uses default for torchelastic, 0 restarts

    python -m torch.distributed.launch --nproc_per_node=4  --use_env --no_python  main.py -> produces error
    python -m torch.distributed.launch --nproc_per_node=4  --use_env main.py -> no warning
    python -m torch.distributed.launch --nproc_per_node=4  --no_python  main.py ->warning

Output of running torch.distributed.launch without --use_env:

    $path/torch/distributed/launch.py:173: FutureWarning: The module torch.distributed.launch is deprecated
    and will be removed in future. Use torch.distributed.run.
    Note that --use_env is set by default in torch.distributed.run.
    If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead.
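
As the warning says, scripts that used to receive `--local_rank` from `torch.distributed.launch` should read the `LOCAL_RANK` environment variable when moving to `torch.distributed.run`. A hedged sketch of the migration (the helper name is illustrative; the `--local_rank` fallback is only there to keep old launch commands working):

    import argparse
    import os

    def get_local_rank() -> int:
        # Old behaviour: torch.distributed.launch passed --local_rank=<n>
        # as a command-line argument to each worker.
        parser = argparse.ArgumentParser()
        parser.add_argument("--local_rank", type=int, default=None)
        args, _ = parser.parse_known_args()
        if args.local_rank is not None:
            return args.local_rank
        # New behaviour: torch.distributed.run exports LOCAL_RANK
        # (--use_env is implied).
        return int(os.environ["LOCAL_RANK"])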

New section:

{F628923078}

{F628974089}

Reviewed By: cbalioglu

Differential Revision: D29559553

fbshipit-source-id: 03ed9ba638bf154354e1530ffc964688431edf6b

Co-authored-by: Kento Nozawa <[email protected]>
Co-authored-by: Howard Huang <[email protected]>
Co-authored-by: Aliaksandr Ivanou <[email protected]>
4 people authored Sep 10, 2021
1 parent 04dd41d commit 61e9e88
Showing 12 changed files with 163 additions and 64 deletions.
2 changes: 2 additions & 0 deletions docs/source/elastic/errors.rst
@@ -1,3 +1,5 @@
.. _elastic_errors-api:

Error Propagation
==================

7 changes: 2 additions & 5 deletions docs/source/elastic/run.rst
@@ -1,9 +1,6 @@
.. _launcher-api:

Elastic Launch
============================

torch.distributed.run
----------------------
torch.distributed.run (Elastic Launch)
======================================

.. automodule:: torch.distributed.run
18 changes: 11 additions & 7 deletions docs/source/elastic/train_script.rst
@@ -1,3 +1,5 @@
.. _elastic_train_script:

Train script
-------------

@@ -7,18 +9,20 @@ working with ``torch.distributed.run`` with these differences:
1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
``MASTER_ADDR``, and ``MASTER_PORT``.

2. ``rdzv_backend`` and ``rdzv_endpoint`` must be provided. For most users
this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_).
2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_). The default
``rdzv_backend`` creates a non-elastic rendezvous where ``rdzv_endpoint`` holds
the master address.

3. Make sure you have a ``load_checkpoint(path)`` and
``save_checkpoint(path)`` logic in your script. When workers fail
we restart all the workers with the same program arguments so you will
lose progress up to the most recent checkpoint
``save_checkpoint(path)`` logic in your script. When any number of
workers fail we restart all the workers with the same program
arguments so you will lose progress up to the most recent checkpoint
(see `elastic launch <distributed.html>`_).

4. ``use_env`` flag has been removed. If you were parsing local rank by parsing
the ``--local_rank`` option, you need to get the local rank from the
environment variable ``LOCAL_RANK`` (e.g. ``os.environ["LOCAL_RANK"]``).
environment variable ``LOCAL_RANK`` (e.g. ``int(os.environ["LOCAL_RANK"])``).

Below is an expository example of a training script that checkpoints on each
epoch, hence the worst-case progress lost on failure is one full epoch worth
@@ -31,7 +35,7 @@ of training.
state = load_checkpoint(args.checkpoint_path)
initialize(state)
# torch.distributed.run ensure that this will work
# torch.distributed.run ensures that this will work
# by exporting all the env vars needed to initialize the process group
torch.distributed.init_process_group(backend=args.backend)
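
For context on the train_script.rst change above, here is a hedged completion of the expository checkpoint-per-epoch script it describes; `load_checkpoint`, `save_checkpoint`, `initialize`, and `train_one_epoch` are user-provided placeholders from the doc, not torch APIs, and one GPU per worker process is assumed:

    import os

    import torch
    import torch.distributed as dist

    def main(args):
        # Resume from the latest checkpoint; a failure costs at most one epoch.
        state = load_checkpoint(args.checkpoint_path)   # user-provided
        initialize(state)                               # user-provided

        # torch.distributed.run exports RANK, WORLD_SIZE, MASTER_ADDR and
        # MASTER_PORT, so no explicit rank/world size arguments are needed.
        dist.init_process_group(backend=args.backend)

        # --use_env behaviour: the local rank arrives as an env var.
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        for epoch in range(state.epoch, args.max_epochs):
            train_one_epoch(state, epoch)                 # user-provided
            state.epoch = epoch + 1
            save_checkpoint(args.checkpoint_path, state)  # user-provided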
@@ -205,7 +205,6 @@ def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
result = self._pcontext.wait(0)
if result:
if result.is_failed():
log.error(f"[{role}] Worker group failed")
# map local rank failure to global rank
worker_failures = {}
for local_rank, failure in result.failures.items():
5 changes: 3 additions & 2 deletions torch/distributed/elastic/events/__init__.py
@@ -19,6 +19,7 @@
"""

import os
import logging

from torch.distributed.elastic.events.handlers import get_logging_handler
@@ -46,12 +47,12 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger:
return _events_logger
logging_handler = get_logging_handler(destination)
_events_logger = logging.getLogger(f"torchelastic-events-{destination}")
_events_logger.setLevel(logging.DEBUG)
_events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO"))
# Do not propagate message to the root logger
_events_logger.propagate = False
_events_logger.addHandler(logging_handler)
return _events_logger


def record(event: Event, destination: str = "console") -> None:
def record(event: Event, destination: str = "null") -> None:
_get_or_create_logger(destination).info(event.serialize())
3 changes: 2 additions & 1 deletion torch/distributed/elastic/events/handlers.py
@@ -12,8 +12,9 @@

_log_handlers: Dict[str, logging.Handler] = {
"console": logging.StreamHandler(),
"null": logging.NullHandler(),
}


def get_logging_handler(destination: str = "console") -> logging.Handler:
def get_logging_handler(destination: str = "null") -> logging.Handler:
return _log_handlers[destination]
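
The effect of the two handler changes above is that events are swallowed by default and only reach the console when a caller explicitly asks for the `console` destination. A standard-library-only sketch of that registry pattern (the names below are illustrative, not torch APIs):

    import logging
    from typing import Dict

    # Illustrative registry mirroring the pattern above; not a torch API.
    _handlers: Dict[str, logging.Handler] = {
        "console": logging.StreamHandler(),
        "null": logging.NullHandler(),
    }

    def get_handler(destination: str = "null") -> logging.Handler:
        return _handlers[destination]

    events_log = logging.getLogger("events-demo")
    events_log.setLevel(logging.INFO)
    events_log.propagate = False

    events_log.addHandler(get_handler())            # default "null": swallowed
    events_log.info("recorded but not printed")

    events_log.addHandler(get_handler("console"))   # opt back in to console
    events_log.info("now visible on stderr")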
33 changes: 26 additions & 7 deletions torch/distributed/elastic/multiprocessing/api.py
@@ -465,24 +465,32 @@ def __init__(
entrypoint: str,
args: Tuple,
env: Dict[str, str],
preexec_fn: Callable,
preexec_fn: Optional[Callable],
stdout: str,
stderr: str,
):
self._stdout = open(stdout, "w") if stdout else None
self._stderr = open(stderr, "w") if stderr else None
args_str = [str(e) for e in args]

# inherit parent environment vars
env_vars = os.environ.copy()
env_vars.update(env)

self.proc: subprocess.Popen = subprocess.Popen(
args_str = (entrypoint, *[str(e) for e in args])
self.proc: subprocess.Popen = self._popen(args_str, env_vars, preexec_fn)

def _popen(
self, args: Tuple, env: Dict[str, str], preexec_fn: Optional[Callable]
) -> subprocess.Popen:
if IS_WINDOWS:
# Reset preexec_fn on windows, since windows does not support it
preexec_fn = None

return subprocess.Popen(
# pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
# _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
# `Tuple[str, *Tuple[Any, ...]]`.
args=(entrypoint, *args_str),
env=env_vars,
args=args,
env=env,
preexec_fn=preexec_fn,
stdout=self._stdout,
stderr=self._stderr,
@@ -497,6 +505,17 @@ def close(self):
self._stderr.close()


def _pr_set_pdeathsig() -> None:
"""
Sets PR_SET_PDEATHSIG to ensure a child process is
terminated appropriately.
See http://stackoverflow.com/questions/1884941/ for more information.
For libc.so.6 read http://www.linux-m68k.org/faq/glibcinfo.html
"""
mp._prctl_pr_set_pdeathsig(signal.SIGTERM) # type: ignore[attr-defined]


class SubprocessContext(PContext):
"""
``PContext`` holding worker processes invoked as a binary.
@@ -541,7 +560,7 @@ def _start(self):
entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str
args=self.args[local_rank],
env=self.envs[local_rank],
preexec_fn=mp._prctl_pr_set_pdeathsig(signal.SIGTERM), # type: ignore[attr-defined]
preexec_fn=_pr_set_pdeathsig,
stdout=self.stdouts[local_rank],
stderr=self.stderrs[local_rank],
)
4 changes: 2 additions & 2 deletions torch/distributed/elastic/utils/logging.py
@@ -17,7 +17,7 @@ def get_logger(name: Optional[str] = None):
"""
Util function to set up a simple logger that writes
into stderr. The loglevel is fetched from the LOGLEVEL
env. variable or INFO as default. The function will use the
env. variable or WARNING as default. The function will use the
module name of the caller if no name is provided.
Args:
@@ -32,7 +32,7 @@

def _setup_logger(name: Optional[str] = None):
log = logging.getLogger(name)
log.setLevel(os.environ.get("LOGLEVEL", "INFO"))
log.setLevel(os.environ.get("LOGLEVEL", "WARNING"))
return log


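
With the default level dropped from INFO to WARNING, the agent's informational logs disappear unless the user exports `LOGLEVEL` (e.g. `LOGLEVEL=INFO`) before launching. A small illustrative sketch of the same lookup, using only the standard library (the helper below is not a torch API):

    import logging
    import os

    def make_logger(name: str) -> logging.Logger:
        # Illustrative helper (not a torch API) showing the same lookup.
        log = logging.getLogger(name)
        log.setLevel(os.environ.get("LOGLEVEL", "WARNING"))
        log.addHandler(logging.StreamHandler())
        return log

    log = make_logger("elastic-demo")
    log.info("suppressed unless LOGLEVEL=INFO (or lower) is exported")
    log.warning("still printed under the new WARNING default")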
4 changes: 0 additions & 4 deletions torch/distributed/elastic/utils/store.py
@@ -6,7 +6,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from datetime import timedelta
from typing import List

@@ -64,8 +63,5 @@ def barrier(
Note: Since the data is not removed from the store, the barrier can be used
once per unique ``key_prefix``.
"""
warnings.warn(
"This is an experimental API and will be changed in future.", FutureWarning
)
data = f"{rank}".encode(encoding="UTF-8")
synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)
32 changes: 25 additions & 7 deletions torch/distributed/launch.py
@@ -1,8 +1,10 @@
r"""
`torch.distributed.launch` is a module that spawns up multiple distributed
``torch.distributed.launch`` is a module that spawns up multiple distributed
training processes on each of the training nodes.
NOTE: This module is deprecated, use torch.distributed.run.
.. warning::
This module is going to be deprecated in favor of :ref:`torch.distributed.run <launcher-api>`.
The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
@@ -136,9 +138,12 @@
https://github.com/pytorch/pytorch/issues/12042 for an example of
how things can go wrong if you don't do this correctly.
"""

import logging
import warnings

from torch.distributed.run import get_args_parser, run

@@ -159,14 +164,27 @@ def parse_args(args):
return parser.parse_args(args)


def launch(args):
if args.no_python and not args.use_env:
raise ValueError(
"When using the '--no_python' flag,"
" you must also set the '--use_env' flag."
)
run(args)


def main(args=None):
logger.warn(
"The module torch.distributed.launch is deprecated "
"and going to be removed in future."
"Migrate to torch.distributed.run"
warnings.warn(
"The module torch.distributed.launch is deprecated\n"
"and will be removed in future. Use torch.distributed.run.\n"
"Note that --use_env is set by default in torch.distributed.run.\n"
"If your script expects `--local_rank` argument to be set, please\n"
"change it to read from `os.environ['LOCAL_RANK']` instead. See \n"
"https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"
"further instructions\n", FutureWarning
)
args = parse_args(args)
run(args)
launch(args)


if __name__ == "__main__":
5 changes: 2 additions & 3 deletions torch/distributed/launcher/api.py
@@ -15,7 +15,7 @@
from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState # type: ignore[import]
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent # type: ignore[import]
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
@@ -112,7 +112,7 @@ def __init__(
self._config = config
self._entrypoint = entrypoint

def __call__(self, *args, **kwargs):
def __call__(self, *args):
return launch_agent(self._config, self._entrypoint, list(args))


@@ -172,7 +172,6 @@ def _get_addr_and_port(

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# torch.distributed.elastic.multiprocessing.errors.record.
@record
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
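
The `__call__(self, *args)` change above affects programmatic use of the launcher, which now forwards only positional arguments to the entrypoint. A hedged sketch of that entry point; the `LaunchConfig` field names and the rendezvous settings below are assumptions about this torch version, so verify them before relying on this:

    from torch.distributed.launcher.api import LaunchConfig, elastic_launch

    def trainer(step: str) -> int:
        # Runs once per local worker; rank info arrives via environment vars.
        print(f"worker running step={step}")
        return 0

    if __name__ == "__main__":
        config = LaunchConfig(          # field names assumed, verify locally
            min_nodes=1,
            max_nodes=1,
            nproc_per_node=2,
            rdzv_backend="c10d",
            rdzv_endpoint="localhost:29400",
            run_id="demo",
            max_restarts=0,
        )
        # After this commit only positional args reach the entrypoint.
        results = elastic_launch(config, trainer)("warmup")
        print(results)  # mapping of global rank -> return value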