This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Dev new pruner #1679

Merged · 27 commits · Nov 21, 2019
Changes from 4 commits
6 changes: 4 additions & 2 deletions docs/en_US/Compressor/Overview.md
@@ -11,6 +11,8 @@ We have provided two naive compression algorithms and three popular ones for use
|---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [Filter Pruner](./Pruner.md#filter-pruner) | Pruning the least important filters in convolution layers (PRUNING FILTERS FOR EFFICIENT CONVNETS) [Reference Paper](https://arxiv.org/abs/1608.08710) |
| [Slim Pruner](./Pruner.md#slim-pruner) | Pruning channels in convolution layers by pruning scaling factors in BN layers (Learning Efficient Convolutional Networks through Network Slimming) [Reference Paper](https://arxiv.org/abs/1708.06519) |
| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits |
| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)|
| [DoReFa Quantizer](./Quantizer.md#dorefa-quantizer) | DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. [Reference Paper](https://arxiv.org/abs/1606.06160)|
@@ -115,7 +117,7 @@ class YourPruner(nni.compression.tensorflow.Pruner):
def calc_mask(self, weight, config, **kwargs):
# weight is the target weight tensor
# config is the selected dict object in config_list for this layer
# kwargs contains op, op_types, and op_name
# kwargs contains op, op_type, and op_name
# design your mask and return your mask
return your_mask

@@ -158,7 +160,7 @@ class YourPruner(nni.compression.tensorflow.Quantizer):
def quantize_weight(self, weight, config, **kwargs):
# weight is the target weight tensor
# config is the selected dict object in config_list for this layer
# kwargs contains op, op_types, and op_name
# kwargs contains op, op_type, and op_name
# design your quantizer and return new weight
return new_weight

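The hunks above document the `calc_mask` / `quantize_weight` hooks that a user-defined compressor implements. The following is only a rough sketch of a custom pruner built on that hook: `MyThresholdPruner`, the `threshold` config key, and the assumption that `Pruner` is importable from `nni.compression.torch` (mirroring the `nni.compression.tensorflow.Pruner` shown above) are illustrative, not part of this PR.

```
# Minimal sketch of a user-defined pruner using the calc_mask hook documented above.
# MyThresholdPruner and the 'threshold' key are hypothetical; the import path mirrors
# the torch built-in pruners in this PR and may differ in practice.
import torch
import torch.nn as nn
from nni.compression.torch import Pruner

class MyThresholdPruner(Pruner):
    def calc_mask(self, weight, config, **kwargs):
        # keep only weights whose absolute value exceeds the configured threshold
        threshold = config.get('threshold', 0.5)
        return torch.gt(weight.abs(), threshold).type_as(weight)

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
config_list = [{'threshold': 0.5, 'op_types': ['default']}]
pruner = MyThresholdPruner(config_list)
pruner(model)
```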
54 changes: 53 additions & 1 deletion docs/en_US/Compressor/Pruner.md
@@ -65,7 +65,7 @@ config_list = [{
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_types': 'default'
'op_types': ['default']
}]
pruner = AGP_Pruner(config_list)
pruner(model)
@@ -92,3 +92,55 @@ You can view example for more information

***

## Filter Pruner

In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf propose pruning the convolution filters with the smallest sums of absolute kernel weights, together with their corresponding feature maps. A standalone sketch of this selection step follows the procedure below.

![](../../img/filter_pruner.png)

> The procedure of pruning $m$ filters from the $i$th convolutional layer is as follows:
> 1. For each filter $F_{i,j}$, calculate the sum of its absolute kernel weights $s_j = \sum_{l=1}^{n_i}\sum|K_l|$.
> 2. Sort the filters by $s_j$.
> 3. Prune $m$ filters with the smallest sum values and their corresponding feature maps. The
> kernels in the next convolutional layer corresponding to the pruned feature maps are also
> removed.
> 4. A new kernel matrix is created for both the $i$th and $i+1$th layers, and the remaining kernel
> weights are copied to the new model.

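As a standalone illustration of steps 1–3 (outside the NNI API), the sketch below computes the per-filter sums of absolute kernel weights for a single `Conv2d` weight tensor and masks the filters with the smallest sums; the tensor shape and the 0.5 sparsity value are made up for the example.

```
import torch

weight = torch.randn(8, 4, 3, 3)   # toy Conv2d weight: (filters, in_channels, kH, kW)
sparsity = 0.5                     # fraction of filters to prune (illustrative)
num_filters = weight.shape[0]
num_prune = int(num_filters * sparsity)

# step 1: sum of absolute kernel weights per filter, s_j
s = weight.abs().view(num_filters, -1).sum(dim=1)
# steps 2-3: filters whose sums are among the num_prune smallest are masked out
threshold = torch.topk(s, num_prune, largest=False).values.max()
mask = torch.gt(s, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
print(mask[:, 0, 0, 0])            # 1.0 for kept filters, 0.0 for pruned filters
```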
### Usage

PyTorch code

```
from nni.compression.torch import FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = FilterPruner(config_list)
pruner(model)
```

#### User configuration for Filter Pruner

- **sparsity:** The target sparsity of the operations to be compressed, i.e. the fraction of convolution filters to prune in each layer

## Slim Pruner

In ['Learning Efficient Convolutional Networks through Network Slimming'](https://arxiv.org/pdf/1708.06519.pdf), authors Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang propose pruning channels through the scaling factors of batch normalization (BN) layers. A standalone sketch of the global ranking step follows the description below.

![](../../img/slim_pruner.png)

> Slim Pruner prunes channels in the convolution layers by masking the corresponding scaling factors in the following BN layers. L1 regularization on these scaling factors should be applied in the BN layers while training. The scaling factors of all BN layers are **globally ranked** while pruning, so the sparse model can be found automatically for a given sparsity.

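As a standalone illustration of the global ranking (outside the NNI API), the sketch below gathers the scaling factors of every BN layer in a toy model, derives one global threshold, and masks the channels below it; the model, the random scaling factors standing in for trained L1-regularized ones, and the 0.7 sparsity value are made up for the example.

```
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 32, 3), nn.BatchNorm2d(32), nn.ReLU(),
)
sparsity = 0.7

# stand-in for trained, L1-regularized scaling factors (all 1.0 at initialization)
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.data.uniform_(0.0, 1.0)

# gather the scaling factors of all BN layers and rank them globally
scales = torch.cat([m.weight.data.abs().clone()
                    for m in model.modules() if isinstance(m, nn.BatchNorm2d)])
k = int(scales.shape[0] * sparsity)
global_threshold = torch.topk(scales, k, largest=False).values.max()

# channels whose scaling factor does not exceed the global threshold are pruned
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        mask = torch.gt(m.weight.data.abs(), global_threshold).type_as(m.weight.data)
        print(int(mask.sum().item()), 'of', mask.numel(), 'channels kept')
```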
### Usage

PyTorch code

```
from nni.compression.torch import SlimPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }]
pruner = SlimPruner(config_list)
pruner(model)
```

#### User configuration for Slim Pruner

- **sparsity:** The target sparsity of the operations to be compressed, i.e. the fraction of channels to prune across all BN layers
Binary file added docs/img/filter_pruner.PNG
Binary file added docs/img/slim_pruner.PNG
@@ -31,7 +31,6 @@ def calc_mask(self, weight, config, op_name, **kwargs):
class AGP_Pruner(Pruner):
"""An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.

Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices,
89 changes: 87 additions & 2 deletions src/sdk/pynni/nni/compression/torch/builtin_pruners.py
@@ -2,7 +2,7 @@
import torch
from .compressor import Pruner

__all__ = ['LevelPruner', 'AGP_Pruner']
__all__ = ['LevelPruner', 'AGP_Pruner', 'FilterPruner', 'SlimPruner']

logger = logging.getLogger('torch pruner')

@@ -102,5 +102,90 @@ def compute_target_sparsity(self, config):
def update_epoch(self, epoch):
if epoch > 0:
self.now_epoch = epoch
for k in self.if_init_list:
for k in self.if_init_list.keys():
self.if_init_list[k] = True


class FilterPruner(Pruner):
"""A structured pruning algorithm that prunes the filters of smallest magnitude
weights sum in the convolution layers to achieve a preset level of network sparsity.

Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf,
"PRUNING FILTERS FOR EFFICIENT CONVNETS", 2017 ICLR
https://arxiv.org/abs/1608.08710
"""

def __init__(self, config_list):
"""
config_list: supported keys:
- sparsity
"""
super().__init__(config_list)
self.mask_list = {}
self.if_init_list = {}

def calc_mask(self, weight, config, op_name, op_type, **kwargs):
assert op_type == 'Conv2d', 'FilterPruner only supports 2d convolution layer pruning'
if self.if_init_list.get(op_name, True):
kernels = weight.shape[0]
w_abs = weight.abs()
k = int(kernels * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
w_abs_structured = w_abs.view(kernels, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
self.mask_list.update({op_name: mask})
self.if_init_list.update({op_name: False})
else:
mask = self.mask_list[op_name]
return mask


class SlimPruner(Pruner):
"""A structured pruning algorithm that prunes channels by pruning the weights of BN layers

Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
"""

def __init__(self, config_list):
"""
config_list: supported keys:
- sparsity
"""
super().__init__(config_list)
self.mask_list = {}
self.if_init_list = {}

def bind_model(self, model):
weight_list = []
config = self._config_list[0]
op_types = config.get('op_types')
op_names = config.get('op_names')
if op_types is not None:
assert op_types == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
for name, m in model.named_modules():
if type(m).__name__ == 'BatchNorm2d':
weight_list.append(m.weight.data.clone())
else:
for name, m in model.named_modules():
if name in op_names:
assert type(
m).__name__ == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(m.weight.data.clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False).values.max()

def calc_mask(self, weight, config, op_name, op_type, **kwargs):
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if self.if_init_list.get(op_name, True):
w_abs = weight.abs()
mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
self.mask_list.update({op_name: mask})
self.if_init_list.update({op_name: False})
else:
mask = self.mask_list[op_name]
return mask