
Support the Resnet/Squeezenet/Mobilenet for speedup #2579

Merged
merged 90 commits into microsoft:master on Jun 24, 2020

Conversation

zheng-ningxin
Contributor

In this PR, the speedup module adds support for the add/cat operations and for convolution layers with more than one group. I have tested the speedup module on resnet18, squeezenet1_1, and mobilenet_v2, and it works fine.
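
For reference, a minimal sketch of how this support would be exercised end to end, following the NNI 1.x speedup workflow (the pruner config and file names here are illustrative, not taken from this PR):

```python
import torch
from torchvision.models import resnet18
from nni.compression.torch import L1FilterPruner, ModelSpeedup, apply_compression_results

# Prune all Conv2d layers to 50% sparsity and export the masks.
model = resnet18()
pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
pruner.compress()
pruner.export_model(model_path='pruned_resnet18.pth', mask_path='mask.pth')

# Speed up a fresh copy of the model: with this PR, the residual adds
# and grouped convolutions inside resnet18 are handled as well.
model = resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
apply_compression_results(model, 'mask.pth')
ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()
```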

Ningxin added 30 commits May 14, 2020 01:26
Signed-off-by: Ningxin <[email protected]>
The model should be set to eval mode before the jit.trace call.

Signed-off-by: Ningxin <[email protected]>
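
A one-line illustration of why this matters (illustrative snippet, not from the PR): eval mode freezes dropout and batch-norm behavior, so the traced graph is deterministic.

```python
import torch
from torchvision.models import resnet18

model = resnet18()
model.eval()  # freeze dropout / batch-norm behavior before tracing
traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
```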
Signed-off-by: Ningxin <[email protected]>
In the original approach, addmm would also trigger the dependency-set search,
which could lead to a wrong dependency set.

Signed-off-by: Ningxin <[email protected]>
Signed-off-by: Ningxin <[email protected]>
The name of the node is not a unique identifier globally.

Signed-off-by: Ningxin <[email protected]>
mask_conflict can fix the mask conflicts of layers that
have channel dependencies. It should be called before the
speedup function, so that the speedup module can handle models
with residual connection/concat operations.

Signed-off-by: Ningxin <[email protected]>
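
A hedged sketch of the intended call order (the `fix_mask_conflict` entry point and its signature are my assumption about this PR's mask_conflict module):

```python
import torch
from torchvision.models import resnet18
# Assumed module path / entry point for the mask_conflict tool added here.
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict

model = resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)
# Resolve conflicting masks on channel-dependent layers first...
fix_mask_conflict('mask.pth', model, dummy_input)
# ...then run the speedup, which can now handle residual add / cat.
```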
Update the interface: if we already have the traced graph of the model,
we do not need to trace the model again.

Signed-off-by: Ningxin <[email protected]>
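
A sketch of the updated interface as I read this commit (the `traced` parameter name is an assumption):

```python
import torch
from torchvision.models import resnet18
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict

# Trace once, in eval mode, and reuse the graph across the analysis tools
# instead of letting each tool re-trace the model.
model = resnet18().eval()
with torch.no_grad():
    traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
fix_mask_conflict('mask.pth', traced=traced)  # parameter name assumed
```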
Add unit tests for the tools in analysis_utils to verify the correctness of the
visualization, channel dependency, and mask conflict.

Signed-off-by: Ningxin <[email protected]>
Signed-off-by: Ningxin <[email protected]>
@zheng-ningxin
Contributor Author

I found another problem in the TorchModuleGraph (#2581). It may be too late to fix it before this release (code freeze on 6.22), but fortunately, not many models are affected. I'll try to fix it in the next release. Thanks~

Ningxin added 3 commits June 19, 2020 12:21
Signed-off-by: Ningxin <[email protected]>
Signed-off-by: Ningxin <[email protected]>
@@ -11,13 +11,13 @@

from nni.compression.torch import L1FilterPruner
Contributor

Would you please add test cases to verify the model speedup correctness for resnet, squeezenet and mobilenet like this test case?
https://github.com/microsoft/nni/blob/master/src/sdk/pynni/tests/test_model_speedup.py#L106

Contributor Author

Sure~

input_shapes = [t.type().sizes() for t in input_tensors]
cat_info['in_shape'] = input_shapes
return cat_info

def _extract_shape_info(self, node):
Contributor
@QuanluZhang Jun 22, 2020

This function was written by me, and it only handles view (pretty limited). Maybe we can generalize it to extract different modules' shapes if needed in the future.

Returns
-------
dict
Include auxiliary information for the cat operation.
Contributor

It might be better to explain the contents of the dict.

# after the build_index function.
input_order = []
list_construct_cpp = list(cpp_node.inputs())[0].node()
input_tensors = list(list_construct_cpp.inputs())
Contributor

so the order of tensors returned by .inputs() is the order of input arguments?

Contributor Author

According to my observations and experimental results, yes, it is.
However, because jit itself lacks documentation, I have no official reference to support this point.
I will read the jit source code and double-check it.
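
A quick probe that makes this observable (illustrative, not from the PR): trace a cat of two tensors and print the graph; the prim::ListConstruct node feeding aten::cat lists its inputs in the order they were passed.

```python
import torch

def f(a, b):
    return torch.cat([a, b], dim=1)

traced = torch.jit.trace(f, (torch.randn(1, 2), torch.randn(1, 3)))
print(traced.graph)
# The prim::ListConstruct feeding aten::cat shows its inputs in the
# same order (a, b) that they were passed to torch.cat.
```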

self.torch_graph = build_module_graph(model, dummy_input)

def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None, out_shape=None):
"""
Contributor

Please update the docstring accordingly.


Parameters
----------
model : torch.nn.Module
Contributor

Please use an order consistent with the input arguments.

"""
torch.save(self.masks, path)

class CatMaskPadding(MaskFix):
Contributor

Could you explain the logic of cat mask padding in the docstring, to convey the high-level idea of how the conflict is resolved?

# no layer is pruned
continue
elif count == len(layers):
# all the layers have been pruned
Contributor

Even if all the layers have been pruned, is it possible that their masks are still not consistent?

for layer in layers:
module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
Contributor

so the mask is all ones?

Contributor Author

Cat concatenates the input masks to form the output mask, so when some of the input layers are not pruned, we still need to pass the masks of those un-pruned layers (all ones) to the cat operation to ensure that the shape of the final output mask is correct.
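
A self-contained sketch of that padding logic (tensor sizes are made up for illustration):

```python
import torch

# Two conv layers feed a cat along the channel dimension; only the
# first one was pruned by the pruner.
conv1_out_mask = torch.tensor([1., 0., 1., 0.])  # 4 channels, 2 pruned
conv2_out_mask = torch.ones(6)                   # un-pruned layer: pad with all ones
# The cat's output mask is just the concatenation of the input masks,
# so the all-ones padding keeps its shape consistent with the real output.
cat_out_mask = torch.cat([conv1_out_mask, conv2_out_mask])
assert cat_out_mask.numel() == 4 + 6
```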

graph : torch._C.Graph
masks : dict
a dict object that stores the masks
graph : torch._C.torch.jit.TopLevelTracedModule
Contributor

inconsistent order with input arguments

@chicm-ms chicm-ms merged commit e6817d2 into microsoft:master Jun 24, 2020
Successfully merging this pull request may close these issues.

Support for more architecture and functions
pruned model size no change and inference time is even longer