diff --git a/CHANGELOG.md b/CHANGELOG.md
index 707b37b885..960e0ed7ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added Torch ORT Callback [#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720))
 
+- Added SparseML Callback [#724](https://github.com/PyTorchLightning/lightning-bolts/pull/724))
+
+
 ### Changed
 
 - Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
diff --git a/docs/source/callbacks/sparseml.rst b/docs/source/callbacks/sparseml.rst
new file mode 100644
index 0000000000..54f8051b6b
--- /dev/null
+++ b/docs/source/callbacks/sparseml.rst
@@ -0,0 +1,62 @@
+=================
+SparseML Callback
+=================
+
+`SparseML `__ allows you to leverage sparsity to improve inference times substantially.
+
+SparseML requires you to fine-tune your model with the ``SparseMLCallback`` + a SparseML Recipe. By training with the ``SparseMLCallback``, you can leverage the `DeepSparse `__ engine to exploit the introduced sparsity, resulting in large performance improvements.
+
+.. warning::
+
+    The SparseML callback requires the model to be ONNX exportable. This can be tricky for models that require dynamic sequence lengths, such as RNNs.
+
+To leverage SparseML & DeepSparse, follow the steps below:
+
+1. Choose your Sparse Recipe
+----------------------------
+
+To choose a recipe, have a look at `recipes `__ and `Sparse Zoo `__.
+
+It may be easier to infer a recipe via the UI dashboard using `Sparsify `__, which allows you to tweak and configure a recipe.
+This requires importing an ONNX model, which you can get from your ``LightningModule`` by calling ``model.to_onnx(output_path)``.
+
+2. Train with SparseMLCallback
+------------------------------
+
+.. testcode::
+    :skipif: not _SPARSEML_AVAILABLE
+
+    from pytorch_lightning import LightningModule, Trainer
+    from pl_bolts.callbacks import SparseMLCallback
+
+    class MyModel(LightningModule):
+        ...
+
+    model = MyModel()
+
+    trainer = Trainer(
+        callbacks=SparseMLCallback(recipe_path='recipe.yaml')
+    )
+
+3. Export to ONNX!
+------------------
+
+Using the helper function, we handle any quantization/pruning internally and export the model into ONNX format.
+Note that this assumes you have either implemented the ``example_input_array`` property on the model, or you provide a sample batch as shown below.
+
+.. testcode::
+    :skipif: not _SPARSEML_AVAILABLE
+
+    import torch
+
+    model = MyModel()
+    ...
+
+    # export the onnx model, using the `model.example_input_array`
+    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/')
+
+    # export the onnx model, providing a sample batch
+    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/', sample_batch=torch.randn(1, 128, 128, dtype=torch.float32))
+
+
+Once your model has been exported, you can import it into either `Sparsify `__ or `DeepSparse `__.
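+
+As a rough sketch only (this assumes ``deepsparse`` is installed and that the export above wrote ``onnx_export/model.onnx``; adjust the path and input shape to your actual export), running the sparse model with the DeepSparse engine looks roughly like:
+
+.. code-block:: python
+
+    import numpy
+
+    from deepsparse import compile_model
+
+    # compile the exported sparse ONNX model into a DeepSparse engine
+    engine = compile_model("onnx_export/model.onnx", batch_size=1)
+
+    # run inference on an input matching the exported sample batch shape
+    outputs = engine.run([numpy.random.randn(1, 128, 128).astype(numpy.float32)])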
diff --git a/docs/source/conf.py b/docs/source/conf.py
index ea38d1de68..c285ed475a 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -405,5 +405,7 @@ def find_source():
 import pytorch_lightning as pl
 from pytorch_lightning import Trainer, LightningModule
+from pl_bolts.utils import _SPARSEML_AVAILABLE
+
 """
 coverage_skip_undoc_in_source = True
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 860b82ba12..c0eeea8689 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -24,6 +24,7 @@ Lightning-Bolts documentation
    callbacks/variational_callbacks
    callbacks/vision_callbacks
    callbacks/torch_ort
+   callbacks/sparseml
 
 .. toctree::
    :maxdepth: 2
diff --git a/pl_bolts/callbacks/__init__.py b/pl_bolts/callbacks/__init__.py
index 9558aed816..f52a408c48 100644
--- a/pl_bolts/callbacks/__init__.py
+++ b/pl_bolts/callbacks/__init__.py
@@ -2,6 +2,7 @@
 from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
 from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor
 from pl_bolts.callbacks.printing import PrintTableMetricsCallback
+from pl_bolts.callbacks.sparseml import SparseMLCallback
 from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
 from pl_bolts.callbacks.torch_ort import ORTCallback
 from pl_bolts.callbacks.variational import LatentDimInterpolator
@@ -20,4 +21,5 @@
     "ConfusedLogitCallback",
     "TensorboardGenerativeModelImageSampler",
     "ORTCallback",
+    "SparseMLCallback",
 ]
diff --git a/pl_bolts/callbacks/sparseml.py b/pl_bolts/callbacks/sparseml.py
new file mode 100644
index 0000000000..de05088994
--- /dev/null
+++ b/pl_bolts/callbacks/sparseml.py
@@ -0,0 +1,94 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from pl_bolts.utils import _PL_GREATER_EQUAL_1_4_5, _SPARSEML_AVAILABLE, _TORCH_MAX_VERSION_1_8_1
+
+if _SPARSEML_AVAILABLE:
+    from sparseml.pytorch.optim import ScheduledModifierManager
+    from sparseml.pytorch.utils import ModuleExporter
+
+
+class SparseMLCallback(Callback):
+    """Enables SparseML aware training. Requires a recipe to run during training.
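+
+    A minimal usage sketch, mirroring the documentation (assumes a ``recipe.yaml`` exists on disk)::
+
+        from pytorch_lightning import Trainer
+        from pl_bolts.callbacks import SparseMLCallback
+
+        trainer = Trainer(callbacks=SparseMLCallback(recipe_path="recipe.yaml"))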
+
+    Args:
+        recipe_path: Path to a SparseML compatible yaml recipe.
+            More information at https://docs.neuralmagic.com/sparseml/source/recipes.html
+    """
+
+    def __init__(self, recipe_path):
+        if not _SPARSEML_AVAILABLE:
+            if not _PL_GREATER_EQUAL_1_4_5:
+                raise MisconfigurationException("SparseML requires PyTorch Lightning 1.4.5 or greater.")
+            if not _TORCH_MAX_VERSION_1_8_1:
+                raise MisconfigurationException("SparseML requires PyTorch version 1.8.1 or lower.")
+            raise MisconfigurationException("SparseML has not been installed, install with `pip install sparseml`")
+        self.manager = ScheduledModifierManager.from_yaml(recipe_path)
+
+    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        optimizer = trainer.optimizers
+
+        if len(optimizer) > 1:
+            raise MisconfigurationException("SparseML only supports training with one optimizer.")
+        optimizer = optimizer[0]
+        # wrap the optimizer so the recipe modifiers are applied over the course of training
+        optimizer = self.manager.modify(
+            pl_module, optimizer, steps_per_epoch=self._num_training_steps_per_epoch(trainer), epoch=0
+        )
+        trainer.optimizers = [optimizer]
+
+    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        self.manager.finalize(pl_module)
+
+    def _num_training_steps_per_epoch(self, trainer: Trainer) -> int:
+        """Total training steps inferred from the datamodule and devices."""
+        if isinstance(trainer.limit_train_batches, int) and trainer.limit_train_batches != 0:
+            dataset_size = trainer.limit_train_batches
+        elif isinstance(trainer.limit_train_batches, float):
+            # limit_train_batches is a percentage of batches
+            dataset_size = len(trainer.datamodule.train_dataloader())
+            dataset_size = int(dataset_size * trainer.limit_train_batches)
+        else:
+            dataset_size = len(trainer.datamodule.train_dataloader())
+
+        num_devices = max(1, trainer.num_gpus, trainer.num_processes)
+        if trainer.tpu_cores:
+            num_devices = max(num_devices, trainer.tpu_cores)
+
+        # batches are split across devices and accumulated before each optimizer step
+        effective_batch_size = trainer.accumulate_grad_batches * num_devices
+        max_estimated_steps = dataset_size // effective_batch_size
+
+        if trainer.max_steps and trainer.max_steps < max_estimated_steps:
+            return trainer.max_steps
+        return max_estimated_steps
+
+    @staticmethod
+    def export_to_sparse_onnx(
+        model: LightningModule, output_dir: str, sample_batch: Optional[torch.Tensor] = None
+    ) -> None:
+        """Exports the model to ONNX format."""
+        with model._prevent_trainer_and_dataloaders_deepcopy():
+            exporter = ModuleExporter(model, output_dir=output_dir)
+            sample_batch = sample_batch if sample_batch is not None else model.example_input_array
+            if sample_batch is None:
+                raise MisconfigurationException(
+                    "To export the model, a sample batch must be passed via "
+                    "``SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch)`` "
+                    "or an ``example_input_array`` property within the LightningModule"
+                )
+            exporter.export_onnx(sample_batch=sample_batch)
diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py
index 8a77cd66e7..9f523d2953 100644
--- a/pl_bolts/utils/__init__.py
+++ b/pl_bolts/utils/__init__.py
@@ -39,6 +39,9 @@ def _compare_version(package: str, op, version) -> bool:
 _MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
 _TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1")
 _PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
+_PL_GREATER_EQUAL_1_4_5 = _compare_version("pytorch_lightning", operator.ge, "1.4.5")
 _TORCH_ORT_AVAILABLE = _module_available("torch_ort")
+_TORCH_MAX_VERSION_1_8_1 = _compare_version("torch", operator.le, "1.8.1")
+_SPARSEML_AVAILABLE = _module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_1_8_1
 
 __all__ = ["BatchGradientVerification"]
diff --git a/requirements/test.txt b/requirements/test.txt
index 7c862c329b..da446df50d 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -12,3 +12,4 @@ mypy>=0.790
 atari-py==0.2.6  # needed for RL
 scikit-learn>=0.23
+sparseml
diff --git a/tests/callbacks/test_sparseml.py b/tests/callbacks/test_sparseml.py
new file mode 100644
index 0000000000..006416dc3a
--- /dev/null
+++ b/tests/callbacks/test_sparseml.py
@@ -0,0 +1,104 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from pathlib import Path
+
+import pytest
+import torch
+from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from pl_bolts.callbacks import SparseMLCallback
+from pl_bolts.utils import _SPARSEML_AVAILABLE
+from tests.helpers.boring_model import BoringModel
+
+if _SPARSEML_AVAILABLE:
+    from sparseml.pytorch.optim import RecipeManagerStepWrapper
+
+
+@pytest.fixture
+def recipe():
+    return """
+    version: 0.1.0
+    modifiers:
+        - !EpochRangeModifier
+            start_epoch: 0.0
+            end_epoch: 1.0
+
+        - !LearningRateModifier
+            start_epoch: 0
+            end_epoch: -1.0
+            update_frequency: -1.0
+            init_lr: 0.005
+            lr_class: MultiStepLR
+            lr_kwargs: {'milestones': [43, 60], 'gamma': 0.1}
+
+        - !GMPruningModifier
+            start_epoch: 0
+            end_epoch: 40
+            update_frequency: 1.0
+            init_sparsity: 0.05
+            final_sparsity: 0.85
+            mask_type: unstructured
+            params: __ALL__
+    """
+
+
+@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
+def test_train_sparse_ml_callback(tmpdir, recipe):
+    class TestCallback(Callback):
+        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+            assert isinstance(trainer.optimizers[0], RecipeManagerStepWrapper)
+
+    recipe_path = Path(tmpdir) / "recipe.yaml"
+    recipe_path.write_text(recipe)
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        callbacks=[SparseMLCallback(recipe_path=str(recipe_path)), TestCallback()],
+    )
+    trainer.fit(model)
+
+    sample_batch = torch.randn(1, 32)
+    output_dir = Path(tmpdir) / "model_export/"
+    SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch)
+    assert os.path.exists(output_dir)
+
+
+@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
+def test_fail_if_no_example_input_array_or_sample_batch(tmpdir, recipe):
+    model = BoringModel()
+    with pytest.raises(MisconfigurationException, match="To export the model, a sample batch must be passed"):
+        output_dir = Path(tmpdir) / "model_export/"
+        SparseMLCallback.export_to_sparse_onnx(model, output_dir)
+
+
+@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
+def test_fail_if_multiple_optimizers(tmpdir, recipe):
+    recipe_path = Path(tmpdir) / "recipe.yaml"
+    recipe_path.write_text(recipe)
+
+    class TestModel(BoringModel):
+        def configure_optimizers(self):
+            return [torch.optim.Adam(self.parameters()), torch.optim.Adam(self.parameters())], []
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, fast_dev_run=True, callbacks=[SparseMLCallback(recipe_path=str(recipe_path))]
+    )
+    with pytest.raises(MisconfigurationException, match="SparseML only supports training with one optimizer."):
+        trainer.fit(model)