This repository was archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression] auto compression #3631

Merged
merged 18 commits on May 28, 2021
119 changes: 119 additions & 0 deletions docs/en_US/Compression/AutoCompression.rst
@@ -0,0 +1,119 @@
Auto Compression with NNI Experiment
====================================

If you want to compress your model but don't know which compression algorithm to choose, what sparsity is suitable for your model, or simply want to try more possibilities, auto compression may help you.
Users can choose different compression algorithms and define the algorithms' search space; auto compression will then launch an NNI experiment and automatically try different compression algorithms with varying sparsity.
Of course, in addition to the sparsity rate, users can also introduce other related parameters into the search space.
If you don't know what a search space is or how to write one, `this <./Tutorial/SearchSpaceSpec.rst>`__ is for your reference.
Launching an auto compression experiment is similar to launching an NNI experiment from Python.
The main differences are as follows:

* Use a generator to help generate the search space object.
* You need to provide the model to be compressed, and the model should have already been pre-trained.
Contributor: have already been pre-trained

Contributor Author: Fixed.

* No need to set ``trial_command``; instead, you need to set the ``auto_compress_module`` as the ``AutoCompressExperiment`` input.

Generate search space
---------------------

Due to the extensive use of nested search spaces, we recommend using a generator to configure the search space.
Contributor: recommend using a?

The following is an example. Use ``add_config()`` to add a sub-config, then call ``dumps()`` to dump the search space dict.

.. code-block:: python

from nni.algorithms.compression.pytorch.auto_compress import AutoCompressSearchSpaceGenerator

generator = AutoCompressSearchSpaceGenerator()
generator.add_config('level', [
{
"sparsity": {
"_type": "uniform",
"_value": [0.01, 0.99]
},
'op_types': ['default']
}
])
generator.add_config('qat', [
{
'quant_types': ['weight', 'output'],
'quant_bits': {
'weight': 8,
'output': 8
},
'op_types': ['Conv2d', 'Linear']
}])

search_space = generator.dumps()

Now we support the following pruners and quantizers:

.. code-block:: python

PRUNER_DICT = {
'level': LevelPruner,
'slim': SlimPruner,
'l1': L1FilterPruner,
'l2': L2FilterPruner,
'fpgm': FPGMPruner,
'taylorfo': TaylorFOWeightFilterPruner,
'apoz': ActivationAPoZRankFilterPruner,
'mean_activation': ActivationMeanRankFilterPruner
}

QUANTIZER_DICT = {
'naive': NaiveQuantizer,
'qat': QAT_Quantizer,
'dorefa': DoReFaQuantizer,
'bnn': BNNQuantizer
}

Provide user model for compression
----------------------------------

Users need to inherit ``AbstractAutoCompressModule`` and override the abstract class functions.

.. code-block:: python

from nni.algorithms.compression.pytorch.auto_compress import AbstractAutoCompressModule

class AutoCompressModule(AbstractAutoCompressModule):
@classmethod
def model(cls) -> nn.Module:
...
return _model

@classmethod
def evaluator(cls) -> Callable[[nn.Module], float]:
...
return _evaluator

Users need to implement at least ``model()`` and ``evaluator()``.
If you use an iterative pruner, you additionally need to implement ``optimizer_factory()``, ``criterion()`` and ``sparsifying_trainer()``.
If you want to finetune the model after compression, you need to implement ``optimizer_factory()``, ``criterion()``, ``post_compress_finetuning_trainer()`` and ``post_compress_finetuning_epochs()``.
``optimizer_factory()`` should return a factory function; its input is an iterable of parameters, i.e. your ``model.parameters()``, and its output is an optimizer instance.
The two kinds of ``trainer()`` should return a trainer with input ``model, optimizer, criterion, current_epoch``.
For the full abstract interface, refer to :githublink:`interface.py <nni/algorithms/compression/pytorch/auto_compress/interface.py>`.
For an example implementation of ``AutoCompressModule``, refer to :githublink:`auto_compress_module.py <examples/model_compress/auto_compress/torch/auto_compress_module.py>`.
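For reference, below is a minimal sketch of these optional methods, adapted from the example module shipped with this PR. ``_model``, ``_evaluator`` and ``_train`` are assumed to be user-defined helpers, and the SGD settings and the two finetuning epochs are illustrative choices, not requirements.

.. code-block:: python

    from typing import Callable, Iterable, Optional

    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    from nni.algorithms.compression.pytorch.auto_compress import AbstractAutoCompressModule

    class AutoCompressModule(AbstractAutoCompressModule):
        @classmethod
        def model(cls) -> nn.Module:
            return _model  # the pre-trained model to compress

        @classmethod
        def evaluator(cls) -> Callable[[nn.Module], float]:
            return _evaluator  # takes the model, returns a float metric

        @classmethod
        def optimizer_factory(cls) -> Optional[Callable[[Iterable], optim.Optimizer]]:
            # returns a fresh optimizer for the given parameters
            def _factory(params: Iterable):
                return optim.SGD(params, lr=0.01)
            return _factory

        @classmethod
        def criterion(cls) -> Optional[Callable]:
            return F.nll_loss

        @classmethod
        def sparsifying_trainer(cls, compress_algorithm_name: str) -> Optional[Callable[[nn.Module, optim.Optimizer, Callable, int], None]]:
            # trainer signature: (model, optimizer, criterion, current_epoch)
            return _train

        @classmethod
        def post_compress_finetuning_trainer(cls, compress_algorithm_name: str) -> Optional[Callable[[nn.Module, optim.Optimizer, Callable, int], None]]:
            return _train

        @classmethod
        def post_compress_finetuning_epochs(cls, compress_algorithm_name: str) -> int:
            return 2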

Launch NNI experiment
---------------------

Similar to launching an NNI experiment from Python, the difference is that there is no need to set ``trial_command``; instead, pass the user-provided ``AutoCompressModule`` as the ``AutoCompressExperiment`` input.

.. code-block:: python

from pathlib import Path
from nni.algorithms.compression.pytorch.auto_compress import AutoCompressExperiment

from auto_compress_module import AutoCompressModule

experiment = AutoCompressExperiment(AutoCompressModule, 'local')
experiment.config.experiment_name = 'auto compress torch example'
experiment.config.trial_concurrency = 1
experiment.config.max_trial_number = 10
experiment.config.search_space = search_space
experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True
Contributor: If I am not mistaken, this feature is for users who want to try different model compression algorithms without much effort. I think some of them would be confused by the experiment config settings if they are not familiar with NNI. Maybe we should tell users what these experiment parameters are, or refer to the related NNI doc which introduces the parameters in detail.

Contributor Author: Good suggestion, trying to refactor and use the original config for less effort.

Contributor Author: Refactored; now we can use ``experiment = AutoCompressExperiment(AutoCompressModule, 'local')``, no need to use a specific config.


experiment.run(8088)
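Once ``experiment.run(8088)`` is called, the experiment starts and you can monitor the compression trials through the NNI web UI on port 8088.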
75 changes: 0 additions & 75 deletions docs/en_US/Compression/AutoPruningUsingTuners.rst

This file was deleted.

2 changes: 1 addition & 1 deletion docs/en_US/Compression/advanced.rst
@@ -6,4 +6,4 @@ Advanced Usage

Framework <./Framework>
Customize a new algorithm <./CustomizeCompressor>
-   Automatic Model Compression <./AutoPruningUsingTuners>
+   Automatic Model Compression (Beta) <./AutoCompression>
1 change: 0 additions & 1 deletion docs/en_US/Compression/pruning.rst
@@ -23,4 +23,3 @@ For details, please refer to the following tutorials:
Pruners <Pruner>
Dependency Aware Mode <DependencyAware>
Model Speedup <ModelSpeedup>
-   Automatic Model Pruning with NNI Tuners <AutoPruningUsingTuners>
130 changes: 130 additions & 0 deletions examples/model_compress/auto_compress/torch/auto_compress_module.py
@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, Optional, Iterable

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms

from nni.algorithms.compression.pytorch.auto_compress import AbstractAutoCompressModule

torch.manual_seed(1)

class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output

_use_cuda = torch.cuda.is_available()

_train_kwargs = {'batch_size': 64}
_test_kwargs = {'batch_size': 1000}
if _use_cuda:
_cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
_train_kwargs.update(_cuda_kwargs)
_test_kwargs.update(_cuda_kwargs)

_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])

_dataset1 = datasets.MNIST('./data', train=True, download=True, transform=_transform)
_dataset2 = datasets.MNIST('./data', train=False, transform=_transform)
_train_loader = torch.utils.data.DataLoader(_dataset1, **_train_kwargs)
_test_loader = torch.utils.data.DataLoader(_dataset2, **_test_kwargs)

_device = torch.device("cuda" if _use_cuda else "cpu")
_epoch = 2

def _train(model, optimizer, criterion, epoch):
model.train()
for data, target in _train_loader:
data, target = data.to(_device), target.to(_device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()

def _test(model):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in _test_loader:
data, target = data.to(_device), target.to(_device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(_test_loader.dataset)
acc = 100 * correct / len(_test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(_test_loader.dataset), acc))
return acc

_model = LeNet().to(_device)

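# Uncomment the following block to pre-train the model before compression.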
# _pre_train_optimizer = optim.Adadelta(_model.parameters(), lr=1)
# _scheduler = StepLR(_pre_train_optimizer, step_size=1, gamma=0.7)
# for i in range(_epoch):
# _train(_model, _pre_train_optimizer, F.nll_loss, i)
# _scheduler.step()

class AutoCompressModule(AbstractAutoCompressModule):
Contributor: Where is this module used?

Contributor Author: This module is implemented by the user, and will be imported by ``import_`` in ``AutoCompressEngine.trial_execute_compress()``. It is strange to fix the code file name as auto_compress_module.py; I will modify this.

Contributor: Do users have to use the name "AutoCompressModule"?

Contributor Author: Refactored; there is no need to fix the name "AutoCompressModule" anymore.

Contributor: Please add docstrings for the member functions.

@classmethod
def model(cls) -> nn.Module:
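    """Return the model to be compressed."""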
return _model

@classmethod
def evaluator(cls) -> Callable[[nn.Module], float]:
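    """Return the evaluation function; it takes the model and returns a float metric (here, test accuracy)."""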
return _test

@classmethod
def optimizer_factory(cls) -> Optional[Callable[[Iterable], optim.Optimizer]]:
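    """Return a factory function that takes an iterable of parameters (i.e. model.parameters()) and returns an optimizer instance."""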
def _optimizer_factory(params: Iterable):
return torch.optim.SGD(params, lr=0.01)
return _optimizer_factory

@classmethod
def criterion(cls) -> Optional[Callable]:
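    """Return the loss function used by the trainers."""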
return F.nll_loss

@classmethod
def sparsifying_trainer(cls, compress_algorithm_name: str) -> Optional[Callable[[nn.Module, optim.Optimizer, Callable, int], None]]:
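    """Return the trainer used during sparsification; it takes (model, optimizer, criterion, current_epoch)."""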
return _train

@classmethod
def post_compress_finetuning_trainer(cls, compress_algorithm_name: str) -> Optional[Callable[[nn.Module, optim.Optimizer, Callable, int], None]]:
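    """Return the trainer used to finetune the model after compression."""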
return _train

@classmethod
def post_compress_finetuning_epochs(cls, compress_algorithm_name: str) -> int:
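    """Return the number of epochs for post-compression finetuning."""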
return 2
50 changes: 50 additions & 0 deletions examples/model_compress/auto_compress/torch/auto_compress_torch.py
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pathlib import Path

from nni.algorithms.compression.pytorch.auto_compress import AutoCompressExperiment, AutoCompressSearchSpaceGenerator

from auto_compress_module import AutoCompressModule

generator = AutoCompressSearchSpaceGenerator()
generator.add_config('level', [
{
"sparsity": {
"_type": "uniform",
"_value": [0.01, 0.99]
},
'op_types': ['default']
}
])
generator.add_config('l1', [
{
"sparsity": {
"_type": "uniform",
"_value": [0.01, 0.99]
},
'op_types': ['Conv2d']
}
])
generator.add_config('qat', [
{
'quant_types': ['weight', 'output'],
'quant_bits': {
'weight': 8,
'output': 8
},
'op_types': ['Conv2d', 'Linear']
}])
search_space = generator.dumps()

experiment = AutoCompressExperiment(AutoCompressModule, 'local')
experiment.config.experiment_name = 'auto compress torch example'
experiment.config.trial_concurrency = 1
experiment.config.max_trial_number = 10
experiment.config.search_space = search_space
experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True

experiment.run(8088)
6 changes: 6 additions & 0 deletions nni/algorithms/compression/pytorch/auto_compress/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .experiment import AutoCompressExperiment
from .interface import AbstractAutoCompressModule
from .utils import AutoCompressSearchSpaceGenerator