This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Add bias quantization in QAT and refactor the code of weight quantization #2914
Merged
Merged
Changes from 4 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
1c5954b
Add bias quantization in QAT and refactor the code of weight quantiza…
linbinskn ffba6c6
delete quantize_weight test
linbinskn f39b1df
change the test file
linbinskn 42e4869
delete the whitespace
linbinskn b075e50
modify test code
linbinskn 52aa12e
add weight and bias value validation
linbinskn e7ce8ab
modify bias judge problem
linbinskn 578195f
Update quantizers.py
linbinskn 8b4f492
fix low accuracy problem
linbinskn 8bd4d7c
fix test error
linbinskn 9b09924
change import order
linbinskn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
class NaiveQuantizer(Quantizer): | ||
"""quantize weight to 8 bits | ||
""" | ||
|
||
def __init__(self, model, config_list, optimizer=None): | ||
super().__init__(model, config_list, optimizer) | ||
self.layer_scale = {} | ||
|
@@ -29,12 +30,13 @@ def validate_config(self, model, config_list): | |
|
||
schema.validate(config_list) | ||
|
||
def quantize_weight(self, weight, wrapper, **kwargs): | ||
def quantize_weight(self, wrapper, **kwargs): | ||
weight = wrapper.module.weight.data | ||
new_scale = weight.abs().max() / 127 | ||
scale = max(self.layer_scale.get(wrapper.name, 0), new_scale) | ||
self.layer_scale[wrapper.name] = scale | ||
orig_type = weight.type() # TODO: user layer | ||
return weight.div(scale).type(torch.int8).type(orig_type).mul(scale) | ||
wrapper.module.weight.data = weight.div(scale).type(torch.int8).type(orig_type).mul(scale) | ||
|
||
|
||
def update_ema(biased_ema, value, decay, step): | ||
|
@@ -60,6 +62,7 @@ def update_ema(biased_ema, value, decay, step): | |
unbiased_ema = biased_ema / (1 - decay ** step) # Bias correction | ||
return biased_ema, unbiased_ema | ||
|
||
|
||
def update_quantization_param(bits, rmin, rmax): | ||
""" | ||
calculate the `zero_point` and `scale`. | ||
|
@@ -116,6 +119,7 @@ class QAT_Quantizer(Quantizer): | |
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference | ||
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf | ||
""" | ||
|
||
def __init__(self, model, config_list, optimizer=None): | ||
""" | ||
Parameters | ||
|
@@ -215,20 +219,32 @@ def _dequantize(self, op, quantized_val): | |
real_val = op.scale * (quantized_val - op.zero_point) | ||
return real_val | ||
|
||
def quantize_weight(self, weight, wrapper, **kwargs): | ||
def quantize_weight(self, wrapper, **kwargs): | ||
config = wrapper.config | ||
module = wrapper.module | ||
weight = wrapper.module.weight.data | ||
weight_bits = get_bits_length(config, 'weight') | ||
quant_start_step = config.get('quant_start_step', 0) | ||
assert weight_bits >= 1, "quant bits length should be at least 1" | ||
|
||
if quant_start_step > self.steps: | ||
return weight | ||
return | ||
rmin, rmax = torch.min(weight), torch.max(weight) | ||
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax) | ||
out = self._quantize(weight_bits, module, weight) | ||
out = self._dequantize(module, out) | ||
return out | ||
weight = self._quantize(weight_bits, module, weight) | ||
weight = self._dequantize(module, weight) | ||
wrapper.module.weight.data = weight | ||
|
||
# if bias exists, quantize bias to uint32 | ||
if not hasattr(wrapper.module, 'bias') or wrapper.module.bias is None: | ||
return | ||
bias = wrapper.module.bias.data | ||
bias_bits = 32 | ||
rmin, rmax = torch.min(bias), torch.max(bias) | ||
module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax) | ||
bias = self._quantize(bias_bits, module, bias) | ||
bias = self._dequantize(module, bias) | ||
wrapper.module.bias.data = bias | ||
|
||
def quantize_output(self, output, wrapper, **kwargs): | ||
config = wrapper.config | ||
|
@@ -241,8 +257,10 @@ def quantize_output(self, output, wrapper, **kwargs): | |
return output | ||
|
||
current_min, current_max = torch.min(output), torch.max(output) | ||
module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps) | ||
module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps) | ||
module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, | ||
module.ema_decay, self.steps) | ||
module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, | ||
module.ema_decay, self.steps) | ||
module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max) | ||
out = self._quantize(output_bits, module, output) | ||
out = self._dequantize(module, out) | ||
|
@@ -264,6 +282,7 @@ class DoReFaQuantizer(Quantizer): | |
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients | ||
(https://arxiv.org/abs/1606.06160) | ||
""" | ||
|
||
def __init__(self, model, config_list, optimizer=None): | ||
super().__init__(model, config_list, optimizer) | ||
|
||
|
@@ -287,17 +306,18 @@ def validate_config(self, model, config_list): | |
|
||
schema.validate(config_list) | ||
|
||
def quantize_weight(self, weight, wrapper, **kwargs): | ||
def quantize_weight(self, wrapper, **kwargs): | ||
weight = wrapper.module.weight.data | ||
weight_bits = get_bits_length(wrapper.config, 'weight') | ||
out = weight.tanh() | ||
out = out / (2 * out.abs().max()) + 0.5 | ||
out = self.quantize(out, weight_bits) | ||
out = 2 * out -1 | ||
return out | ||
weight = weight.tanh() | ||
weight = weight / (2 * weight.abs().max()) + 0.5 | ||
weight = self.quantize(weight, weight_bits) | ||
weight = 2 * weight - 1 | ||
wrapper.module.weight.data = weight | ||
|
||
def quantize(self, input_ri, q_bits): | ||
scale = pow(2, q_bits)-1 | ||
output = torch.round(input_ri*scale)/scale | ||
scale = pow(2, q_bits) - 1 | ||
output = torch.round(input_ri * scale) / scale | ||
return output | ||
|
||
|
||
|
@@ -314,6 +334,7 @@ class BNNQuantizer(Quantizer): | |
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1 | ||
(https://arxiv.org/abs/1602.02830) | ||
""" | ||
|
||
def __init__(self, model, config_list, optimizer=None): | ||
super().__init__(model, config_list, optimizer) | ||
self.quant_grad = ClipGrad | ||
|
@@ -339,11 +360,12 @@ def validate_config(self, model, config_list): | |
|
||
schema.validate(config_list) | ||
|
||
def quantize_weight(self, weight, wrapper, **kwargs): | ||
out = torch.sign(weight) | ||
def quantize_weight(self, wrapper, **kwargs): | ||
weight = wrapper.module.weight.data | ||
weight = torch.sign(weight) | ||
# remove zeros | ||
out[out == 0] = 1 | ||
return out | ||
weight[weight == 0] = 1 | ||
wrapper.module.weight.data = weight | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Returned value is quantized weight. There is no need to return in this version because weight is quantized in place in the function quantize_weight(). |
||
|
||
def quantize_output(self, output, wrapper, **kwargs): | ||
out = torch.sign(output) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add return here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I have fixed it.