Merge pull request #73 from microsoft/master
pull code
Showing 56 changed files with 1,653 additions and 902 deletions.
## Overview

The model compression framework has two main components: the `pruner` and the `module wrapper`.

### pruner

A `pruner` is responsible for:
1. providing a `calc_mask` method that calculates the masks for weights and biases;
2. replacing the selected modules with `module wrapper`s based on the config;
3. modifying the optimizer so that `calc_mask` is called every time its `step` method is called.
### module wrapper

A `module wrapper` is a module containing:
1. the original module;
2. some buffers used by `calc_mask`;
3. a new forward method that applies the masks before running the original forward method.

There are two reasons to use a `module wrapper` (see the sketch below):
1. `calc_mask` needs some buffers to calculate masks, and these buffers should be registered on the `module wrapper` so that the original module is not contaminated;
2. a new `forward` method is needed to apply the masks to the weight before calling the real `forward` method.
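
As a rough illustration, a `module wrapper` might look like the minimal sketch below. This is not NNI's actual implementation; the class name and the way the mask is applied are simplified assumptions:

```python
import torch
import torch.nn as nn

class ModuleWrapper(nn.Module):
    """A simplified sketch of a module wrapper (illustrative, not NNI's code)."""
    def __init__(self, module):
        super().__init__()
        self.module = module  # the original module, left untouched
        # buffers used by calc_mask are registered here, on the wrapper,
        # so the original module is not contaminated; an all-ones mask
        # means nothing is pruned yet
        self.register_buffer('weight_mask', torch.ones_like(module.weight))

    def forward(self, *inputs):
        # apply the mask to the weight before calling the real forward
        self.module.weight.data = self.module.weight.data * self.weight_mask
        return self.module(*inputs)
```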

## How it works

Basic usage of a pruner looks like this:
```python
configure_list = [{
    'sparsity': 0.7,
    'op_types': ['BatchNorm2d'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = SlimPruner(model, configure_list, optimizer)
model = pruner.compress()
```

A pruner receives the model, config list and optimizer as arguments. In the `__init__` method, the optimizer's `step` method is replaced with a new `step` method that also calls `calc_mask`. In addition, every module is checked against the config to see whether it needs to be pruned; if so, it is replaced by a `module wrapper`. Afterward, the new model and new optimizer are returned and can be trained as before. The `compress` method calculates the default masks.
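
The optimizer patching can be pictured roughly as below. This is only a sketch of the idea: `modules_wrapper`, the assumed list of wrapped modules, and the patching style are illustrative, not NNI's actual code.

```python
def patch_optimizer(pruner, optimizer):
    # wrap the optimizer's step so that masks are recalculated
    # after every parameter update
    original_step = optimizer.step

    def step_with_mask_update(*args, **kwargs):
        result = original_step(*args, **kwargs)   # run the real update first
        for wrapper in pruner.modules_wrapper:    # assumed list of wrapped modules
            pruner.calc_mask(wrapper)             # then refresh each module's mask
        return result

    optimizer.step = step_with_mask_update
```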

## Implement a new pruning algorithm

Implementing a new pruning algorithm requires writing a new `pruner` class, which should subclass `Pruner` and override the `calc_mask` method. `calc_mask` is called by the patched `optimizer.step` method. The `Pruner` base class provides the basic functionality listed above, for example, replacing modules and patching the optimizer.

A basic pruner looks like this:
```python
class NewPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        # do some initialization

    def calc_mask(self, wrapper, **kwargs):
        # do something to calculate weight_mask
        wrapper.weight_mask = weight_mask
```
### Set wrapper attribute

Sometimes `calc_mask` needs to save some state data. Users can use the `set_wrappers_attribute` API to register attributes, just as buffers are registered in PyTorch modules. These buffers are registered on the `module wrapper`, and users can access them through the `module wrapper`.

```python
class NewPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)

    def calc_mask(self, wrapper):
        # do something to calculate weight_mask
        if wrapper.if_calculated:
            pass
        else:
            wrapper.if_calculated = True
            # update masks
```

### Collect data during forward

Sometimes users want to collect data during the module's `forward` method, for example, the mean value of the activations. To do this, a customized collector can be added to the module.

```python
class ActivationRankFilterPruner(Pruner):
    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        self.set_wrappers_attribute("collected_activation", [])
        self.statistics_batch_num = statistics_batch_num

        # pick the activation function before registering the collector
        assert activation in ['relu', 'relu6']
        if activation == 'relu':
            self.activation = torch.nn.functional.relu
        else:
            self.activation = torch.nn.functional.relu6

        def collector(module_, input_, output):
            # record the activations of the first few batches only
            if len(module_.collected_activation) < self.statistics_batch_num:
                module_.collected_activation.append(self.activation(output.detach().cpu()))
        self.add_activation_collector(collector)
```
The collector function is called each time the `forward` method runs.

Users can also remove the collector, like this:
```python
collector_id = self.add_activation_collector(collector)
# ...
self.remove_activation_collector(collector_id)
```
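
Putting the pieces together, a `calc_mask` could then rank filters by the collected activations. The sketch below assumes a 4D convolutional output and a fixed 50% sparsity; it is a hedged illustration, not the actual `ActivationRankFilterPruner` logic:

```python
def calc_mask(self, wrapper, **kwargs):
    if wrapper.if_calculated:
        return
    # stack the activations gathered by the collector: shape (N, C, H, W)
    activations = torch.cat(wrapper.collected_activation, dim=0)
    channel_score = activations.mean(dim=(0, 2, 3))  # one score per output channel
    num_prune = channel_score.numel() // 2           # assume a fixed 50% sparsity
    prune_idx = channel_score.topk(num_prune, largest=False).indices
    mask = torch.ones_like(channel_score)
    mask[prune_idx] = 0.0
    mask = mask.to(wrapper.module.weight.device)     # collector stored CPU tensors
    # broadcast the per-channel mask over the conv weight (out_ch, in_ch, kH, kW)
    wrapper.weight_mask = mask.view(-1, 1, 1, 1).expand_as(wrapper.module.weight).clone()
    wrapper.if_calculated = True
```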

### Multi-GPU support

During multi-GPU training, buffers and parameters are copied to each GPU every time the `forward` method runs. If buffers or parameters are updated in the `forward` method, the update must be done in place for it to take effect.
Since `calc_mask` is called in the `optimizer.step` method, which happens after the `forward` method and runs on only one GPU, multi-GPU training is supported naturally.
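
As a toy illustration of the in-place rule (assumed names, not NNI code): a buffer updated inside `forward` should be mutated in place, as below, rather than rebound to a new tensor, which on a replicated module would only change a GPU-local copy:

```python
import torch

class CountingWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.register_buffer('batch_count', torch.zeros(1))

    def forward(self, x):
        self.batch_count += 1  # in-place: mutates the shared buffer storage
        # `self.batch_count = self.batch_count + 1` would instead rebind the
        # attribute on a per-GPU replica, and the update would be lost
        return self.module(x)
```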
**Run an Experiment on DLTS**
===

NNI supports running an experiment on [DLTS](https://github.com/microsoft/DLWorkspace.git), called dlts mode. Before starting to use NNI dlts mode, you should have an account to access the DLTS dashboard.

## Setup Environment

Step 1. Choose a cluster from the DLTS dashboard and ask the administrator for the cluster dashboard URL.

![Choose Cluster](../../img/dlts-step1.png)

Step 2. Prepare an NNI config YAML like the following:

```yaml
# Set this field to "dlts"
trainingServicePlatform: dlts
authorName: your_name
experimentName: auto_mnist
trialConcurrency: 2
maxExecDuration: 3h
maxTrialNum: 100
searchSpacePath: search_space.json
useAnnotation: false
tuner:
  builtinTunerName: TPE
  classArgs:
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  gpuNum: 1
  image: msranni/nni
# Configuration to access DLTS
dltsConfig:
  dashboard: # Ask administrator for the cluster dashboard URL
```
Remember to fill in the cluster dashboard URL on the last line.

Step 3. Open your working directory on the cluster and paste the NNI config as well as the related code into a directory.

![Copy Config](../../img/dlts-step3.png)

Step 4. Submit an NNI manager job to the specified cluster.

![Submit Job](../../img/dlts-step4.png)

Step 5. Go to the Endpoints tab of the newly created job and click the Port 40000 link to view the trials' information.

![View NNI WebUI](../../img/dlts-step5.png)