Destroy process group in atexit handler #19931

Merged
9 commits merged on Jun 5, 2024
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870), [#19872](https://github.com/Lightning-AI/pytorch-lightning/pull/19872))

- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))


### Changed

10 changes: 10 additions & 0 deletions src/lightning/fabric/utilities/distributed.py
@@ -1,3 +1,4 @@
import atexit
import contextlib
import logging
import os
@@ -291,6 +292,10 @@ def _init_dist_connection(
    log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

    if torch_distributed_backend == "nccl":
        # PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
        atexit.register(_destroy_dist_connection)

    # On rank=0 let everyone know training is starting
    rank_zero_info(
        f"{'-' * 100}\n"
@@ -300,6 +305,11 @@
    )


def _destroy_dist_connection() -> None:
    if _distributed_is_initialized():
        torch.distributed.destroy_process_group()


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
    return "nccl" if device.type == "cuda" else "gloo"

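For reference, the cleanup pattern this hunk introduces boils down to registering a guarded destructor right after the process group is created. The sketch below is a minimal, standalone illustration and not part of the PR: it uses a single-process `gloo` group so it runs on CPU, whereas the PR registers the handler only for the `nccl` backend, and the helper name is hypothetical.

```python
import atexit
import os

import torch.distributed


def destroy_process_group_if_initialized() -> None:
    # atexit handlers must be safe to run unconditionally, so only destroy
    # a process group that is actually still alive.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Single-process gloo group so the example runs without GPUs.
    torch.distributed.init_process_group("gloo", rank=0, world_size=1)
    atexit.register(destroy_process_group_if_initialized)
    # ... work with the process group here; it is destroyed automatically at interpreter exit.
```
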
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `ModelParallelStrategy` to support 2D parallelism ([#19878](https://github.com/Lightning-AI/pytorch-lightning/pull/19878), [#19888](https://github.com/Lightning-AI/pytorch-lightning/pull/19888))

- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))


### Changed
5 changes: 2 additions & 3 deletions tests/tests_fabric/conftest.py
@@ -23,7 +23,7 @@
import torch.distributed
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
from lightning.fabric.utilities.distributed import _distributed_is_initialized
from lightning.fabric.utilities.distributed import _destroy_dist_connection

if sys.version_info >= (3, 9):
from concurrent.futures.process import _ExecutorManagerThread
@@ -78,8 +78,7 @@ def restore_env_variables():
def teardown_process_group():
    """Ensures that the distributed process group gets closed before the next test runs."""
    yield
    if _distributed_is_initialized():
        torch.distributed.destroy_process_group()
    _destroy_dist_connection()


@pytest.fixture(autouse=True)
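
Note that `_destroy_dist_connection()` performs the same `_distributed_is_initialized()` check the fixture previously did inline, so the teardown remains a no-op for tests that never initialized a process group. A rough plain-pytest equivalent of this fixture, written against public `torch.distributed` APIs only (a sketch, not part of the PR), would be:

```python
import pytest
import torch.distributed


@pytest.fixture(autouse=True)
def teardown_process_group():
    """Close any leftover distributed process group before the next test runs."""
    yield
    # Guarded so the fixture is harmless for tests that never set up a group.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```
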
12 changes: 12 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
@@ -11,8 +11,10 @@
from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import (
    _destroy_dist_connection,
    _gather_all_tensors,
    _InfiniteBarrier,
    _init_dist_connection,
    _set_num_threads_if_needed,
    _suggested_max_num_threads,
    _sync_ddp,
@@ -217,3 +219,13 @@ def test_infinite_barrier():
    barrier.__exit__(None, None, None)
    assert barrier.barrier.call_count == 2
    dist_mock.destroy_process_group.assert_called_once()


@mock.patch("lightning.fabric.utilities.distributed.atexit")
@mock.patch("lightning.fabric.utilities.distributed.torch.distributed.init_process_group")
def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
    _init_dist_connection(LightningEnvironment(), "nccl")
    atexit_mock.register.assert_called_once_with(_destroy_dist_connection)
    atexit_mock.reset_mock()
    _init_dist_connection(LightningEnvironment(), "gloo")
    atexit_mock.register.assert_not_called()
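
The new test patches `atexit` in the `lightning.fabric.utilities.distributed` namespace, i.e. where the name is looked up rather than where it is defined, so no real exit handler is installed while asserting that registration happens for `nccl` but not for `gloo`. A stripped-down, self-contained sketch of that mocking pattern (a simplified stand-in function, not the actual Lightning code) is:

```python
import atexit
from unittest import mock


def init_connection(backend: str) -> None:
    # Stand-in for _init_dist_connection: register cleanup only for the NCCL backend.
    if backend == "nccl":
        atexit.register(lambda: None)


if __name__ == "__main__":
    # Patch `atexit` where this script looks it up ("__main__.atexit"), analogous to
    # patching "lightning.fabric.utilities.distributed.atexit" in the test above.
    with mock.patch("__main__.atexit") as atexit_mock:
        init_connection("nccl")
        atexit_mock.register.assert_called_once()
        atexit_mock.reset_mock()
        init_connection("gloo")
        atexit_mock.register.assert_not_called()
```
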
5 changes: 2 additions & 3 deletions tests/tests_pytorch/conftest.py
@@ -27,7 +27,7 @@
import torch.distributed
from lightning.fabric.plugins.environments.lightning import find_free_network_port
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
from lightning.fabric.utilities.distributed import _distributed_is_initialized
from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.accelerators import XLAAccelerator
from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
@@ -123,8 +123,7 @@ def restore_signal_handlers():
def teardown_process_group():
    """Ensures that the distributed process group gets closed before the next test runs."""
    yield
    if _distributed_is_initialized():
        torch.distributed.destroy_process_group()
    _destroy_dist_connection()


@pytest.fixture(autouse=True)