This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
Analysis utils #2435
Merged
49 commits
0a4b7b0
Add analysis tools for sensitivity and topology.
712d982
Reformat the code and add several small new features.
202593c
Add the flops information rendering for the visulization.
8a7a799
Add the depedency rendering feature.
362441a
Update the interface of the SensitivityAnalysis
e69e78f
Add sensitivity rendering feature.
5823276
Add copyright and license.
4a70d79
Remove the unrelated files.
fc95dd7
Fix some typos.
a90c35e
Fix a small issue.
1909ff0
Fix a small issue.
2d13dda
Fix bug.
0e79624
Add compatibility with versions prior to torch-1.4.0.
96cea74
Add the mask conflict fix module.
6029603
Update the interface.
9beb1e2
Add unit test for analysis_utils.
6b25ff3
Fix the format warnings from pylint.
d0bda49
Add dependencies.
4154cf0
comment the visualization test temporarily.
83f0b26
update
388056c
Skip the test when the torch version is too old.
ccbcc6c
update
4ce8255
update according to the review comments.
0f70f67
update according to review comments.
2eac259
Add docs for analysis_utils.
810f20e
update rst
dcdc736
Merge branch 'master' of https://github.com/microsoft/nni into analys…
3b9f4df
Use TorchModuleGraph to analyze the shape dependency.
a214bb8
refactor the compression utils.
6d1a546
Update the corresponding unit test.
3aeb8a2
Remove the visualization modules and related dependencies.
bf72f3d
update
caced25
Update the docs.
69ea95e
Merge branch 'master' of https://github.com/microsoft/nni into analys…
c0e93e5
update docs.
6d7ea88
update docs.
e7790a2
update docs
b7671da
Update according the review comments.
1b9705b
Rename the unit test.
9d0519e
update docs
f563802
fix pylint errors
a24acd0
Update.
91d5f49
update
33178a2
fix grammar
7cab808
update doc
e8d4c31
Update the docs.
3351cef
update doc
db0ff63
update Docs.
7153bd7
remove unnecessray comments
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Python API Reference of Compression Utilities | ||
|
||
```eval_rst | ||
.. contents:: | ||
``` | ||
|
||
## Sensitivity Utilities | ||
|
||
```eval_rst | ||
.. autoclass:: nni.compression.torch.utils.sensitivity_analysis.SensitivityAnalysis | ||
    :members: | ||
|
||
``` | ||
|
||
## Topology Utilities | ||
|
||
```eval_rst | ||
.. autoclass:: nni.compression.torch.utils.shape_dependency.ChannelDependency | ||
    :members: | ||
|
||
.. autoclass:: nni.compression.torch.utils.mask_conflict.MaskConflict | ||
    :members: | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Analysis Utils for Model Compression | ||
|
||
```eval_rst | ||
.. contents:: | ||
``` | ||
|
||
We provide several easy-to-use tools for users to analyze their model during model compression. | ||
|
||
## Sensitivity Analysis | ||
First, we provide a sensitivity analysis tool (**SensitivityAnalysis**) for users to analyze the sensitivity of each convolutional layer in their model. Specifically, SensitivityAnalysis gradually prunes each layer of the model and tests the accuracy of the model at each step. Note that SensitivityAnalysis prunes only one layer at a time; the other layers keep their original weights. From the accuracies of the different convolutional layers under different sparsities, we can easily find out which layers the model accuracy is most sensitive to. | ||
|
||
### Usage | ||
|
||
The following code shows the basic usage of SensitivityAnalysis. | ||
```python | ||
import os | ||
import torch | ||
from nni.compression.torch.utils.sensitivity_analysis import SensitivityAnalysis | ||
|
||
# net, val_loader, outdir and filename are assumed to be defined by the user | ||
def val(model): | ||
    model.eval() | ||
    total = 0 | ||
    correct = 0 | ||
    with torch.no_grad(): | ||
        for batchid, (data, label) in enumerate(val_loader): | ||
            data, label = data.cuda(), label.cuda() | ||
            out = model(data) | ||
            _, predicted = out.max(1) | ||
            total += data.size(0) | ||
            correct += predicted.eq(label).sum().item() | ||
    return correct / total | ||
|
||
s_analyzer = SensitivityAnalysis(model=net, val_func=val) | ||
sensitivity = s_analyzer.analysis(val_args=[net]) | ||
os.makedirs(outdir, exist_ok=True) | ||
s_analyzer.export(os.path.join(outdir, filename)) | ||
``` | ||
|
||
The two key parameters of SensitivityAnalysis are `model` and `val_func`. `model` is the neural network to be analyzed, and `val_func` is the validation function that returns the model accuracy, loss, or another metric on the validation dataset. Because different scenarios may calculate the loss or accuracy in different ways, users should prepare a function that returns the model accuracy or loss on the dataset and pass it to SensitivityAnalysis. | ||
SensitivityAnalysis can export the sensitivity results as a csv file; the usage is shown in the example above. | ||
|
||
Furthermore, users can specify the sparsity values used to prune each layer through the optional parameter `sparsities`. | ||
```python | ||
s_analyzer = SensitivityAnalysis(model=net, val_func=val, sparsities=[0.25, 0.5, 0.75]) | ||
``` | ||
With this setting, SensitivityAnalysis prunes 25%, 50%, and 75% of the weights of each layer in turn and records the model's accuracy at each step (SensitivityAnalysis prunes only one layer at a time; the other layers keep their original weights). If `sparsities` is not set, SensitivityAnalysis uses `numpy.arange(0.1, 1.0, 0.1)` as the default sparsity values. | ||
|
||
Users can also speed up the sensitivity analysis with the `early_stop_mode` and `early_stop_value` options. By default, SensitivityAnalysis tests the accuracy under all sparsities for each layer. In contrast, when `early_stop_mode` and `early_stop_value` are set, the sensitivity analysis for a layer stops once the accuracy/loss has met the threshold set by `early_stop_value`. We support four early stop modes: minimize, maximize, dropped, raised. | ||
|
||
minimize: The analysis stops when the validation metric returned by `val_func` is lower than `early_stop_value`. | ||
|
||
maximize: The analysis stops when the validation metric returned by `val_func` is larger than `early_stop_value`. | ||
|
||
dropped: The analysis stops when the validation metric has dropped by `early_stop_value`. | ||
|
||
raised: The analysis stops when the validation metric has risen by `early_stop_value`. | ||
|
||
```python | ||
s_analyzer = SensitivityAnalysis(model=net, val_func=val, sparsities=[0.25, 0.5, 0.75], early_stop_mode='dropped', early_stop_value=0.1) | ||
``` | ||
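
The four modes can be read as simple threshold checks. The sketch below only illustrates that reading and is not NNI's implementation: `should_stop` and `baseline` (the metric of the unpruned model) are hypothetical names, and whether the comparisons are strict is an assumption.

```python
def should_stop(mode, early_stop_value, baseline, metric):
    """Return True if the analysis of the current layer can stop.

    baseline: metric of the unpruned model (hypothetical name).
    metric:   metric measured at the current sparsity.
    """
    if mode == 'minimize':
        # stop once the metric falls below the threshold
        return metric < early_stop_value
    if mode == 'maximize':
        # stop once the metric exceeds the threshold
        return metric > early_stop_value
    if mode == 'dropped':
        # stop once the metric has dropped by more than the threshold
        return baseline - metric > early_stop_value
    if mode == 'raised':
        # stop once the metric has risen by more than the threshold
        return metric - baseline > early_stop_value
    raise ValueError('unknown early_stop_mode: %s' % mode)
```

Under this reading, `early_stop_mode='dropped'` with `early_stop_value=0.1` would stop analyzing a layer as soon as its accuracy falls more than 0.1 below the unpruned accuracy.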
If users only want to analyze several specified convolutional layers, they can specify the target conv layers with the `specified_layers` argument of the analysis function. `specified_layers` is a list of the PyTorch module names of the conv layers. For example: | ||
```python | ||
sensitivity = s_analyzer.analysis(val_args=[net], specified_layers=['Conv1']) | ||
``` | ||
In this example, only the `Conv1` layer is analyzed. In addition, users can easily parallelize the analysis by launching multiple processes and assigning different conv layers of the same model to each process. | ||
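
One possible way to divide the work between processes is a round-robin split of the layer names. This is a minimal sketch: `split_layers` is a hypothetical helper, not part of NNI, and launching the processes is left to the user.

```python
def split_layers(layer_names, num_workers):
    """Round-robin split of layer names into num_workers groups."""
    return [layer_names[i::num_workers] for i in range(num_workers)]


layers = ['features.0', 'features.3', 'features.6', 'features.8', 'features.10']
for worker_id, chunk in enumerate(split_layers(layers, 2)):
    # Each process would then run something like:
    #   s_analyzer.analysis(val_args=[net], specified_layers=chunk)
    print(worker_id, chunk)
```

Each process analyzes only its own chunk and exports its own csv file, so the per-layer results can be concatenated afterwards.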
|
||
|
||
### Output example | ||
The following lines are an example of the csv file exported by SensitivityAnalysis. The first line consists of 'layername' and the list of sparsities. Here, the sparsity value is the fraction of weights that SensitivityAnalysis prunes for each layer. Each following line records the model accuracy when the given layer is pruned to the different sparsities. Note that, due to the early_stop option, some layers may not have model accuracies/losses under all sparsities, for example because their accuracy drop has already exceeded the threshold set by the user. | ||
``` | ||
layername,0.05,0.1,0.2,0.3,0.4,0.5,0.7,0.85,0.95 | ||
features.0,0.54566,0.46308,0.06978,0.0374,0.03024,0.01512,0.00866,0.00492,0.00184 | ||
features.3,0.54878,0.51184,0.37978,0.19814,0.07178,0.02114,0.00438,0.00442,0.00142 | ||
features.6,0.55128,0.53566,0.4887,0.4167,0.31178,0.19152,0.08612,0.01258,0.00236 | ||
features.8,0.55696,0.54194,0.48892,0.42986,0.33048,0.2266,0.09566,0.02348,0.0056 | ||
features.10,0.55468,0.5394,0.49576,0.4291,0.3591,0.28138,0.14256,0.05446,0.01578 | ||
``` | ||
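
A csv file in this layout can be post-processed with the standard library. The sketch below picks, for each layer, the largest sparsity whose accuracy stays above a chosen floor. `max_sparsity_above` is a hypothetical helper, and the assumption that an early-stopped layer shows up as a shorter row or as empty cells is not confirmed by the text above.

```python
import csv
import io

SAMPLE = """layername,0.05,0.1,0.2,0.3,0.4,0.5,0.7,0.85,0.95
features.0,0.54566,0.46308,0.06978,0.0374,0.03024,0.01512,0.00866,0.00492,0.00184
features.10,0.55468,0.5394,0.49576,0.4291,0.3591,0.28138,0.14256,0.05446,0.01578
"""


def max_sparsity_above(csv_text, min_accuracy):
    """Map each layer to the largest sparsity whose accuracy is >= min_accuracy."""
    rows = list(csv.reader(io.StringIO(csv_text)))
    sparsities = [float(s) for s in rows[0][1:]]
    result = {}
    for row in rows[1:]:
        if not row:
            continue
        best = 0.0
        # zip stops at the shorter sequence, so rows cut short by early stop are handled
        for sparsity, acc in zip(sparsities, row[1:]):
            if acc and float(acc) >= min_accuracy:
                best = max(best, sparsity)
        result[row[0]] = best
    return result


print(max_sparsity_above(SAMPLE, 0.45))
# → {'features.0': 0.1, 'features.10': 0.2}
```

With an accuracy floor of 0.45, `features.0` tolerates far less pruning than `features.10`, which matches the impression given by the raw numbers.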
|
||
## Topology Analysis | ||
We also provide several tools for topology analysis during model compression. These tools help users compress their models better. Because of the complex topology of the network, users often need to spend a lot of effort checking whether the compression configuration is reasonable. These topology analysis tools reduce that burden. | ||
|
||
### ChannelDependency | ||
Complicated models may have residual connections or concat operations. When users prune these models, they need to be careful about the channel-count dependencies between the convolution layers. Take the following residual block in resnet18 as an example. The output features of `layer2.0.conv2` and `layer2.0.downsample.0` are added together, so the number of output channels of `layer2.0.conv2` and `layer2.0.downsample.0` should be the same, or there may be a tensor shape conflict. | ||
|
||
![](../../img/channel_dependency_example.jpg) | ||
|
||
|
||
If layers that have a channel dependency are assigned different sparsities (here we only discuss structured pruning by L1FilterPruner/L2FilterPruner), there will be a shape conflict between these layers. Even if the pruned model with masks works fine, the pruned model cannot be sped up directly into the final model that runs on devices, because there will be a shape conflict when the model tries to add/concat the outputs of these layers. This tool finds the layers that have channel-count dependencies to help users prune their models better. | ||
|
||
#### Usage | ||
```python | ||
import torch | ||
from nni.compression.torch.utils.shape_dependency import ChannelDependency | ||
data = torch.ones(1, 3, 224, 224).cuda() | ||
channel_depen = ChannelDependency(net, data) | ||
channel_depen.export('dependency.csv') | ||
``` | ||
|
||
#### Output Example | ||
The following lines are an example of the output for torchvision.models.resnet18 exported by ChannelDependency. The layers on the same line have output channel dependencies with each other. For example, layer1.1.conv2, conv1, and layer1.0.conv2 have output channel dependencies with each other, which means the numbers of output channels (filters) of these three layers should be the same; otherwise, the model may have a shape conflict. | ||
``` | ||
Dependency Set,Convolutional Layers | ||
Set 1,layer1.1.conv2,layer1.0.conv2,conv1 | ||
Set 2,layer1.0.conv1 | ||
Set 3,layer1.1.conv1 | ||
Set 4,layer2.0.conv1 | ||
Set 5,layer2.1.conv2,layer2.0.conv2,layer2.0.downsample.0 | ||
Set 6,layer2.1.conv1 | ||
Set 7,layer3.0.conv1 | ||
Set 8,layer3.0.downsample.0,layer3.1.conv2,layer3.0.conv2 | ||
Set 9,layer3.1.conv1 | ||
Set 10,layer4.0.conv1 | ||
Set 11,layer4.0.downsample.0,layer4.1.conv2,layer4.0.conv2 | ||
Set 12,layer4.1.conv1 | ||
``` | ||
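
The exported dependency sets can be used to keep a pruning configuration consistent. The sketch below gives every layer in a set the smallest sparsity requested within that set. `unify_sparsity` is a hypothetical helper, and taking the minimum is just one conservative choice, not something NNI prescribes.

```python
import csv
import io

SAMPLE = """Dependency Set,Convolutional Layers
Set 1,layer1.1.conv2,layer1.0.conv2,conv1
Set 2,layer1.0.conv1
Set 5,layer2.1.conv2,layer2.0.conv2,layer2.0.downsample.0
"""


def unify_sparsity(csv_text, sparsity_by_layer):
    """Give all layers of a dependency set the smallest sparsity requested in that set."""
    unified = dict(sparsity_by_layer)
    for row in list(csv.reader(io.StringIO(csv_text)))[1:]:
        # row[0] is the set label; the remaining cells are layer names
        layers = [name for name in row[1:] if name in sparsity_by_layer]
        if len(layers) < 2:
            continue  # nothing to reconcile in this set
        shared = min(sparsity_by_layer[name] for name in layers)
        for name in layers:
            unified[name] = shared
    return unified


config = {'layer2.1.conv2': 0.5, 'layer2.0.conv2': 0.3,
          'layer2.0.downsample.0': 0.7, 'layer1.0.conv1': 0.6}
print(unify_sparsity(SAMPLE, config))
```

Equal sparsities give equal channel counts, but the pruners may still pick different filter indices in each layer; aligning the indices themselves is what MaskConflict is for.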
|
||
### MaskConflict | ||
When the masks of different layers in a model conflict (for example, when different sparsities are assigned to layers that have a channel dependency), we can fix the conflict with MaskConflict. Specifically, MaskConflict loads the masks exported by the pruners (L1FilterPruner, etc.) and checks whether there is a mask conflict; if so, MaskConflict sets the conflicting masks to the same value. | ||
|
||
```python | ||
from nni.compression.torch.utils.mask_conflict import MaskConflict | ||
mc = MaskConflict('./resnet18_mask', net, data) | ||
mc.fix_mask_conflict() | ||
mc.export('./resnet18_fixed_mask') | ||
``` |
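
The core idea can be shown without tensors. In `fix_mask_conflict` in this PR, a filter is kept in every layer of a dependency set if any layer of the set keeps it. The sketch below mirrors that union step with plain per-filter 0/1 flags. `fix_conflict` and the flag-list representation are simplifications for illustration, not the NNI API.

```python
def fix_conflict(filter_masks):
    """filter_masks maps layer name -> list of 0/1 flags, one per output filter.

    A filter is kept in every layer if at least one layer in the set keeps it.
    """
    num_filters = len(next(iter(filter_masks.values())))
    kept = [any(mask[i] for mask in filter_masks.values())
            for i in range(num_filters)]
    return {name: [1 if k else 0 for k in kept] for name in filter_masks}


masks = {
    'layer2.0.conv2':        [1, 0, 0, 1],
    'layer2.0.downsample.0': [1, 1, 0, 0],
}
print(fix_conflict(masks))
```

Because the result is the union of the kept filters, the fixed masks are never sparser than the original ones, so the effective sparsity of a layer may end up lower than the value that was configured.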
96 changes: 96 additions & 0 deletions
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
import logging | ||
import torch | ||
import numpy as np | ||
from .shape_dependency import ChannelDependency | ||
# logging.basicConfig(level = logging.DEBUG) | ||
_logger = logging.getLogger('FixMaskConflict') | ||
|
||
class MaskConflict: | ||
    def __init__(self, mask_file, model=None, dummy_input=None, graph=None): | ||
        """ | ||
        MaskConflict fixes the mask conflict between the layers that | ||
        have channel dependency with each other. | ||
|
||
        Parameters | ||
        ---------- | ||
        model : torch.nn.Module | ||
            model to fix the mask conflict | ||
        dummy_input : torch.Tensor | ||
            input example to trace the model | ||
        mask_file : str | ||
            the path of the original mask file | ||
        graph : torch._C.Graph | ||
            the traced graph of the target model; if this parameter is not None, | ||
            we do not use the model and dummy_input to get the traced graph. | ||
        """ | ||
        # check if the parameters are valid | ||
        parameter_valid = False | ||
        if graph is not None: | ||
            parameter_valid = True | ||
        elif (model is not None) and (dummy_input is not None): | ||
            parameter_valid = True | ||
        if not parameter_valid: | ||
            raise Exception('The input parameters are invalid!') | ||
        self.model = model | ||
        self.dummy_input = dummy_input | ||
        self.graph = graph | ||
        self.mask_file = mask_file | ||
        self.masks = torch.load(self.mask_file) | ||
|
||
    def fix_mask_conflict(self): | ||
        """ | ||
        Fix the mask conflict before the mask inference for the layers that | ||
        have shape dependencies. This function should be called before the | ||
        mask inference of the 'speedup' module. | ||
|
||
        """ | ||
        channel_depen = ChannelDependency(self.model, self.dummy_input, self.graph) | ||
        depen_sets = channel_depen.dependency_sets | ||
        for dset in depen_sets: | ||
            if len(dset) == 1: | ||
                # This layer has no channel dependency with other layers | ||
                continue | ||
            channel_remain = set() | ||
            fine_grained = False | ||
            for name in dset: | ||
                if name not in self.masks: | ||
                    # this layer is not pruned | ||
                    continue | ||
                w_mask = self.masks[name]['weight'] | ||
                shape = w_mask.size() | ||
                count = np.prod(shape[1:]) | ||
                all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist() | ||
                all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist() | ||
                if len(all_ones) + len(all_zeros) < w_mask.size(0): | ||
                    # In fine-grained pruning, there is no need to check | ||
                    # the shape conflict | ||
                    _logger.info('Layers %s using fine-grained pruning', ','.join(dset)) | ||
                    fine_grained = True | ||
                    break | ||
                channel_remain.update(all_ones) | ||
                _logger.debug('Layer: %s ', name) | ||
                _logger.debug('Original pruned filters: %s', str(all_zeros)) | ||
            # Update the masks for the layers in the dependency set | ||
            if fine_grained: | ||
                continue | ||
            ori_channels = 0 | ||
            for name in dset: | ||
                if name not in self.masks: | ||
                    # this layer is not pruned, so it has no mask to update | ||
                    continue | ||
                mask = self.masks[name] | ||
                w_shape = mask['weight'].size() | ||
                ori_channels = w_shape[0] | ||
                for i in channel_remain: | ||
                    mask['weight'][i] = torch.ones(w_shape[1:]) | ||
                    # mask is a dict, so test for the key rather than an attribute | ||
                    if 'bias' in mask and mask['bias'] is not None: | ||
                        mask['bias'][i] = 1 | ||
            _logger.info(','.join(dset)) | ||
            _logger.info('Pruned Filters after fixing conflict:') | ||
            pruned_filters = set(list(range(ori_channels)))-channel_remain | ||
            _logger.info(str(sorted(pruned_filters))) | ||
        return self.masks | ||
|
||
    def export(self, path): | ||
        """ | ||
        Export the masks after fixing the conflict to file. | ||
        """ | ||
        torch.save(self.masks, path) |
could add
after this line
in order for users to easily get what is the content of this doc