This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Support the Resnet/Squeezenet/Mobilenet for speedup #2579

Merged
merged 90 commits into from
Jun 24, 2020
Changes from 87 commits
0a4b7b0
Add analysis tools for sensitivity and topology.
May 14, 2020
712d982
Reformat the code and add several small new features.
May 14, 2020
202593c
Add the flops information rendering for the visulization.
May 14, 2020
8a7a799
Add the depedency rendering feature.
May 14, 2020
362441a
Update the interface of the SensitivityAnalysis
May 15, 2020
e69e78f
Add sensitivity rendering feature.
May 15, 2020
5823276
Add copyright and license.
May 15, 2020
4a70d79
Remove the unrelated files.
May 15, 2020
fc95dd7
Fix some typos.
May 15, 2020
a90c35e
Fix a small issue.
May 18, 2020
1909ff0
Fix a small issue.
May 19, 2020
2d13dda
Fix bug.
May 20, 2020
0e79624
Add compatibility with versions prior to torch-1.4.0.
May 21, 2020
d998d17
Update shape_dependency.
May 23, 2020
1c1eb89
Merge branch 'master' of https://github.com/microsoft/nni into speedu…
May 23, 2020
5181fe8
Find a bug in _graph_utils.
May 25, 2020
96cea74
Add the mask conflict fix module.
May 25, 2020
6029603
Update the interface.
May 25, 2020
9beb1e2
Add unit test for analysis_utils.
May 26, 2020
6b25ff3
Fix the format warnings from pylint.
May 28, 2020
d0bda49
Add dependencies.
May 28, 2020
4154cf0
comment the visualization test temporarily.
May 28, 2020
83f0b26
update
May 28, 2020
388056c
Skip the test when the torch version is too old.
May 28, 2020
ccbcc6c
update
May 28, 2020
4ce8255
update according to the review comments.
Jun 1, 2020
0f70f67
update according to review comments.
Jun 1, 2020
2eac259
Add docs for analysis_utils.
Jun 1, 2020
810f20e
update rst
Jun 1, 2020
dcdc736
Merge branch 'master' of https://github.com/microsoft/nni into analys…
Jun 1, 2020
3b9f4df
Use TorchModuleGraph to analyze the shape dependency.
Jun 10, 2020
a214bb8
refactor the compression utils.
Jun 10, 2020
6d1a546
Update the corresponding unit test.
Jun 10, 2020
3aeb8a2
Remove the visualization modules and related dependencies.
Jun 10, 2020
bf72f3d
update
Jun 10, 2020
caced25
Update the docs.
Jun 11, 2020
b25bb09
Merge branch 'master' of https://github.com/microsoft/nni into speedu…
Jun 11, 2020
0adfd7d
Merge branch 'master' of https://github.com/microsoft/nni into speedu…
Jun 11, 2020
69ea95e
Merge branch 'master' of https://github.com/microsoft/nni into analys…
Jun 11, 2020
a605901
use unique_name to speedup the model
Jun 11, 2020
c0e93e5
update docs.
Jun 12, 2020
6d7ea88
update docs.
Jun 12, 2020
e7790a2
update docs
Jun 12, 2020
8ad21bf
Add the auiliary fetch for cat operation.
Jun 13, 2020
525bd07
Add cat support for the speedup module.
Jun 14, 2020
102b3ae
Add/Concat support for the speedup module.
Jun 15, 2020
521adae
Add support for dropout.
Jun 15, 2020
b7671da
Update according the review comments.
Jun 15, 2020
1b9705b
Rename the unit test.
Jun 15, 2020
9d0519e
update docs
Jun 15, 2020
f563802
fix pylint errors
Jun 15, 2020
8a20204
merge with analysis_utils
Jun 15, 2020
38f0759
group dependency
Jun 15, 2020
a24acd0
Update.
Jun 16, 2020
91d5f49
update
Jun 16, 2020
984dbc2
Merge branch 'analysis_utils' of https://github.com/zheng-ningxin/nni…
Jun 16, 2020
ecc1b98
update
Jun 16, 2020
33178a2
fix grammar
Jun 16, 2020
7cab808
update doc
Jun 16, 2020
e8d4c31
Update the docs.
Jun 16, 2020
3351cef
update doc
Jun 16, 2020
db0ff63
update Docs.
Jun 16, 2020
7153bd7
remove unnecessray comments
Jun 16, 2020
1ead6a6
Merge branch 'analysis_utils' of https://github.com/zheng-ningxin/nni…
Jun 16, 2020
48a247e
Merge branch 'master' of https://github.com/microsoft/nni into speedu…
Jun 17, 2020
7c1dacb
refactor the code.
Jun 17, 2020
10e238a
Add support for group convolution.
Jun 17, 2020
f31cda6
find dependency on torchvision
Jun 17, 2020
60e259a
fix a bug in group support
Jun 18, 2020
ef12416
Find another bug that we need to trace by onnx.set_traing.
Jun 18, 2020
fab9933
Fix a bug of speedup CoarseMask
Jun 18, 2020
f78c693
tested on squeezenet1_1.
Jun 18, 2020
7cda82b
Fix a bug in maskconflict.
Jun 19, 2020
7ecb89f
add the support for the mean operation.
Jun 19, 2020
ae309f1
remove the unnecessary print and comments.
Jun 19, 2020
7bb9848
refactor the code.
Jun 19, 2020
bcfba41
Update the docstring.
Jun 19, 2020
840a416
fix a typo.
Jun 19, 2020
048aba9
update the test_compression_utils
Jun 19, 2020
dfc9d2e
fix pylint errors.
Jun 19, 2020
a84dab7
update
Jun 19, 2020
5d003e8
update doc.
Jun 20, 2020
93bcb36
Add the integration test for the speedu module.
Jun 23, 2020
5c1acad
add support for adaptive_avg_pool2d
Jun 23, 2020
4c0334e
add the support for reshape operation.
Jun 23, 2020
70573f8
update the doc string.
Jun 23, 2020
682a351
Add the absolute threshold for the unitest.
Jun 23, 2020
221329a
Add assert to make sure input channels are evenly pruned.
Jun 24, 2020
0cdd850
mute the progress bar when download the pretrained model.
Jun 24, 2020
e11c36d
use a small batch size
Jun 24, 2020
12 changes: 11 additions & 1 deletion docs/en_US/Compressor/CompressionReference.md
@@ -18,6 +18,16 @@
.. autoclass:: nni.compression.torch.utils.shape_dependency.ChannelDependency
:members:

.. autoclass:: nni.compression.torch.utils.mask_conflict.MaskConflict
.. autoclass:: nni.compression.torch.utils.shape_dependency.GroupDependency
:members:

.. autoclass:: nni.compression.torch.utils.mask_conflict.CatMaskPadding
:members:

.. autoclass:: nni.compression.torch.utils.mask_conflict.GroupMaskConflict
:members:

.. autoclass:: nni.compression.torch.utils.mask_conflict.ChannelMaskConflict
:members:

```
6 changes: 2 additions & 4 deletions docs/en_US/Compressor/CompressionUtils.md
@@ -116,8 +116,6 @@ Set 12,layer4.1.conv1
When the masks of different layers in a model conflict (for example, when layers that share a channel dependency are assigned different sparsities), we can fix the conflict with MaskConflict. Specifically, MaskConflict loads the masks exported by the pruners (L1FilterPruner, etc.) and checks whether any masks conflict; if so, it sets the conflicting masks to the same value.

```
from nni.compression.torch.utils.mask_conflict import MaskConflict
mc = MaskConflict('./resnet18_mask', net, data)
mc.fix_mask_conflict()
mc.export('./resnet18_fixed_mask')
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
```
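To make "sets the conflicting masks to the same value" concrete, here is a minimal sketch in plain Python. The merge policy (a channel survives if any layer in the dependency set keeps it), the helper name `unify_channel_masks`, and the layer names are assumptions chosen for illustration; this is not the code from `mask_conflict.py`.

```python
def unify_channel_masks(masks, dependency_set):
    """Give every layer in `dependency_set` the same per-channel mask.

    `masks` maps a layer name to a list of 0/1 flags, one per output channel.
    Assumed policy: a channel is kept if at least one layer keeps it.
    """
    names = [n for n in dependency_set if n in masks]
    if not names:
        return dict(masks)
    width = len(masks[names[0]])
    assert all(len(masks[n]) == width for n in names), "channel counts must match"
    # Merge channel by channel across all layers of the dependency set.
    merged = [int(any(masks[n][c] for n in names)) for c in range(width)]
    fixed = dict(masks)
    for n in names:
        fixed[n] = list(merged)
    return fixed


# Two layers whose outputs are added together must prune the same channels.
masks = {
    "layer1.0.conv2": [1, 0, 1, 0],
    "layer1.1.conv2": [1, 1, 0, 0],
}
fixed = unify_channel_masks(masks, ["layer1.0.conv2", "layer1.1.conv2"])
print(fixed)
```

Keeping the union of surviving channels is the conservative choice: no layer loses a channel it wanted to keep, at the cost of a lower overall sparsity.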
78 changes: 76 additions & 2 deletions src/sdk/pynni/nni/_graph_utils.py
@@ -10,6 +10,7 @@
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
CAT_KIND = 'aten::cat'

_logger = logging.getLogger(__name__)

@@ -236,6 +237,7 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
super().__init__(model, dummy_input, traced_model)
self.global_count = 0
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
self._extract_auxiliary_info()

def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
module_type):
@@ -364,6 +366,58 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
node_group, inputs=inputs, outputs=outputs)
return nodepy

def _extract_cat_info(self, node_group, cpp_node):
"""
Extract the detailed information of the cat operation,
such as the order of the input tensors, the shape of each
input tensor, the output shape, and the cat dimension.

Parameters
----------
node_group : NodePyGroup
cpp_node: torch._C.Node
It should be ```aten::cat``` node

Returns
-------
dict
Include auxiliary information for the cat operation.
[Review comment from a contributor: might be better to explain the content of the dict]

This dict object has four keys: 'cat_dim', 'out_shape',
'in_order' and 'in_shape'. cat_dim is the dimension along
which the cat operation concatenates the input tensors. out_shape
is the shape of the output tensor of the cat operation.
in_order is an ordered list which contains the corresponding
parent operation nodes of the input tensors. in_shape is also
an ordered list that contains the shapes of the input
tensors.
"""
# only support the cat operation
assert cpp_node.kind() == CAT_KIND
cat_info = {}
# get the shape of the output tensor
t_output = cpp_node.output()
out_shape = t_output.type().sizes()
cat_info['out_shape'] = out_shape
# get the cat dimension
inputs = cpp_node.inputs()
cat_dim = list(inputs)[1].toIValue()
cat_info['cat_dim'] = cat_dim
# get the order of the input tensors
# To get the order of the input tensors, we need
# to be aware of the topology of the model, which
# means we should extract the auxiliary information
# after the build_index function.
input_order = []
list_construct_cpp = list(cpp_node.inputs())[0].node()
input_tensors = list(list_construct_cpp.inputs())
[Review comment from a contributor: so the order of tensors returned by .inputs() is the order of input arguments?]
[Reply from the PR author: According to my observation and experimental results, yes, it is. However, because jit itself lacks documentation, I have nothing to cite in support of this. I will read the jit source code and double-check it.]

for _tensor in input_tensors:
debug_name = _tensor.debugName()
input_order.append(self.output_to_node[debug_name].unique_name)
cat_info['in_order'] = input_order
input_shapes = [t.type().sizes() for t in input_tensors]
cat_info['in_shape'] = input_shapes
return cat_info
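
The four keys of `cat_info` are enough to map channels kept in each input to positions in the concatenated output. The sketch below is an illustration under assumptions: `cat_output_index`, the `kept_per_input` argument, and the node names are hypothetical and not part of this PR; only the dict layout follows the docstring above.

```python
def cat_output_index(cat_info, kept_per_input):
    """Map the kept channels of each cat input to indices in the cat output.

    `cat_info` uses the keys 'cat_dim', 'in_order', 'in_shape', 'out_shape'.
    `kept_per_input` (hypothetical) maps a parent node name to the channel
    indices that survive pruning along `cat_dim`; inputs that are missing
    from it are treated as unpruned.
    """
    dim = cat_info['cat_dim']
    offset = 0
    out_index = []
    for name, shape in zip(cat_info['in_order'], cat_info['in_shape']):
        width = shape[dim]
        kept = kept_per_input.get(name, range(width))
        # Shift each input-local index by the width of the inputs before it.
        out_index.extend(offset + i for i in kept)
        offset += width
    assert offset == cat_info['out_shape'][dim], "input widths must sum to the output width"
    return out_index


cat_info = {
    'cat_dim': 1,
    'in_order': ['fire.expand1x1', 'fire.expand3x3'],
    'in_shape': [[1, 4, 8, 8], [1, 4, 8, 8]],
    'out_shape': [1, 8, 8, 8],
}
index = cat_output_index(
    cat_info, {'fire.expand1x1': [0, 2], 'fire.expand3x3': [1, 3]})
print(index)
```

This is why the input order matters: swapping the two entries of `in_order` would shift a different input by the offset and produce a different output index.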

def _extract_shape_info(self, node):
[Review comment from @QuanluZhang, Jun 22, 2020: this function was written by me; it is only for view (pretty limited). Maybe we can generalize this function to extract the shapes of other modules if needed in the future.]

"""
Extract the shape information of ```aten::view``` node
@@ -541,8 +595,8 @@ def _build_graph(self):
node, nodes, input_to_node, output_to_node, 'func')
nodes_py.nodes_op.append(node_group)
# get shape info for view (aten::view) func
if node_group.op_type in ['aten::view', 'aten::flatten']:
node_group.auxiliary = self._extract_shape_info(node)
# if node_group.op_type in ['aten::view', 'aten::flatten']:
# node_group.auxiliary = self._extract_shape_info(node)

for node in graph.outputs(): # Create sink nodes for output ops
node_py = NodePyIO(node, 'output')
@@ -552,6 +606,26 @@ def _build_graph(self):
# build index
return self._build_index(self.nodes_py.nodes_op)

def _extract_auxiliary_info(self):
"""
Extract the auxiliary information for the nodegroups
if necessary. For example, view/flatten operations may
need the shape of the input tensor and output tensor.
"""
# extract the input & output shape for the view and flatten
for node_group in self.nodes_py.nodes_op:
if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
# get shape info for view (aten::view) func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
node_group.auxiliary = self._extract_shape_info(cpp_node)
elif node_group.op_type == CAT_KIND:
# get the detail information for cat func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
node_group.auxiliary = self._extract_cat_info(
node_group, cpp_node)

def find_predecessors(self, unique_name):
"""
Find predecessor node of the given node
38 changes: 31 additions & 7 deletions src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
@@ -14,7 +14,11 @@
'AvgPool2d': lambda module, mask: no_replace(module, mask),
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask)
'ReLU6': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask),
'Dropout': lambda module, mask: no_replace(module, mask),
'Dropout2d': lambda module, mask: no_replace(module, mask),
'Dropout3d': lambda module, mask: no_replace(module, mask)
}

def no_replace(module, mask):
@@ -111,28 +115,48 @@ def replace_conv2d(conv, mask):
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size()[0]

_logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=1, # currently only support groups is 1
groups=conv.groups,
bias=conv.bias is not None,
padding_mode=conv.padding_mode)

new_conv.to(conv.weight.device)
tmp_weight_data = tmp_bias_data = None

if mask.output_mask is not None:
tmp_weight_data = torch.index_select(conv.weight.data, 0, out_channels_index)
if conv.bias is not None:
tmp_bias_data = torch.index_select(conv.bias.data, 0, out_channels_index)
# NOTE: does not support group
else:
tmp_weight_data = conv.weight.data
# For convolutional layers that have more than one group,
# we need to copy the weight group by group, because the input
# channels are also divided into several groups and each group
# filter may have different input channel indexes.
input_step = int(conv.in_channels / conv.groups)
filter_step = int(out_channels / conv.groups)
if mask.input_mask is not None:
tmp_weight_data = torch.index_select(conv.weight.data if tmp_weight_data is None else tmp_weight_data,
1, in_channels_index)
assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
new_conv.weight.data.copy_(tmp_weight_data)
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
chicm-ms marked this conversation as resolved.
current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device) # pylint: disable=not-callable
f_start = groupid * filter_step
f_end = (groupid + 1) * filter_step
new_conv.weight.data[f_start:f_end] = torch.index_select(tmp_weight_data[f_start:f_end], 1, current_input_index)
else:
new_conv.weight.data.copy_(tmp_weight_data)

if conv.bias is not None:
new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)

return new_conv
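
The group-by-group copy above hinges on turning global input-channel indexes into group-local ones. The following plain-Python sketch isolates that step; the helper name `split_index_by_group` and the example numbers are assumptions for illustration, and the even-pruning check mirrors the assert mentioned in the commit list rather than reproducing it.

```python
def split_index_by_group(in_channels_index, in_channels, groups):
    """Split global input-channel indexes into per-group local indexes.

    `in_channels` is the channel count before pruning, so every group
    originally owns `in_channels // groups` consecutive channels.
    """
    step = in_channels // groups
    per_group = []
    for group_id in range(groups):
        start, end = group_id * step, (group_id + 1) * step
        # Keep the indexes that fall in this group and shift them to start at 0.
        per_group.append([i - start for i in in_channels_index if start <= i < end])
    # Every group must keep the same number of channels; otherwise the
    # pruned layer is no longer a valid grouped convolution.
    assert len({len(g) for g in per_group}) == 1, "input channels must be evenly pruned"
    return per_group


# 8 input channels in 2 groups; channels 0, 1, 4 and 6 survive pruning.
local = split_index_by_group([0, 1, 4, 6], in_channels=8, groups=2)
print(local)
```

Each inner list can then be used to select along dimension 1 of that group's slice of the weight tensor, which is what the loop in `replace_conv2d` does with `torch.index_select`.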
31 changes: 23 additions & 8 deletions src/sdk/pynni/nni/compression/torch/speedup/compressor.py
@@ -4,6 +4,7 @@
import logging
import torch
from nni._graph_utils import build_module_graph
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape

@@ -53,9 +54,10 @@ def __init__(self, model, dummy_input, masks_file, map_location=None):
self.bound_model = model
self.masks = torch.load(masks_file, map_location)
self.inferred_masks = dict() # key: module_name, value: ModuleMasks
self.dummy_input = dummy_input
self.torch_graph = build_module_graph(model, dummy_input)

def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None, out_shape=None):
"""
[Review comment from a contributor: please update docstring accordingly]

Infer input shape / output shape based on the module's weight mask / input shape / output shape.

@@ -71,6 +73,8 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=Non
----------
module_name : str
The name of the node
last_module : str
The name of last visited node
mask : tensor of mask or ModuleMasks
Mask of the weights in this node (i.e., module)
in_shape : ModuleMasks
@@ -100,10 +104,17 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=Non
raise RuntimeError(
"Has not supported infering output shape from input shape for module/function: `{}`, {}"
.format(m_type, module_name))
if m_type in ['aten::view', 'aten::flatten']:
if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
output_cmask = infer_from_inshape[m_type](module_masks,
in_shape,
self.torch_graph.name_to_node[module_name].auxiliary)
elif m_type in ['aten::cat']:
# To calculate the mask for the concat operation, we need the output
# shape, the cat dimension, and the order of the input parameters.
output_cmask = infer_from_inshape[m_type](module_masks,
in_shape,
self.torch_graph.name_to_node[module_name].auxiliary,
last_module)
else:
output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
if out_shape is not None:
@@ -117,18 +128,19 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=Non
if input_cmask:
predecessors = self.torch_graph.find_predecessors(module_name)
for _module_name in predecessors:
self.infer_module_mask(_module_name, out_shape=input_cmask)
self.infer_module_mask(_module_name, module_name, out_shape=input_cmask)
if output_cmask:
successors = self.torch_graph.find_successors(module_name)
for _module_name in successors:
self.infer_module_mask(_module_name, in_shape=output_cmask)
self.infer_module_mask(_module_name, module_name, in_shape=output_cmask)

def infer_modules_masks(self):
"""
Do shape inference of involved modules, including the shape of weights, inputs, output
"""
for module_name, mask in self.masks.items():
self.infer_module_mask(module_name, mask=mask)
_logger.debug('Start mask inference from %s', module_name)
self.infer_module_mask(module_name, None, mask=mask)
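
The new `last_module` argument records which node a mask arrived from, which the cat inference needs in order to tell its inputs apart. The toy sketch below shows only the forward half of such a traversal on a three-node chain; `propagate`, the `passthrough` set, and the graph are hypothetical simplifications, and the real method also walks predecessors and stops when an inference function reports no change.

```python
def propagate(graph, passthrough, node, last_node, kept, in_masks, trace):
    """Push a list of kept channel indexes forward through `graph`.

    `last_node` is the node we arrived from, mirroring `last_module`.
    Nodes in `passthrough` (for example ReLU) forward the mask unchanged;
    any other node (for example a conv) only has its input masked.
    """
    trace.append((last_node, node))
    in_masks[node] = kept
    if node not in passthrough:
        return
    for succ in graph.get(node, ()):
        propagate(graph, passthrough, succ, node, kept, in_masks, trace)


graph = {'conv1': ['relu'], 'relu': ['conv2'], 'conv2': []}
in_masks, trace = {}, []
# conv1 keeps output channels 0 and 2; start from its successors.
for succ in graph['conv1']:
    propagate(graph, {'relu'}, succ, 'conv1', [0, 2], in_masks, trace)
print(trace)
print(in_masks)
```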

def replace_compressed_modules(self):
"""
@@ -144,19 +156,20 @@ def replace_compressed_modules(self):
_logger.debug("replace %s, in %s type, with op_type %s",
module_name, g_node.type, g_node.op_type)
if g_node.type == 'module':
super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
super_module, leaf_module = get_module_by_name(self.bound_model, g_node.name)
m_type = g_node.op_type
if not m_type in replace_module:
raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
_logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
_logger.info("replace module (name: %s, op_type: %s)", g_node.name, m_type)
compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
setattr(super_module, module_name.split('.')[-1], compressed_module)
setattr(super_module, g_node.name.split('.')[-1], compressed_module)
elif g_node.type == 'func':
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
module_name, g_node.op_type)
else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type))


def speedup_model(self):
"""
There are basically two steps:
@@ -165,6 +178,8 @@ def speedup_model(self):
"""
training = self.bound_model.training
_logger.info("start to speed up the model")
_logger.info("fix the mask conflict of the interdependent layers")
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info("replace compressed modules...")