Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Analysis utils #2435

Merged
merged 49 commits into from
Jun 16, 2020
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
0a4b7b0
Add analysis tools for sensitivity and topology.
May 14, 2020
712d982
Reformat the code and add several small new features.
May 14, 2020
202593c
Add the flops information rendering for the visulization.
May 14, 2020
8a7a799
Add the depedency rendering feature.
May 14, 2020
362441a
Update the interface of the SensitivityAnalysis
May 15, 2020
e69e78f
Add sensitivity rendering feature.
May 15, 2020
5823276
Add copyright and license.
May 15, 2020
4a70d79
Remove the unrelated files.
May 15, 2020
fc95dd7
Fix some typos.
May 15, 2020
a90c35e
Fix a small issue.
May 18, 2020
1909ff0
Fix a small issue.
May 19, 2020
2d13dda
Fix bug.
May 20, 2020
0e79624
Add compatibility with versions prior to torch-1.4.0.
May 21, 2020
96cea74
Add the mask conflict fix module.
May 25, 2020
6029603
Update the interface.
May 25, 2020
9beb1e2
Add unit test for analysis_utils.
May 26, 2020
6b25ff3
Fix the format warnings from pylint.
May 28, 2020
d0bda49
Add dependencies.
May 28, 2020
4154cf0
comment the visualization test temporarily.
May 28, 2020
83f0b26
update
May 28, 2020
388056c
Skip the test when the torch version is too old.
May 28, 2020
ccbcc6c
update
May 28, 2020
4ce8255
update according to the review comments.
Jun 1, 2020
0f70f67
update according to review comments.
Jun 1, 2020
2eac259
Add docs for analysis_utils.
Jun 1, 2020
810f20e
update rst
Jun 1, 2020
dcdc736
Merge branch 'master' of https://github.com/microsoft/nni into analys…
Jun 1, 2020
3b9f4df
Use TorchModuleGraph to analyze the shape dependency.
Jun 10, 2020
a214bb8
refactor the compression utils.
Jun 10, 2020
6d1a546
Update the corresponding unit test.
Jun 10, 2020
3aeb8a2
Remove the visualization modules and related dependencies.
Jun 10, 2020
bf72f3d
update
Jun 10, 2020
caced25
Update the docs.
Jun 11, 2020
69ea95e
Merge branch 'master' of https://github.com/microsoft/nni into analys…
Jun 11, 2020
c0e93e5
update docs.
Jun 12, 2020
6d7ea88
update docs.
Jun 12, 2020
e7790a2
update docs
Jun 12, 2020
b7671da
Update according the review comments.
Jun 15, 2020
1b9705b
Rename the unit test.
Jun 15, 2020
9d0519e
update docs
Jun 15, 2020
f563802
fix pylint errors
Jun 15, 2020
a24acd0
Update.
Jun 16, 2020
91d5f49
update
Jun 16, 2020
33178a2
fix grammar
Jun 16, 2020
7cab808
update doc
Jun 16, 2020
e8d4c31
Update the docs.
Jun 16, 2020
3351cef
update doc
Jun 16, 2020
db0ff63
update Docs.
Jun 16, 2020
7153bd7
remove unnecessray comments
Jun 16, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deployment/pypi/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@
'scipy',
'coverage',
'colorama',
'scikit-learn>=0.20,<0.22'
'scikit-learn>=0.20,<0.22',
'graphviz',
'matplotlib'
],
classifiers = [
'Programming Language :: Python :: 3',
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def read(fname):
'schema',
'PythonWebHDFS',
'colorama',
'scikit-learn>=0.20,<0.22'
'scikit-learn>=0.20,<0.22',
'graphviz',
'matplotlib'
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
],

