
Speedup enhancement #2719

Merged: 8 commits, Aug 4, 2020
7 changes: 7 additions & 0 deletions src/sdk/pynni/nni/compression/torch/speedup/compressor.py
@@ -141,6 +141,13 @@ def infer_modules_masks(self):
"""
for module_name, mask in self.masks.items():
_logger.debug('Start mask inference from %s', module_name)
if module_name not in self.torch_graph.name_to_node:
# This module is not traced in the torch_graph.
# jit.trace only correctly records functions and
# modules that are not data dependent (i.e., do not
# have conditionals on data in tensors), so if a
# node is not traced, we simply skip it.
continue
self.infer_module_mask(module_name, None, mask=mask)

def replace_compressed_modules(self):
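The skip added in infer_modules_masks relies on a property of torch.jit.trace: tracing records only the single execution path taken by the example input, so modules guarded by data-dependent conditionals may be missing from the traced graph entirely. A toy illustration of that behavior (standalone, not part of this PR):

import torch
import torch.nn as nn

class DataDependent(nn.Module):
    """Toy module whose control flow depends on tensor values."""
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(4, 4)
        self.branch_b = nn.Linear(4, 4)

    def forward(self, x):
        # The branch taken depends on the data, not on the shape.
        if x.sum() > 0:
            return self.branch_a(x)
        return self.branch_b(x)

model = DataDependent()
# Tracing records only the branch exercised by this particular input
# (and emits a TracerWarning); the other branch is absent from the
# traced graph, so a mask registered for it would not be found in
# torch_graph.name_to_node and is skipped by the new check above.
traced = torch.jit.trace(model, torch.ones(1, 4))
print(traced.graph)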
27 changes: 18 additions & 9 deletions src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
@@ -222,6 +222,7 @@ def __repr__(self):
'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
@@ -241,7 +242,8 @@ def __repr__(self):
'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask)
'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask)
}
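
The entries added above ('aten::relu_', 'aten::dropout') simply route additional traced op types to the existing inshape handlers. A minimal sketch of how such a dispatch table might be consulted; the helper below is illustrative and only covers the common two-argument rules, not the cat/mean/view entries that take extra arguments:

def lookup_inshape(op_type, module_masks, mask, inshape_rules):
    """Illustrative lookup: pick the inshape rule for an op type and apply it.

    `inshape_rules` plays the role of the dict shown above; unsupported op
    types raise early instead of silently propagating a wrong mask.
    """
    if op_type not in inshape_rules:
        raise RuntimeError('shape inference not implemented for op: %s' % op_type)
    return inshape_rules[op_type](module_masks, mask)

# The newly registered op types route to existing handlers, e.g.:
# lookup_inshape('aten::relu_', module_masks, mask, inshape_rules)
# lookup_inshape('aten::dropout', module_masks, mask, inshape_rules)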

"""
@@ -258,8 +260,14 @@ def dropout_inshape(module_masks, mask):
return module_masks.output_mask
# if already visited
assert module_masks.input_mask <= mask
if module_masks.input_mask == mask:
return None
# They should be the same: the masks are passed by reference (not by value),
# so mask and module_masks.input_mask are actually two references to the
# same object. Therefore we should continue to pass the mask to the
# following nodes even when module_masks.input_mask == mask.
# If the mask were passed via copy.deepcopy(), we could stop when
# module_masks.input_mask == mask:
# if module_masks.input_mask == mask:
# return None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return module_masks.output_mask
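
The reasoning in the comment above comes down to Python reference semantics: the stored input mask and the incoming mask can be the very same object, so equality alone cannot tell whether downstream nodes still need to see the mask. A small standalone illustration (a toy mask class, not the PR's CoarseMask):

import copy

class ToyMask:
    def __init__(self, kept):
        self.kept = set(kept)
    def __eq__(self, other):
        return self.kept == other.kept

incoming = ToyMask({0, 2, 3})

# Stored by reference: the stored mask *is* the incoming mask, so
# `stored == incoming` is trivially True and says nothing about whether
# the mask still has to be propagated to the following nodes.
stored_by_ref = incoming
assert stored_by_ref is incoming and stored_by_ref == incoming

# Stored by deep copy: the two objects are independent, so equality
# genuinely means "nothing changed since the last visit" and the
# traversal could stop early, as the commented-out early return intended.
stored_by_copy = copy.deepcopy(incoming)
incoming.kept.add(5)                 # the mask is refined later
assert stored_by_copy != incoming    # the copy notices the change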
@@ -413,7 +421,8 @@ def linear_inshape(module_masks, mask):
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[0] is None
assert module_masks.input_mask is None
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
module_masks.set_input_mask(mask)
return None

@@ -451,7 +460,10 @@ def view_inshape(module_masks, mask, shape):
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
assert module_masks.input_mask is None
# due to the cat operation, the same node may be
# accessed more than once
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
module_masks.set_input_mask(mask)
output_cmask = CoarseMask(num_dim=2)
index = []
@@ -535,12 +547,9 @@ def relu_inshape(module_masks, mask):
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
# TODO: double check this assert, is it possible that a module is passed twice
if module_masks.input_mask is not None:
# check if has a mask conflict
assert module_masks.input_mask == mask
# No need to pass the mask again
return None
assert module_masks.input_mask <= mask
# assert module_masks.input_mask is None, "A relu op can only be processed once"
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
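The relaxed asserts compare masks with <= rather than requiring exact identity, so a node revisited along another path (e.g. after a cat) only has to carry a mask compatible with the one recorded earlier. A hedged sketch of one way such an ordering could be defined for coarse channel masks, assuming <= means "every channel kept by the stored mask is also kept by the incoming one"; this is illustrative, not CoarseMask's actual implementation:

class ToyCoarseMask:
    """Toy per-channel mask: `kept` holds the retained channel indices."""
    def __init__(self, kept):
        self.kept = frozenset(kept)

    def __le__(self, other):
        # Assumed semantics: everything self keeps, other keeps too,
        # i.e. the stored mask is covered by the incoming one.
        return self.kept <= other.kept

old = ToyCoarseMask({0, 2})        # recorded on the first visit
new = ToyCoarseMask({0, 2, 5})     # arrives via another path (e.g. after a cat)
assert old <= new                  # compatible: safe to propagate again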
6 changes: 3 additions & 3 deletions src/sdk/pynni/tests/test_model_speedup.py
@@ -145,18 +145,18 @@ def test_speedup_bigmodel(self):
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2']:
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']:
Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False).to(device)
speedup_model = Model().to(device)
net.eval() # this line is necessary
speedup_model.eval()
# randomly generate the prune config for the pruner
cfgs = generate_random_sparsity(net)
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
speedup_model = Model().to(device)
speedup_model.eval()
state_dict = torch.load(MODEL_FILE)
speedup_model.load_state_dict(state_dict)
zero_bn_bias(net)
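For context, the hunk above covers the pruning half of the integration test; the elided remainder loads the exported checkpoint and runs the speedup. A rough end-to-end sketch of that flow with the NNI 1.x API follows. The fixed config, file paths, and dummy-input shape are placeholders, and the exact ModelSpeedup call should be treated as an assumption rather than the test's verbatim code:

import torch
import torchvision.models as models
from nni.compression.torch import L1FilterPruner, ModelSpeedup  # NNI 1.x layout

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_FILE, MASK_FILE = './model.pth', './mask.pth'  # placeholder paths

net = models.resnet18(pretrained=True, progress=False).to(device)
net.eval()

# Prune with a fixed config instead of the test's generate_random_sparsity helper.
cfgs = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()

# Rebuild a clean model, load the pruned weights, then apply the exported
# masks so pruned channels are physically removed from the graph.
speedup_model = models.resnet18().to(device)
speedup_model.eval()
speedup_model.load_state_dict(torch.load(MODEL_FILE))
dummy_input = torch.randn(1, 3, 224, 224).to(device)
ModelSpeedup(speedup_model, dummy_input, MASK_FILE).speedup_model()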