This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add model export for QAT #3458

Merged (8 commits, Mar 24, 2021)
157 changes: 156 additions & 1 deletion nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -110,7 +110,6 @@ def get_bits_length(config, quant_type):
else:
return config["quant_bits"].get(quant_type)


class QATGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
@@ -153,13 +152,26 @@ def __init__(self, model, config_list, optimizer=None):
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
layer.module.register_buffer("scale", torch.Tensor([1.0]))
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
if "output" in config.get("quant_types", []):
layer.module.register_buffer('activation_bit', torch.zeros(1))
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
layer.module.register_buffer('tracked_min', torch.zeros(1))
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))

def _del_simulated_attr(self, module):
"""
Delete redundant parameters in the quantized module.
"""
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased', 'tracked_min', \
'tracked_max', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)

def validate_config(self, model, config_list):
"""
Parameters
@@ -256,13 +268,15 @@ def quantize_weight(self, wrapper, **kwargs):
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
weight = self._quantize(weight_bits, module, weight)
weight = self._dequantize(module, weight)
module.weight_bit = torch.Tensor([weight_bits])
wrapper.module.weight = weight
return weight

def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
output_bits = get_bits_length(config, 'output')
module.activation_bit = torch.Tensor([output_bits])
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"

@@ -282,6 +296,47 @@ def quantize_output(self, output, wrapper, **kwargs):
out = self._dequantize(module, out)
return out

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional).

Parameters
----------
model_path : str
path to save the quantized model weights
calibration_path : str
(optional) path to save quantization parameters after calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting the onnx file;
the tensor is placed on CPU if ``device`` is None

Returns
-------
dict
calibration parameters of each quantized module
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}

for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
calibration_config[name] = {}
if hasattr(module, 'weight_bit'):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
self._del_simulated_attr(module)

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

return calibration_config

def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
@@ -301,6 +356,19 @@ class DoReFaQuantizer(Quantizer):

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))

def _del_simulated_attr(self, module):
"""
Delete redundant parameters in the quantized module.
"""
del_attr_list = ['old_weight', 'weight_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)

def validate_config(self, model, config_list):
"""
@@ -330,6 +398,7 @@ def quantize_weight(self, wrapper, **kwargs):
weight = self.quantize(weight, weight_bits)
weight = 2 * weight - 1
wrapper.module.weight = weight
wrapper.module.weight_bit = torch.Tensor([weight_bits])
# wrapper.module.weight.data = weight
return weight

@@ -338,6 +407,42 @@ def quantize(self, input_ri, q_bits):
output = torch.round(input_ri * scale) / scale
return output

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional).

Parameters
----------
model_path : str
path to save the quantized model weights
calibration_path : str
(optional) path to save quantization parameters after calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting the onnx file;
the tensor is placed on CPU if ``device`` is None

Returns
-------
dict
calibration parameters of each quantized module
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}

for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit'):
calibration_config[name] = {}
calibration_config[name]['weight_bit'] = int(module.weight_bit)
Review comment (Contributor): so this quantizer does not calibrate activation?

Reply (Contributor, author): In our current implementation of DoReFa, activation is not quantized, so we don't need to calibrate it.

Reply (Contributor): could you double-check the paper? Does it mention how to calibrate activation?

Reply (Contributor, author): After discussion, we reached an agreement that the refactor of DoReFa should start after a survey and will be done in another PR. Also, a unit test for export_model() has been added.

self._del_simulated_attr(module)

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

return calibration_config


class ClipGrad(QuantGrad):
@staticmethod
@@ -356,6 +461,19 @@ class BNNQuantizer(Quantizer):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))

def _del_simulated_attr(self, module):
"""
Delete redundant parameters in the quantized module.
"""
del_attr_list = ['old_weight', 'weight_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)

def validate_config(self, model, config_list):
"""
@@ -384,10 +502,47 @@ def quantize_weight(self, wrapper, **kwargs):
# remove zeros
weight[weight == 0] = 1
wrapper.module.weight = weight
wrapper.module.weight_bit = torch.Tensor([1.0])
return weight

def quantize_output(self, output, wrapper, **kwargs):
out = torch.sign(output)
# remove zeros
out[out == 0] = 1
return out

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional).

Parameters
----------
model_path : str
path to save the quantized model weights
calibration_path : str
(optional) path to save quantization parameters after calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting the onnx file;
the tensor is placed on CPU if ``device`` is None

Returns
-------
dict
calibration parameters of each quantized module
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}

for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit'):
calibration_config[name] = {}
calibration_config[name]['weight_bit'] = int(module.weight_bit)
self._del_simulated_attr(module)

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

return calibration_config
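
The three overrides above share the same calling pattern. As a rough usage sketch (not part of the diff; the toy model, config values and file names are made up, while QAT_Quantizer, compress() and the new export_model() are the real API), exporting a QAT-quantized model might look like this:

import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

# toy model and config, purely illustrative
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = [{'quant_types': ['weight', 'output'],
                'quant_bits': {'weight': 8, 'output': 8},
                'op_types': ['Conv2d']}]

quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
# ... quantization-aware training loop runs here ...

# export the quantized weights, the calibration parameters and an onnx model
calibration_config = quantizer.export_model(
    model_path='qat_model.pth',
    calibration_path='qat_calibration.pth',
    onnx_path='qat_model.onnx',
    input_shape=(1, 3, 32, 32),
    device=torch.device('cpu'))
# calibration_config maps each quantized module name to its weight_bit /
# activation_bit and the tracked activation range (tracked_min / tracked_max)
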
62 changes: 61 additions & 1 deletion nni/compression/pytorch/compressor.py
@@ -21,7 +21,6 @@ def _setattr(model, name, module):
model = getattr(model, name)
setattr(model, name_list[-1], module)


class Compressor:
"""
Abstract base PyTorch compressor
@@ -573,6 +572,67 @@ def _wrap_modules(self, layer, config):

return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, \
input_shape=None, device=None):
"""
This method saves the pytorch model, the calibration config and the onnx model for a quantizer.

Parameters
----------
model : pytorch model
pytorch model to be saved
model_path : str
path to save the pytorch model
calibration_config : dict
(optional) config of calibration parameters
calibration_path : str
(optional) path to save quantization parameters after calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting the onnx file;
the tensor is placed on CPU if ``device`` is None
"""
torch.save(model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)
if calibration_path is not None:
torch.save(calibration_config, calibration_path)
_logger.info('Calibration config saved to %s', calibration_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters.

Parameters
----------
model_path : str
path to save the quantized model weights
calibration_path : str
(optional) path to save quantization parameters after calibration
onnx_path : str
(optional) path to save the onnx model
input_shape : list or tuple
input shape of the onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting the onnx file;
the tensor is placed on CPU if ``device`` is None

Returns
-------
dict
calibration parameters of each quantized module
"""
raise NotImplementedError('Quantizer must overload export_model()')

def step_with_optimizer(self):
pass

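
Because export_model_save() writes both files with plain torch.save(), the exported artifacts can be read back with torch.load(). A hypothetical follow-up to the usage sketch earlier in this diff (file names match it; the printed values are only illustrative):

import torch
import torch.nn as nn

# rebuild the same (toy) architecture that was quantized, then load the exported weights
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
model.load_state_dict(torch.load('qat_model.pth'))

# per-layer calibration parameters returned by export_model() and saved to calibration_path
calibration_config = torch.load('qat_calibration.pth')
print(calibration_config)
# e.g. {'0': {'weight_bit': 8, 'activation_bit': 8, 'tracked_min': -1.2, 'tracked_max': 3.4}}
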