Skip to content

Commit

Permalink
fix IT pruning example issue (microsoft#2772)
Browse files Browse the repository at this point in the history
  • Loading branch information
suiguoxin authored Aug 11, 2020
1 parent 654e824 commit d0a9b10
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 17 deletions.
36 changes: 26 additions & 10 deletions src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ class LevelPruner(OneshotPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Operation types to prune.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def __init__(self, model, config_list):
super().__init__(model, config_list, pruning_algorithm='level')
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='level', optimizer=optimizer)

class SlimPruner(OneshotPruner):
"""
Expand All @@ -108,9 +110,11 @@ class SlimPruner(OneshotPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Only BatchNorm2d is supported in Slim Pruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def __init__(self, model, config_list):
super().__init__(model, config_list, pruning_algorithm='slim')
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='slim', optimizer=optimizer)

def validate_config(self, model, config_list):
schema = CompressorSchema([{
Expand Down Expand Up @@ -147,9 +151,11 @@ class L1FilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Only Conv2d is supported in L1FilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def __init__(self, model, config_list):
super().__init__(model, config_list, pruning_algorithm='l1')
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer)

class L2FilterPruner(_StructuredFilterPruner):
"""
Expand All @@ -161,9 +167,11 @@ class L2FilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Only Conv2d is supported in L2FilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def __init__(self, model, config_list):
super().__init__(model, config_list, pruning_algorithm='l2')
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer)

class FPGMPruner(_StructuredFilterPruner):
"""
Expand All @@ -175,9 +183,11 @@ class FPGMPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Only Conv2d is supported in FPGM Pruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def __init__(self, model, config_list):
super().__init__(model, config_list, pruning_algorithm='fpgm')
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, pruning_algorithm='fpgm', optimizer=optimizer)

class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
"""
Expand All @@ -189,6 +199,8 @@ class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : How much percentage of convolutional filters are to be pruned.
- op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1):
super().__init__(model, config_list, pruning_algorithm='taylorfo', optimizer=optimizer, statistics_batch_num=statistics_batch_num)
Expand All @@ -203,6 +215,8 @@ class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : How much percentage of convolutional filters are to be pruned.
- op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, \
Expand All @@ -218,6 +232,8 @@ class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : How much percentage of convolutional filters are to be pruned.
- op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, \
Expand Down
10 changes: 6 additions & 4 deletions src/sdk/pynni/tests/test_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def test_torch_quantizer_modules_detection(self):

def test_torch_level_pruner(self):
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(model, configure_list).compress()
torch_compressor.LevelPruner(model, configure_list, optimizer).compress()

@tf2
def test_tf_level_pruner(self):
Expand Down Expand Up @@ -128,7 +129,7 @@ def test_torch_fpgm_pruner(self):

model = TorchModel()
config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
pruner = torch_compressor.FPGMPruner(model, config_list)
pruner = torch_compressor.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01))

model.conv2.module.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(model.conv2)
Expand Down Expand Up @@ -314,7 +315,7 @@ def test_torch_QAT_quantizer(self):
def test_torch_pruner_validation(self):
# test bad configuraiton
pruner_classes = [torch_compressor.__dict__[x] for x in \
['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', \
['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGPPruner',\
'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]

bad_configs = [
Expand All @@ -336,10 +337,11 @@ def test_torch_pruner_validation(self):
]
]
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for pruner_class in pruner_classes:
for config_list in bad_configs:
try:
pruner_class(model, config_list)
pruner_class(model, config_list, optimizer)
print(config_list)
assert False, 'Validation error should be raised for bad configuration'
except schema.SchemaError:
Expand Down
4 changes: 1 addition & 3 deletions src/sdk/pynni/tests/test_pruners.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'])
elif pruner_name == 'autocompress':
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x)
elif pruner_name in ['level', 'slim', 'fpgm', 'l1', 'l2']:
pruner = prune_config[pruner_name]['pruner_class'](model, config_list)
elif pruner_name in ['agp', 'taylorfo', 'mean_activation', 'apoz']:
else:
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
pruner.compress()

Expand Down

0 comments on commit d0a9b10

Please sign in to comment.