Add model export for QAT #3458
Changes from 5 commits
@@ -110,7 +110,6 @@ def get_bits_length(config, quant_type):
    else:
        return config["quant_bits"].get(quant_type)


class QATGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
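For orientation: `quant_backward` is the straight-through estimator (STE) hook. Rounding has zero gradient almost everywhere, so QAT passes gradients through unchanged wherever the quantized value stayed inside [qmin, qmax] and zeroes them where it clipped. A minimal sketch of that idea, assuming the forward pass computed `round(tensor / scale + zero_point)` (illustrative, not the exact NNI implementation):

```python
import torch

def ste_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
    # Recompute the integer value produced by the forward quantization.
    q = torch.round(tensor / scale + zero_point)
    # Pass gradients through inside the representable range, zero outside.
    inside = (q >= qmin) & (q <= qmax)
    return grad_output * inside.float()
```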
@@ -153,13 +152,41 @@ def __init__(self, model, config_list, optimizer=None):
        for layer, config in modules_to_compress:
            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
            layer.module.register_buffer("scale", torch.Tensor([1.0]))
            if "weight" in config.get("quant_types", []):
                layer.module.register_buffer('weight_bit', torch.zeros(1))
            if "output" in config.get("quant_types", []):
                layer.module.register_buffer('activation_bit', torch.zeros(1))
                layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
                layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
                layer.module.register_buffer('tracked_min', torch.zeros(1))
                layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
                layer.module.register_buffer('tracked_max', torch.zeros(1))
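The biased/plain buffer pairs above support exponential-moving-average tracking of the activation min/max with bias correction. A sketch of the update rule such buffers typically back (the exact rule is an assumption based on standard bias-corrected EMAs; names mirror the registered buffers, and `step` counts updates starting at 1):

```python
def update_ema(biased_ema, value, decay, step):
    # tracked_{min,max}_biased: plain EMA of the per-batch statistic.
    biased_ema = biased_ema * decay + (1 - decay) * value
    # tracked_{min,max}: bias-corrected estimate, as in Adam-style EMAs.
    corrected = biased_ema / (1 - decay ** step)
    return biased_ema, corrected
```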
    def del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        if hasattr(module, 'old_weight'):
            delattr(module, 'old_weight')
        if hasattr(module, 'ema_decay'):
            delattr(module, 'ema_decay')
        if hasattr(module, 'tracked_min_biased'):
            delattr(module, 'tracked_min_biased')
        if hasattr(module, 'tracked_max_biased'):
            delattr(module, 'tracked_max_biased')
        if hasattr(module, 'tracked_min'):
            delattr(module, 'tracked_min')
        if hasattr(module, 'tracked_max'):
            delattr(module, 'tracked_max')
        if hasattr(module, 'scale'):
            delattr(module, 'scale')
        if hasattr(module, 'zero_point'):
            delattr(module, 'zero_point')
        if hasattr(module, 'weight_bit'):
            delattr(module, 'weight_bit')
        if hasattr(module, 'activation_bit'):
            delattr(module, 'activation_bit')
Review comment: suggest to use `_del_simulated_attr`
Reply: Good point! Modified.
    def validate_config(self, model, config_list):
        """
        Parameters
@@ -256,13 +283,15 @@ def quantize_weight(self, wrapper, **kwargs):
        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
        weight = self._quantize(weight_bits, module, weight)
        weight = self._dequantize(module, weight)
        module.weight_bit = torch.Tensor([weight_bits])
        wrapper.module.weight = weight
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
        config = wrapper.config
        module = wrapper.module
        output_bits = get_bits_length(config, 'output')
        module.activation_bit = torch.Tensor([output_bits])
        quant_start_step = config.get('quant_start_step', 0)
        assert output_bits >= 1, "quant bits length should be at least 1"

@@ -282,6 +311,55 @@ def quantize_output(self, output, wrapper, **kwargs):
        out = self._dequantize(module, out)
        return out
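`update_quantization_param`, called in quantize_weight above, maps the observed real range [rmin, rmax] onto the integer range for the given bit width, following the standard affine scheme of Jacob et al. A hedged sketch of what such a helper computes (the exact NNI version may differ in details such as clamping or degenerate-range handling):

```python
def update_quantization_param(bits, rmin, rmax):
    qmin, qmax = 0, (1 << bits) - 1               # unsigned integer range
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)   # range must contain zero
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    zero_point = max(qmin, min(qmax, zero_point))  # clamp into [qmin, qmax]
    return scale, zero_point
```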
    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weights
        calibration_path : str
            (optional) path to save quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape of the onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor when exporting the onnx file;
            the tensor is placed on the CPU if `device` is None
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_bit'):
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
            if hasattr(module, 'activation_bit'):
                calibration_config[name]['activation_bit'] = int(module.activation_bit)
                calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
                calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
            self.del_simulated_attr(module)

        torch.save(self.bound_model.state_dict(), model_path)
        logger.info('Model state_dict saved to %s', model_path)
        if calibration_path is not None:
            torch.save(calibration_config, calibration_path)
            logger.info('Calibration config saved to %s', calibration_path)
        if onnx_path is not None:
            assert input_shape is not None, 'input_shape must be specified to export onnx model'
            # a dummy input with the given shape is needed for onnx export
            if device is None:
                device = torch.device('cpu')
            input_data = torch.Tensor(*input_shape)
            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
            logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
Review comment: have you tested the onnx export? and better to write a test for this feature.
Reply: Has tested.
        return calibration_config

    def fold_bn(self, config, **kwargs):
        # TODO: simulate folded weight
        pass
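End to end, the new API can be exercised like this; the model/optimizer setup and file names below are illustrative, and the MNIST-style input shape is just an example:

```python
quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
# ... run quantization-aware training so scale/zero_point and the
# tracked min/max statistics are populated ...
calibration_config = quantizer.export_model(
    model_path='qat_model.pth',              # quantized weights (state_dict)
    calibration_path='qat_calibration.pth',  # per-layer bits and tracked min/max
    onnx_path='qat_model.onnx',
    input_shape=[1, 1, 28, 28],              # dummy input shape for onnx export
    device=torch.device('cpu'))
```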
@@ -301,6 +379,19 @@ class DoReFaQuantizer(Quantizer):

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
                layer.module.register_buffer('weight_bit', torch.zeros(1))
    def del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        if hasattr(module, 'old_weight'):
            delattr(module, 'old_weight')
        if hasattr(module, 'weight_bit'):
            delattr(module, 'weight_bit')

Review comment: -> _del_simulated_attr
Reply: Modified.
    def validate_config(self, model, config_list):
        """

@@ -330,6 +421,7 @@ def quantize_weight(self, wrapper, **kwargs):
        weight = self.quantize(weight, weight_bits)
        weight = 2 * weight - 1
        wrapper.module.weight = weight
        wrapper.module.weight_bit = torch.Tensor([weight_bits])
        # wrapper.module.weight.data = weight
        return weight
@@ -338,6 +430,50 @@ def quantize(self, input_ri, q_bits):
        output = torch.round(input_ri * scale) / scale
        return output
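The `scale` here is the DoReFa quantization step: for k-bit quantization of values in [0, 1] the DoReFa-Net paper uses scale = 2^k − 1, so inputs snap to 2^k evenly spaced levels. A self-contained sketch under that assumption:

```python
import torch

def quantize_k(input_ri: torch.Tensor, q_bits: int) -> torch.Tensor:
    scale = 2 ** q_bits - 1                   # number of intervals in [0, 1]
    return torch.round(input_ri * scale) / scale

x = torch.tensor([0.0, 0.26, 0.74, 1.0])
print(quantize_k(x, 2))  # tensor([0.0000, 0.3333, 0.6667, 1.0000])
```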
    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weights
        calibration_path : str
            (optional) path to save quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape of the onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor when exporting the onnx file;
            the tensor is placed on the CPU if `device` is None
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit'):
                calibration_config[name] = {}
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
Review comment: so this quantizer does not calibrate activation?
Reply: In our current implementation of DoReFa, it does not quantize activation, so we don't need to calibrate activation.
Review comment: could you double check the paper? does the paper mention how to calibrate activation?
Reply: After discussion, we reached an agreement that the refactor of ...
            self.del_simulated_attr(module)

        torch.save(self.bound_model.state_dict(), model_path)
        logger.info('Model state_dict saved to %s', model_path)
        if calibration_path is not None:
            torch.save(calibration_config, calibration_path)
            logger.info('Calibration config saved to %s', calibration_path)
        if onnx_path is not None:
            assert input_shape is not None, 'input_shape must be specified to export onnx model'
            # a dummy input with the given shape is needed for onnx export
            if device is None:
                device = torch.device('cpu')
            input_data = torch.Tensor(*input_shape)
            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
            logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

        return calibration_config
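For downstream tooling, the saved artifacts can be reloaded with plain `torch.load`; the file names below are illustrative, not anything defined in the PR:

```python
import torch

state_dict = torch.load('dorefa_model.pth')                # quantized weights
calibration_config = torch.load('dorefa_calibration.pth')  # per-layer bits
for layer_name, params in calibration_config.items():
    # DoReFa records only weight bits; QAT also records activation_bit
    # and tracked_min/tracked_max per layer.
    print(layer_name, params['weight_bit'])
```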
class ClipGrad(QuantGrad):
    @staticmethod

@@ -391,3 +527,37 @@ def quantize_output(self, output, wrapper, **kwargs):
        # remove zeros
        out[out == 0] = 1
        return out
    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weights
        calibration_path : str
            (optional) path to save quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape of the onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor when exporting the onnx file;
            the tensor is placed on the CPU if `device` is None
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()

        torch.save(self.bound_model.state_dict(), model_path)
        logger.info('Model state_dict saved to %s', model_path)
        if calibration_path is not None:
            logger.info('No calibration config will be saved because there is no calibration data in BNN quantizer')
Review comment: I think we should export the bit number even though they are all 1 bit, because the speedup module needs this information to use 1 bit; the speedup module does not know you are using BNNQuantizer.
Reply: Makes sense. Have added.
        if onnx_path is not None:
            assert input_shape is not None, 'input_shape must be specified to export onnx model'
            # a dummy input with the given shape is needed for onnx export
            if device is None:
                device = torch.device('cpu')
            input_data = torch.Tensor(*input_shape)
            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
            logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

Review comment: For other implementations return a dict, add ... And all ...
Reply: Useful suggestions! Have modified.