This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add model export for QAT #3458

Merged — 8 commits merged on Mar 24, 2021 (changes shown from 5 commits)
172 changes: 171 additions & 1 deletion nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -110,7 +110,6 @@ def get_bits_length(config, quant_type):
else:
return config["quant_bits"].get(quant_type)


class QATGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
@@ -153,13 +152,41 @@ def __init__(self, model, config_list, optimizer=None):
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
layer.module.register_buffer("scale", torch.Tensor([1.0]))
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
if "output" in config.get("quant_types", []):
layer.module.register_buffer('activation_bit', torch.zeros(1))
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
layer.module.register_buffer('tracked_min', torch.zeros(1))
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))

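For context on the buffers registered above: ema_decay, tracked_min_biased/tracked_max_biased and tracked_min/tracked_max support an exponential-moving-average estimate of the output range used for activation quantization. A minimal sketch of such an EMA update with bias correction follows (this mirrors the standard QAT scheme; the actual update code lives in quantize_output and is not shown in this hunk):

def update_ema(biased_ema, value, decay, step):
    # Running EMA of the observed value (e.g. per-batch min or max of the output).
    biased_ema = biased_ema * decay + (1 - decay) * value
    # Bias correction for the early steps, as in Adam-style moving averages.
    unbiased_ema = biased_ema / (1 - decay ** step)
    return biased_ema, unbiased_ema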
def del_simulated_attr(self, module):
Contributor:
-> _del_simulated_attr

Contributor Author:
Modified.

"""
delete redundant parameters in quantize module
"""
if hasattr(module, 'old_weight'):
delattr(module, 'old_weight')
if hasattr(module, 'ema_decay'):
delattr(module, 'ema_decay')
if hasattr(module, 'tracked_min_biased'):
delattr(module, 'tracked_min_biased')
if hasattr(module, 'tracked_max_biased'):
delattr(module, 'tracked_max_biased')
if hasattr(module, 'tracked_min'):
delattr(module, 'tracked_min')
if hasattr(module, 'tracked_max'):
delattr(module, 'tracked_max')
if hasattr(module, 'scale'):
delattr(module, 'scale')
if hasattr(module, 'zero_point'):
delattr(module, 'zero_point')
if hasattr(module, 'weight_bit'):
delattr(module, 'weight_bit')
if hasattr(module, 'activation_bit'):
delattr(module, 'activation_bit')
Contributor (@QuanluZhang, Mar 22, 2021):
Suggest using a for loop:

to_del = ['old_weight', 'ema_decay', ...]
for each in to_del:
    if hasattr():
        delattr()

Contributor Author:
Good point! Modified.

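A sketch of how the suggestion above could be applied, using the attribute names registered in __init__ (the final merged implementation may differ slightly):

def _del_simulated_attr(self, module):
    """Delete simulated quantization attributes from a wrapped module."""
    to_del = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased',
              'tracked_min', 'tracked_max', 'scale', 'zero_point',
              'weight_bit', 'activation_bit']
    for attr in to_del:
        if hasattr(module, attr):
            delattr(module, attr)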

def validate_config(self, model, config_list):
"""
Parameters
@@ -256,13 +283,15 @@ def quantize_weight(self, wrapper, **kwargs):
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
weight = self._quantize(weight_bits, module, weight)
weight = self._dequantize(module, weight)
module.weight_bit = torch.Tensor([weight_bits])
wrapper.module.weight = weight
return weight

def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
output_bits = get_bits_length(config, 'output')
module.activation_bit = torch.Tensor([output_bits])
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"

@@ -282,6 +311,55 @@ def quantize_output(self, output, wrapper, **kwargs):
out = self._dequantize(module, out)
return out

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional)

Parameters
----------
model_path : str
path to save quantized model weight
calibration_path : str
(optional) path to save quantize parameters after calibration
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}

for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
calibration_config[name] = {}
if hasattr(module, 'weight_bit'):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
self.del_simulated_attr(module)

torch.save(self.bound_model.state_dict(), model_path)
logger.info('Model state_dict saved to %s', model_path)
if calibration_path is not None:
torch.save(calibration_config, calibration_path)
logger.info('Mask dict saved to %s', calibration_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
Contributor:
Have you tested the onnx export? It would be better to write a test for this feature.

Contributor Author:
The export_model function has been tested, including the PyTorch state_dict() and onnx export, in all three algorithms.


return calibration_config

def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
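For context, a minimal usage sketch of the new export path (the toy model and file names are hypothetical; the import path and config format are assumed from the file location in this diff and the existing NNI API):

import torch
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

# A toy model, stand-in for a real network.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Conv2d'],
}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
# ... quantization-aware training loop ...

calibration_config = quantizer.export_model(
    model_path='qat_model.pth',
    calibration_path='qat_calibration.pth',
    onnx_path='qat_model.onnx',
    input_shape=[1, 3, 224, 224])
# calibration_config maps layer names to entries such as
# {'weight_bit': 8, 'activation_bit': 8, 'tracked_min': ..., 'tracked_max': ...},
# matching the fields collected in export_model above.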
@@ -301,6 +379,19 @@ class DoReFaQuantizer(Quantizer):

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))

def del_simulated_attr(self, module):
Contributor:
-> _del_simulated_attr

Contributor Author:
Modified.

"""
delete redundant parameters in quantize module
"""
if hasattr(module, 'old_weight'):
delattr(module, 'old_weight')
if hasattr(module, 'weight_bit'):
delattr(module, 'weight_bit')

def validate_config(self, model, config_list):
"""
@@ -330,6 +421,7 @@ def quantize_weight(self, wrapper, **kwargs):
weight = self.quantize(weight, weight_bits)
weight = 2 * weight - 1
wrapper.module.weight = weight
wrapper.module.weight_bit = torch.Tensor([weight_bits])
# wrapper.module.weight.data = weight
return weight

@@ -338,6 +430,50 @@ def quantize(self, input_ri, q_bits):
output = torch.round(input_ri * scale) / scale
return output
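As a quick illustration of the rounding scheme above (the scale definition is not visible in this hunk; following the DoReFa paper it is assumed to be 2**q_bits - 1 for an input already normalized to [0, 1]):

# k-bit quantization of a value in [0, 1]:
q_bits = 2
scale = 2 ** q_bits - 1        # assumed scale: 3, giving levels {0, 1/3, 2/3, 1}
x = 0.40
q = round(x * scale) / scale   # 0.333..., the nearest representable level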

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional)

Parameters
----------
model_path : str
path to save quantized model weight
calibration_path : str
(optional) path to save quantize parameters after calibration
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}

for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit'):
calibration_config[name] = {}
calibration_config[name]['weight_bit'] = int(module.weight_bit)
Contributor:
So this quantizer does not calibrate activations?

Contributor Author:
In our current implementation of DoReFa, activations are not quantized, so we don't need to calibrate them.

Contributor:
Could you double-check the paper? Does it mention how to calibrate activations?

Contributor Author:
After discussion, we reached an agreement that the refactor of DoReFa should start after a survey and will be done in another PR. What's more, a unit test for export_model() has been added to the code.

self.del_simulated_attr(module)

torch.save(self.bound_model.state_dict(), model_path)
logger.info('Model state_dict saved to %s', model_path)
if calibration_path is not None:
torch.save(calibration_config, calibration_path)
logger.info('Mask dict saved to %s', calibration_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

return calibration_config
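For completeness, a hedged sketch of consuming the exported artifacts (file names follow the usage sketch earlier; this applies to any quantizer here, since both files are written with torch.save):

import torch

state_dict = torch.load('qat_model.pth')                # quantized (simulated) weights
calibration_config = torch.load('qat_calibration.pth')  # per-layer calibration entries
for layer_name, params in calibration_config.items():
    # DoReFa entries contain only 'weight_bit'; QAT entries also carry
    # 'activation_bit', 'tracked_min' and 'tracked_max'.
    print(layer_name, params)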


class ClipGrad(QuantGrad):
@staticmethod
@@ -391,3 +527,37 @@ def quantize_output(self, output, wrapper, **kwargs):
# remove zeros
out[out == 0] = 1
return out

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional)

Parameters
----------
model_path : str
path to save quantized model weight
calibration_path : str
(optional) path to save quantize parameters after calibration
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()

torch.save(self.bound_model.state_dict(), model_path)
logger.info('Model state_dict saved to %s', model_path)
if calibration_path is not None:
logger.info('No calibration config will be saved because no calibration data in BNN quantizer')
Contributor:
I think we should export the bit number even though it is always 1 bit, because the speedup module needs this information to use 1 bit; the speedup module does not know you are using BNNQuantizer.

Contributor Author:
Makes sense. Have added.
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
Contributor:
The other implementations return a dict; would adding return {} be better here?

Also, all export_model() implementations seem to share mostly the same logic. Could we reuse the implementation in QAT_Quantizer for all quantizers, or just specify how to construct calibration_config in each quantizer?

Contributor Author:
Useful suggestions! Have modified.
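One possible way to factor the shared saving logic into the Quantizer base class, per the suggestion above (the helper name export_model_save and its exact signature are illustrative, not necessarily what the final revision uses):

# Sketch against the Quantizer base class in nni/compression/pytorch/compressor.py,
# where torch, logger and Compressor are already defined.
class Quantizer(Compressor):
    # ... existing methods ...

    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None,
                          onnx_path=None, input_shape=None, device=None):
        """Shared saving logic: state_dict, optional calibration dict, optional onnx export."""
        torch.save(model.state_dict(), model_path)
        logger.info('Model state_dict saved to %s', model_path)
        if calibration_path is not None:
            torch.save(calibration_config, calibration_path)
            logger.info('Calibration config saved to %s', calibration_path)
        if onnx_path is not None:
            assert input_shape is not None, 'input_shape must be specified to export onnx model'
            device = device or torch.device('cpu')
            dummy_input = torch.randn(*input_shape).to(device)
            torch.onnx.export(model, dummy_input, onnx_path)
            logger.info('Model in onnx with input shape %s saved to %s', dummy_input.shape, onnx_path)

    # Each concrete quantizer then only builds its calibration_config and calls
    # self.export_model_save(...) from its own export_model().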

21 changes: 20 additions & 1 deletion nni/compression/pytorch/compressor.py
@@ -21,7 +21,6 @@ def _setattr(model, name, module):
model = getattr(model, name)
setattr(model, name_list[-1], module)


class Compressor:
"""
Abstract base PyTorch compressor
@@ -573,6 +572,26 @@ def _wrap_modules(self, layer, config):

return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters

Parameters
----------
model_path : str
path to save quantized model weight
calibration_path : str
(optional) path to save quantize parameters after calibration
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
raise NotImplementedError('Quantizer must overload export_model()')

def step_with_optimizer(self):
pass