entry_points = {
Expand Down
2 changes: 2 additions & 0 deletions src/sdk/pynni/nni/analysis_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
2 changes: 2 additions & 0 deletions src/sdk/pynni/nni/analysis_utils/sensitivity/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the MIT license.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .sensitivity_analysis import SensitivityAnalysis
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import copy
import csv
import logging
from collections import OrderedDict
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

import numpy as np
import torch.nn as nn

from nni.compression.torch import LevelPruner
from nni.compression.torch import L1FilterPruner
from nni.compression.torch import L2FilterPruner

# use Agg backend
matplotlib.use('Agg')
SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d']
SUPPORTED_OP_TYPE = [getattr(nn, name) for name in SUPPORTED_OP_NAME]

logger = logging.getLogger('Sensitivity_Analysis')
logger.setLevel(logging.INFO)


class SensitivityAnalysis:
def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop=1.0):
"""
Perform sensitivity analysis for this model.
Parameters
----------
model:
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
the model to perform sensitivity analysis
val_func:
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
validation function for the model. Due to
different models may need different dataset/criterion
, therefore the user need to cover this part by themselves.
val_func take the model as the first input parameter, and
return the accuracy as output.
sparsities:
The sparsity list provided by users.
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
prune_type:
The pruner type used to prune the conv layers, default is 'l1',
and 'l2', 'fine-grained' is also supported.
early_stop:
If this flag is set, the sensitivity analysis
for a conv layer will early stop when the accuracy
drop already reach the value of early_stop (0.05 for example).
The default value is 1.0, which means the analysis won't stop
until all given sparsities are tested.

"""
self.model = model
self.val_func = val_func
self.target_layer = OrderedDict()
self.ori_state_dict = copy.deepcopy(self.model.state_dict())
self.target_layer = {}
self.sensitivities = {}
if sparsities is not None:
self.sparsities = sorted(sparsities)
else:
self.sparsities = np.arange(0.1, 1.0, 0.1)
self.sparsities = [np.round(x, 2) for x in self.sparsities]
self.Pruner = L1FilterPruner
if prune_type == 'l2':
self.Pruner = L2FilterPruner
elif prune_type == 'fine-grained':
self.Pruner = LevelPruner
self.early_stop = early_stop
self.ori_acc = None # original accuracy for the model
# already_pruned is for the iterative sensitivity analysis
# For example, sensitivity_pruner iteratively prune the target
# model according to the sensitivity. After each round of
# pruning, the sensitivity_pruner will test the new sensitivity
# for each layer
self.already_pruned = {}
self.model_parse()

@property
def layers_count(self):
return len(self.target_layer)

def model_parse(self):
for name, submodel in self.model.named_modules():
for op_type in SUPPORTED_OP_TYPE:
if isinstance(submodel, op_type):
self.target_layer[name] = submodel
self.already_pruned[name] = 0

def analysis(self, val_args=None, val_kwargs=None, start=0, end=None):
"""
This function analyze the sensitivity to pruning for
each conv layer in the target model.
If %start and %end are not set, we analyze all the conv
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
layers by default. Users can specify several layers to
analyze or parallelize the analysis process easily through
the %start and %end parameter.

Parameters
----------
start:
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
Layer index of the sensitivity analysis start
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the meaning of "Layer index"?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this index is only used internally, right? not exposed to users? if it is also exposed to users, we should explain it in more detail.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Start and end is the range of the conv layers to be analyzed. For example, if start and end are set to 1 and 10 respectively, then we only analyze the sensitivity of the first ten conv layers of the model. These parameters are for users to analyze the layers in parallel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how does user know which conv is index 2? my concern is this user interface is not very friendly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about we use the conv names to identify the start point and end point?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can discuss this part in the meeting.

end:
Layer index of the sensitivity analysis end
val_args:
args for the val_function
val_kwargs:
kwargs for the val_funtion
The val_funtion will be called as:
val_function(*val_args, **val_kwargs)
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
sensitivities:
dict object that stores the trajectory of the
accuracy when the prune ratio changes
"""
if not end:
end = self.layers_count
assert start >= 0 and end <= self.layers_count
assert start <= end
if val_args is None:
val_args = []
if val_kwargs is None:
val_kwargs = {}
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
# Get the validation accuracy before pruning
if self.ori_acc is None:
self.ori_acc = self.val_func(*val_args, **val_kwargs)
namelist = list(self.target_layer.keys())
for layerid in range(start, end):
name = namelist[layerid]
self.sensitivities[name] = {}
for sparsity in self.sparsities:
# Calculate the actual prune ratio based on the already pruned ratio
sparsity = (
1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name]
# TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
# I think the L1/L2 Pruner should specify the op_types automaticlly
# according to the op_names
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, let's discuss in the meeting

