Bugfix: horovod optimizer missing 2 required positional arguments #7840

Merged on Jul 21, 2021. 32 commits; changes shown from 29 commits.

Commits
175521f  Update horovod.py (marsggbo, Jun 5, 2021)
42e09c5  Merge pull request #1 from marsggbo/marsggbo-fixbug-horovod (marsggbo, Jun 5, 2021)
80542f3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 5, 2021)
18d0f22  Merge branch 'PyTorchLightning:master' into master (marsggbo, Jun 7, 2021)
37e31e0  Update horovod.py (marsggbo, Jun 7, 2021)
e781a1d  Create test_horovod.py (marsggbo, Jun 7, 2021)
1f2fe56  solve issue #7853 (marsggbo, Jun 7, 2021)
4488425  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 7, 2021)
224e824  Update tests/plugins/test_horovod.py (awaelchli, Jun 8, 2021)
3d35464  Merge branch 'master' into master (awaelchli, Jul 20, 2021)
71618ff  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 20, 2021)
c71630a  extract method (awaelchli, Jul 20, 2021)
f1e7dd8  update typehint hvd (awaelchli, Jul 20, 2021)
f161b4e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 20, 2021)
82b9f63  update docstring (awaelchli, Jul 20, 2021)
3e3964b  rename local var (awaelchli, Jul 20, 2021)
f80fce9  add assertions to default train model (awaelchli, Jul 20, 2021)
1220851  fix test for cpu case (awaelchli, Jul 20, 2021)
d455a77  fix assertion for cpu device (awaelchli, Jul 20, 2021)
441c5c9  remove old test_horovod file (awaelchli, Jul 20, 2021)
d7e3e0f  keep list comprehension as suggested by @carmocca (awaelchli, Jul 20, 2021)
e3878fb  prevent wrapping when testing (awaelchli, Jul 20, 2021)
8021fbd  revert previous commit (awaelchli, Jul 20, 2021)
44fbd76  update changelog (awaelchli, Jul 20, 2021)
8ee0b98  update changelog (awaelchli, Jul 20, 2021)
df41b53  remove flaky memory test (awaelchli, Jul 20, 2021)
8e0565f  add trainer.test() call to verify fix (awaelchli, Jul 20, 2021)
f3fa5cf  avoid wrapping optimizers for test (awaelchli, Jul 20, 2021)
8ac3abd  rename uppercase constant dist_group (awaelchli, Jul 20, 2021)
7be9623  Merge branch 'master' into master (tchaton, Jul 20, 2021)
62a3f01  Update pytorch_lightning/plugins/training_type/horovod.py (awaelchli, Jul 20, 2021)
a1ee5b6  Merge branch 'master' into master (awaelchli, Jul 21, 2021)
CHANGELOG.md (3 additions, 0 deletions)

@@ -493,6 +493,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}` runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))


- Fixed a `TypeError` when wrapping optimizers in the `HorovodPlugin` and running `Trainer.test` ([#7840](https://github.com/PyTorchLightning/pytorch-lightning/pull/7840))


## [1.3.8] - 2021-07-01

### Fixed
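For context, the scenario behind the changelog entry above can be sketched as follows. This is a hedged reproduction sketch, not code from this PR: the `BoringModel` import path and the `accelerator="horovod"` Trainer flag are assumptions based on the 1.3/1.4-era Lightning test helpers, and the script is assumed to be launched under `horovodrun`.

```python
# Hedged reproduction sketch (not part of this diff): run fit() followed by
# test() under the Horovod plugin. Assumes Horovod is installed and the script
# is launched via horovodrun; import path and Trainer flags are assumptions.
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # Lightning's test helper model

model = BoringModel()
trainer = Trainer(accelerator="horovod", max_epochs=1, limit_train_batches=4)

trainer.fit(model)
# Before this fix, the test run went through pre_dispatch() again and tried to
# wrap the already wrapped Horovod optimizers, raising the TypeError mentioned
# in the changelog entry; with this PR the wrapping is skipped.
trainer.test(model)
```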
pytorch_lightning/plugins/training_type/horovod.py (25 additions, 15 deletions)

@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import distributed_available, group, rank_zero_only, ReduceOp
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp

if _HOROVOD_AVAILABLE:
import horovod.torch as hvd
@@ -60,6 +64,10 @@ def setup(self, model):

def pre_dispatch(self):

if not self.lightning_module.trainer.training:
# no need to setup optimizers
return

def _unpack_lightning_optimizer(opt):
return opt._optimizer if isinstance(opt, LightningOptimizer) else opt

@@ -84,17 +92,7 @@ def _unpack_lightning_optimizer(opt):
for optimizer in optimizers:
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

def _filter_named_parameters(model, optimizer):
opt_params = set([p for group in optimizer.param_groups for p in group.get("params", [])])
return [(name, p) for name, p in model.named_parameters() if p in opt_params]

# Horovod: wrap optimizers to perform gradient aggregation via allreduce
optimizers = [
hvd.DistributedOptimizer(
optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer)
) for optimizer in optimizers
]
self.lightning_module.trainer.accelerator.optimizers = optimizers
self.lightning_module.trainer.accelerator.optimizers = self._wrap_optimizers(optimizers)

def start_training(self, trainer):
with ExitStack() as stack:
@@ -176,10 +174,10 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[
def all_gather(
self,
result: Union[torch.Tensor],
group: Optional[Any] = group.WORLD,
group: Optional[Any] = dist_group.WORLD,
sync_grads: bool = False
) -> torch.Tensor:
if group is not None and group != group.WORLD:
if group is not None and group != dist_group.WORLD:
raise ValueError(
"Horovod does not support allgather using a subcommunicator at this time. "
"Unset `group`."
@@ -199,3 +197,15 @@ def post_backward(self, closure_loss: torch.Tensor) -> None:
# synchronize all horovod optimizers.
for optimizer in self.lightning_module.trainer.optimizers:
optimizer.synchronize()

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
""" Wraps optimizers to perform gradient aggregation via allreduce. """
return [
hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt))
if "horovod" not in str(opt.__class__) else opt for opt in optimizers
]

@staticmethod
def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
opt_params = set([p for group in optimizer.param_groups for p in group.get("params", [])])
return [(name, p) for name, p in model.named_parameters() if p in opt_params]
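The two helpers introduced above, `_wrap_optimizers` and `_filter_named_parameters`, can be illustrated standalone. The sketch below mirrors their logic without requiring Horovod: a dummy optimizer class whose module path contains "horovod" stands in for `hvd.DistributedOptimizer`, and the parameter filter is reproduced with the same comprehension; everything outside those two checks is illustrative only.

```python
# Standalone sketch mirroring the helpers above; DummyDistributedOptimizer is a
# stand-in for hvd.DistributedOptimizer so the example runs without Horovod.
import torch
import torch.nn as nn
from torch.optim import SGD


class DummyDistributedOptimizer(SGD):
    # Mimic the module path of a Horovod-wrapped optimizer class.
    __module__ = "horovod.torch.optimizer"


def already_wrapped(opt: torch.optim.Optimizer) -> bool:
    # Same check as `"horovod" not in str(opt.__class__)` in _wrap_optimizers:
    # on a second pre_dispatch (e.g. fit followed by test) wrapped optimizers
    # are left untouched instead of being wrapped twice.
    return "horovod" in str(opt.__class__)


def filter_named_parameters(model: nn.Module, optimizer: torch.optim.Optimizer):
    # Same logic as HorovodPlugin._filter_named_parameters: keep only the
    # named parameters that belong to this optimizer's param groups.
    opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
    return [(name, p) for name, p in model.named_parameters() if p in opt_params]


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
opt_a = SGD(model[0].parameters(), lr=0.1)                        # plain optimizer
opt_b = DummyDistributedOptimizer(model[1].parameters(), lr=0.1)  # "wrapped" one

assert not already_wrapped(opt_a)   # would be wrapped by _wrap_optimizers
assert already_wrapped(opt_b)       # would be passed through unchanged

# Each optimizer only sees the named parameters it actually owns.
assert [name for name, _ in filter_named_parameters(model, opt_a)] == ["0.weight", "0.bias"]
assert [name for name, _ in filter_named_parameters(model, opt_b)] == ["1.weight", "1.bias"]
```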
tests/models/data/horovod/train_default_model.py (8 additions, 0 deletions)

@@ -55,14 +55,22 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True):

class TestModel(BoringModel):

def on_train_start(self) -> None:
expected_device = torch.device("cuda", self.trainer.local_rank) if on_gpu else torch.device("cpu")
assert self.device == expected_device

def training_epoch_end(self, outputs) -> None:
res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum")
assert res.sum() == self.trainer.training_type_plugin.world_size

model = TestModel()
trainer = Trainer(**trainer_options)

trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"
trainer.test(model)

assert model.device == torch.device("cpu")

# Horovod should be initialized following training. If not, this will raise an exception.
if check_size:
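For readers unfamiliar with how `train_default_model.py` is exercised: such scripts are typically not run directly but launched under `horovodrun` from the test suite, roughly as sketched below. The `--trainer-options` flag and the option values are assumptions for illustration, not taken from this diff.

```python
# Hedged launch sketch (assumptions marked): drive the helper script under two
# Horovod workers, passing the Trainer options as JSON. The CLI flag name and
# the option values are illustrative, not taken from this PR.
import json
import subprocess
import sys

trainer_options = {"max_epochs": 1, "limit_train_batches": 4}  # assumed options
cmd = [
    "horovodrun", "-np", "2",  # two Horovod worker processes on this host
    sys.executable, "tests/models/data/horovod/train_default_model.py",
    "--trainer-options", json.dumps(trainer_options),
]
subprocess.run(cmd, check=True)  # raises CalledProcessError if fit/test fails
```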