diff --git a/CHANGELOG.md b/CHANGELOG.md
index 707b37b885..960e0ed7ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Torch ORT Callback ([#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720))
+- Added SparseML Callback ([#724](https://github.com/PyTorchLightning/lightning-bolts/pull/724))
+
+
### Changed
- Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
diff --git a/docs/source/callbacks/sparseml.rst b/docs/source/callbacks/sparseml.rst
new file mode 100644
index 0000000000..54f8051b6b
--- /dev/null
+++ b/docs/source/callbacks/sparseml.rst
@@ -0,0 +1,62 @@
+=================
+SparseML Callback
+=================
+
+`SparseML <https://github.com/neuralmagic/sparseml>`__ allows you to leverage sparsity to improve inference times substantially.
+
+SparseML requires you to fine-tune your model with the ``SparseMLCallback`` + a SparseML Recipe. By training with the ``SparseMLCallback``, you can leverage the `DeepSparse <https://github.com/neuralmagic/deepsparse>`__ engine to exploit the introduced sparsity, resulting in large performance improvements.
+
+.. warning::
+
+    The SparseML callback requires the model to be ONNX exportable. This can be tricky when the model requires dynamic sequence lengths such as RNNs.
+
+To leverage SparseML & DeepSparse, follow the steps below:
+
+1. Choose your Sparse Recipe
+----------------------------
+
+To choose a recipe, have a look at `recipes <https://docs.neuralmagic.com/sparseml/source/recipes.html>`__ and the `Sparse Zoo <https://sparsezoo.neuralmagic.com/>`__.
+
+It may be easier to infer a recipe via the UI dashboard of `Sparsify <https://github.com/neuralmagic/sparsify>`__, which allows you to tweak and configure a recipe.
+This requires importing an ONNX model, which you can get from your ``LightningModule`` by calling ``model.to_onnx(output_path)``.
+
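+For example, a minimal sketch of exporting an ONNX file to load into Sparsify. This assumes your ``LightningModule`` defines ``example_input_array``; otherwise pass an explicit ``input_sample``, and note the input shape here is purely illustrative:
+
+.. code-block:: python
+
+    import torch
+    from pytorch_lightning import LightningModule
+
+    class MyModel(LightningModule):
+        ...
+
+    model = MyModel()
+
+    # traces the model using ``model.example_input_array``
+    model.to_onnx("model.onnx")
+
+    # or provide an explicit input sample instead
+    model.to_onnx("model.onnx", input_sample=torch.randn(1, 128))
+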
+2. Train with SparseMLCallback
+------------------------------
+
+.. testcode::
+    :skipif: not _SPARSEML_AVAILABLE
+
+    from pytorch_lightning import LightningModule, Trainer
+    from pl_bolts.callbacks import SparseMLCallback
+
+    class MyModel(LightningModule):
+        ...
+
+    model = MyModel()
+
+    trainer = Trainer(
+        callbacks=SparseMLCallback(recipe_path='recipe.yaml')
+    )
+
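+Training then proceeds as usual; the callback wraps the optimizer and applies the recipe's modifiers over the course of ``fit``:
+
+.. code-block:: python
+
+    trainer.fit(model)
+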
+3. Export to ONNX!
+------------------
+
+Using the helper function, we handle any quantization/pruning internally and export the model to ONNX format.
+Note this assumes you have either implemented the ``example_input_array`` property on the model or you provide a sample batch, as shown below.
+
+.. testcode::
+    :skipif: not _SPARSEML_AVAILABLE
+
+    import torch
+
+    model = MyModel()
+    ...
+
+    # export the onnx model, using the `model.example_input_array`
+    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/')
+
+    # export the onnx model, providing a sample batch
+    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/', sample_batch=torch.randn(1, 128, 128, dtype=torch.float32))
+
+
+Once your model has been exported, you can import it into either `Sparsify <https://github.com/neuralmagic/sparsify>`__ or `DeepSparse <https://github.com/neuralmagic/deepsparse>`__.
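+
+A rough sketch of running the exported model with the DeepSparse engine for CPU inference. The ``compile_model``/``run`` API names are taken from the DeepSparse documentation, and the file path and input shape below assume the export example above, so adjust both to your setup:
+
+.. code-block:: python
+
+    import numpy
+    from deepsparse import compile_model
+
+    # compile the sparse ONNX export for accelerated CPU inference
+    engine = compile_model("onnx_export/model.onnx", batch_size=1)
+
+    # run inference on a dummy input shaped like the sample batch used for export
+    inputs = [numpy.random.randn(1, 128, 128).astype(numpy.float32)]
+    outputs = engine.run(inputs)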
diff --git a/docs/source/conf.py b/docs/source/conf.py
index ea38d1de68..c285ed475a 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -405,5 +405,7 @@ def find_source():
import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule
+from pl_bolts.utils import _SPARSEML_AVAILABLE
+
"""
coverage_skip_undoc_in_source = True
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 860b82ba12..c0eeea8689 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -24,6 +24,7 @@ Lightning-Bolts documentation
callbacks/variational_callbacks
callbacks/vision_callbacks
callbacks/torch_ort
+   callbacks/sparseml
.. toctree::
:maxdepth: 2
diff --git a/pl_bolts/callbacks/__init__.py b/pl_bolts/callbacks/__init__.py
index 9558aed816..f52a408c48 100644
--- a/pl_bolts/callbacks/__init__.py
+++ b/pl_bolts/callbacks/__init__.py
@@ -2,6 +2,7 @@
from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor
from pl_bolts.callbacks.printing import PrintTableMetricsCallback
+from pl_bolts.callbacks.sparseml import SparseMLCallback
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.callbacks.torch_ort import ORTCallback
from pl_bolts.callbacks.variational import LatentDimInterpolator
@@ -20,4 +21,5 @@
"ConfusedLogitCallback",
"TensorboardGenerativeModelImageSampler",
"ORTCallback",
+ "SparseMLCallback",
]
diff --git a/pl_bolts/callbacks/sparseml.py b/pl_bolts/callbacks/sparseml.py
new file mode 100644
index 0000000000..de05088994
--- /dev/null
+++ b/pl_bolts/callbacks/sparseml.py
@@ -0,0 +1,94 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from pl_bolts.utils import _PL_GREATER_EQUAL_1_4_5, _SPARSEML_AVAILABLE, _TORCH_MAX_VERSION_1_8_1
+
+if _SPARSEML_AVAILABLE:
+    from sparseml.pytorch.optim import ScheduledModifierManager
+    from sparseml.pytorch.utils import ModuleExporter
+
+
+class SparseMLCallback(Callback):
+    """Enables SparseML aware training. Requires a recipe to run during training.
+
+    Args:
+        recipe_path: Path to a SparseML compatible yaml recipe.
+            More information at https://docs.neuralmagic.com/sparseml/source/recipes.html
+    """
+
+    def __init__(self, recipe_path: str):
+        if not _SPARSEML_AVAILABLE:
+            if not _PL_GREATER_EQUAL_1_4_5:
+                raise MisconfigurationException("SparseML requires PyTorch Lightning 1.4.5 or greater.")
+            if not _TORCH_MAX_VERSION_1_8_1:
+                raise MisconfigurationException("SparseML requires PyTorch version 1.8.1 or lower.")
+            raise MisconfigurationException("SparseML has not been installed, install with: pip install sparseml")
+        self.manager = ScheduledModifierManager.from_yaml(recipe_path)
+
+    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        optimizers = trainer.optimizers
+
+        if len(optimizers) > 1:
+            raise MisconfigurationException("SparseML only supports training with one optimizer.")
+        optimizer = optimizers[0]
+        # wrap the optimizer so the recipe's modifiers are applied at every training step
+        optimizer = self.manager.modify(
+            pl_module, optimizer, steps_per_epoch=self._num_training_steps_per_epoch(trainer), epoch=0
+        )
+        trainer.optimizers = [optimizer]
+
+    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        self.manager.finalize(pl_module)
+
+    def _num_training_steps_per_epoch(self, trainer: Trainer) -> int:
+        """Total training steps inferred from the datamodule and devices."""
+        if isinstance(trainer.limit_train_batches, int) and trainer.limit_train_batches != 0:
+            dataset_size = trainer.limit_train_batches
+        elif isinstance(trainer.limit_train_batches, float):
+            # limit_train_batches is a percentage of batches
+            dataset_size = len(trainer.datamodule.train_dataloader())
+            dataset_size = int(dataset_size * trainer.limit_train_batches)
+        else:
+            dataset_size = len(trainer.datamodule.train_dataloader())
+
+        num_devices = max(1, trainer.num_gpus, trainer.num_processes)
+        if trainer.tpu_cores:
+            num_devices = max(num_devices, trainer.tpu_cores)
+
+        effective_batch_size = trainer.accumulate_grad_batches * num_devices
+        max_estimated_steps = dataset_size // effective_batch_size
+
+        if trainer.max_steps and trainer.max_steps < max_estimated_steps:
+            return trainer.max_steps
+        return max_estimated_steps
+
+    @staticmethod
+    def export_to_sparse_onnx(
+        model: LightningModule, output_dir: str, sample_batch: Optional[torch.Tensor] = None
+    ) -> None:
+        """Exports the model to ONNX format."""
+        with model._prevent_trainer_and_dataloaders_deepcopy():
+            exporter = ModuleExporter(model, output_dir=output_dir)
+            sample_batch = sample_batch if sample_batch is not None else model.example_input_array
+            if sample_batch is None:
+                raise MisconfigurationException(
+                    "To export the model, a sample batch must be passed via "
+                    "``SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch)`` "
+                    "or an ``example_input_array`` property within the LightningModule"
+                )
+            exporter.export_onnx(sample_batch=sample_batch)
diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py
index 8a77cd66e7..9f523d2953 100644
--- a/pl_bolts/utils/__init__.py
+++ b/pl_bolts/utils/__init__.py
@@ -39,6 +39,9 @@ def _compare_version(package: str, op, version) -> bool:
_MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
_TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1")
_PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
+_PL_GREATER_EQUAL_1_4_5 = _compare_version("pytorch_lightning", operator.ge, "1.4.5")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")
+_TORCH_MAX_VERSION_1_8_1 = _compare_version("torch", operator.le, "1.8.1")
+_SPARSEML_AVAILABLE = _module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_1_8_1
__all__ = ["BatchGradientVerification"]
diff --git a/requirements/test.txt b/requirements/test.txt
index 7c862c329b..da446df50d 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -12,3 +12,4 @@ mypy>=0.790
atari-py==0.2.6 # needed for RL
scikit-learn>=0.23
+sparseml
diff --git a/tests/callbacks/test_sparseml.py b/tests/callbacks/test_sparseml.py
new file mode 100644
index 0000000000..006416dc3a
--- /dev/null
+++ b/tests/callbacks/test_sparseml.py
@@ -0,0 +1,104 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from pathlib import Path
+
+import pytest
+import torch
+from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from pl_bolts.callbacks import SparseMLCallback
+from pl_bolts.utils import _SPARSEML_AVAILABLE
+from tests.helpers.boring_model import BoringModel
+
+if _SPARSEML_AVAILABLE:
+    from sparseml.pytorch.optim import RecipeManagerStepWrapper
+
+
+@pytest.fixture
+def recipe():
+ return """
+ version: 0.1.0
+ modifiers:
+ - !EpochRangeModifier
+ start_epoch: 0.0
+ end_epoch: 1.0
+
+ - !LearningRateModifier
+ start_epoch: 0
+ end_epoch: -1.0
+ update_frequency: -1.0
+ init_lr: 0.005
+ lr_class: MultiStepLR
+ lr_kwargs: {'milestones': [43, 60], 'gamma': 0.1}
+
+ - !GMPruningModifier
+ start_epoch: 0
+ end_epoch: 40
+ update_frequency: 1.0
+ init_sparsity: 0.05
+ final_sparsity: 0.85
+ mask_type: unstructured
+ params: __ALL__
+ """
+
+
+@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
+def test_train_sparse_ml_callback(tmpdir, recipe):
+    class TestCallback(Callback):
+        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+            assert isinstance(trainer.optimizers[0], RecipeManagerStepWrapper)
+
+    recipe_path = Path(tmpdir) / "recipe.yaml"
+    recipe_path.write_text(recipe)
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        callbacks=[SparseMLCallback(recipe_path=str(recipe_path)), TestCallback()],
+    )
+    trainer.fit(model)
+
+    sample_batch = torch.randn(1, 32)
+    output_dir = Path(tmpdir) / "model_export/"
+    SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch)
+    assert os.path.exists(output_dir)
+
+
+@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
+def test_fail_if_no_example_input_array_or_sample_batch(tmpdir, recipe):
+    model = BoringModel()
+    with pytest.raises(MisconfigurationException, match="To export the model, a sample batch must be passed"):
+        output_dir = Path(tmpdir) / "model_export/"
+        SparseMLCallback.export_to_sparse_onnx(model, output_dir)
+
+
+@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
+def test_fail_if_multiple_optimizers(tmpdir, recipe):
+    recipe_path = Path(tmpdir) / "recipe.yaml"
+    recipe_path.write_text(recipe)
+
+    class TestModel(BoringModel):
+        def configure_optimizers(self):
+            return [torch.optim.Adam(self.parameters()), torch.optim.Adam(self.parameters())], []
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, fast_dev_run=True, callbacks=[SparseMLCallback(recipe_path=str(recipe_path))]
+    )
+    with pytest.raises(MisconfigurationException, match="SparseML only supports training with one optimizer."):
+        trainer.fit(model)