From 7755572b4f37b811b83f6a933329b01af4735e66 Mon Sep 17 00:00:00 2001
From: chaton
Date: Fri, 11 Dec 2020 14:51:45 +0100
Subject: [PATCH] Check if optimizer supports closure (#4981)

* check if optimizer support closure

* cleanup test

* resolve tests

* resolve flake

* update test due to patch limit

* update

* update dep

* Update tests/core/test_lightning_optimizer.py

Co-authored-by: Rohit Gupta

* Update tests/core/test_lightning_optimizer.py

Co-authored-by: Rohit Gupta

* resolve bug

* update test

* resolve tests

* Update requirements/extra.txt

Co-authored-by: Jirka Borovec

* remove bolts dep

* remove bolts

* add missing bolts dep for tests

* remove need for bolts

Co-authored-by: Rohit Gupta
Co-authored-by: Jirka Borovec
---
 pytorch_lightning/core/optimizer.py           | 10 ++-
 pytorch_lightning/utilities/__init__.py       |  1 +
 requirements/extra.txt                        |  2 +-
 tests/core/test_lightning_module.py           |  5 ++
 tests/core/test_lightning_optimizer.py        | 74 ++++++++++++++++---
 .../optimization/test_manual_optimization.py  | 11 ++-
 6 files changed, 81 insertions(+), 22 deletions(-)

diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index dc63231ba6ccb..e6b973b336e43 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import types
 from typing import Any, Callable, Optional
 from weakref import proxy
@@ -60,7 +61,7 @@ def __init__(self,
         self._trainer = None
         self._optimizer = optimizer
         self._accumulate_grad_batches = accumulate_grad_batches
-        self._automatic_optimization = None
+        self._support_closure = 'closure' in inspect.signature(optimizer.step).parameters
         self._optimizer_idx = None
 
     @property
@@ -73,7 +74,6 @@ def accumulate_grad_batches(self, accumulate_grad_batches):
 
     def _on_trainer_init(self, trainer):
         self._trainer = proxy(trainer)
-        self._automatic_optimization = trainer.train_loop.automatic_optimization
         for opt_idx, opt in enumerate(trainer.optimizers):
             if opt == self._optimizer:
                 self._optimizer_idx = opt_idx
@@ -111,7 +111,11 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
 
         else:
             with trainer.profiler.profile(profiler_name):
-                optimizer.step(closure=closure, *args, **kwargs)
+                if self._support_closure:
+                    optimizer.step(closure=closure, *args, **kwargs)
+                else:
+                    closure()
+                    optimizer.step(*args, **kwargs)
 
         accelerator_backend = trainer.accelerator_backend
         if accelerator_backend is not None and accelerator_backend.rpc_enabled:
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index a11862b4003bc..e5641337cc8d2 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -54,6 +54,7 @@ def _module_available(module_path: str) -> bool:
 OMEGACONF_AVAILABLE = _module_available("omegaconf")
 HYDRA_AVAILABLE = _module_available("hydra")
 HOROVOD_AVAILABLE = _module_available("horovod.torch")
+BOLTS_AVAILABLE = _module_available("pl_bolts")
 
 TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
diff --git a/requirements/extra.txt b/requirements/extra.txt
index ad54358269bd1..3f14b1e5910dd 100644
--- a/requirements/extra.txt
+++ b/requirements/extra.txt
@@ -7,4 +7,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
 onnx>=1.7.0
 onnxruntime>=1.3.0
 hydra-core>=1.0
-https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
\ No newline at end of file
+https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 3e2e6d040f44c..a7054a3a7ef49 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -55,6 +55,11 @@ def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir):
 
     class TestModel(BoringModel):
 
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"loss": loss}
+
         def configure_optimizers(self):
             optimizer = SGD(self.layer.parameters(), lr=0.1)
             optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index e6ec59ec4f5aa..16963a2af3c0d 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -19,10 +19,12 @@
 import torch.nn as nn
 from torch.optim import Adam, Optimizer
 
+import pytorch_lightning as pl
 from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
+from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset
 
 
 def test_lightning_optimizer(tmpdir):
@@ -80,8 +82,8 @@ def configure_optimizers(self):
     assert trainer.optimizers[0].__repr__() == expected
 
 
-@patch("torch.optim.Adam.step")
-@patch("torch.optim.SGD.step")
+@patch("torch.optim.Adam.step", autospec=True)
+@patch("torch.optim.SGD.step", autospec=True)
 def test_lightning_optimizer_manual_optimization(mock_sgd_step, mock_adam_step, tmpdir):
     """
     Test that the user can use our LightningOptimizer. Not recommended for now.
@@ -96,13 +98,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
             output = self.layer(batch)
             loss_1 = self.loss(batch, output)
             self.manual_backward(loss_1, opt_1)
-            opt_1.step(idx="1")
+            opt_1.step()
 
             def closure():
                 output = self.layer(batch)
                 loss_2 = self.loss(batch, output)
                 self.manual_backward(loss_2, opt_2)
-            opt_2.step(closure=closure, idx="2")
+            opt_2.step(closure=closure)
 
         def configure_optimizers(self):
             optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -133,8 +135,8 @@ def automatic_optimization(self) -> bool:
     assert len(mock_adam_step.mock_calls) == 8
 
 
-@patch("torch.optim.Adam.step")
-@patch("torch.optim.SGD.step")
+@patch("torch.optim.Adam.step", autospec=True)
+@patch("torch.optim.SGD.step", autospec=True)
 def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(mock_sgd_step, mock_adam_step, tmpdir):
     """
     Test that the user can use our LightningOptimizer. Not recommended.
@@ -149,13 +151,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
             output = self.layer(batch)
             loss_1 = self.loss(batch, output)
             self.manual_backward(loss_1, opt_1)
-            opt_1.step(idx="1")
+            opt_1.step()
 
             def closure():
                 output = self.layer(batch)
                 loss_2 = self.loss(batch, output)
                 self.manual_backward(loss_2, opt_2)
-            opt_2.step(closure=closure, idx="2")
+            opt_2.step(closure=closure)
 
         def configure_optimizers(self):
             optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -195,9 +197,8 @@ def test_state(tmpdir):
     assert isinstance(lightning_optimizer, Adam)
     assert isinstance(lightning_optimizer, Optimizer)
     lightning_dict = {}
-    special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx",
-                     "_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization",
-                     "_accumulate_grad_batches"]
+    special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure",
+                     "_trainer"]
     for k, v in lightning_optimizer.__dict__.items():
         if k not in special_attrs:
             lightning_dict[k] = v
@@ -206,6 +207,55 @@ def test_state(tmpdir):
     assert optimizer.state == lightning_optimizer.state
 
 
+def test_lightning_optimizer_with_wrong_optimizer_interface(tmpdir):
+    class OptimizerWrapper(object):
+        def __init__(self, optimizer):
+            self.optim = optimizer
+            self.state_dict = self.optim.state_dict
+            self.load_state_dict = self.optim.load_state_dict
+            self.zero_grad = self.optim.zero_grad
+            self.add_param_group = self.optim.add_param_group
+            self.__setstate__ = self.optim.__setstate__
+            self.__getstate__ = self.optim.__getstate__
+            self.__repr__ = self.optim.__repr__
+
+        @property
+        def __class__(self):
+            return Optimizer
+
+        @property
+        def state(self):
+            return self.optim.state
+
+        @property
+        def param_groups(self):
+            return self.optim.param_groups
+
+        @param_groups.setter
+        def param_groups(self, value):
+            self.optim.param_groups = value
+
+        def step(self):
+            # wrongly defined step. Should contain closure
+            self.optim.step(closure=None)
+
+    class TestLightningOptimizerModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
+            optimizer = OptimizerWrapper(optimizer)
+            return [optimizer]
+
+    model = TestLightningOptimizerModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+        log_every_n_steps=1,
+    )
+    trainer.fit(model)
+
+
 def test_lightning_optimizer_automatic_optimization(tmpdir):
     """
     Test lightning optimize works with make_optimizer_step in automatic_optimization
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 5e341e9c66f63..9e369a874acd0 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -825,7 +825,7 @@ def optimizer_closure():
                 retain_graph = num_backward != backward_idx # noqa E225
                 self.manual_backward(loss_1, opt, retain_graph=retain_graph)
 
-            opt.step(1, closure=optimizer_closure, something="new")
+            opt.step(closure=optimizer_closure)
 
         def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
@@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool:
     )
     trainer.fit(model)
 
-    expected_calls = [call(1, closure=ANY, something="new") for s in range(2)]
+    expected_calls = [call() for s in range(2)]
     step_mock.assert_has_calls(expected_calls)
 
 
@@ -902,7 +902,7 @@ def dis_closure():
             if batch_idx % 4 == 0 :
                 # Note: Set make_optimizer_step to True or it will use by default
                 # Trainer(accumulate_grad_batches=x)
-                opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam')
+                opt_dis.step(closure=dis_closure, make_optimizer_step=True)
 
         def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
@@ -933,10 +933,9 @@ def automatic_optimization(self) -> bool:
     )
     trainer.fit(model)
 
-    expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
+    expected_calls = [call(optim='sgd') for s in range(4)]
     mock_sgd_step.assert_has_calls(expected_calls)
-
-    expected_calls = [call(closure=ANY, optim='adam') for s in range(2)]
+    expected_calls = [call() for s in range(2)]
    mock_adam_step.assert_has_calls(expected_calls)
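
The core of this change is the new `_support_closure` flag in `LightningOptimizer`: the closure is only forwarded when the wrapped optimizer's `step()` signature accepts one, otherwise the closure is evaluated first and `step()` is called without it. Below is a standalone sketch of that detection and fallback logic, not part of the patch itself; `NoClosureSGD` and the helper names are hypothetical, used only for illustration. The `test_lightning_optimizer_with_wrong_optimizer_interface` test added above exercises exactly this fallback path inside the Trainer.

# Standalone sketch (assumption: not part of the patch). `NoClosureSGD` is a
# hypothetical optimizer whose step() takes no closure, used only to
# illustrate the check this patch adds to LightningOptimizer.
import inspect

import torch
from torch.optim import SGD


class NoClosureSGD(SGD):
    # Hypothetical optimizer: its step() does not accept a closure.
    def step(self):
        return super().step(closure=None)


def supports_closure(optimizer) -> bool:
    # Same check as the new `_support_closure` attribute in LightningOptimizer.__init__.
    return 'closure' in inspect.signature(optimizer.step).parameters


def call_step(optimizer, closure):
    # Same branching as the patched __optimizer_step: forward the closure
    # when supported, otherwise evaluate it before stepping.
    if supports_closure(optimizer):
        optimizer.step(closure=closure)
    else:
        closure()
        optimizer.step()


if __name__ == "__main__":
    params = [torch.nn.Parameter(torch.randn(2, 2))]

    def closure():
        # A real closure would recompute the loss and call backward().
        return None

    assert supports_closure(SGD(params, lr=0.1))
    assert not supports_closure(NoClosureSGD(params, lr=0.1))

    call_step(SGD(params, lr=0.1), closure)           # closure passed through
    call_step(NoClosureSGD(params, lr=0.1), closure)  # closure run manually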