This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Dev new pruner #1679

Merged 27 commits on Nov 21, 2019
Changes from 8 commits
2 changes: 2 additions & 0 deletions docs/en_US/Compressor/Overview.md
@@ -12,6 +12,8 @@ We have provided two naive compression algorithms and three popular ones for use
|---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [Filter Pruner](./Pruner.md#filter-pruner) | Pruning the least important filters in convolution layers (Pruning Filters for Efficient ConvNets) [Reference Paper](https://arxiv.org/abs/1608.08710) |
| [Slim Pruner](./Pruner.md#slim-pruner) | Pruning channels in convolution layers by pruning the scaling factors in BN layers (Learning Efficient Convolutional Networks through Network Slimming) [Reference Paper](https://arxiv.org/abs/1708.06519) |
| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits |
| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)|
| [DoReFa Quantizer](./Quantizer.md#dorefa-quantizer) | DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. [Reference Paper](https://arxiv.org/abs/1606.06160)|
54 changes: 53 additions & 1 deletion docs/en_US/Compressor/Pruner.md
@@ -65,7 +65,7 @@ config_list = [{
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_types': 'default'
'op_types': ['default']
}]
pruner = AGP_Pruner(model, config_list)
pruner.compress()
@@ -92,3 +92,55 @@ You can view the example for more information

***

## Filter Pruner

In ['Pruning Filters for Efficient ConvNets'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf propose pruning the filters with the smallest sum of absolute kernel weights (L1 norm), together with their corresponding feature maps.

![](../../img/filter_pruner.PNG)

> The procedure of pruning $m$ filters from the $i$th convolutional layer is as follows:
> 1. For each filter $F_{i,j}$ , calculate the sum of its absolute kernel weights $s_j = \sum_{l=1}^{n_i}\sum|K_l|$
> 2. Sort the filters by $s_j$.
> 3. Prune $m$ filters with the smallest sum values and their corresponding feature maps. The
> kernels in the next convolutional layer corresponding to the pruned feature maps are also
> removed.
> 4. A new kernel matrix is created for both the $i$th and $i+1$th layers, and the remaining kernel
> weights are copied to the new model.
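
A minimal sketch of steps 1-3 for a single convolution layer, assuming a plain PyTorch `Conv2d` weight tensor of shape `(out_channels, in_channels, kH, kW)`; the function name and mask convention are illustrative only, not the API introduced by this PR, and the corresponding kernels of the next layer are not handled here:

```python
import torch

def filter_mask(weight, sparsity):
    """Return a 0/1 mask that zeroes the filters with the smallest L1 norm."""
    out_channels = weight.size(0)
    # s_j: sum of absolute kernel weights of each filter
    l1_norms = weight.abs().view(out_channels, -1).sum(dim=1)
    num_prune = int(out_channels * sparsity)
    mask = torch.ones(out_channels)
    if num_prune > 0:
        smallest = torch.argsort(l1_norms)[:num_prune]  # smallest sums first
        mask[smallest] = 0.
    # broadcast the per-filter mask over the whole weight tensor
    return mask.view(-1, 1, 1, 1).expand_as(weight)

# e.g. conv.weight.data.mul_(filter_mask(conv.weight.data, sparsity=0.8))
```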

### Usage

PyTorch code

```python
from nni.compression.torch import FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = FilterPruner(model, config_list)
pruner.compress()
```

#### User configuration for Filter Pruner

- **sparsity:** specifies the sparsity that the specified operations will be compressed to

## Slim Pruner

In ['Learning Efficient Convolutional Networks through Network Slimming'](https://arxiv.org/pdf/1708.06519.pdf), authors Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang propose pruning channels by imposing L1 regularization on the scaling factors of batch normalization (BN) layers during training.

![](../../img/slim_pruner.PNG)

> Slim Pruner prunes channels in the convolution layers by masking the corresponding scaling factors in the subsequent BN layers. L1 regularization on the scaling factors should be applied in the batch normalization (BN) layers while training. The scaling factors of all BN layers are **globally ranked** while pruning, so the sparse model can be found automatically for a given sparsity.
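
A minimal sketch of the two ingredients described above, independent of the SlimPruner class added in this PR: the L1 penalty on BN scaling factors applied as a sub-gradient during training, and a global threshold over all scaling factors at pruning time. The function names and the penalty coefficient `s` are illustrative assumptions:

```python
import torch
import torch.nn as nn

def add_bn_l1_grad(model, s=1e-4):
    """Add the sub-gradient of s * |gamma| to every BN scaling factor (call after loss.backward())."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))

def global_bn_threshold(model, sparsity):
    """Globally rank all BN scaling factors and return the pruning threshold."""
    gammas = torch.cat([m.weight.data.abs().view(-1)
                        for m in model.modules() if isinstance(m, nn.BatchNorm2d)])
    k = min(int(gammas.numel() * sparsity), gammas.numel() - 1)
    return torch.sort(gammas)[0][k].item() if k > 0 else 0.0
```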

### Usage

PyTorch code

```python
from nni.compression.torch import SlimPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = SlimPruner(model, config_list)
pruner.compress()
```

#### User configuration for Slim Pruner

- **sparsity:** specifies the sparsity that the specified operations will be compressed to
Binary file added docs/img/filter_pruner.PNG
Binary file added docs/img/slim_pruner.PNG
147 changes: 147 additions & 0 deletions examples/model_compress/main_filter_pruner.py
@@ -0,0 +1,147 @@
from nni.compression.torch import AGP_Pruner
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms


class vgg(nn.Module):
def __init__(self, init_weights=True):
super(vgg, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.cfg = cfg
self.feature = self.make_layers(cfg, True)
num_classes = 10
self.classifier = nn.Sequential(
nn.Linear(cfg[-1], 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Linear(512, num_classes)
)
if init_weights:
self._initialize_weights()

def make_layers(self, cfg, batch_norm=True):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)

def forward(self, x):
x = self.feature(x)
x = nn.AvgPool2d(2)(x)
x = x.view(x.size(0), -1)
y = self.classifier(x)
return y

def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(0.5)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()


def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item()  # the model outputs raw logits, so use cross_entropy as in train()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))


def main():
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([
transforms.Pad(4),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=256, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=200, shuffle=False)

model = vgg()
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 150, 0)
for epoch in range(300):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16.pth')
# You can switch to LevelPruner instead, e.g.:
# pruner = LevelPruner(configure_list)
# configure_list = [{
# 'initial_sparsity': 0,
# 'final_sparsity': 0.8,
# 'start_epoch': 0,
# 'end_epoch': 10,
# 'frequency': 1,
# 'op_types': ['default']
# }]
#
# pruner = AGP_Pruner(model, configure_list)
# pruner.compress()
# # you can also use compress(model) method
# # like that pruner.compress(model)
#
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# for epoch in range(10):
# pruner.update_epoch(epoch)
# print('# Epoch {} #'.format(epoch))
# train(model, device, train_loader, optimizer)
# test(model, device, test_loader)
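#
# A possible way to wire in the FilterPruner added in this PR (assuming the same
# (model, config_list) / compress() interface as AGP_Pruner above; the final API
# may differ):
# from nni.compression.torch import FilterPruner
# pruner = FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
# pruner.compress()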


if __name__ == '__main__':
main()
@@ -33,7 +33,6 @@ def calc_mask(self, layer, config):
class AGP_Pruner(Pruner):
"""An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.

Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices,