Skip to content

Commit

Permalink
Raise error for devices for cpu not being int
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 committed Jul 20, 2021
1 parent 12634ca commit 25456e0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,11 @@ def _map_devices_to_accelerator(self, accelerator: str) -> bool:
self.parallel_device_ids = device_parser.parse_gpu_ids(self.devices)
return True
if accelerator == DeviceType.CPU:
if not isinstance(self.devices, int):
raise MisconfigurationException(
"The flag `devices` only supports integer for `accelerator='cpu'`,"
f" got `devices={self.devices}` instead."
)
self.num_processes = self.devices
return True
return False
Expand Down
6 changes: 6 additions & 0 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,3 +713,9 @@ def test_set_devices_if_none_gpu():

trainer = Trainer(accelerator="gpu", gpus=2)
assert trainer.devices == 2


def test_devices_with_cpu_only_supports_integer():

with pytest.raises(MisconfigurationException, match="The flag `devices` only supports integer"):
Trainer(accelerator="cpu", devices="1,3")

0 comments on commit 25456e0

Please sign in to comment.