From 14e7e7900fda96c7e43c9443b1b41cbac8a815d4 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Fri, 3 Sep 2021 13:20:03 +0100
Subject: [PATCH 1/8] Add the sparseml callback

---
 docs/source/callbacks/sparseml.rst | 55 +++++++++++++++
 pl_bolts/callbacks/__init__.py | 2 +
 pl_bolts/callbacks/sparseml.py | 90 +++++++++++++++++++++++++
 pl_bolts/utils/__init__.py | 1 +
 requirements/test.txt | 1 +
 tests/callbacks/test_sparseml.py | 104 +++++++++++++++++++++++++++++
 6 files changed, 253 insertions(+)
 create mode 100644 docs/source/callbacks/sparseml.rst
 create mode 100644 pl_bolts/callbacks/sparseml.py
 create mode 100644 tests/callbacks/test_sparseml.py

diff --git a/docs/source/callbacks/sparseml.rst b/docs/source/callbacks/sparseml.rst
new file mode 100644
index 0000000000..062d951798
--- /dev/null
+++ b/docs/source/callbacks/sparseml.rst
@@ -0,0 +1,55 @@
+=================
+SparseML Callback
+=================
+
+`SparseML `__ allows you to leverage sparsity to improve inference times substantially.
+
+SparseML requires you to fine-tune your model with the ``SparseMLCallback`` + a SparseML Recipe. By training with the ``SparseMLCallback``, you can leverage the `DeepSparse `__ engine to exploit the introduced sparsity, resulting in large performance improvements.
+
+.. warning::
+
+    The SparseML callback requires the model to be ONNX exportable. This can be tricky for models that require dynamic sequence lengths, such as RNNs.
+
+To leverage SparseML & DeepSparse, follow the steps below:
+
+1. Choose your Sparse Recipe
+----------------------------
+
+To choose a recipe, have a look at `recipes `__ and `Sparse Zoo `__.
+
+It may be easier to infer a recipe via the UI dashboard using `Sparsify `__, which allows you to tweak and configure a recipe.
+This requires importing an ONNX model, which you can get from your ``LightningModule`` by calling ``model.to_onnx(output_path)``.
+
+2. Train with SparseMLCallback
+------------------------------
+
+.. code-block:: python
+
+    from pytorch_lightning import LightningModule, Trainer
+    from pl_bolts.callbacks import SparseMLCallback
+
+
+    model = MyModel()
+
+    trainer = Trainer(
+        callbacks=SparseMLCallback(recipe_path='recipe.yaml')
+    )
+
+3. Export to ONNX!
+------------------
+
+Using the helper function, we handle any quantization/pruning internally and export the model into ONNX format.
+Note that this assumes you have either implemented the ``example_input_array`` property in the model, or that you provide a sample batch as shown below.
+
+.. code-block:: python
+
+    import torch
+
+    # export the onnx model, using the `model.example_input_array`
+    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/')
+
+    # export the onnx model, providing a sample batch
+    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/', sample_batch=torch.randn(1, 128, 128, dtype=torch.float32))
+
+
+Once your model has been exported, you can import it into either `Sparsify `__ or `DeepSparse `__.
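As a pointer for that last step, the snippet below sketches how the exported ONNX file could be consumed with the DeepSparse engine. It is illustrative only: it assumes ``deepsparse`` is installed, that the export above wrote ``onnx_export/model.onnx`` (the file name is an assumption), and it follows DeepSparse's ``compile_model`` quick-start API rather than anything added by this patch.

.. code-block:: python

    import numpy

    from deepsparse import compile_model

    # Path assumed to be written by ``SparseMLCallback.export_to_sparse_onnx`` above;
    # adjust the file name if your export differs.
    onnx_path = "onnx_export/model.onnx"

    # Compile the sparse ONNX graph for CPU inference with the DeepSparse engine.
    engine = compile_model(onnx_path, batch_size=1)

    # DeepSparse expects a list of numpy arrays, one entry per model input.
    inputs = [numpy.random.randn(1, 128, 128).astype(numpy.float32)]
    outputs = engine.run(inputs)
    print(outputs)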
diff --git a/pl_bolts/callbacks/__init__.py b/pl_bolts/callbacks/__init__.py index 9558aed816..f52a408c48 100644 --- a/pl_bolts/callbacks/__init__.py +++ b/pl_bolts/callbacks/__init__.py @@ -2,6 +2,7 @@ from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor from pl_bolts.callbacks.printing import PrintTableMetricsCallback +from pl_bolts.callbacks.sparseml import SparseMLCallback from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.callbacks.torch_ort import ORTCallback from pl_bolts.callbacks.variational import LatentDimInterpolator @@ -20,4 +21,5 @@ "ConfusedLogitCallback", "TensorboardGenerativeModelImageSampler", "ORTCallback", + "SparseMLCallback", ] diff --git a/pl_bolts/callbacks/sparseml.py b/pl_bolts/callbacks/sparseml.py new file mode 100644 index 0000000000..ca0b40f750 --- /dev/null +++ b/pl_bolts/callbacks/sparseml.py @@ -0,0 +1,90 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from pl_bolts.utils import _SPARSEML_AVAILABLE + +if _SPARSEML_AVAILABLE: + from sparseml.pytorch.optim import ScheduledModifierManager + from sparseml.pytorch.utils import ModuleExporter + + +class SparseMLCallback(Callback): + """Enables SparseML aware training. Requires a recipe to run during training. + + Args: + recipe_path: Path to a SparseML compatible yaml recipe. 
+            More information at https://docs.neuralmagic.com/sparseml/source/recipes.html
+    """
+
+    def __init__(self, recipe_path):
+        if not _SPARSEML_AVAILABLE:
+            raise MisconfigurationException("SparseML has not been installed, install with pip install sparseml")
+        self.manager = ScheduledModifierManager.from_yaml(recipe_path)
+
+    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        optimizer = trainer.optimizers
+
+        if len(optimizer) > 1:
+            raise MisconfigurationException("SparseML only supports training with one optimizer.")
+        optimizer = optimizer[0]
+        optimizer = self.manager.modify(
+            pl_module, optimizer, steps_per_epoch=self._num_training_steps_per_epoch(trainer), epoch=0
+        )
+        trainer.optimizers = [optimizer]
+
+    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        self.manager.finalize(pl_module)
+
+    def _num_training_steps_per_epoch(self, trainer: Trainer) -> int:
+        """Total training steps inferred from the datamodule and devices."""
+        if isinstance(trainer.limit_train_batches, int) and trainer.limit_train_batches != 0:
+            dataset_size = trainer.limit_train_batches
+        elif isinstance(trainer.limit_train_batches, float):
+            # limit_train_batches is a percentage of batches
+            dataset_size = len(trainer.datamodule.train_dataloader())
+            dataset_size = int(dataset_size * trainer.limit_train_batches)
+        else:
+            dataset_size = len(trainer.datamodule.train_dataloader())
+
+        num_devices = max(1, trainer.num_gpus, trainer.num_processes)
+        if trainer.tpu_cores:
+            num_devices = max(num_devices, trainer.tpu_cores)
+
+        effective_batch_size = trainer.accumulate_grad_batches * num_devices
+        max_estimated_steps = dataset_size // effective_batch_size
+
+        if trainer.max_steps and trainer.max_steps < max_estimated_steps:
+            return trainer.max_steps
+        return max_estimated_steps
+
+    @staticmethod
+    def export_to_sparse_onnx(
+        model: LightningModule, output_dir: str, sample_batch: Optional[torch.Tensor] = None
+    ) -> None:
+        """Exports the model to ONNX format."""
+        with model._prevent_trainer_and_dataloaders_deepcopy():
+            exporter = ModuleExporter(model, output_dir=output_dir)
+            sample_batch = sample_batch if sample_batch is not None else model.example_input_array
+            if sample_batch is None:
+                raise MisconfigurationException(
+                    "To export the model, a sample batch must be passed via "
+                    "``SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch)`` "
+                    "or an ``example_input_array`` property within the LightningModule"
+                )
+            exporter.export_onnx(sample_batch=sample_batch)
diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py
index 8a77cd66e7..834e3221ee 100644
--- a/pl_bolts/utils/__init__.py
+++ b/pl_bolts/utils/__init__.py
@@ -40,5 +40,6 @@ def _compare_version(package: str, op, version) -> bool:
 _TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1")
 _PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
 _TORCH_ORT_AVAILABLE = _module_available("torch_ort")
+_SPARSEML_AVAILABLE = _module_available("sparseml")

 __all__ = ["BatchGradientVerification"]
diff --git a/requirements/test.txt b/requirements/test.txt
index 7c862c329b..da446df50d 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -12,3 +12,4 @@ mypy>=0.790
 atari-py==0.2.6  # needed for RL
 scikit-learn>=0.23
+sparseml
diff --git a/tests/callbacks/test_sparseml.py b/tests/callbacks/test_sparseml.py
new file mode 100644
index 0000000000..006416dc3a
--- /dev/null
+++ 
b/tests/callbacks/test_sparseml.py @@ -0,0 +1,104 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path + +import pytest +import torch +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from pl_bolts.callbacks import SparseMLCallback +from pl_bolts.utils import _SPARSEML_AVAILABLE +from tests.helpers.boring_model import BoringModel + +if _SPARSEML_AVAILABLE: + from sparseml.pytorch.optim import RecipeManagerStepWrapper + + +@pytest.fixture +def recipe(): + return """ + version: 0.1.0 + modifiers: + - !EpochRangeModifier + start_epoch: 0.0 + end_epoch: 1.0 + + - !LearningRateModifier + start_epoch: 0 + end_epoch: -1.0 + update_frequency: -1.0 + init_lr: 0.005 + lr_class: MultiStepLR + lr_kwargs: {'milestones': [43, 60], 'gamma': 0.1} + + - !GMPruningModifier + start_epoch: 0 + end_epoch: 40 + update_frequency: 1.0 + init_sparsity: 0.05 + final_sparsity: 0.85 + mask_type: unstructured + params: __ALL__ + """ + + +@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.") +def test_train_sparse_ml_callback(tmpdir, recipe): + class TestCallback(Callback): + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + assert isinstance(trainer.optimizers[0], RecipeManagerStepWrapper) + + recipe_path = Path(tmpdir) / "recipe.yaml" + recipe_path.write_text(recipe) + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + callbacks=[SparseMLCallback(recipe_path=str(recipe_path)), TestCallback()], + ) + trainer.fit(model) + + sample_batch = torch.randn(1, 32) + output_dir = Path(tmpdir) / "model_export/" + SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch) + assert os.path.exists(output_dir) + + +@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.") +def test_fail_if_no_example_input_array_or_sample_batch(tmpdir, recipe): + model = BoringModel() + with pytest.raises(MisconfigurationException, match="To export the model, a sample batch must be passed"): + output_dir = Path(tmpdir) / "model_export/" + SparseMLCallback.export_to_sparse_onnx(model, output_dir) + + +@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.") +def test_fail_if_multiple_optimizers(tmpdir, recipe): + recipe_path = Path(tmpdir) / "recipe.yaml" + recipe_path.write_text(recipe) + + class TestModel(BoringModel): + def configure_optimizers(self): + return [torch.optim.Adam(self.parameters()), torch.optim.Adam(self.parameters())], [] + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, callbacks=[SparseMLCallback(recipe_path=str(recipe_path))] + ) + with pytest.raises(MisconfigurationException, match="SparseML only supports training with one optimizer."): + trainer.fit(model) From 68f5c24816acfc1e7a9a8e3e429c02533ed4eea3 
Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 3 Sep 2021 13:21:49 +0100 Subject: [PATCH 2/8] Add CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 707b37b885..960e0ed7ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Torch ORT Callback [#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720)) +- Added SparseML Callback [#724](https://github.com/PyTorchLightning/lightning-bolts/pull/724)) + + ### Changed - Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701)) From 298f1935caa48fe8e606522f1a3e4b0c73d2966d Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Fri, 3 Sep 2021 14:30:00 +0100 Subject: [PATCH 3/8] Apply suggestions from code review Co-authored-by: Jirka Borovec --- docs/source/callbacks/sparseml.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/callbacks/sparseml.rst b/docs/source/callbacks/sparseml.rst index 062d951798..35808ee631 100644 --- a/docs/source/callbacks/sparseml.rst +++ b/docs/source/callbacks/sparseml.rst @@ -23,7 +23,7 @@ This requires to import an ONNX model, which you can get from your ``LightningMo 2. Train with SparseMLCallback ------------------------------ -.. code-block:: python +.. testcode:: from pytorch_lightning import LightningModule, Trainer from pl_bolts.callbacks import SparseMLCallback @@ -41,7 +41,7 @@ This requires to import an ONNX model, which you can get from your ``LightningMo Using the helper function, we handle any quantization/pruning internally and export the model into ONNX format. Note this assumes either you have implemented the property ``example_input_array`` in the model or you must provide a sample batch as below. -.. code-block:: python +.. testcode:: import torch From 56b1d4540e377979611559217ebce2891d5a7928 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 3 Sep 2021 14:30:25 +0100 Subject: [PATCH 4/8] Add sample model --- docs/source/callbacks/sparseml.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/callbacks/sparseml.rst b/docs/source/callbacks/sparseml.rst index 35808ee631..b17369b4fc 100644 --- a/docs/source/callbacks/sparseml.rst +++ b/docs/source/callbacks/sparseml.rst @@ -45,6 +45,9 @@ Note this assumes either you have implemented the property ``example_input_array import torch + model = MyModel() + ... + # export the onnx model, using the `model.example_input_array` SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/') From 4f6df39947323bb00524e3bf4696b4ead158fa58 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 3 Sep 2021 15:57:41 +0100 Subject: [PATCH 5/8] Address review, fix CI --- docs/source/callbacks/sparseml.rst | 2 ++ pl_bolts/callbacks/sparseml.py | 6 +++++- pl_bolts/utils/__init__.py | 4 +++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/callbacks/sparseml.rst b/docs/source/callbacks/sparseml.rst index b17369b4fc..ddf7f77b02 100644 --- a/docs/source/callbacks/sparseml.rst +++ b/docs/source/callbacks/sparseml.rst @@ -28,6 +28,8 @@ This requires to import an ONNX model, which you can get from your ``LightningMo from pytorch_lightning import LightningModule, Trainer from pl_bolts.callbacks import SparseMLCallback + class MyModel(LightningModule): + ... 
     model = MyModel()

diff --git a/pl_bolts/callbacks/sparseml.py b/pl_bolts/callbacks/sparseml.py
index ca0b40f750..de05088994 100644
--- a/pl_bolts/callbacks/sparseml.py
+++ b/pl_bolts/callbacks/sparseml.py
@@ -17,7 +17,7 @@
 from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

-from pl_bolts.utils import _SPARSEML_AVAILABLE
+from pl_bolts.utils import _PL_GREATER_EQUAL_1_4_5, _SPARSEML_AVAILABLE, _TORCH_MAX_VERSION_1_8_1

 if _SPARSEML_AVAILABLE:
     from sparseml.pytorch.optim import ScheduledModifierManager
     from sparseml.pytorch.utils import ModuleExporter
@@ -34,6 +34,10 @@ class SparseMLCallback(Callback):
     def __init__(self, recipe_path):
         if not _SPARSEML_AVAILABLE:
+            if not _PL_GREATER_EQUAL_1_4_5:
+                raise MisconfigurationException("SparseML requires PyTorch Lightning 1.4.5 or greater.")
+            if not _TORCH_MAX_VERSION_1_8_1:
+                raise MisconfigurationException("SparseML requires PyTorch version 1.8.1 or lower.")
             raise MisconfigurationException("SparseML has not been installed, install with pip install sparseml")
         self.manager = ScheduledModifierManager.from_yaml(recipe_path)
diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py
index 834e3221ee..9f523d2953 100644
--- a/pl_bolts/utils/__init__.py
+++ b/pl_bolts/utils/__init__.py
@@ -39,7 +39,9 @@ def _compare_version(package: str, op, version) -> bool:
 _MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
 _TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1")
 _PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
+_PL_GREATER_EQUAL_1_4_5 = _compare_version("pytorch_lightning", operator.ge, "1.4.5")
 _TORCH_ORT_AVAILABLE = _module_available("torch_ort")
-_SPARSEML_AVAILABLE = _module_available("sparseml")
+_TORCH_MAX_VERSION_1_8_1 = _compare_version("torch", operator.le, "1.8.1")
+_SPARSEML_AVAILABLE = _module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_1_8_1

 __all__ = ["BatchGradientVerification"]
From f4e740d281c0d80ef056b17b865fe0261de455ff Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Fri, 3 Sep 2021 16:50:09 +0100
Subject: [PATCH 6/8] Add rst

---
 docs/source/index.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 860b82ba12..c0eeea8689 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -24,6 +24,7 @@ Lightning-Bolts documentation
    callbacks/variational_callbacks
    callbacks/vision_callbacks
    callbacks/torch_ort
+   callbacks/sparseml

 .. toctree::
    :maxdepth: 2
From 6c933514ee98ef90deaa49686242584f45ab0027 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Fri, 3 Sep 2021 16:56:36 +0100
Subject: [PATCH 7/8] Add checks

---
 docs/source/callbacks/sparseml.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/callbacks/sparseml.rst b/docs/source/callbacks/sparseml.rst
index ddf7f77b02..54f8051b6b 100644
--- a/docs/source/callbacks/sparseml.rst
+++ b/docs/source/callbacks/sparseml.rst
@@ -24,6 +24,7 @@ This requires to import an ONNX model, which you can get from your ``LightningMo
 ------------------------------

 .. testcode::
+   :skipif: not _SPARSEML_AVAILABLE

     from pytorch_lightning import LightningModule, Trainer
     from pl_bolts.callbacks import SparseMLCallback
@@ -44,6 +45,7 @@ Using the helper function, we handle any quantization/pruning internally and exp
 Note that this assumes you have either implemented the ``example_input_array`` property in the model, or that you provide a sample batch as shown below.

 ..
testcode:: + :skipif: not _SPARSEML_AVAILABLE import torch From d039520b8d8fecebf5df11cd49bc4e2e01a0c324 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 3 Sep 2021 17:16:05 +0100 Subject: [PATCH 8/8] Add import --- docs/source/conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index ea38d1de68..c285ed475a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -405,5 +405,7 @@ def find_source(): import pytorch_lightning as pl from pytorch_lightning import Trainer, LightningModule +from pl_bolts.utils import _SPARSEML_AVAILABLE + """ coverage_skip_undoc_in_source = True
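Taken together, the pieces added across these patches can be exercised end to end roughly as follows. This is a sketch only: the one-modifier recipe, the ``BoringModel`` stand-in from the test helpers, and the file paths are illustrative choices borrowed from the tests above, not requirements of the callback.

.. code-block:: python

    from pathlib import Path

    import torch
    from pytorch_lightning import Trainer

    from pl_bolts.callbacks import SparseMLCallback
    from tests.helpers.boring_model import BoringModel  # any ONNX-exportable LightningModule works

    # Minimal recipe, mirroring the structure used in tests/callbacks/test_sparseml.py.
    recipe = """
    version: 0.1.0
    modifiers:
        - !EpochRangeModifier
            start_epoch: 0.0
            end_epoch: 1.0
    """
    recipe_path = Path("recipe.yaml")
    recipe_path.write_text(recipe)

    # Fine-tune with the recipe applied via the callback; fast_dev_run keeps this a quick
    # smoke run, matching how the tests drive it.
    model = BoringModel()
    trainer = Trainer(fast_dev_run=True, callbacks=[SparseMLCallback(recipe_path=str(recipe_path))])
    trainer.fit(model)

    # Export to ONNX, passing a sample batch since BoringModel defines no ``example_input_array``.
    SparseMLCallback.export_to_sparse_onnx(model, "onnx_export/", sample_batch=torch.randn(1, 32))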