The model compression framework has two main components: `pruner` and `module wrapper`.

A `pruner` is responsible for:

- providing a `calc_mask` method that calculates masks for weight and bias,
- replacing the original module with a `module wrapper` based on the config,
- modifying the optimizer so that the `calc_mask` method is called every time the `step` method is called.
A `module wrapper` is a module containing:

- the original module,
- some buffers used by `calc_mask`,
- a new `forward` method that applies the masks before running the original `forward` method.
The reasons to use a `module wrapper` are:

- some buffers are needed by `calc_mask` to calculate masks, and these buffers should be registered in the `module wrapper` so that the original modules are not contaminated,
- a new `forward` method is needed to apply the masks to the weight before calling the real `forward` method (see the sketch after this list).
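The following is a minimal, illustrative sketch of such a wrapper. The class name `ExampleModuleWrapper` and its exact fields are assumptions made for this example, not the framework's actual implementation:

    import torch

    class ExampleModuleWrapper(torch.nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module  # the original module
            # buffer used by calc_mask; registered here so the original module is not contaminated
            self.register_buffer("weight_mask", torch.ones_like(module.weight))

        def forward(self, *inputs):
            # apply the mask to the weight before running the original forward method
            self.module.weight.data.mul_(self.weight_mask)
            return self.module(*inputs)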
Basic usage of a pruner:
    configure_list = [{
        'sparsity': 0.7,
        'op_types': ['BatchNorm2d'],
    }]

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    pruner = SlimPruner(model, configure_list, optimizer)
    model = pruner.compress()
A pruner receives the model, config, and optimizer as arguments. In the `__init__` method, the optimizer's `step` method is replaced with a new `step` method that calls `calc_mask`. All modules are also checked against the config to determine whether they need to be pruned; if a module needs to be pruned, it is replaced with a `module wrapper`. Afterward, the new model and new optimizer are returned and can be trained as before. The `compress` method will calculate the default masks.
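To make the optimizer patching concrete, here is a rough sketch of how `step` could be wrapped so that every call also recalculates the masks. The helper name `patch_optimizer` and the `wrappers` argument are assumptions for this illustration, not the framework's actual code:

    def patch_optimizer(pruner, optimizer, wrappers):
        # keep a reference to the original bound step method
        original_step = optimizer.step

        def new_step(*args, **kwargs):
            output = original_step(*args, **kwargs)
            # recalculate masks for every wrapped module after the weights have been updated
            for wrapper in wrappers:
                pruner.calc_mask(wrapper)
            return output

        optimizer.step = new_step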
Implementing a new pruning algorithm requires writing a new pruner class, which should subclass `Pruner` and override the `calc_mask` method. `calc_mask` is called by the `optimizer.step` method.

The `Pruner` base class provides the basic functionality listed above, for example, replacing modules and patching the optimizer.
A basic pruner looks like this:

    class NewPruner(Pruner):
        def __init__(self, model, config_list, optimizer):
            super().__init__(model, config_list, optimizer)
            # do some initialization

        def calc_mask(self, wrapper, **kwargs):
            # do something to calculate weight_mask
            wrapper.weight_mask = weight_mask
Sometimes `calc_mask` needs to save some state data, so users can use the `set_wrappers_attribute` API to register attributes, just like buffers are registered in PyTorch modules. These buffers are registered to the `module wrapper`, and users can access them through the `module wrapper`.
    class NewPruner(Pruner):
        def __init__(self, model, config_list, optimizer):
            super().__init__(model, config_list, optimizer)
            self.set_wrappers_attribute("if_calculated", False)

        def calc_mask(self, wrapper):
            # do something to calculate weight_mask
            if wrapper.if_calculated:
                pass
            else:
                wrapper.if_calculated = True
                # update masks
Sometimes users want to collect some data during the modules' `forward` method, for example, the mean value of the activations. In that case, users can add a customized collector to a module.
    class ActivationRankFilterPruner(Pruner):
        def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
            super().__init__(model, config_list, optimizer)
            self.set_wrappers_attribute("if_calculated", False)
            self.set_wrappers_attribute("collected_activation", [])
            self.statistics_batch_num = statistics_batch_num

            def collector(module_, input_, output):
                if len(module_.collected_activation) < self.statistics_batch_num:
                    module_.collected_activation.append(self.activation(output.detach().cpu()))
            self.add_activation_collector(collector)

            assert activation in ['relu', 'relu6']
            if activation == 'relu':
                self.activation = torch.nn.functional.relu
            elif activation == 'relu6':
                self.activation = torch.nn.functional.relu6
            else:
                self.activation = None
The collector function will be called each time the forward method runs.
Users can also remove this collector like this:
    collector_id = self.add_activation_collector(collector)
    # ...
    self.remove_activation_collector(collector_id)
In multi-GPU training, buffers and parameters are copied to each GPU every time the `forward` method runs. If buffers and parameters are updated in the `forward` method, an in-place update is needed to ensure the update is effective.
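For example, a collector that accumulates statistics during `forward` should modify a registered buffer in place rather than re-binding the attribute. This is a hedged sketch; the buffer name `activation_sum` is assumed for illustration:

    def collector(module_, input_, output):
        # in-place tensor update (+= dispatches to add_), so the registered buffer itself is modified
        module_.activation_sum += output.detach().sum(dim=0)
        # by contrast, `module_.activation_sum = module_.activation_sum + ...` re-binds the
        # attribute on the replica, so the update may not be effective in multi-GPU training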
Since `calc_mask` is called in the `optimizer.step` method, which happens after the `forward` method and only on one GPU, it supports multi-GPU training naturally.