[Fix] Remove DeepSpeed Plugin FP16 exception (#8462)
* Remove error, add mixed to check

* Add test

* Remove test

* Add changelog

* Add test for mixed

* Update tests/plugins/test_deepspeed_plugin.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add special

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jul 19, 2021
1 parent 999ef5c commit 06ac7d9
Showing 3 changed files with 15 additions and 21 deletions.
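In practical terms, DeepSpeed (including ZeRO) no longer forces precision=16: leaving the Trainer at its default FP32 precision now works instead of raising MisconfigurationException, and precision='mixed' is treated the same as 16. A minimal sketch of the newly supported FP32 case, mirroring the test added below (assumes a single GPU and the deepspeed package installed; the BoringModel import path is an assumption and any LightningModule works):

    from pytorch_lightning import Trainer
    from tests.helpers import BoringModel  # assumed test-suite import path; any LightningModule works here

    model = BoringModel()
    # precision is left at its default (32); before this commit, a ZeRO config here raised
    # MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.")
    trainer = Trainer(gpus=1, plugins='deepspeed_stage_3', fast_dev_run=True)
    trainer.fit(model)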
CHANGELOG.md: 3 additions & 0 deletions
@@ -406,6 +406,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287))


+ - Removed DeepSpeed FP16 Exception as FP32 is now supported ([#8462](https://github.com/PyTorchLightning/pytorch-lightning/pull/8462))


### Fixed

- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
pytorch_lightning/plugins/training_type/deepspeed.py: 1 addition & 3 deletions
@@ -541,7 +541,7 @@ def _format_precision_config(self):
    amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
    amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
    precision = self.lightning_module.trainer.accelerator_connector.precision
-   if precision == 16:
+   if precision in (16, 'mixed'):
        if "fp16" not in self.config and amp_type == AMPType.NATIVE:
            # FP16 is a DeepSpeed standalone AMP implementation
            rank_zero_info("Enabling DeepSpeed FP16.")
@@ -559,8 +559,6 @@ def _format_precision_config(self):
"enabled": True,
"opt_level": amp_level,
}
if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config):
raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.")

def _create_default_config(
self,
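For orientation, a simplified sketch of what the precision formatting now does, paraphrased from the hunks above rather than copied from the full method: precision 16 or 'mixed' with native AMP enables DeepSpeed's own FP16 section, apex gets an "amp" section carrying the opt_level, and the old guard that demanded precision=16 whenever "zero_optimization" was configured is simply gone. The helper below is an assumption for illustration, not the plugin's actual API:

    from typing import Union

    def format_precision_config(config: dict, precision: Union[int, str], amp_type: str, amp_level: str) -> dict:
        # Simplified sketch of the post-change branching; the real method lives on DeepSpeedPlugin.
        if precision in (16, 'mixed'):
            if amp_type == 'native' and 'fp16' not in config:
                # DeepSpeed's standalone FP16/AMP implementation
                config['fp16'] = {'enabled': True}
            elif amp_type == 'apex' and 'amp' not in config:
                config['amp'] = {'enabled': True, 'opt_level': amp_level}
        # The MisconfigurationException for ZeRO-without-FP16 has been removed,
        # so an FP32 run with "zero_optimization" in the config is now allowed.
        return config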
tests/plugins/test_deepspeed_plugin.py: 11 additions & 18 deletions
@@ -164,13 +164,14 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config):


@RunIf(amp_native=True, deepspeed=True)
+ @pytest.mark.parametrize("precision", [16, 'mixed'])
@pytest.mark.parametrize(
    "amp_backend", [
        pytest.param("native", marks=RunIf(amp_native=True)),
        pytest.param("apex", marks=RunIf(amp_apex=True)),
    ]
)
- def test_deepspeed_precision_choice(amp_backend, tmpdir):
+ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir):
    """
    Test to ensure precision plugin is also correctly chosen.
    DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin
@@ -181,12 +182,12 @@ def test_deepspeed_precision_choice(amp_backend, tmpdir):
        default_root_dir=tmpdir,
        plugins='deepspeed',
        amp_backend=amp_backend,
-         precision=16,
+         precision=precision,
    )

    assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
    assert isinstance(trainer.accelerator.precision_plugin, DeepSpeedPrecisionPlugin)
-     assert trainer.accelerator.precision_plugin.precision == 16
+     assert trainer.accelerator.precision_plugin.precision == precision


@RunIf(deepspeed=True)
@@ -224,21 +225,6 @@ def test_deepspeed_defaults(tmpdir):
    assert isinstance(plugin.config["zero_optimization"], dict)


- @RunIf(min_gpus=1, deepspeed=True)
- def test_invalid_deepspeed_defaults_no_precision(tmpdir):
-     """Test to ensure that using defaults, if precision is not set to 16, we throw an exception."""
-     model = BoringModel()
-     trainer = Trainer(
-         default_root_dir=tmpdir,
-         fast_dev_run=True,
-         plugins='deepspeed',
-     )
-     with pytest.raises(
-         MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.'
-     ):
-         trainer.fit(model)


@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_warn_deepspeed_override_backward(tmpdir):
    """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning."""
@@ -448,6 +434,13 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
    _assert_save_model_is_equal(model, tmpdir, trainer)


+ @RunIf(min_gpus=1, deepspeed=True, special=True)
+ def test_deepspeed_fp32_works(tmpdir):
+     model = BoringModel()
+     trainer = Trainer(default_root_dir=tmpdir, gpus=1, plugins='deepspeed_stage_3', fast_dev_run=True)
+     trainer.fit(model)


class ModelParallelClassificationModel(LightningModule):

    def __init__(self, lr: float = 0.01, num_blocks: int = 5):
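The updated test_deepspeed_precision_choice stacks a second parametrize decorator, so it now runs for every combination of precision (16, 'mixed') and AMP backend (native, apex): pytest builds the cross product from stacked decorators. A small self-contained illustration of that mechanic, independent of Lightning and requiring no GPU:

    import pytest

    @pytest.mark.parametrize("precision", [16, 'mixed'])
    @pytest.mark.parametrize("amp_backend", ["native", "apex"])
    def test_parametrize_cross_product(amp_backend, precision):
        # Four cases are generated: (native, 16), (native, 'mixed'), (apex, 16), (apex, 'mixed').
        assert precision in (16, 'mixed')
        assert amp_backend in ("native", "apex")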
