support lottery ticket hypothesis #1685

Merged
merged 28 commits on Nov 20, 2019
1 change: 1 addition & 0 deletions docs/en_US/Compressor/Overview.md
@@ -12,6 +12,7 @@ We have provided two naive compression algorithms and three popular ones for use
|---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [Lottery Ticket Pruner](./Pruner.md#lottery-ticket-hypothesis) | The pruning process used by "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks". It prunes a model iteratively. [Reference Paper](https://arxiv.org/abs/1803.03635)|
| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits |
| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)|
| [DoReFa Quantizer](./Quantizer.md#dorefa-quantizer) | DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. [Reference Paper](https://arxiv.org/abs/1606.06160)|
41 changes: 41 additions & 0 deletions docs/en_US/Compressor/Pruner.md
@@ -92,3 +92,44 @@ You can view example for more information

***

## Lottery Ticket Hypothesis
[The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635), by Jonathan Frankle and Michael Carbin, provides comprehensive measurement and analysis, and articulates the *lottery ticket hypothesis*: dense, randomly-initialized, feed-forward networks contain subnetworks (*winning tickets*) that -- when trained in isolation -- reach test accuracy comparable to the original network in a similar number of iterations.

In this paper, the authors use the following process to prune a model, called *iterative pruning*:
>1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
>2. Train the network for j iterations, arriving at parameters theta_j.
>3. Prune p% of the parameters in theta_j, creating a mask m.
>4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
>5. Repeat steps 2, 3, and 4.

If the configured final sparsity is P (e.g., 0.8) and the pruning is iterated n times, each iteration prunes 1-(1-P)^(1/n) of the weights that survive the previous round.
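
For example, with P = 0.8 and n = 5, the per-round schedule can be computed as in this minimal sketch (illustrative only, not NNI code):

```python
# Illustrative schedule: final sparsity P reached over n pruning iterations.
P, n = 0.8, 5
keep_ratio_once = (1 - P) ** (1 / n)  # fraction of surviving weights kept per round
for i in range(1, n + 1):
    print('round {}: cumulative sparsity {:.3f}'.format(i, 1 - keep_ratio_once ** i))
# Each round prunes 1 - keep_ratio_once (about 27.5% here) of the weights
# that survived the previous round; after round 5 the sparsity reaches 0.8.
```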

### Usage

PyTorch code
```python
from nni.compression.torch import LotteryTicketPruner
config_list = [{
'prune_iterations': 5,
'sparsity': 0.8,
'op_types': ['default']
}]
pruner = LotteryTicketPruner(model, config_list, optimizer)
pruner.compress()
for _ in pruner.get_prune_iterations():
pruner.prune_iteration_start()
for epoch in range(epoch_num):
...
```

The above configuration specifies 5 iterations of pruning. Because all 5 iterations are executed within the same run, LotteryTicketPruner needs `model` and `optimizer` (**note: `lr_scheduler` should also be passed if one is used**) to reset their states each time a new prune iteration starts. Use `get_prune_iterations` to get the pruning iterations, and invoke `prune_iteration_start` at the beginning of each iteration. `epoch_num` should be large enough for the model to converge, because the hypothesis is that the accuracy reached in later rounds, with high sparsity, is comparable to that of the first round.
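
If training uses a learning rate scheduler, pass it to the pruner as well so that its state is reset together with the model and optimizer at each round. A minimal sketch continuing the snippet above, assuming a `StepLR` scheduler (the scheduler choice and its parameters are illustrative):

```python
import torch

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
pruner = LotteryTicketPruner(model, config_list, optimizer, lr_scheduler=scheduler)
```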


*A TensorFlow version will be supported later.*

#### User configuration for LotteryTicketPruner

* **prune_iterations:** The number of rounds of iterative pruning.
* **sparsity:** The final sparsity when the compression is done.
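
When multiple configurations are given, their `prune_iterations` values must be identical; the pruner validates this. A hypothetical sketch with per-type configurations (the `op_types` values other than `'default'` are assumptions for illustration):

```python
config_list = [{
    'prune_iterations': 5,
    'sparsity': 0.8,
    'op_types': ['Conv2d']
}, {
    'prune_iterations': 5,  # must match the value above
    'sparsity': 0.5,
    'op_types': ['Linear']
}]
```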

***
95 changes: 95 additions & 0 deletions examples/model_compress/lottery_torch_mnist.py
@@ -0,0 +1,95 @@
from nni.compression.torch import LotteryTicketPruner
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms


class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))

def apply_pruner(model, optimizer):
configure_list = [{
'prune_iterations': 5,
'sparsity': 0.8,
'op_types': ['default']
}]
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()
return pruner

def main():
torch.manual_seed(0)
device = torch.device('cpu')

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)

model = Mnist()
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

pruner = apply_pruner(model, optimizer)

for _ in pruner.get_prune_iterations():
pruner.prune_iteration_start()
for epoch in range(5):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28])


if __name__ == '__main__':
main()
2 changes: 1 addition & 1 deletion examples/model_compress/main_torch_pruner.py
Original file line number Diff line number Diff line change
@@ -93,6 +93,6 @@ def main():
test(model, device, test_loader)
pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28])


if __name__ == '__main__':
main()

1 change: 1 addition & 0 deletions src/sdk/pynni/nni/compression/torch/__init__.py
@@ -1,3 +1,4 @@
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
from .lottery_ticket import LotteryTicketPruner
41 changes: 21 additions & 20 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -13,7 +13,6 @@ def __init__(self, name, module):

self._forward = None


class Compressor:
"""
Abstract base PyTorch compressor
@@ -32,7 +31,21 @@ def __init__(self, model, config_list):
"""
self.bound_model = model
self.config_list = config_list
self.modules_to_compress = []
self.modules_to_compress = None

def detect_modules_to_compress(self):
"""
        Detect all modules that should be compressed, and save the result in `self.modules_to_compress`.
        The model will be instrumented and the user should never edit it after calling this method.
"""
if self.modules_to_compress is None:
self.modules_to_compress = []
for name, module in self.bound_model.named_modules():
layer = LayerInfo(name, module)
config = self.select_config(layer)
if config is not None:
self.modules_to_compress.append((layer, config))
return self.modules_to_compress

def compress(self):
"""
@@ -41,26 +54,11 @@ def compress(self):
        The model will be instrumented and the user should never edit it after calling this method.
`self.modules_to_compress` records all the to-be-compressed layers
"""
for name, module in self.bound_model.named_modules():
layer = LayerInfo(name, module)
config = self.select_config(layer)
if config is not None:
self._instrument_layer(layer, config)
self.modules_to_compress.append((layer, config))
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
return self.bound_model

def get_modules_to_compress(self):
"""
To obtain all the to-be-compressed layers.

Returns
-------
self.modules_to_compress : list
a list of the layers, each of which is a tuple (`layer`, `config`),
`layer` is `LayerInfo`, `config` is a `dict`
"""
return self.modules_to_compress

def select_config(self, layer):
"""
Find the configuration for `layer` by parsing `self.config_list`
@@ -204,6 +202,9 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
input_shape : list or tuple
input shape to onnx model
"""
if self.detect_modules_to_compress() and not self.mask_dict:
_logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
print('Warning: You may not use self.mask_dict in base Pruner class to record masks')
assert model_path is not None, 'model_path must be specified'
for name, m in self.bound_model.named_modules():
mask = self.mask_dict.get(name)
155 changes: 155 additions & 0 deletions src/sdk/pynni/nni/compression/torch/lottery_ticket.py
@@ -0,0 +1,155 @@
import copy
import logging
import torch
from .compressor import Pruner

_logger = logging.getLogger(__name__)


class LotteryTicketPruner(Pruner):
"""
    This is a PyTorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks",
following NNI model compression interface.

1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
2. Train the network for j iterations, arriving at parameters theta_j.
3. Prune p% of the parameters in theta_j, creating a mask m.
4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
    5. Repeat steps 2, 3, and 4.
"""
def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
"""
Parameters
----------
model : pytorch model
The model to be pruned
config_list : list
Supported keys:
- prune_iterations : The number of rounds for the iterative pruning.
- sparsity : The final sparsity when the compression is done.
optimizer : pytorch optimizer
The optimizer for the model
lr_scheduler : pytorch lr scheduler
The lr scheduler for the model if used
reset_weights : bool
            Whether to reset weights and optimizer at the beginning of each round.
"""
super().__init__(model, config_list)
self.curr_prune_iteration = None
self.prune_iterations = self._validate_config(config_list)

# save init weights and optimizer
self.reset_weights = reset_weights
if self.reset_weights:
self._model = model
self._optimizer = optimizer
self._model_state = copy.deepcopy(model.state_dict())
self._optimizer_state = copy.deepcopy(optimizer.state_dict())
self._lr_scheduler = lr_scheduler
if lr_scheduler is not None:
self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())

def _validate_config(self, config_list):
prune_iterations = None
for config in config_list:
assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
assert 'sparsity' in config, 'sparsity must exist in your config'
if prune_iterations is not None:
assert prune_iterations == config['prune_iterations'], 'The values of prune_iterations must be equal in your config'
prune_iterations = config['prune_iterations']
return prune_iterations

def _print_masks(self, print_mask=False):
torch.set_printoptions(threshold=1000)
for op_name in self.mask_dict.keys():
mask = self.mask_dict[op_name]
print('op name: ', op_name)
if print_mask:
print('mask: ', mask)
# calculate current sparsity
mask_num = mask.sum().item()
mask_size = mask.numel()
print('sparsity: ', 1 - mask_num / mask_size)
torch.set_printoptions(profile='default')

def _calc_sparsity(self, sparsity):
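        # cumulative sparsity after round i equals 1 - (1 - sparsity)^(i / prune_iterations)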
keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
return max(1 - curr_keep_ratio, 0)

def _calc_mask(self, weight, sparsity, op_name):
if self.curr_prune_iteration == 0:
mask = torch.ones(weight.shape).type_as(weight)
else:
curr_sparsity = self._calc_sparsity(sparsity)
assert self.mask_dict.get(op_name) is not None
curr_mask = self.mask_dict.get(op_name)
w_abs = weight.abs() * curr_mask
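            # magnitude pruning among still-unmasked weights: the k smallest
            # |w| values define the threshold and are pruned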
k = int(w_abs.numel() * curr_sparsity)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
return mask

def calc_mask(self, layer, config, **kwargs):
"""
        Generate the mask for the weight of the given ``layer``.

Parameters
----------
layer : LayerInfo
The layer to be pruned
config : dict
Pruning configurations for this weight
**kwargs
Not used

Returns
-------
tensor
The mask for this weight
"""
        assert self.mask_dict.get(layer.name) is not None, 'Please call prune_iteration_start before training'
mask = self.mask_dict[layer.name]
return mask

def get_prune_iterations(self):
"""
        Return the range of pruning iterations.
        In the first prune iteration, masks are all ones; therefore, one more iteration is added.

Returns
-------
        range
            A range over the pruning iterations
"""
return range(self.prune_iterations + 1)

def prune_iteration_start(self):
"""
Control the pruning procedure on updated epoch number.
Should be called at the beginning of the epoch.

Parameters
----------
epoch : num
The current epoch number provided by user call
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.curr_prune_iteration is None:
self.curr_prune_iteration = 0
else:
self.curr_prune_iteration += 1
        assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceeded the configured prune_iterations'

modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
sparsity = config.get('sparsity')
mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
self.mask_dict.update({layer.name: mask})
self._print_masks()

# reinit weights back to original after new masks are generated
if self.reset_weights:
self._model.load_state_dict(self._model_state)
self._optimizer.load_state_dict(self._optimizer_state)
if self._lr_scheduler is not None:
self._lr_scheduler.load_state_dict(self._scheduler_state)