
support lottery ticket hypothesis #1685

Merged · 28 commits · Nov 20, 2019
23 changes: 23 additions & 0 deletions docs/en_US/Compressor/LotteryTicketHypothesis.md
@@ -0,0 +1,23 @@
Lottery Ticket Hypothesis on NNI
===

## Introduction

The paper [The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635) is mainly a measurement and analysis paper, but it delivers very interesting insights. To support it on NNI, we mainly implement the training approach for finding *winning tickets*.

In this paper, the authors use the following process, called *iterative pruning*, to prune a model:
>1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
>2. Train the network for j iterations, arriving at parameters theta_j.
>3. Prune p% of the parameters in theta_j, creating a mask m.
>4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
>5. Repeat steps 2, 3, and 4.

If the configured final sparsity is P (e.g., 0.8) and there are n rounds of iterative pruning, each round prunes 1-(1-P)^(1/n) of the weights that survived the previous round.
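
As a quick sanity check of this schedule, the snippet below (a small illustrative sketch, not part of the NNI code) computes the per-round pruning fraction for two settings: the 5-round, 0.8-sparsity configuration used in the usage example and the 10-round, 0.96-sparsity configuration used in the MNIST experiment below.

```python
def per_round_prune_fraction(final_sparsity, n_rounds):
    # Each round keeps (1 - final_sparsity) ** (1 / n_rounds) of the surviving weights,
    # so it prunes the complement of that keep ratio.
    return 1 - (1 - final_sparsity) ** (1 / n_rounds)

print(per_round_prune_fraction(0.8, 5))    # ~0.275: prune ~27.5% of surviving weights per round
print(per_round_prune_fraction(0.96, 10))  # ~0.275 as well (the MNIST experiment below)
```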

## Reproduce Results

We try to reproduce the experimental results of the fully connected network on MNIST using the same configuration as in the paper. The code can be found [here](https://github.com/microsoft/nni/tree/master/examples/model_compress/lottery_torch_mnist_fc.py). In this experiment, we prune 10 times; after each pruning, we train the pruned model for 50 epochs.

![](../../img/lottery_ticket_mnist_fc.png)

The above figure shows the result of the fully connected network. `round0-sparsity-0.0` is the performance without pruning. Consistent with the paper, pruning around 80% of the weights yields performance similar to no pruning and converges a little faster, while pruning too much, e.g., more than 94%, lowers accuracy and slows convergence slightly. One small difference is that the trend in the paper's data is somewhat clearer.
1 change: 1 addition & 0 deletions docs/en_US/Compressor/Overview.md
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ We have provided two naive compression algorithms and three popular ones for use
|---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [Lottery Ticket Pruner](./Pruner.md#lottery-ticket-hypothesis) | The pruning process used by "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks". It prunes a model iteratively. [Reference Paper](https://arxiv.org/abs/1803.03635)|
| [FPGM Pruner](./Pruner.md#fpgm-pruner) | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration [Reference Paper](https://arxiv.org/pdf/1811.00250.pdf)|
| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits |
| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)|
41 changes: 41 additions & 0 deletions docs/en_US/Compressor/Pruner.md
@@ -92,6 +92,47 @@ You can view example for more information

***

## Lottery Ticket Hypothesis
[The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635), by Jonathan Frankle and Michael Carbin, provides comprehensive measurement and analysis and articulates the *lottery ticket hypothesis*: dense, randomly-initialized, feed-forward networks contain subnetworks (*winning tickets*) that -- when trained in isolation -- reach test accuracy comparable to the original network in a similar number of iterations.

In this paper, the authors use the following process, called *iterative pruning*, to prune a model:
>1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
>2. Train the network for j iterations, arriving at parameters theta_j.
>3. Prune p% of the parameters in theta_j, creating a mask m.
>4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
>5. Repeat steps 2, 3, and 4.

If the configured final sparsity is P (e.g., 0.8) and there are n rounds of iterative pruning, each round prunes 1-(1-P)^(1/n) of the weights that survived the previous round.

### Usage

PyTorch code
```python
from nni.compression.torch import LotteryTicketPruner
config_list = [{
'prune_iterations': 5,
'sparsity': 0.8,
'op_types': ['default']
}]
pruner = LotteryTicketPruner(model, config_list, optimizer)
pruner.compress()
for _ in pruner.get_prune_iterations():
pruner.prune_iteration_start()
for epoch in range(epoch_num):
...
```

The above configuration specifies 5 rounds of iterative pruning. Because all 5 rounds are executed in the same run, `LotteryTicketPruner` needs `model` and `optimizer` (**note: also pass `lr_scheduler` if one is used**) so that it can reset their states every time a new pruning iteration starts. Please use `get_prune_iterations` to get the pruning iterations and invoke `prune_iteration_start` at the beginning of each iteration. `epoch_num` should be large enough for the model to converge, because the hypothesis is that the accuracy obtained in later rounds with high sparsity is comparable to that obtained in the first round. Simple reproduced results can be found [here](./LotteryTicketHypothesis.md).
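
For completeness, here is a slightly fuller sketch of the intended usage. It assumes that `model`, `optimizer`, `train`, `test`, `train_loader`, `test_loader`, `criterion`, `epoch_num`, and an optional `lr_scheduler` are defined elsewhere (as in the MNIST example shipped with this change), and additionally shows passing the learning-rate scheduler and exporting the result:

```python
from nni.compression.torch import LotteryTicketPruner

config_list = [{
    'prune_iterations': 5,
    'sparsity': 0.8,
    'op_types': ['default']
}]
# Pass lr_scheduler as well if one is used, so its state is reset at each round.
pruner = LotteryTicketPruner(model, config_list, optimizer, lr_scheduler=lr_scheduler)
pruner.compress()

for i in pruner.get_prune_iterations():
    pruner.prune_iteration_start()
    for epoch in range(epoch_num):
        train(model, train_loader, optimizer, criterion)
    accuracy = test(model, test_loader, criterion)
    print('prune iteration: {}, accuracy: {}'.format(i, accuracy))

# Export the pruned weights and the corresponding masks.
pruner.export_model('model.pth', 'mask.pth')
```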


*A TensorFlow version will be supported later.*

#### User configuration for LotteryTicketPruner

* **prune_iterations:** The number of rounds of iterative pruning, i.e., how many times the model is pruned.
* **sparsity:** The final sparsity when the compression is done.

***
## FPGM Pruner
FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf)

Binary file added docs/img/lottery_ticket_mnist_fc.png
83 changes: 83 additions & 0 deletions examples/model_compress/lottery_torch_mnist_fc.py
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import LotteryTicketPruner

class fc1(nn.Module):

def __init__(self, num_classes=10):
super(fc1, self).__init__()
self.classifier = nn.Sequential(
nn.Linear(28*28, 300),
nn.ReLU(inplace=True),
nn.Linear(300, 100),
nn.ReLU(inplace=True),
nn.Linear(100, num_classes),
)

def forward(self, x):
x = torch.flatten(x, 1)
x = self.classifier(x)
return x

def train(model, train_loader, optimizer, criterion):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train()
for batch_idx, (imgs, targets) in enumerate(train_loader):
optimizer.zero_grad()
imgs, targets = imgs.to(device), targets.to(device)
output = model(imgs)
train_loss = criterion(output, targets)
train_loss.backward()
optimizer.step()
return train_loss.item()

def test(model, test_loader, criterion):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss (the model outputs raw logits, so use cross_entropy rather than nll_loss)
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
return accuracy


if __name__ == '__main__':
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
testdataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=0, drop_last=False)
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=0, drop_last=True)

model = fc1().to("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
criterion = nn.CrossEntropyLoss()

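    # 10 rounds of iterative pruning to a final sparsity of 0.96:
    # each round prunes roughly 1 - (1 - 0.96) ** (1 / 10), i.e. about 27.5% of the surviving weights.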
configure_list = [{
'prune_iterations': 10,
'sparsity': 0.96,
'op_types': ['default']
}]
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()

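    # get_prune_iterations returns prune_iterations + 1 rounds: the first round trains with all-ones masks.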
for i in pruner.get_prune_iterations():
pruner.prune_iteration_start()
loss = 0
accuracy = 0
for epoch in range(50):
loss = train(model, train_loader, optimizer, criterion)
accuracy = test(model, test_loader, criterion)
print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
pruner.export_model('model.pth', 'mask.pth')
1 change: 1 addition & 0 deletions src/sdk/pynni/nni/compression/torch/__init__.py
@@ -1,3 +1,4 @@
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
from .lottery_ticket import LotteryTicketPruner
10 changes: 2 additions & 8 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -13,7 +13,6 @@ def __init__(self, name, module):

self._forward = None


class Compressor:
"""
Abstract base PyTorch compressor
@@ -37,7 +36,6 @@ def __init__(self, model, config_list):
def detect_modules_to_compress(self):
"""
detect all modules should be compressed, and save the result in `self.modules_to_compress`.

The model will be instrumented and user should never edit it after calling this method.
"""
if self.modules_to_compress is None:
@@ -49,7 +47,6 @@ def detect_modules_to_compress(self):
self.modules_to_compress.append((layer, config))
return self.modules_to_compress


def compress(self):
"""
Compress the model with algorithm implemented by subclass.
@@ -218,6 +215,8 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
input_shape : list or tuple
input shape to onnx model
"""
if self.detect_modules_to_compress() and not self.mask_dict:
_logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert model_path is not None, 'model_path must be specified'
for name, m in self.bound_model.named_modules():
if name == "":
@@ -227,25 +226,20 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
mask_sum = mask.sum().item()
mask_num = mask.numel()
_logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
print('Layer: %s Sparsity: %.2f' % (name, 1 - mask_sum / mask_num))
m.weight.data = m.weight.data.mul(mask)
else:
_logger.info('Layer: %s NOT compressed', name)
print('Layer: %s NOT compressed' % name)
torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)
print('Model state_dict saved to %s' % model_path)
if mask_path is not None:
torch.save(self.mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)
print('Mask dict saved to %s' % mask_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data, onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
print('Model in onnx with input shape %s saved to %s' % (input_data.shape, onnx_path))


class Quantizer(Compressor):
148 changes: 148 additions & 0 deletions src/sdk/pynni/nni/compression/torch/lottery_ticket.py
@@ -0,0 +1,148 @@
import copy
import logging
import torch
from .compressor import Pruner

_logger = logging.getLogger(__name__)


class LotteryTicketPruner(Pruner):
"""
    This is a PyTorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks",
following NNI model compression interface.

1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
2. Train the network for j iterations, arriving at parameters theta_j.
3. Prune p% of the parameters in theta_j, creating a mask m.
4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
5. Repeat step 2, 3, and 4.
"""
def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
"""
Parameters
----------
model : pytorch model
The model to be pruned
config_list : list
Supported keys:
- prune_iterations : The number of rounds for the iterative pruning.
- sparsity : The final sparsity when the compression is done.
optimizer : pytorch optimizer
The optimizer for the model
lr_scheduler : pytorch lr scheduler
The lr scheduler for the model if used
reset_weights : bool
            Whether to reset the weights and the optimizer at the beginning of each round.
"""
super().__init__(model, config_list)
self.curr_prune_iteration = None
self.prune_iterations = self._validate_config(config_list)

# save init weights and optimizer
self.reset_weights = reset_weights
if self.reset_weights:
self._model = model
self._optimizer = optimizer
self._model_state = copy.deepcopy(model.state_dict())
self._optimizer_state = copy.deepcopy(optimizer.state_dict())
self._lr_scheduler = lr_scheduler
if lr_scheduler is not None:
self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())

def _validate_config(self, config_list):
prune_iterations = None
for config in config_list:
assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
assert 'sparsity' in config, 'sparsity must exist in your config'
if prune_iterations is not None:
assert prune_iterations == config['prune_iterations'], 'The values of prune_iterations must be equal in your config'
prune_iterations = config['prune_iterations']
return prune_iterations

def _print_masks(self, print_mask=False):
torch.set_printoptions(threshold=1000)
for op_name in self.mask_dict.keys():
mask = self.mask_dict[op_name]
print('op name: ', op_name)
if print_mask:
print('mask: ', mask)
# calculate current sparsity
mask_num = mask.sum().item()
mask_size = mask.numel()
print('sparsity: ', 1 - mask_num / mask_size)
torch.set_printoptions(profile='default')

def _calc_sparsity(self, sparsity):
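        # keep_ratio_once is the per-round keep ratio (1 - sparsity) ** (1 / prune_iterations);
        # after curr_prune_iteration rounds the cumulative keep ratio is keep_ratio_once ** curr_prune_iteration,
        # so this round's target sparsity is its complement.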
keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
return max(1 - curr_keep_ratio, 0)

def _calc_mask(self, weight, sparsity, op_name):
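        # In the first iteration nothing is pruned (all-ones mask); afterwards, prune the
        # smallest-magnitude weights that are still alive under the current mask until the
        # cumulative target sparsity for this round is reached.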
if self.curr_prune_iteration == 0:
mask = torch.ones(weight.shape).type_as(weight)
else:
curr_sparsity = self._calc_sparsity(sparsity)
assert self.mask_dict.get(op_name) is not None
curr_mask = self.mask_dict.get(op_name)
w_abs = weight.abs() * curr_mask
k = int(w_abs.numel() * curr_sparsity)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
return mask

def calc_mask(self, layer, config):
"""
Generate mask for the given ``weight``.

Parameters
----------
layer : LayerInfo
The layer to be pruned
config : dict
Pruning configurations for this weight

Returns
-------
tensor
The mask for this weight
"""
assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
mask = self.mask_dict[layer.name]
return mask

def get_prune_iterations(self):
"""
Return the range for iterations.
        In the first prune iteration, masks are all ones; thus, one more iteration is added.

Returns
-------
list
A list for pruning iterations
"""
return range(self.prune_iterations + 1)

def prune_iteration_start(self):
"""
        Control the pruning procedure: update the current prune iteration and recompute the masks.
        Should be called at the beginning of each prune iteration.
"""
if self.curr_prune_iteration is None:
self.curr_prune_iteration = 0
else:
self.curr_prune_iteration += 1
assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

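        # Recompute the mask of every layer to be compressed at this round's cumulative sparsity.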
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
sparsity = config.get('sparsity')
mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
self.mask_dict.update({layer.name: mask})
self._print_masks()

# reinit weights back to original after new masks are generated
if self.reset_weights:
self._model.load_state_dict(self._model_state)
self._optimizer.load_state_dict(self._optimizer_state)
if self._lr_scheduler is not None:
self._lr_scheduler.load_state_dict(self._scheduler_state)