diff --git a/docs/en_US/Compressor/Overview.md b/docs/en_US/Compressor/Overview.md
index 7d0b360185..47aaec1e06 100644
--- a/docs/en_US/Compressor/Overview.md
+++ b/docs/en_US/Compressor/Overview.md
@@ -118,7 +118,8 @@ class YourPruner(nni.compressors.tf_compressor.TfPruner):
def update_epoch(self, epoch_num, sess):
pass
- def step(self):
+ # note: for the pytorch version, there is no sess in the input arguments
+ def step(self, sess):
# can do some processing based on the model or weights binded
# in the func bind_model
pass
@@ -159,10 +160,18 @@ class YourPruner(nni.compressors.tf_compressor.TfQuantizer):
def update_epoch(self, epoch_num, sess):
pass
- def step(self):
+ # note: for the pytorch version, there is no sess in the input arguments
+ def step(self, sess):
# can do some processing based on the model or weights binded
# in the func bind_model
pass
+
+ # you can also design your method
+ def your_method(self, your_input):
+ #your code
+
+ def bind_model(self, model):
+ #preprocess model
```
__[TODO]__ Will add another member function `quantize_layer_output`, as some quantization algorithms also quantize layers' output.
diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md
index c147d61db8..99787e350d 100644
--- a/docs/en_US/Compressor/Pruner.md
+++ b/docs/en_US/Compressor/Pruner.md
@@ -13,16 +13,21 @@ We first sort the weights in the specified layer by their absolute values. And t
Tensorflow code
```
-pruner = nni.compressors.tf_compressor.LevelPruner([{'sparsity':0.8,'support_type': 'default'}])
+configure_list = [{'sparsity':0.8,'support_type': 'default'}]
+pruner = nni.compressors.tf_compressor.LevelPruner(configure_list)
pruner(model_graph)
```
Pytorch code
```
-pruner = nni.compressors.torch_compressor.LevelPruner([{'sparsity':0.8,'support_type': 'default'}])
+configure_list = [{'sparsity':0.8,'support_type': 'default'}]
+pruner = nni.compressors.torch_compressor.LevelPruner(configure_list)
pruner(model)
```
+#### User configuration for LevelPruner
+* **sparsity:** This is to specify the target sparsity to which the operations will be compressed
+
***
@@ -41,13 +46,29 @@ First, you should import pruner and add mask to model.
Tensorflow code
```
from nni.compressors.tfCompressor import AGPruner
-pruner = AGPruner(initial_sparsity=0, final_sparsity=0.8, start_epoch=1, end_epoch=10, frequency=1)
+configure_list = [{
+ 'initial_sparsity': 0,
+ 'final_sparsity': 0.8,
+ 'start_epoch': 1,
+ 'end_epoch': 10,
+ 'frequency': 1,
+ 'support_type': 'default'
+ }]
+pruner = AGPruner(configure_list)
pruner(tf.get_default_graph())
```
Pytorch code
```
from nni.compressors.torchCompressor import AGPruner
-pruner = AGPruner(initial_sparsity=0, final_sparsity=0.8, start_epoch=1, end_epoch=10, frequency=1)
+configure_list = [{
+ 'initial_sparsity': 0,
+ 'final_sparsity': 0.8,
+ 'start_epoch': 1,
+ 'end_epoch': 10,
+ 'frequency': 1,
+ 'support_type': 'default'
+ }]
+pruner = AGPruner(configure_list)
pruner(model)
```
@@ -62,6 +83,14 @@ Pytorch code
pruner.update_epoch(epoch)
```
You can view example for more information
+
+#### User configuration for AGPruner
+* **initial_sparsity:** This is to specify the sparsity when the compressor starts compressing
+* **final_sparsity:** This is to specify the sparsity when the compressor finishes compressing
+* **start_epoch:** This is to specify the epoch number at which the compressor starts compressing
+* **end_epoch:** This is to specify the epoch number at which the compressor finishes compressing
+* **frequency:** This is to specify that the compressor compresses once every *frequency* epochs
+
***
@@ -76,15 +105,15 @@ You can prune weight step by step and reach one target sparsity by SensitivityPr
Tensorflow code
```
from nni.compressors.tfCompressor import SensitivityPruner
-
-pruner = SensitivityPruner(sparsity = 0.8)
+configure_list = [{'sparsity':0.8,'support_type': 'default'}]
+pruner = SensitivityPruner(configure_list)
pruner(tf.get_default_graph())
```
Pytorch code
```
from nni.compressors.torchCompressor import SensitivityPruner
-
-pruner = SensitivityPruner(sparsity = 0.8)
+configure_list = [{'sparsity':0.8,'support_type': 'default'}]
+pruner = SensitivityPruner(configure_list)
pruner(model)
```
Like AGPruner, you should update mask information every epoch by adding code below
@@ -98,4 +127,8 @@ Pytorch code
pruner.update_epoch(epoch)
```
You can view example for more information
+
+#### User configuration for SensitivityPruner
+* **sparsity:** This is to specify the target sparsity to which the operations will be compressed
+
***
diff --git a/docs/en_US/Compressor/Quantizer.md b/docs/en_US/Compressor/Quantizer.md
index 43d8e655b7..bc9a2eb365 100644
--- a/docs/en_US/Compressor/Quantizer.md
+++ b/docs/en_US/Compressor/Quantizer.md
@@ -5,7 +5,7 @@ Quantizer on NNI Compressor
## NaiveQuantizer
-We provide NaiveQuantizer to quantizer weight to default 8 bits, you can use it to test quantize algorithm.
+We provide NaiveQuantizer to quantize weights to 8 bits by default; you can use it to test quantization algorithms without any configuration.
### Usage
tensorflow
@@ -16,6 +16,7 @@ pytorch
```
nni.compressors.torch_compressor.NaiveQuantizer()(model)
```
+
***
@@ -34,18 +35,24 @@ You can quantize your model to 8 bits with the code below before your training c
Tensorflow code
```
from nni.compressors.tfCompressor import QATquantizer
-quantizer = QATquantizer(q_bits = 8)
+configure_list = [{'q_bits':8, 'support_type':'default'}]
+quantizer = QATquantizer(configure_list)
quantizer(tf.get_default_graph())
```
Pytorch code
```
from nni.compressors.torchCompressor import QATquantizer
-quantizer = QATquantizer(q_bits = 8)
+configure_list = [{'q_bits':8, 'support_type':'default'}]
+quantizer = QATquantizer(configure_list)
quantizer(model)
```
You can view example for more information
+#### User configuration for QATquantizer
+* **q_bits:** This is to specify the number of bits to which the operations will be quantized
+
+
***
@@ -58,14 +65,19 @@ To implement DoReFaQuantizer, you can add code below before your training code
Tensorflow code
```
from nni.compressors.tfCompressor import DoReFaQuantizer
-quantizer = DoReFaQuantizer(q_bits = 8)
+configure_list = [{'q_bits':8, 'support_type':'default'}]
+quantizer = DoReFaQuantizer(configure_list)
quantizer(tf.get_default_graph())
```
Pytorch code
```
from nni.compressors.torchCompressor import DoReFaQuantizer
-quantizer = DoReFaQuantizer(q_bits = 8)
+configure_list = [{'q_bits':8, 'support_type':'default'}]
+quantizer = DoReFaQuantizer(configure_list)
quantizer(model)
```
You can view example for more information
+
+#### User configuration for DoReFaQuantizer
+* **q_bits:** This is to specify the number of bits to which the operations will be quantized
diff --git a/src/sdk/pynni/nni/compressors/tf_compressor/_nnimc_tf.py b/src/sdk/pynni/nni/compressors/tf_compressor/_nnimc_tf.py
index 85a73443ef..715f3a07cd 100644
--- a/src/sdk/pynni/nni/compressors/tf_compressor/_nnimc_tf.py
+++ b/src/sdk/pynni/nni/compressors/tf_compressor/_nnimc_tf.py
@@ -32,7 +32,7 @@ def compress(self, model):
"""
assert self._bound_model is None, "Each NNI compressor instance can only compress one model"
self._bound_model = model
- self.preprocess_model(model)
+ self.bind_model(model)
def compress_default_graph(self):
"""
@@ -42,7 +42,7 @@ def compress_default_graph(self):
self.compress(tf.get_default_graph())
- def preprocess_model(self, model):
+ def bind_model(self, model):
"""
This method is called when a model is bound to the compressor.
Users can optionally overload this method to do model-specific initialization.
diff --git a/src/sdk/pynni/nni/compressors/torch_compressor/_nnimc_torch.py b/src/sdk/pynni/nni/compressors/torch_compressor/_nnimc_torch.py
index 0311194957..7b484dccbb 100644
--- a/src/sdk/pynni/nni/compressors/torch_compressor/_nnimc_torch.py
+++ b/src/sdk/pynni/nni/compressors/torch_compressor/_nnimc_torch.py
@@ -31,10 +31,10 @@ def compress(self, model):
"""
assert self._bound_model is None, "Each NNI compressor instance can only compress one model"
self._bound_model = model
- self.preprocess_model(model)
+ self.bind_model(model)
- def preprocess_model(self, model):
+ def bind_model(self, model):
"""
This method is called when a model is bound to the compressor.
Users can optionally overload this method to do model-specific initialization.