Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

bug bash for sensitivity_pruner #2815

Merged
merged 7 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/model_compress/models/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, n_class, profile='normal'):
def forward(self, x):
x = self.conv1(x)
x = self.features(x)
x = x.mean(3).mean(2) # global average pooling
x = x.mean([2, 3]) # global average pooling

x = self.classifier(x)
return x
Expand Down
5 changes: 4 additions & 1 deletion examples/model_compress/models/mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.):

def forward(self, x):
x = self.features(x)
x = x.mean(3).mean(2)
# it's same with .mean(3).mean(2), but
# speedup only suport the mean option
# whose output only have two dimensions
x = x.mean([2, 3])
x = self.classifier(x)
return x

Expand Down
20 changes: 18 additions & 2 deletions src/sdk/pynni/nni/compression/torch/pruning/sensitivity_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
MAX_PRUNE_RATIO_PER_ITER = 0.95

_logger = logging.getLogger('Sensitivity_Pruner')

_logger.setLevel(logging.INFO)

class SensitivityPruner(Pruner):
"""
Expand Down Expand Up @@ -202,10 +202,10 @@ def _max_prune_ratio(self, ori_acc, threshold, sensitivities):
prune_ratios = sorted(sensitivities[layer].keys())
last_ratio = 0
for ratio in prune_ratios:
last_ratio = ratio
cur_acc = sensitivities[layer][ratio]
if cur_acc + threshold < ori_acc:
break
last_ratio = ratio
max_ratio[layer] = last_ratio
return max_ratio

Expand Down Expand Up @@ -244,6 +244,7 @@ def normalize(self, ratios, target_pruned):
# MAX_PRUNE_RATIO_PER_ITER we rescal all prune
# ratios under this threshold
if _Max > MAX_PRUNE_RATIO_PER_ITER:

for layername in ratios:
ratios[layername] = ratios[layername] * \
MAX_PRUNE_RATIO_PER_ITER / _Max
Expand Down Expand Up @@ -317,6 +318,7 @@ def compress(self, eval_args=None, eval_kwargs=None,
finetune_kwargs = {}
if self.ori_acc is None:
self.ori_acc = self.evaluator(*eval_args, **eval_kwargs)
assert isinstance(self.ori_acc, float) or isinstance(self.ori_acc, int)
if not resume_sensitivity:
self.sensitivities = self.analyzer.analysis(
val_args=eval_args, val_kwargs=eval_kwargs)
Expand All @@ -330,6 +332,7 @@ def compress(self, eval_args=None, eval_kwargs=None,
iteration_count = 0
if self.checkpoint_dir is not None:
os.makedirs(self.checkpoint_dir, exist_ok=True)
modules_wrapper_final = None
while cur_ratio > target_ratio:
iteration_count += 1
# Each round have three steps:
Expand All @@ -343,9 +346,16 @@ def compress(self, eval_args=None, eval_kwargs=None,
# layers according to the sensitivity result
proportion = self.sparsity_proportion_calc(
ori_acc, self.acc_drop_threshold, self.sensitivities)

new_pruneratio = self.normalize(proportion, self.sparsity_per_iter)
cfg_list = self.create_cfg(new_pruneratio)
if not cfg_list:
_logger.error('The threshold is too small, please set a larger threshold')
return self.model
_logger.debug('Pruner Config: %s', str(cfg_list))
cfg_str = ['%s:%.3f'%(cfg['op_names'][0], cfg['sparsity']) for cfg in cfg_list]
_logger.info('Current Sparsities: %s', ','.join(cfg_str))

pruner = self.Pruner(self.model, cfg_list)
pruner.compress()
pruned_acc = self.evaluator(*eval_args, **eval_kwargs)
Expand All @@ -367,6 +377,7 @@ def compress(self, eval_args=None, eval_kwargs=None,
self.analyzer.already_pruned[name] = sparsity
# update the cur_ratio
cur_ratio = 1 - self.current_sparsity()
modules_wrapper_final = pruner.get_modules_wrapper()
del pruner
_logger.info('Currently remained weights: %f', cur_ratio)

Expand All @@ -383,14 +394,19 @@ def compress(self, eval_args=None, eval_kwargs=None,
with open(cfg_path, 'w') as jf:
json.dump(cfg_list, jf)
self.analyzer.export(sensitivity_path)

if cur_ratio > target_ratio:
# If this is the last prune iteration, skip the time-consuming
# sensitivity analysis

self.analyzer.load_state_dict(self.model.state_dict())
self.sensitivities = self.analyzer.analysis(
val_args=eval_args, val_kwargs=eval_kwargs)

_logger.info('After Pruning: %.2f weights remains', cur_ratio)
self.modules_wrapper = modules_wrapper_final

self._wrap_model()
return self.model

def calc_mask(self, wrapper, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,28 +163,30 @@ def analysis(self, val_args=None, val_kwargs=None, specified_layers=None):
if val_kwargs is None:
val_kwargs = {}
# Get the original validation metric(accuracy/loss) before pruning
if self.ori_metric is None:
self.ori_metric = self.val_func(*val_args, **val_kwargs)
# Get the accuracy baseline before starting the analysis.
self.ori_metric = self.val_func(*val_args, **val_kwargs)
namelist = list(self.target_layer.keys())
if specified_layers is not None:
# only analyze several specified conv layers
namelist = list(filter(lambda x: x in specified_layers, namelist))
for name in namelist:
self.sensitivities[name] = {}
for sparsity in self.sparsities:
# here the sparsity is the relative sparsity of the
# the remained weights
# Calculate the actual prune ratio based on the already pruned ratio
sparsity = (
real_sparsity = (
1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name]
# TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
# I think the L1/L2 Pruner should specify the op_types automaticlly
# according to the op_names
cfg = [{'sparsity': sparsity, 'op_names': [
cfg = [{'sparsity': real_sparsity, 'op_names': [
name], 'op_types': ['Conv2d']}]
pruner = self.Pruner(self.model, cfg)
pruner.compress()
val_metric = self.val_func(*val_args, **val_kwargs)
logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f',
name, sparsity, val_metric)
name, real_sparsity, val_metric)

self.sensitivities[name][sparsity] = val_metric
pruner._unwrap_model()
Expand Down