Support the Resnet/Squeezenet/Mobilenet for speedup #2579
Changes from 87 commits
@@ -10,6 +10,7 @@
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
CAT_KIND = 'aten::cat'

_logger = logging.getLogger(__name__)
@@ -236,6 +237,7 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
        super().__init__(model, dummy_input, traced_model)
        self.global_count = 0
        self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
        self._extract_auxiliary_info()

    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
                              module_type):
@@ -364,6 +366,58 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
                                node_group, inputs=inputs, outputs=outputs)
        return nodepy

    def _extract_cat_info(self, node_group, cpp_node):
        """
        Extract the detailed information of the cat operation,
        such as the order of the input tensors, the shape of each
        input tensor, the output shape, and the cat dimension.

        Parameters
        ----------
        node_group : NodePyGroup
        cpp_node : torch._C.Node
            It should be an ```aten::cat``` node.

        Returns
        -------
        dict
            Auxiliary information for the cat operation. This dict
            object has four keys: 'cat_dim', 'out_shape', 'in_order'
            and 'in_shape'. cat_dim is the dimension along which the
            cat operation concatenates the input tensors. out_shape
            is the shape of the output tensor of the cat operation.
            in_order is an ordered list that contains the corresponding
            parent operation nodes of the input tensors. in_shape is
            also an ordered list that contains the shapes of the
            input tensors.
        """
        # only support the cat operation
        assert cpp_node.kind() == CAT_KIND
        cat_info = {}
        # get the shape of the output tensor
        t_output = cpp_node.output()
        out_shape = t_output.type().sizes()
        cat_info['out_shape'] = out_shape
        # get the cat dimension
        inputs = cpp_node.inputs()
        cat_dim = list(inputs)[1].toIValue()
        cat_info['cat_dim'] = cat_dim
        # Get the order of the input tensors. To do so, we need to be
        # aware of the topology of the model, which means we should
        # extract the auxiliary information after the build_index function.
        input_order = []
        list_construct_cpp = list(cpp_node.inputs())[0].node()
        input_tensors = list(list_construct_cpp.inputs())

Review comment: so the order of tensors returned by …
Author reply: According to my observation and experimental results, yes it is.

        for _tensor in input_tensors:
            debug_name = _tensor.debugName()
            input_order.append(self.output_to_node[debug_name].unique_name)
        cat_info['in_order'] = input_order
        input_shapes = [t.type().sizes() for t in input_tensors]
        cat_info['in_shape'] = input_shapes
        return cat_info
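For reference, here is a minimal standalone sketch (not part of the PR) of how these `torch._C` calls behave on a traced graph. The `TinyCat` module and its shapes are invented for the demo; note that the shape-reading call `t.type().sizes()` used above additionally requires the trace to carry shape information, which depends on the PyTorch version:

```python
import torch

class TinyCat(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x, y], dim=1)

dummy = (torch.randn(1, 2, 4, 4), torch.randn(1, 3, 4, 4))
traced = torch.jit.trace(TinyCat(), dummy)
# locate the aten::cat node in the traced graph
cat_node = [n for n in traced.graph.nodes() if n.kind() == 'aten::cat'][0]
# the second input of aten::cat is the constant cat dimension
cat_dim = list(cat_node.inputs())[1].toIValue()            # -> 1
# the first input is a prim::ListConstruct holding the ordered input tensors
list_construct = list(cat_node.inputs())[0].node()
order = [t.debugName() for t in list_construct.inputs()]   # order is preserved
print(cat_dim, order)
```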

    def _extract_shape_info(self, node):

Review comment: this function is written by me, it is only for …

        """
        Extract the shape information of an ```aten::view``` node
@@ -541,8 +595,8 @@ def _build_graph(self):
                    node, nodes, input_to_node, output_to_node, 'func')
                nodes_py.nodes_op.append(node_group)
                # get shape info for view (aten::view) func
-               if node_group.op_type in ['aten::view', 'aten::flatten']:
-                   node_group.auxiliary = self._extract_shape_info(node)
+               # if node_group.op_type in ['aten::view', 'aten::flatten']:
+               #     node_group.auxiliary = self._extract_shape_info(node)

        for node in graph.outputs():  # Create sink nodes for output ops
            node_py = NodePyIO(node, 'output')
@@ -552,6 +606,26 @@ def _build_graph(self):
        # build index
        return self._build_index(self.nodes_py.nodes_op)

    def _extract_auxiliary_info(self):
        """
        Extract the auxiliary information for the node groups
        if necessary. For example, view/flatten operations may
        need the shapes of their input and output tensors.
        """
        # extract the input & output shape for the view and flatten
        for node_group in self.nodes_py.nodes_op:
            if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
                # get shape info for view (aten::view) func
                cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
                                       node_group.node_cpps))[0]
                node_group.auxiliary = self._extract_shape_info(cpp_node)
            elif node_group.op_type == CAT_KIND:
                # get the detailed information for the cat func
                cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
                                       node_group.node_cpps))[0]
                node_group.auxiliary = self._extract_cat_info(
                    node_group, cpp_node)
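To give a concrete picture of what `node_group.auxiliary` holds after this pass (the reviewer's closing note below asks to explain the content of the dict): the cat fields follow the `_extract_cat_info` docstring, while the field names for the view/flatten case are an assumption based on `_extract_shape_info`'s role, with all values invented for illustration:

```python
# hypothetical auxiliary payload for aten::view / aten::flatten /
# aten::mean / aten::reshape, assuming _extract_shape_info records
# the input and output tensor shapes
view_aux = {'in_shape': [1, 64, 7, 7], 'out_shape': [1, 3136]}

# auxiliary payload for aten::cat, per the _extract_cat_info docstring
cat_aux = {
    'cat_dim': 1,                              # dimension being concatenated
    'out_shape': [1, 5, 4, 4],                 # shape of the cat output
    'in_order': ['conv1', 'conv2'],            # unique_names of parent nodes
    'in_shape': [[1, 2, 4, 4], [1, 3, 4, 4]],  # input shapes, in order
}
```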

    def find_predecessors(self, unique_name):
        """
        Find the predecessor nodes of the given node
@@ -4,6 +4,7 @@
import logging
import torch
from nni._graph_utils import build_module_graph
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
@@ -53,9 +54,10 @@ def __init__(self, model, dummy_input, masks_file, map_location=None):
        self.bound_model = model
        self.masks = torch.load(masks_file, map_location)
        self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
        self.dummy_input = dummy_input
        self.torch_graph = build_module_graph(model, dummy_input)

-   def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
+   def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None, out_shape=None):
        """

Review comment: please update docstring accordingly

        Infer input shape / output shape based on the module's weight mask / input shape / output shape.
@@ -71,6 +73,8 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
        ----------
        module_name : str
            The name of the node
        last_module : str
            The name of the last visited node
        mask : tensor of mask or ModuleMasks
            Mask of the weights in this node (i.e., module)
        in_shape : ModuleMasks
@@ -100,10 +104,17 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
            raise RuntimeError(
                "Has not supported inferring output shape from input shape for module/function: `{}`, {}"
                .format(m_type, module_name))
-       if m_type in ['aten::view', 'aten::flatten']:
+       if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
            output_cmask = infer_from_inshape[m_type](module_masks,
                                                      in_shape,
                                                      self.torch_graph.name_to_node[module_name].auxiliary)
        elif m_type in ['aten::cat']:
            # To calculate the mask for the concat operation, we need the
            # output shape, the cat dimension, and the order of the input tensors.
            output_cmask = infer_from_inshape[m_type](module_masks,
                                                      in_shape,
                                                      self.torch_graph.name_to_node[module_name].auxiliary,
                                                      last_module)
        else:
            output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
        if out_shape is not None:
@@ -117,18 +128,19 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
        if input_cmask:
            predecessors = self.torch_graph.find_predecessors(module_name)
            for _module_name in predecessors:
-               self.infer_module_mask(_module_name, out_shape=input_cmask)
+               self.infer_module_mask(_module_name, module_name, out_shape=input_cmask)
        if output_cmask:
            successors = self.torch_graph.find_successors(module_name)
            for _module_name in successors:
-               self.infer_module_mask(_module_name, in_shape=output_cmask)
+               self.infer_module_mask(_module_name, module_name, in_shape=output_cmask)

    def infer_modules_masks(self):
        """
        Do shape inference of the involved modules, including the shapes of weights, inputs and outputs
        """
        for module_name, mask in self.masks.items():
-           self.infer_module_mask(module_name, mask=mask)
+           _logger.debug('Start mask inference from %s', module_name)
+           self.infer_module_mask(module_name, None, mask=mask)
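The recursion above walks the graph in both directions, threading the current node through as the next call's `last_module`. A tiny self-contained toy of the forward-walk pattern (graph, names, and mask are made up for the demo, and a visited set stands in for the real code's bookkeeping):

```python
def propagate(successors, module_name, last_module, mask, seen):
    if module_name in seen:
        return
    seen.add(module_name)
    print('infer %s (reached from %s)' % (module_name, last_module))
    for _next in successors.get(module_name, []):
        # the current node becomes the next call's last_module, which is
        # how an aten::cat node can tell which input slot the mask is for
        propagate(successors, _next, module_name, mask, seen)

toy_graph = {'conv1': ['cat1'], 'conv2': ['cat1'], 'cat1': ['fc']}
propagate(toy_graph, 'conv1', None, 'channel-mask', set())
# prints: conv1 (from None), cat1 (from conv1), fc (from cat1)
```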

    def replace_compressed_modules(self):
        """
@@ -144,19 +156,20 @@ def replace_compressed_modules(self):
            _logger.debug("replace %s, in %s type, with op_type %s",
                          module_name, g_node.type, g_node.op_type)
            if g_node.type == 'module':
-               super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
+               super_module, leaf_module = get_module_by_name(self.bound_model, g_node.name)
                m_type = g_node.op_type
                if not m_type in replace_module:
                    raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
-               _logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
+               _logger.info("replace module (name: %s, op_type: %s)", g_node.name, m_type)
                compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
-               setattr(super_module, module_name.split('.')[-1], compressed_module)
+               setattr(super_module, g_node.name.split('.')[-1], compressed_module)
            elif g_node.type == 'func':
                _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
                             module_name, g_node.op_type)
            else:
                raise RuntimeError("Unsupported node type: {}".format(g_node.type))

    def speedup_model(self):
        """
        There are basically two steps:
@@ -165,6 +178,8 @@ def speedup_model(self):
        """
        training = self.bound_model.training
        _logger.info("start to speed up the model")
        _logger.info("fix the mask conflict of the interdependent layers")
        fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
        _logger.info("infer module masks...")
        self.infer_modules_masks()
        _logger.info("replace compressed modules...")
Review comment: might be better to explain the content of the dict
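For orientation, a hedged end-to-end sketch of how this class is driven (the `nni.compression.torch` import path is inferred from the imports in this diff and may differ between NNI versions; `mask.pth` is a placeholder for masks exported by an NNI pruner, so the last call only succeeds with a real masks file):

```python
import torch
from nni.compression.torch import ModelSpeedup  # assumed v1.x import path

# toy model standing in for a pruned ResNet/SqueezeNet/MobileNet
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
dummy_input = torch.randn(1, 3, 32, 32)

ms = ModelSpeedup(model, dummy_input, 'mask.pth')  # masks_file path is a placeholder
ms.speedup_model()  # fix mask conflicts -> infer masks -> replace modules
```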