Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Compression for Tensorflow #2755

Merged
merged 12 commits into from
Aug 12, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 2 additions & 24 deletions docs/en_US/Compressor/Pruner.md
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ Tensorflow code
```python
from nni.compression.tensorflow import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(model_graph, config_list)
pruner = LevelPruner(model, config_list)
pruner.compress()
```

@@ -117,17 +117,6 @@ FPGMPruner prune filters with the smallest geometric median.

### Usage

Tensorflow code
```python
from nni.compression.tensorflow import FPGMPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2D']
}]
pruner = FPGMPruner(model, config_list)
pruner.compress()
```

PyTorch code
```python
from nni.compression.torch import FPGMPruner
@@ -146,11 +135,6 @@ pruner.compress()
.. autoclass:: nni.compression.torch.FPGMPruner
```

##### Tensorflow
```eval_rst
.. autoclass:: nni.compression.tensorflow.FPGMPruner
```

## L1Filter Pruner

This is an one-shot pruner, In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf.
@@ -383,12 +367,6 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AGPPruner
```

##### Tensorflow

```eval_rst
.. autoclass:: nni.compression.tensorflow.AGPPruner
```

***

## NetAdapt Pruner
@@ -620,4 +598,4 @@ pruner.compress(eval_args=[model], finetune_args=[model])

```eval_rst
.. autoclass:: nni.compression.torch.SensitivityPruner
```
```
12 changes: 1 addition & 11 deletions docs/zh_CN/Compressor/Pruner.md
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ TensorFlow 代码
```python
from nni.compression.tensorflow import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(model_graph, config_list)
pruner = LevelPruner(model, config_list)
pruner.compress()
```

@@ -102,16 +102,6 @@ pruner.compress()

### 用法

TensorFlow 代码
```python
from nni.compression.tensorflow import FPGMPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2D']
}]
pruner = FPGMPruner(model, config_list)
pruner.compress()
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should keep Chinese doc unchanged, and let @JunweiSUN to update Chinese doc using crowdin

PyTorch 代码
```python
from nni.compression.torch import FPGMPruner
82 changes: 82 additions & 0 deletions examples/model_compress/model_prune_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import argparse

import tensorflow as tf

import nni.compression.tensorflow

prune_config = {
'level': {
'dataset_name': 'mnist',
'model_name': 'naive',
'pruner_class': nni.compression.tensorflow.LevelPruner,
'config_list': [{
'sparsity': 0.9,
'op_types': ['default'],
}]
},
}


def get_dataset(dataset_name='mnist'):
assert dataset_name == 'mnist'

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., tf.newaxis] / 255.0
x_test = x_test[..., tf.newaxis] / 255.0
return (x_train, y_train), (x_test, y_test)


def create_model(model_name='naive'):
assert model_name == 'naive'
return tf.keras.Sequential([
tf.keras.layers.Conv2D(filters=20, kernel_size=5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=2),
tf.keras.layers.Conv2D(filters=20, kernel_size=5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=500),
tf.keras.layers.ReLU(),
tf.keras.layers.Dense(units=10),
tf.keras.layers.Softmax()
])


def create_pruner(model, pruner_name):
pruner_class = prune_config[pruner_name]['pruner_class']
config_list = prune_config[pruner_name]['config_list']
return pruner_class(model, config_list)


def main(args):
model_name = prune_config[args.pruner_name]['model_name']
dataset_name = prune_config[args.pruner_name]['dataset_name']
train_set, test_set = get_dataset(dataset_name)
model = create_model(model_name)

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=1e-4)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

print('start training')
model.fit(train_set[0], train_set[1], batch_size=args.batch_size, epochs=args.pretrain_epochs, validation_data=test_set)

print('start model pruning')
optimizer_finetune = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9, decay=1e-4)
pruner = create_pruner(model, args.pruner_name)
model = pruner.compress()
model.compile(optimizer=optimizer_finetune, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_set[0], train_set[1], batch_size=args.batch_size, epochs=args.prune_epochs, validation_data=test_set)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pruner_name', type=str, default='level')
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--pretrain_epochs', type=int, default=10)
parser.add_argument('--prune_epochs', type=int, default=10)

args = parser.parse_args()
main(args)
5 changes: 2 additions & 3 deletions src/sdk/pynni/nni/compression/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
from .compressor import Compressor, Pruner
from .pruning import *
195 changes: 0 additions & 195 deletions src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py

This file was deleted.

Loading