cfg = [{'sparsity': sparsity, 'op_names': [
name], 'op_types': ['Conv2d']}]
pruner = self.Pruner(self.model, cfg)
pruner.compress()
val_acc = self.val_func(*val_args, **val_kwargs)
logger.info('Layer: %s Sparsity: %.2f Accuracy: %.4f',
name, sparsity, val_acc)

self.sensitivities[name][sparsity] = val_acc
pruner._unwrap_model()
del pruner
# if the accuracy drop already reach the 'early_stop'
if val_acc + self.early_stop < self.ori_acc:
break

# reset the weights pruned by the pruner, because
# out sparsities is sorted, so we donnot need to reset
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
# weight of the layer when the sparsity changes, instead,
# we only need reset the weight when the pruning layer changes.
self.model.load_state_dict(self.ori_state_dict)

return self.sensitivities

def visualization(self, outdir, merge=False):
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
"""
Visualize the sensitivity curves of the model

Parameters
----------
outdir:
output directory of the image
merge:
if merge all the sensitivity curves into a
single image. If not, we will draw a picture
for each target layer of the model.
"""
os.makedirs(outdir, exist_ok=True)
LineStyles = [':', '-.', '--', '-']
Markers = list(Line2D.markers.keys())
if not merge:
# Draw the sensitivity curves for each layer first
for name in self.sensitivities:
X = list(self.sensitivities[name].keys())
X = sorted(X)
Y = [self.sensitivities[name][x] for x in X]
if 0.00 not in X:
# add the original accuracy into the figure
X = [0.00] + X
Y = [self.ori_acc] + Y
plt.figure(figsize=(8, 4))
plt.plot(X, Y, marker='*')
plt.xlabel('Prune Ratio')
plt.ylabel('Validation Accuracy')
plt.title(name)
plt.tight_layout()
filepath = os.path.join(outdir, '%s.jpg' % name)
plt.savefig(filepath)
plt.close()
else:
plt.figure()
styleid = 0
for name in self.sensitivities:
X = list(self.sensitivities[name].keys())
X = sorted(X)
Y = [self.sensitivities[name][x] for x in X]
if 0.00 not in X:
# add the original accuracy into the figure
X = [0.00] + X
Y = [self.ori_acc] + Y
linestyle = LineStyles[styleid % len(LineStyles)]
marker = Markers[styleid % len(Markers)]
plt.plot(X, Y, label=name, linestyle=linestyle, marker=marker)
plt.xlabel('Prune Ratio')
plt.ylabel('Validation Accuracy')
plt.legend(loc='center left', bbox_to_anchor=(1.02, 0.5))
plt.tight_layout()
filepath = os.path.join(outdir, 'all.jpg')
plt.savefig(filepath, dpi=1000, bbox_inches='tight')
styleid += 1
plt.close()

def export(self, filepath):
"""
Export the results of the sensitivity analysis
to a csv file.
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
filepath:
Path of the output file
"""
str_sparsities = [str(x) for x in self.sparsities]
header = ['layername'] + str_sparsities
with open(filepath, 'w') as csvf:
csv_w = csv.writer(csvf)
csv_w.writerow(header)
for layername in self.sensitivities:
row = []
row.append(layername)
for sparsity in sorted(self.sensitivities[layername].keys()):
row.append(self.sensitivities[layername][sparsity])
csv_w.writerow(row)

def update_already_pruned(self, layername, ratio):
"""
Set the already pruned ratio for the target layer.
"""
self.already_pruned[layername] = ratio

def load_state_dict(self, state_dict):
"""
Update the weight of the model
"""
self.ori_state_dict = copy.deepcopy(state_dict)
self.model.load_state_dict(self.ori_state_dict)
2 changes: 2 additions & 0 deletions src/sdk/pynni/nni/analysis_utils/topology/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
2 changes: 2 additions & 0 deletions src/sdk/pynni/nni/analysis_utils/topology/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Loading