Support the Resnet/Squeezenet/Mobilenet for speedup #2579
Signed-off-by: Ningxin <[email protected]>
The model should be set to eval mode before the jit.trace call. Signed-off-by: Ningxin <[email protected]>
In the original approach, addmm also triggers the dependency set search, which may lead to a wrong dependency set. Signed-off-by: Ningxin <[email protected]>
The name of the node is not a unique identifier globally. Signed-off-by: Ningxin <[email protected]>
mask_conflict fixes the mask conflicts of layers that have a channel dependency. It should be called before the speedup function so that the speedup module can handle models with residual connections or concat operations. Signed-off-by: Ningxin <[email protected]>
Update the interface. If we already have the traced graph of the model, we do not need to trace the model again. Signed-off-by: Ningxin <[email protected]>
Add unit tests for the tools in analysis_utils to verify the correctness of the visualization, channel dependency, and mask conflict. Signed-off-by: Ningxin <[email protected]>
I found another problem in TorchModuleGraph (#2581). It may be too late to fix it before this release (code freeze on 6.22), but fortunately few models are affected. I'll try to fix it in the next release. Thanks~
@@ -11,13 +11,13 @@

from nni.compression.torch import L1FilterPruner
Would you please add test cases to verify the model speedup correctness for resnet, squeezenet and mobilenet like this test case?
https://github.com/microsoft/nni/blob/master/src/sdk/pynni/tests/test_model_speedup.py#L106
Sure~
input_shapes = [t.type().sizes() for t in input_tensors]
cat_info['in_shape'] = input_shapes
return cat_info

def _extract_shape_info(self, node):
I wrote this function; it only handles view (pretty limited). Maybe we can generalize it to extract shape information for other modules if needed in the future.
Returns
-------
dict
Include auxiliary information for the cat operation.
It might be better to explain the contents of the dict.
# after the build_index function.
input_order = []
list_construct_cpp = list(cpp_node.inputs())[0].node()
input_tensors = list(list_construct_cpp.inputs())
so the order of tensors returned by .inputs() is the order of input arguments?
According to my observations and experimental results, yes, it is.
However, jit itself is poorly documented, so I have no documentation to support this point.
I will read the jit source code and double-check it.
self.torch_graph = build_module_graph(model, dummy_input)

- def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
+ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None, out_shape=None):
"""
please update docstring accordingly

Parameters
----------
model : torch.nn.Module
please use consistent order with input arguments
"""
torch.save(self.masks, path)

class CatMaskPadding(MaskFix):
Could you explain the logic of cat mask padding in the docstring, to convey the high-level idea of how the conflict is resolved?
# no layer is pruned
continue
elif count == len(layers):
# all the layers have been pruned
Even if all the layers have been pruned, is it possible that their masks are still inconsistent?
for layer in layers:
module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
so the mask is all ones?
Cat concatenates the input masks to form the output mask, so when some of the input layers are not pruned, we still need to pass all-ones masks for those unpruned layers to the cat operation so that the final output mask has the right shape.
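The padding idea in the reply above can be sketched with plain Python lists. This is only an illustration under stated assumptions, not the code in this PR: the function name `cat_output_mask`, the layer names, and the use of per-channel 0/1 lists instead of weight-shaped tensors are all hypothetical.

```python
def cat_output_mask(layers, out_channels, masks):
    """Build the channel mask of a cat output from its input layers.

    layers       : input layer names, in the order they are passed to cat
    out_channels : dict, layer name -> number of output channels
    masks        : dict, layer name -> list of 0/1 channel flags
                   (contains entries only for the pruned layers)
    """
    output_mask = []
    for name in layers:
        if name in masks:
            layer_mask = list(masks[name])
        else:
            # Unpruned layer: pad with an all-ones mask so that the
            # concatenated mask still covers every output channel.
            layer_mask = [1] * out_channels[name]
        if len(layer_mask) != out_channels[name]:
            raise ValueError("mask length does not match channels of " + name)
        output_mask.extend(layer_mask)
    return output_mask


# Hypothetical example: only conv_a was pruned, conv_b was left untouched.
layers = ["conv_a", "conv_b"]
out_channels = {"conv_a": 4, "conv_b": 3}
masks = {"conv_a": [1, 0, 1, 0]}
padded = cat_output_mask(layers, out_channels, masks)
print(padded)  # [1, 0, 1, 0, 1, 1, 1]
```

Without the all-ones padding, the output mask would have only 4 entries while the cat output has 7 channels, which is the shape mismatch the reply describes.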
graph : torch._C.Graph
masks : dict
a dict object that stores the masks
graph : torch._C.torch.jit.TopLevelTracedModule
inconsistent order with input arguments
In this PR, the speedup module adds support for add/cat operations and for convolution layers that have more than one group. I have tested the speedup module on resnet18, squeezenet1_1, and mobilenet_v2, and it works fine.
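As a rough illustration of the mask conflict that the commit messages refer to: layers whose outputs are summed by an add operation share a channel dependency, so they need identical channel masks before speedup. The sketch below uses plain Python lists. The function name `resolve_add_conflict` and the resolution rule (keep a channel if any layer in the dependency set keeps it) are assumptions made for illustration and are not necessarily what this PR implements.

```python
def resolve_add_conflict(masks):
    """Give every layer in one dependency set the same channel mask.

    masks : dict, layer name -> list of 0/1 channel flags.
            All layers feed the same add, so the lists must be equally long.

    Assumed rule: a channel is kept if at least one layer keeps it,
    i.e. a channel is pruned only when every layer prunes it.
    """
    lengths = {len(m) for m in masks.values()}
    if len(lengths) != 1:
        raise ValueError("layers feeding an add must have equal channel counts")
    n_channels = lengths.pop()
    shared = [
        1 if any(m[c] for m in masks.values()) else 0
        for c in range(n_channels)
    ]
    # Each layer gets its own copy of the shared mask.
    return {name: list(shared) for name in masks}


# Hypothetical residual block: conv1 and conv2 are summed element-wise.
conflicting = {"conv1": [1, 1, 0, 0], "conv2": [1, 0, 0, 1]}
fixed = resolve_add_conflict(conflicting)
print(fixed["conv1"])  # [1, 1, 0, 1]
```

This rule is conservative: it never removes a channel that some layer still needs, at the cost of pruning fewer channels than the individual masks requested.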