diff --git a/CHANGELOG.md b/CHANGELOG.md
index 49703527a56f3..412eaf32aacce 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -405,6 +405,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287))
 
 
+- Removed DeepSpeed FP16 Exception as FP32 is now supported ([#8462](https://github.com/PyTorchLightning/pytorch-lightning/pull/8462))
+
+
 ### Fixed
 
 - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index e704b662fd6ca..d47d4caa0321b 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -541,7 +541,7 @@ def _format_precision_config(self):
         amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
         amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
         precision = self.lightning_module.trainer.accelerator_connector.precision
-        if precision == 16:
+        if precision in (16, 'mixed'):
             if "fp16" not in self.config and amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
@@ -559,8 +559,6 @@ def _format_precision_config(self):
                 "enabled": True,
                 "opt_level": amp_level,
             }
-        if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config):
-            raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.")
 
     def _create_default_config(
         self,
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index f3d89b54ae236..dea6bd141dfa1 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -164,13 +164,14 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config):
 
 
 @RunIf(amp_native=True, deepspeed=True)
+@pytest.mark.parametrize("precision", [16, 'mixed'])
 @pytest.mark.parametrize(
     "amp_backend", [
         pytest.param("native", marks=RunIf(amp_native=True)),
         pytest.param("apex", marks=RunIf(amp_apex=True)),
     ]
 )
-def test_deepspeed_precision_choice(amp_backend, tmpdir):
+def test_deepspeed_precision_choice(amp_backend, precision, tmpdir):
     """
     Test to ensure precision plugin is also correctly chosen.
     DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin
@@ -181,12 +182,12 @@ def test_deepspeed_precision_choice(amp_backend, tmpdir):
         default_root_dir=tmpdir,
         plugins='deepspeed',
         amp_backend=amp_backend,
-        precision=16,
+        precision=precision,
     )
 
     assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
     assert isinstance(trainer.accelerator.precision_plugin, DeepSpeedPrecisionPlugin)
-    assert trainer.accelerator.precision_plugin.precision == 16
+    assert trainer.accelerator.precision_plugin.precision == precision
 
 
 @RunIf(deepspeed=True)
@@ -224,21 +225,6 @@ def test_deepspeed_defaults(tmpdir):
     assert isinstance(plugin.config["zero_optimization"], dict)
 
 
-@RunIf(min_gpus=1, deepspeed=True)
-def test_invalid_deepspeed_defaults_no_precision(tmpdir):
-    """Test to ensure that using defaults, if precision is not set to 16, we throw an exception."""
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        plugins='deepspeed',
-    )
-    with pytest.raises(
-        MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.'
-    ):
-        trainer.fit(model)
-
-
 @RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_warn_deepspeed_override_backward(tmpdir):
     """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning."""
@@ -448,6 +434,13 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
     _assert_save_model_is_equal(model, tmpdir, trainer)
 
 
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_fp32_works(tmpdir):
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, gpus=1, plugins='deepspeed_stage_3', fast_dev_run=True)
+    trainer.fit(model)
+
+
 class ModelParallelClassificationModel(LightningModule):
 
     def __init__(self, lr: float = 0.01, num_blocks: int = 5):
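Illustration (not part of the patch): with the ZeRO precision check removed, DeepSpeed can now run in full FP32. This is a minimal sketch mirroring the new test_deepspeed_fp32_works test above; the BoringModel import path and running it as a standalone script (rather than under pytest) are assumptions.

# Sketch only: FP32 (the default precision) with DeepSpeed ZeRO stage 3 now trains,
# instead of raising "To use DeepSpeed ZeRO Optimization, you must set precision=16."
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # assumed import path

model = BoringModel()
trainer = Trainer(gpus=1, plugins='deepspeed_stage_3', fast_dev_run=True)  # no precision=16 required
trainer.fit(model)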