This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Speedup enhancement #2719

Merged 8 commits on Aug 4, 2020

Changes from 5 commits
27 changes: 18 additions & 9 deletions src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
@@ -222,6 +222,7 @@ def __repr__(self):
'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
@@ -241,7 +242,8 @@ def __repr__(self):
'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask)
'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask)
}
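A quick way to see why the in-place variants need their own table entries: torch.jit traces nn.ReLU(inplace=True) to the distinct kernel name aten::relu_, and nn.Dropout to aten::dropout, so without the new keys these nodes would have no shape-inference handler. A minimal sketch (not part of this PR):

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)  # traced as aten::relu_
        self.drop = nn.Dropout()           # traced as aten::dropout
    def forward(self, x):
        return self.drop(self.relu(x))

traced = torch.jit.trace(Tiny().eval(), torch.randn(1, 3))
print(traced.graph)  # the graph contains aten::relu_ and aten::dropout nodes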

"""
@@ -258,8 +260,14 @@ def dropout_inshape(module_masks, mask):
return module_masks.output_mask
# if already visited
assert module_masks.input_mask <= mask
if module_masks.input_mask == mask:
return None
# The two masks should be identical: we pass masks by reference (not by
# value), so mask and module_masks.input_mask are actually two references
# to the same object. We therefore must keep passing the mask to the
# following nodes even when module_masks.input_mask == mask.
# If the mask were passed via copy.deepcopy(), we could stop here when
# module_masks.input_mask == mask:
# if module_masks.input_mask == mask:
#     return None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return module_masks.output_mask
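A minimal illustration of the aliasing the new comment describes (a hypothetical Mask stand-in, not the real CoarseMask):

import copy

class Mask:  # hypothetical stand-in for CoarseMask
    def __init__(self, idx):
        self.idx = idx
    def __eq__(self, other):
        return self.idx == other.idx

mask = Mask({0, 1})
input_mask = mask               # passed by reference: two names, one object
mask.idx = {0, 1, 2}            # an upstream update mutates both "copies"
print(input_mask == mask)       # True, yet the update still needs propagating

snapshot = copy.deepcopy(mask)  # with a deep copy, equality would be a
mask.idx = {0}                  # reliable "nothing changed" signal
print(snapshot == mask)         # False: the change is now detectable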
@@ -413,7 +421,8 @@ def linear_inshape(module_masks, mask):
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[0] is None
assert module_masks.input_mask is None
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
module_masks.set_input_mask(mask)
return None
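One plausible reading of the <= check above, sketched with a hypothetical stand-in class (CoarseMask's real __le__ may differ): the previously recorded input mask must be compatible with the incoming one, i.e. keep no channels that the new mask has already pruned away.

class IndexMask:  # hypothetical stand-in; not NNI's CoarseMask
    def __init__(self, kept):
        self.kept = set(kept)           # channel indices that survive pruning
    def __le__(self, other):
        return self.kept <= other.kept  # subset check: no mask conflict

assert IndexMask([0, 2]) <= IndexMask([0, 1, 2])     # compatible re-visit
assert not (IndexMask([0, 3]) <= IndexMask([0, 1]))  # would trip the assert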

@@ -451,7 +460,10 @@ def view_inshape(module_masks, mask, shape):
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
assert module_masks.input_mask is None
# due to the cat operation, the same node may be
# accessed more than once
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
module_masks.set_input_mask(mask)
output_cmask = CoarseMask(num_dim=2)
index = []
@@ -535,12 +547,9 @@ def relu_inshape(module_masks, mask):
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
# TODO: double check this assert, is it possible that a module is passed twice
if module_masks.input_mask is not None:
# check if has a mask conflict
assert module_masks.input_mask == mask
# No need to pass the mask again
return None
assert module_masks.input_mask <= mask
# assert module_masks.input_mask is None, "A relu op can only be processed once"
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
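Why a ReLU can legitimately be visited more than once: when its output feeds two branches that are later concatenated, mask inference reaches the shared node once per branch, so the old "processed once" assumption no longer holds. A minimal sketch of such a topology (hypothetical network, not from this PR):

import torch
import torch.nn as nn

class Branchy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.relu = nn.ReLU()
        self.left = nn.Conv2d(8, 4, 1)
        self.right = nn.Conv2d(8, 4, 1)
    def forward(self, x):
        shared = self.relu(self.conv(x))  # reached from both branches below
        return torch.cat([self.left(shared), self.right(shared)], dim=1)

out = Branchy()(torch.randn(1, 3, 16, 16))  # out has 4 + 4 = 8 channels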
13 changes: 9 additions & 4 deletions src/sdk/pynni/tests/test_model_speedup.py
@@ -145,18 +145,23 @@ def test_speedup_bigmodel(self):
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2']:
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']:
Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False).to(device)
if model_name == 'inception_v3':
# jit.trace cannot capture the aux_logits path when net.training is False
net = Model(pretrained=True, progress=False, aux_logits=False).to(device)
speedup_model = Model(aux_logits=False).to(device)
else:
net = Model(pretrained=True, progress=False).to(device)
speedup_model = Model().to(device)
net.eval() # necessary: the jit trace must run in deterministic inference mode
speedup_model.eval()
# randomly generate the pruning config for the pruner
cfgs = generate_random_sparsity(net)
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
speedup_model = Model().to(device)
speedup_model.eval()
state_dict = torch.load(MODEL_FILE)
speedup_model.load_state_dict(state_dict)
zero_bn_bias(net)
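A minimal demonstration of why net.eval() is needed before tracing and before comparing outputs (assumes torchvision's mobilenet_v2, which contains a Dropout layer):

import torch
from torchvision import models

net = models.mobilenet_v2()
x = torch.randn(1, 3, 224, 224)

net.train()
print(torch.allclose(net(x), net(x)))      # False: dropout makes outputs stochastic

net.eval()
with torch.no_grad():
    print(torch.allclose(net(x), net(x)))  # True: inference mode is deterministic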