Add torchelastic check when sanitizing GPUs (#8095)
* Add torchelastic check

* Add changelog

* Address review

* fix
Sean Naren authored Jun 23, 2021
1 parent 4dc08e4 commit 8bd7b1b
Showing 3 changed files with 35 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073))


- Add torchelastic check when sanitizing GPUs ([#8095](https://github.com/PyTorchLightning/pytorch-lightning/pull/8095))


### Changed


6 changes: 6 additions & 0 deletions pytorch_lightning/utilities/device_parser.py
@@ -16,6 +16,7 @@

import torch

from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
@@ -78,6 +79,11 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
    gpus = _normalize_parse_gpu_input_to_list(gpus)
    if not gpus:
        raise MisconfigurationException("GPUs requested but none are available.")

    if TorchElasticEnvironment.is_using_torchelastic() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1:
        # Omit the sanity check under torchelastic, which by default exposes one visible GPU per process.
        return gpus

    gpus = _sanitize_gpu_ids(gpus)

    return gpus
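For context, a minimal sketch of the resulting behaviour (not part of this commit), assuming the same environment variables that the new test below patches: under torchelastic each worker typically sees a single device through `CUDA_VISIBLE_DEVICES`, so `parse_gpu_ids` returns the requested ids unchanged instead of raising a `MisconfigurationException`.

```python
# Minimal sketch (illustration only): mirror the torchelastic environment used in the
# added test and confirm that GPU sanitization is skipped when only one device is visible.
import os
from unittest import mock

from pytorch_lightning.utilities import device_parser

torchelastic_env = {
    "CUDA_VISIBLE_DEVICES": "0",
    "LOCAL_RANK": "1",
    "GROUP_RANK": "1",
    "RANK": "3",
    "WORLD_SIZE": "4",
    "LOCAL_WORLD_SIZE": "2",
}

with mock.patch.dict(os.environ, torchelastic_env), \
        mock.patch("torch.cuda.device_count", return_value=1):
    # Three GPUs are requested but only one is visible to this process; with the
    # torchelastic check in place the ids are returned as-is rather than raising
    # a MisconfigurationException.
    assert device_parser.parse_gpu_ids([0, 1, 2]) == [0, 1, 2]
```

In a real torchelastic launch these variables are set by the launcher itself; sanitizing the full id list against the locally visible devices would otherwise fail for every worker.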
26 changes: 26 additions & 0 deletions tests/models/test_gpu.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
import os
from collections import namedtuple
from unittest import mock
from unittest.mock import patch

import pytest
@@ -21,6 +23,7 @@
import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
@@ -219,6 +222,29 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_count
device_parser.parse_gpu_ids(gpus)


@mock.patch.dict(
    os.environ, {
        "CUDA_VISIBLE_DEVICES": "0",
        "LOCAL_RANK": "1",
        "GROUP_RANK": "1",
        "RANK": "3",
        "WORLD_SIZE": "4",
        "LOCAL_WORLD_SIZE": "2",
    }
)
@mock.patch('torch.cuda.device_count', return_value=1)
@pytest.mark.parametrize("gpus", [[0, 1, 2], 2, '0'])
def test_torchelastic_gpu_parsing(mocked_device_count, gpus):
    """
    Ensure that when torchelastic is used and nproc_per_node is left at its default of one process
    per GPU device, we skip sanitizing the GPUs, since only one of the requested GPUs is visible
    to each process.
    """
    trainer = Trainer(gpus=gpus)
    assert isinstance(trainer.accelerator_connector.cluster_environment, TorchElasticEnvironment)
    assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus)
    assert trainer.gpus == gpus


@RunIf(min_gpus=1)
def test_single_gpu_batch_parse():
    trainer = Trainer(gpus=1)