From 175521fe6bb6a7295b7b446c704d1542362bb88b Mon Sep 17 00:00:00 2001 From: marsggbo Date: Sat, 5 Jun 2021 16:33:37 +0800 Subject: [PATCH 01/27] Update horovod.py fixed issue #7839 --- .../plugins/training_type/horovod.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 99899aed11753..054eae72a2f94 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -90,12 +90,21 @@ def _filter_named_parameters(model, optimizer): return [(name, p) for name, p in model.named_parameters() if p in opt_params] # Horovod: wrap optimizers to perform gradient aggregation via allreduce - optimizers = [ - hvd.DistributedOptimizer( - optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) - ) for optimizer in optimizers - ] - self.lightning_module.trainer.accelerator.optimizers = optimizers + h_optimizers = [] + for optimizer in optimizers: + if 'horovod' not in str(optimizer.__class__): + h_optimizers.append(hvd.DistributedOptimizer( + optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) + )) + else: + h_optimizers.append(optimizer) + self.lightning_module.trainer.accelerator.optimizers = h_optimizers + # optimizers = [ + # hvd.DistributedOptimizer( + # optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) + # ) for optimizer in optimizers + # ] + # self.lightning_module.trainer.accelerator.optimizers = optimizers def start_training(self, trainer): with ExitStack() as stack: From 80542f3f26846912e8deae1cae8be36c3f4e5d2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 5 Jun 2021 08:44:32 +0000 Subject: [PATCH 02/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/horovod.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 054eae72a2f94..8a250db57cbdf 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -93,9 +93,11 @@ def _filter_named_parameters(model, optimizer): h_optimizers = [] for optimizer in optimizers: if 'horovod' not in str(optimizer.__class__): - h_optimizers.append(hvd.DistributedOptimizer( - optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) - )) + h_optimizers.append( + hvd.DistributedOptimizer( + optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) + ) + ) else: h_optimizers.append(optimizer) self.lightning_module.trainer.accelerator.optimizers = h_optimizers From 37e31e033c2024b67e4444c19a6910c1661f7d0f Mon Sep 17 00:00:00 2001 From: marsggbo Date: Mon, 7 Jun 2021 10:33:51 +0800 Subject: [PATCH 03/27] Update horovod.py --- pytorch_lightning/plugins/training_type/horovod.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 8a250db57cbdf..4c391f2c31c8a 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -101,12 +101,6 @@ def _filter_named_parameters(model, optimizer): 
else: h_optimizers.append(optimizer) self.lightning_module.trainer.accelerator.optimizers = h_optimizers - # optimizers = [ - # hvd.DistributedOptimizer( - # optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) - # ) for optimizer in optimizers - # ] - # self.lightning_module.trainer.accelerator.optimizers = optimizers def start_training(self, trainer): with ExitStack() as stack: From e781a1d82e578c53c9f86cb6ba6bb2225486ab43 Mon Sep 17 00:00:00 2001 From: marsggbo Date: Mon, 7 Jun 2021 12:04:17 +0800 Subject: [PATCH 04/27] Create test_horovod.py --- tests/plugins/test_horovod.py | 48 +++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/plugins/test_horovod.py diff --git a/tests/plugins/test_horovod.py b/tests/plugins/test_horovod.py new file mode 100644 index 0000000000000..a7d187eea9af7 --- /dev/null +++ b/tests/plugins/test_horovod.py @@ -0,0 +1,48 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import HorovodPlugin +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +class BoringModelGPU(BoringModel): + + def on_train_start(self) -> None: + # make sure that the model is on GPU when training + assert self.device == torch.device(f"cuda:{self.trainer.training_type_plugin.local_rank}") + self.start_cuda_memory = torch.cuda.memory_allocated() + + +@RunIf(skip_windows=True, min_gpus=2, special=True) +def test_horovod_with_2_gpus(): + """Tests if device is set correctely when training and after teardown for HorovodPlugin.""" + trainer = Trainer(gpus=2, accelerator="horovod", fast_dev_run=True) + # assert training type plugin attributes for device setting + assert isinstance(trainer.training_type_plugin, HorovodPlugin) + assert trainer.training_type_plugin.on_gpu + assert not trainer.training_type_plugin.on_tpu + local_rank = trainer.training_type_plugin.local_rank + assert trainer.training_type_plugin.root_device == torch.device(f"cuda:{local_rank}") + + model = BoringModelGPU() + + trainer.fit(model) + + # assert after training, model is moved to CPU and memory is deallocated + assert model.device == torch.device("cpu") + cuda_memory = torch.cuda.memory_allocated() + assert cuda_memory < model.start_cuda_memory From 1f2fe56beb314d64d72294370a2050f63d7c1185 Mon Sep 17 00:00:00 2001 From: marsggbo Date: Mon, 7 Jun 2021 12:11:44 +0800 Subject: [PATCH 05/27] solve issue #7853 --- pytorch_lightning/plugins/training_type/horovod.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 4c391f2c31c8a..777800d7cb9fa 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,7 +21,8 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from 
pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import group as GROUP if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -181,10 +182,10 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ def all_gather( self, result: Union[torch.Tensor], - group: Optional[Any] = group.WORLD, + group: Optional[Any] = GROUP.WORLD, sync_grads: bool = False ) -> torch.Tensor: - if group is not None and group != group.WORLD: + if group is not None and group != GROUP.WORLD: raise ValueError( "Horovod does not support allgather using a subcommunicator at this time. " "Unset `group`." From 44884252ac2a776a458c2ec19f23561bfc890810 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Jun 2021 04:12:42 +0000 Subject: [PATCH 06/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/horovod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 777800d7cb9fa..5ebc5daf3b4a2 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,8 +21,8 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.distributed import group as GROUP +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp if _HOROVOD_AVAILABLE: import horovod.torch as hvd From 224e824ae2a5814537d235da3e49e422e8d385e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 8 Jun 2021 22:40:25 +0200 Subject: [PATCH 07/27] Update tests/plugins/test_horovod.py Co-authored-by: thomas chaton --- tests/plugins/test_horovod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_horovod.py b/tests/plugins/test_horovod.py index a7d187eea9af7..5ba4f71af630a 100644 --- a/tests/plugins/test_horovod.py +++ b/tests/plugins/test_horovod.py @@ -29,7 +29,7 @@ def on_train_start(self) -> None: @RunIf(skip_windows=True, min_gpus=2, special=True) def test_horovod_with_2_gpus(): - """Tests if device is set correctely when training and after teardown for HorovodPlugin.""" + """Tests if device is set correctly when training and after teardown for HorovodPlugin.""" trainer = Trainer(gpus=2, accelerator="horovod", fast_dev_run=True) # assert training type plugin attributes for device setting assert isinstance(trainer.training_type_plugin, HorovodPlugin) From 71618ff529f845b331c035284f3422f47b82e365 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jul 2021 12:01:19 +0000 Subject: [PATCH 08/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/horovod.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index eded12984c714..312da563d1706 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -20,9 +20,9 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.distributed import group as GROUP -from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_only, ReduceOp - +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp if _HOROVOD_AVAILABLE: import horovod.torch as hvd From c71630ac5f7f564e849364199abc8fb60ec091de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 14:10:17 +0200 Subject: [PATCH 09/27] extract method --- .../plugins/training_type/horovod.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 312da563d1706..8a33787115633 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import ExitStack -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Union, Tuple import torch +import torch.nn as nn +from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer @@ -86,22 +88,7 @@ def _unpack_lightning_optimizer(opt): for optimizer in optimizers: hvd.broadcast_optimizer_state(optimizer, root_rank=0) - def _filter_named_parameters(model, optimizer): - opt_params = set([p for group in optimizer.param_groups for p in group.get("params", [])]) - return [(name, p) for name, p in model.named_parameters() if p in opt_params] - - # Horovod: wrap optimizers to perform gradient aggregation via allreduce - h_optimizers = [] - for optimizer in optimizers: - if 'horovod' not in str(optimizer.__class__): - h_optimizers.append( - hvd.DistributedOptimizer( - optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) - ) - ) - else: - h_optimizers.append(optimizer) - self.lightning_module.trainer.accelerator.optimizers = h_optimizers + self.lightning_module.trainer.accelerator.optimizers = self._wrap_optimizers(optimizers) def start_training(self, trainer): with ExitStack() as stack: @@ -206,3 +193,22 @@ def post_backward(self, closure_loss: torch.Tensor) -> None: # synchronize all horovod optimizers. 
for optimizer in self.lightning_module.trainer.optimizers: optimizer.synchronize() + + def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List[hvd.DistributedOptimizer]: + # Horovod: wrap optimizers to perform gradient aggregation via allreduce + h_optimizers = [] + for optimizer in optimizers: + if "horovod" not in str(optimizer.__class__): + h_optimizers.append( + hvd.DistributedOptimizer( + optimizer, named_parameters=self._filter_named_parameters(self.lightning_module, optimizer) + ) + ) + else: + h_optimizers.append(optimizer) + return h_optimizers + + @staticmethod + def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]: + opt_params = set([p for group in optimizer.param_groups for p in group.get("params", [])]) + return [(name, p) for name, p in model.named_parameters() if p in opt_params] From f1e7dd836f53d0e5d472de26e6b16e33351a8126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 14:22:43 +0200 Subject: [PATCH 10/27] update typehint hvd --- pytorch_lightning/plugins/training_type/horovod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 8a33787115633..c85ccf7b136c6 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -194,7 +194,7 @@ def post_backward(self, closure_loss: torch.Tensor) -> None: for optimizer in self.lightning_module.trainer.optimizers: optimizer.synchronize() - def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List[hvd.DistributedOptimizer]: + def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]: # Horovod: wrap optimizers to perform gradient aggregation via allreduce h_optimizers = [] for optimizer in optimizers: From f161b4e1cf50230c339afc5057a1b70e6376beef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jul 2021 12:39:05 +0000 Subject: [PATCH 11/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/horovod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index c85ccf7b136c6..ea81120e6e934 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import ExitStack -from typing import Any, List, Optional, Union, Tuple +from typing import Any, List, Optional, Tuple, Union import torch import torch.nn as nn From 82b9f63aadfef740b20dbcb7f3301a6d3626f20e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 14:46:29 +0200 Subject: [PATCH 12/27] update docstring --- pytorch_lightning/plugins/training_type/horovod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index ea81120e6e934..24011ef537f5e 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -195,7 +195,7 @@ def post_backward(self, closure_loss: torch.Tensor) -> None: optimizer.synchronize() def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]: - # Horovod: wrap optimizers to perform gradient aggregation via allreduce + """ Wraps optimizers to perform gradient aggregation via allreduce. """ h_optimizers = [] for optimizer in optimizers: if "horovod" not in str(optimizer.__class__): From 3e3964b68ca1f98c181a72dec6c729d2d3b82455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 14:47:47 +0200 Subject: [PATCH 13/27] rename local var --- pytorch_lightning/plugins/training_type/horovod.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 24011ef537f5e..bd139331de812 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -196,17 +196,17 @@ def post_backward(self, closure_loss: torch.Tensor) -> None: def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]: """ Wraps optimizers to perform gradient aggregation via allreduce. 
""" - h_optimizers = [] + hvd_optimizers = [] for optimizer in optimizers: if "horovod" not in str(optimizer.__class__): - h_optimizers.append( + hvd_optimizers.append( hvd.DistributedOptimizer( optimizer, named_parameters=self._filter_named_parameters(self.lightning_module, optimizer) ) ) else: - h_optimizers.append(optimizer) - return h_optimizers + hvd_optimizers.append(optimizer) + return hvd_optimizers @staticmethod def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]: From f80fce94d6d49db27d777182fc95dc1814e663ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 14:55:29 +0200 Subject: [PATCH 14/27] add assertions to default train model --- tests/models/data/horovod/train_default_model.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index c4cbaeb1363c9..543552dd54dd8 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -55,15 +55,24 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True): class TestModel(BoringModel): + def on_train_start(self) -> None: + device = torch.device("cuda", self.trainer.training_type_plugin.local_rank) + assert self.device == device + self.initial_cuda_memory = torch.cuda.memory_allocated(device) + def training_epoch_end(self, outputs) -> None: res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum") assert res.sum() == self.trainer.training_type_plugin.world_size model = TestModel() trainer = Trainer(**trainer_options) + trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" + assert model.device == torch.device("cpu") + assert torch.cuda.memory_allocated() < model.initial_cuda_memory + # Horovod should be initialized following training. If not, this will raise an exception. 
if check_size: assert hvd.size() == 2 From 12208511ffdda1f3d387cc8cee165e03a128d604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 14:58:36 +0200 Subject: [PATCH 15/27] fix test for cpu case --- tests/models/data/horovod/train_default_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 543552dd54dd8..9da7afe523f02 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -56,7 +56,7 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True): class TestModel(BoringModel): def on_train_start(self) -> None: - device = torch.device("cuda", self.trainer.training_type_plugin.local_rank) + device = torch.device("cuda", self.trainer.local_rank) if on_gpu else torch.device("cpu") assert self.device == device self.initial_cuda_memory = torch.cuda.memory_allocated(device) From d455a77c57e4cf1d7ea9554113ca5ee47914e585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:01:30 +0200 Subject: [PATCH 16/27] fix assertion for cpu device --- tests/models/data/horovod/train_default_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 9da7afe523f02..1c49dc5abb050 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -58,7 +58,7 @@ class TestModel(BoringModel): def on_train_start(self) -> None: device = torch.device("cuda", self.trainer.local_rank) if on_gpu else torch.device("cpu") assert self.device == device - self.initial_cuda_memory = torch.cuda.memory_allocated(device) + self.initial_cuda_memory = torch.cuda.memory_allocated(trainer.local_rank) def training_epoch_end(self, outputs) -> None: res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum") @@ -71,7 +71,7 @@ def training_epoch_end(self, outputs) -> None: assert trainer.state.finished, f"Training failed with {trainer.state}" assert model.device == torch.device("cpu") - assert torch.cuda.memory_allocated() < model.initial_cuda_memory + assert torch.cuda.memory_allocated(trainer.local_rank) <= model.initial_cuda_memory # Horovod should be initialized following training. If not, this will raise an exception. if check_size: From 441c5c9d7cfb769b37c2e14fabe694728a408b05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:02:47 +0200 Subject: [PATCH 17/27] remove old test_horovod file --- tests/plugins/test_horovod.py | 48 ----------------------------------- 1 file changed, 48 deletions(-) delete mode 100644 tests/plugins/test_horovod.py diff --git a/tests/plugins/test_horovod.py b/tests/plugins/test_horovod.py deleted file mode 100644 index 5ba4f71af630a..0000000000000 --- a/tests/plugins/test_horovod.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -from pytorch_lightning import Trainer -from pytorch_lightning.plugins import HorovodPlugin -from tests.helpers.boring_model import BoringModel -from tests.helpers.runif import RunIf - - -class BoringModelGPU(BoringModel): - - def on_train_start(self) -> None: - # make sure that the model is on GPU when training - assert self.device == torch.device(f"cuda:{self.trainer.training_type_plugin.local_rank}") - self.start_cuda_memory = torch.cuda.memory_allocated() - - -@RunIf(skip_windows=True, min_gpus=2, special=True) -def test_horovod_with_2_gpus(): - """Tests if device is set correctly when training and after teardown for HorovodPlugin.""" - trainer = Trainer(gpus=2, accelerator="horovod", fast_dev_run=True) - # assert training type plugin attributes for device setting - assert isinstance(trainer.training_type_plugin, HorovodPlugin) - assert trainer.training_type_plugin.on_gpu - assert not trainer.training_type_plugin.on_tpu - local_rank = trainer.training_type_plugin.local_rank - assert trainer.training_type_plugin.root_device == torch.device(f"cuda:{local_rank}") - - model = BoringModelGPU() - - trainer.fit(model) - - # assert after training, model is moved to CPU and memory is deallocated - assert model.device == torch.device("cpu") - cuda_memory = torch.cuda.memory_allocated() - assert cuda_memory < model.start_cuda_memory From d7e3e0f3d996a675e06ab3f8c09cccf1259dca8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:06:36 +0200 Subject: [PATCH 18/27] keep list comprehension as suggested by @carmocca --- .../plugins/training_type/horovod.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index bd139331de812..a1ce3ad3fc80c 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -196,17 +196,10 @@ def post_backward(self, closure_loss: torch.Tensor) -> None: def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]: """ Wraps optimizers to perform gradient aggregation via allreduce. 
""" - hvd_optimizers = [] - for optimizer in optimizers: - if "horovod" not in str(optimizer.__class__): - hvd_optimizers.append( - hvd.DistributedOptimizer( - optimizer, named_parameters=self._filter_named_parameters(self.lightning_module, optimizer) - ) - ) - else: - hvd_optimizers.append(optimizer) - return hvd_optimizers + return [ + hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt)) + if "horovod" not in str(opt.__class__) else opt for opt in optimizers + ] @staticmethod def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]: From e3878fb1ec9bb93889cf7e9f282cec2c9e21791a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:14:01 +0200 Subject: [PATCH 19/27] prevent wrapping when testing --- pytorch_lightning/plugins/training_type/horovod.py | 3 +++ tests/models/data/horovod/train_default_model.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index a1ce3ad3fc80c..34ef875db5f03 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -196,6 +196,9 @@ def post_backward(self, closure_loss: torch.Tensor) -> None: def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]: """ Wraps optimizers to perform gradient aggregation via allreduce. """ + if not self.lightning_module.trainer.training: + # no need to wrap optimizers + return optimizers return [ hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt)) if "horovod" not in str(opt.__class__) else opt for opt in optimizers diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 1c49dc5abb050..0cca1718c9fb8 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -103,6 +103,8 @@ def training_epoch_end(self, outputs) -> None: # Test the root_gpu property assert trainer.root_gpu == hvd.local_rank() + # trainer.test(model) + if __name__ == "__main__": args = parser.parse_args() From 8021fbdf8fa3fcdef94e7a4bf4dce15f2d196f71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:19:01 +0200 Subject: [PATCH 20/27] revert previous commit --- pytorch_lightning/plugins/training_type/horovod.py | 3 --- tests/models/data/horovod/train_default_model.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 34ef875db5f03..a1ce3ad3fc80c 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -196,9 +196,6 @@ def post_backward(self, closure_loss: torch.Tensor) -> None: def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]: """ Wraps optimizers to perform gradient aggregation via allreduce. 
""" - if not self.lightning_module.trainer.training: - # no need to wrap optimizers - return optimizers return [ hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt)) if "horovod" not in str(opt.__class__) else opt for opt in optimizers diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 0cca1718c9fb8..1c49dc5abb050 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -103,8 +103,6 @@ def training_epoch_end(self, outputs) -> None: # Test the root_gpu property assert trainer.root_gpu == hvd.local_rank() - # trainer.test(model) - if __name__ == "__main__": args = parser.parse_args() From 44fbd76c84338b64b6649d49590c0791d1b7dd32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:24:28 +0200 Subject: [PATCH 21/27] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2be735ba3f855..b34e41e793a71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -493,6 +493,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}´ runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442)) +- Fixed a `TypeError` when wrapping optimizers in the `HorovodPlugin` and running `Trainer.test` ([#7839](https://github.com/PyTorchLightning/pytorch-lightning/pull/7839)) + + ## [1.3.8] - 2021-07-01 ### Fixed From 8ee0b98e06123d6d3ae390ecf5323a290f8e5dd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:25:41 +0200 Subject: [PATCH 22/27] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b34e41e793a71..8ba75b878bc68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -493,7 +493,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}´ runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442)) -- Fixed a `TypeError` when wrapping optimizers in the `HorovodPlugin` and running `Trainer.test` ([#7839](https://github.com/PyTorchLightning/pytorch-lightning/pull/7839)) +- Fixed a `TypeError` when wrapping optimizers in the `HorovodPlugin` and running `Trainer.test` ([#7840](https://github.com/PyTorchLightning/pytorch-lightning/pull/7840)) ## [1.3.8] - 2021-07-01 From df41b53788c667d9ac5105eab48e55797b89fbc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:43:27 +0200 Subject: [PATCH 23/27] remove flaky memory test --- tests/models/data/horovod/train_default_model.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 1c49dc5abb050..09f8742a7ec27 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -56,9 +56,8 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True): class TestModel(BoringModel): def on_train_start(self) -> None: - device = torch.device("cuda", self.trainer.local_rank) if on_gpu else torch.device("cpu") - assert self.device == device - self.initial_cuda_memory = torch.cuda.memory_allocated(trainer.local_rank) + expected_device = torch.device("cuda", self.trainer.local_rank) if on_gpu else torch.device("cpu") + assert self.device == expected_device def training_epoch_end(self, outputs) -> None: res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum") @@ -71,7 +70,6 @@ def training_epoch_end(self, outputs) -> None: assert trainer.state.finished, f"Training failed with {trainer.state}" assert model.device == torch.device("cpu") - assert torch.cuda.memory_allocated(trainer.local_rank) <= model.initial_cuda_memory # Horovod should be initialized following training. If not, this will raise an exception. 
if check_size: From 8e0565f26e4713fc449d0ed121dedfe010eb996a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:43:59 +0200 Subject: [PATCH 24/27] add trainer.test() call to verify fix --- tests/models/data/horovod/train_default_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 09f8742a7ec27..adf29459c0966 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -68,6 +68,7 @@ def training_epoch_end(self, outputs) -> None: trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" + trainer.test(model) assert model.device == torch.device("cpu") From f3fa5cfe761abfc1370f5bbb892c0425cce099e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 15:45:39 +0200 Subject: [PATCH 25/27] avoid wrapping optimizers for test --- pytorch_lightning/plugins/training_type/horovod.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index a1ce3ad3fc80c..6821c71630687 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -64,6 +64,10 @@ def setup(self, model): def pre_dispatch(self): + if not self.lightning_module.trainer.training: + # no need to setup optimizers + return + def _unpack_lightning_optimizer(opt): return opt._optimizer if isinstance(opt, LightningOptimizer) else opt From 8ac3abdf81474636ef9e482f251c6b2f3fedf737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 16:30:18 +0200 Subject: [PATCH 26/27] rename uppercase constant dist_group --- pytorch_lightning/plugins/training_type/horovod.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 6821c71630687..716d9597e9693 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -23,7 +23,7 @@ from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available -from pytorch_lightning.utilities.distributed import group as GROUP +from pytorch_lightning.utilities.distributed import group as dist_group from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp if _HOROVOD_AVAILABLE: @@ -174,10 +174,10 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ def all_gather( self, result: Union[torch.Tensor], - group: Optional[Any] = GROUP.WORLD, + group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False ) -> torch.Tensor: - if group is not None and group != GROUP.WORLD: + if group is not None and group != dist_group.WORLD: raise ValueError( "Horovod does not support allgather using a subcommunicator at this time. " "Unset `group`." 
From 62a3f018d4e3e641d54417f42f31f77a4f5303fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 20 Jul 2021 17:05:02 +0200 Subject: [PATCH 27/27] Update pytorch_lightning/plugins/training_type/horovod.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/plugins/training_type/horovod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 716d9597e9693..e9bb986ea10a6 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -207,5 +207,5 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.Distributed @staticmethod def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]: - opt_params = set([p for group in optimizer.param_groups for p in group.get("params", [])]) + opt_params = set(p for group in optimizer.param_groups for p in group.get("params", [])) return [(name, p) for name, p in model.named_parameters() if p in opt_params]
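
For reference, a minimal standalone sketch of the optimizer-wrapping behaviour this series converges on, assuming horovod.torch is installed and hvd.init() has been called. The free-function names below are illustrative only — in the plugin the equivalents are HorovodPlugin._filter_named_parameters and HorovodPlugin._wrap_optimizers, and pre_dispatch additionally returns early when the trainer is not training, so a trainer.test() call after trainer.fit() never re-wraps an already wrapped optimizer.

    from typing import List, Tuple

    import horovod.torch as hvd
    import torch.nn as nn
    from torch.optim import Optimizer


    def filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
        # Keep only the (name, parameter) pairs this optimizer actually updates,
        # so Horovod allreduces gradients for exactly those tensors.
        opt_params = set(p for group in optimizer.param_groups for p in group.get("params", []))
        return [(name, p) for name, p in model.named_parameters() if p in opt_params]


    def wrap_optimizers(model: nn.Module, optimizers: List[Optimizer]) -> List[Optimizer]:
        # Wrap each optimizer at most once: optimizers whose class already comes
        # from horovod (i.e. an hvd.DistributedOptimizer) are passed through untouched,
        # mirroring the plugin's `"horovod" not in str(opt.__class__)` check.
        return [
            opt if "horovod" in str(opt.__class__) else
            hvd.DistributedOptimizer(opt, named_parameters=filter_named_parameters(model, opt))
            for opt in optimizers
        ]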