Add support for devices flag to Trainer #8440

Merged: 22 commits, merged on Jul 20, 2021
Changes from 6 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -172,6 +172,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled traditional/manual launching of DDP processes through `LOCAL_RANK` and `NODE_RANK` environment variable assignments ([#7480](https://github.com/PyTorchLightning/pytorch-lightning/pull/7480))


- Added support for `devices` flag to Trainer ([#8440](https://github.com/PyTorchLightning/pytorch-lightning/pull/8440))


### Changed


38 changes: 35 additions & 3 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -85,6 +85,7 @@ class AcceleratorConnector(object):
def __init__(
self,
num_processes,
devices,
tpu_cores,
ipus,
distributed_backend,
@@ -106,6 +107,7 @@ def __init__(
self._accelerator_type = None

self.num_processes = num_processes
self.devices = devices
# `gpus` is the input passed to the Trainer, whereas `gpu_ids` is a list of parsed gpu ids.
self.gpus = gpus
self.parallel_device_ids = gpu_ids
@@ -179,6 +181,7 @@ def select_accelerator_type(self) -> None:
elif self.has_gpu:
self._accelerator_type = DeviceType.GPU
else:
self._set_devices_to_cpu_num_processes()
self._accelerator_type = DeviceType.CPU
elif self.distributed_backend == DeviceType.TPU:
if not self.has_tpu:
@@ -196,6 +199,7 @@ def select_accelerator_type(self) -> None:
raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.")
self._accelerator_type = DeviceType.GPU
elif self.distributed_backend == DeviceType.CPU:
self._set_devices_to_cpu_num_processes()
self._accelerator_type = DeviceType.CPU

if self.distributed_backend in ["auto"] + list(DeviceType):
@@ -302,7 +306,9 @@ def has_gpu(self) -> bool:
# Here, we are not checking for GPU availability, but instead if User has passed
# `gpus` to Trainer for training.
gpus = self.parallel_device_ids
return gpus is not None and len(gpus) > 0
if gpus is not None and len(gpus) > 0:
return True
return self._map_devices_to_accelerator(DeviceType.GPU)

@property
def use_gpu(self) -> bool:
@@ -312,7 +318,9 @@ def use_gpu(self) -> bool:
def has_tpu(self) -> bool:
# Here, we are not checking for TPU availability, but instead if User has passed
# `tpu_cores` to Trainer for training.
return self.tpu_cores is not None
if self.tpu_cores is not None:
return True
return self._map_devices_to_accelerator(DeviceType.TPU)

@property
def use_tpu(self) -> bool:
@@ -328,12 +336,36 @@ def tpu_id(self) -> Optional[int]:
def has_ipu(self) -> bool:
# Here, we are not checking for IPU availability, but instead if User has passed
# `ipus` to Trainer for training.
return self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin)
if self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin):
return True
return self._map_devices_to_accelerator(DeviceType.IPU)

@property
def use_ipu(self) -> bool:
return self._accelerator_type == DeviceType.IPU and self.has_ipu

def _set_devices_to_cpu_num_processes(self) -> None:
if self.num_processes <= 1:
self._map_devices_to_accelerator(DeviceType.CPU)

def _map_devices_to_accelerator(self, accelerator: str) -> bool:
if self.devices is None:
return False
if accelerator == DeviceType.TPU and _TPU_AVAILABLE:
self.tpu_cores = device_parser.parse_tpu_cores(self.devices)
return True
elif accelerator == DeviceType.IPU and _IPU_AVAILABLE:
self.ipus = self.devices
return True
elif accelerator == DeviceType.GPU and torch.cuda.is_available():
self.gpus = self.devices
self.parallel_device_ids = device_parser.parse_gpu_ids(self.devices)
return True
elif accelerator == DeviceType.CPU:
self.num_processes = self.devices
return True
return False

@property
def use_dp(self) -> bool:
return self._distrib_type == DistributedType.DP
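For reference, below is a minimal standalone sketch of the mapping implemented by `_map_devices_to_accelerator` above. The `resolve_devices` helper is hypothetical and not part of this PR; it assumes the precedence exercised by the tests further down, namely that a dedicated flag, when set, wins over `devices`.

from typing import List, Optional, Union

def resolve_devices(
    accelerator: str,
    devices: Optional[Union[List[int], str, int]],
    gpus: Optional[Union[List[int], str, int]] = None,
    tpu_cores: Optional[Union[List[int], str, int]] = None,
    ipus: Optional[int] = None,
    num_processes: int = 1,
) -> dict:
    """Map `devices` onto the accelerator-specific flag, unless that flag is already set."""
    if devices is not None:
        if accelerator == "gpu" and gpus is None:
            gpus = devices
        elif accelerator == "tpu" and tpu_cores is None:
            tpu_cores = devices
        elif accelerator == "ipu" and ipus is None:
            ipus = devices
        elif accelerator == "cpu" and num_processes <= 1:
            num_processes = devices
    return {"gpus": gpus, "tpu_cores": tpu_cores, "ipus": ipus, "num_processes": num_processes}

# resolve_devices("gpu", devices=2)                  -> gpus becomes 2
# resolve_devices("cpu", devices=8, num_processes=5) -> num_processes stays 5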
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -112,6 +112,7 @@ def __init__(
process_position: int = 0,
num_nodes: int = 1,
num_processes: int = 1,
devices: Optional[Union[List[int], str, int]] = None,
gpus: Optional[Union[List[int], str, int]] = None,
auto_select_gpus: bool = False,
tpu_cores: Optional[Union[List[int], str, int]] = None,
@@ -207,6 +208,9 @@ def __init__(

deterministic: If true enables cudnn.deterministic.

devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
based on the accelerator type.

distributed_backend: deprecated. Please use 'accelerator'

fast_dev_run: runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
@@ -343,8 +347,8 @@ def __init__(
self.optimizer_connector = OptimizerConnector(self)

self.accelerator_connector = AcceleratorConnector(
num_processes, tpu_cores, ipus, distributed_backend, gpus, gpu_ids, num_nodes, sync_batchnorm, benchmark,
replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins
num_processes, devices, tpu_cores, ipus, distributed_backend, gpus, gpu_ids, num_nodes, sync_batchnorm,
benchmark, replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins
)
self.logger_connector = LoggerConnector(self, log_gpu_memory)
self.model_connector = ModelConnector(self)
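Based on the `devices` docstring above and the tests added in this PR, a rough usage sketch of the new flag (each call assumes the corresponding hardware is available):

from pytorch_lightning import Trainer

# `devices` is interpreted according to the chosen accelerator:
Trainer(accelerator="cpu", devices=5)   # equivalent to num_processes=5
Trainer(accelerator="gpu", devices=2)   # equivalent to gpus=2
Trainer(accelerator="tpu", devices=8)   # equivalent to tpu_cores=8
Trainer(accelerator="ipu", devices=8)   # equivalent to ipus=8

# With accelerator="auto", `devices` applies to whichever accelerator is detected.
Trainer(accelerator="auto", devices=1)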
47 changes: 47 additions & 0 deletions tests/accelerators/test_accelerator_connector.py
@@ -639,3 +639,50 @@ def test_accelerator_cpu_with_multiple_gpus():

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)


@pytest.mark.parametrize(["devices", "plugin"], [(1, SingleDevicePlugin), (5, DDPSpawnPlugin)])
def test_accelerator_cpu_with_devices(devices, plugin):

trainer = Trainer(accelerator="cpu", devices=devices)

assert trainer.num_processes == devices
assert isinstance(trainer.training_type_plugin, plugin)
assert isinstance(trainer.accelerator, CPUAccelerator)


def test_accelerator_cpu_with_num_processes_priority():
""" Test for checking num_processes takes priority over devices. """

num_processes = 5
trainer = Trainer(accelerator="cpu", devices=8, num_processes=num_processes)
assert trainer.num_processes == num_processes


@RunIf(min_gpus=2)
@pytest.mark.parametrize(["devices", "plugin"], [(1, SingleDevicePlugin), (2, DDPSpawnPlugin)])
def test_accelerator_gpu_with_devices(devices, plugin):

trainer = Trainer(accelerator="gpu", devices=devices)

assert trainer.gpus == devices
assert isinstance(trainer.training_type_plugin, plugin)
assert isinstance(trainer.accelerator, GPUAccelerator)


@RunIf(min_gpus=1)
def test_accelerator_auto_with_devices_gpu():

trainer = Trainer(accelerator="auto", devices=1)

assert trainer._device_type == "gpu"
assert trainer.gpus == 1


@RunIf(min_gpus=1)
def test_accelerator_gpu_with_gpus_priority():
""" Test for checking `gpus` flag takes priority over `devices`. """

gpus = 1
trainer = Trainer(accelerator="gpu", devices=4, gpus=gpus)
assert trainer.gpus == gpus
28 changes: 28 additions & 0 deletions tests/accelerators/test_ipu.py
@@ -519,3 +519,31 @@ def test_accelerator_cpu_with_ipus_flag():

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)


@RunIf(ipu=True)
def test_accelerator_ipu_with_devices():

trainer = Trainer(accelerator="ipu", devices=8)

assert trainer.ipus == 8
assert isinstance(trainer.training_type_plugin, IPUPlugin)
assert isinstance(trainer.accelerator, IPUAccelerator)


@RunIf(ipu=True)
def test_accelerator_auto_with_devices_ipu():

trainer = Trainer(accelerator="auto", devices=8)

assert trainer._device_type == "ipu"
assert trainer.ipus == 8


@RunIf(ipu=True)
def test_accelerator_ipu_with_ipus_priority():
""" Test for checking `ipus` flag takes priority over `devices`. """

ipus = 8
trainer = Trainer(accelerator="ipu", devices=1, ipus=ipus)
assert trainer.ipus == ipus
30 changes: 30 additions & 0 deletions tests/accelerators/test_tpu_backend.py
@@ -18,6 +18,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import SingleTPUPlugin, TPUSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -148,3 +149,32 @@ def test_accelerator_tpu_with_auto():

assert trainer._device_type == "tpu"
assert isinstance(trainer.accelerator, TPUAccelerator)


@RunIf(tpu=True)
@pytest.mark.parametrize(["devices", "plugin"], [([0], SingleTPUPlugin), (8, TPUSpawnPlugin)])
def test_accelerator_tpu_with_devices(devices, plugin):

trainer = Trainer(accelerator="tpu", devices=devices)

assert trainer.tpu_cores == devices
assert isinstance(trainer.training_type_plugin, plugin)
assert isinstance(trainer.accelerator, TPUAccelerator)


@RunIf(tpu=True)
def test_accelerator_auto_with_devices_tpu():

trainer = Trainer(accelerator="auto", devices=8)

assert trainer._device_type == "tpu"
assert trainer.tpu_cores == 8


@RunIf(tpu=True)
def test_accelerator_tpu_with_tpu_cores_priority():
""" Test for checking `tpu_cores` flag takes priority over `devices`. """

tpu_cores = 8
trainer = Trainer(accelerator="tpu", devices=1, tpu_cores=tpu_cores)
assert trainer.tpu_cores == tpu_cores