Add bias quantization in QAT and refactor the code of weight quantization #2914
Changes from 6 commits
@@ -15,6 +15,7 @@
 class NaiveQuantizer(Quantizer):
     """quantize weight to 8 bits
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.layer_scale = {}
@@ -29,12 +30,13 @@ def validate_config(self, model, config_list):
         schema.validate(config_list)

-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
         new_scale = weight.abs().max() / 127
         scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
         self.layer_scale[wrapper.name] = scale
         orig_type = weight.type()  # TODO: user layer
-        return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        wrapper.module.weight.data = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)


 def update_ema(biased_ema, value, decay, step):
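The refactored method now writes the simulated-quantized weight back into the wrapped module instead of returning it. A minimal standalone sketch of the same symmetric 8-bit round trip (`naive_quantize_inplace` is an illustrative helper, not NNI API):

    import torch

    def naive_quantize_inplace(module):
        # Symmetric per-tensor int8 simulation: scale maps max|w| to 127.
        weight = module.weight.data
        scale = weight.abs().max() / 127
        orig_type = weight.type()
        # Quantize to int8, then dequantize back to the original dtype, in place.
        module.weight.data = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)

    conv = torch.nn.Conv2d(1, 1, 3)
    naive_quantize_inplace(conv)  # conv.weight now holds dequantized values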
@@ -60,6 +62,7 @@ def update_ema(biased_ema, value, decay, step):
     unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
     return biased_ema, unbiased_ema


 def update_quantization_param(bits, rmin, rmax):
     """
     calculate the `zero_point` and `scale`.
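The `scale`/`zero_point` calculation follows the affine scheme of the Jacob et al. paper referenced below. The sketch here is a reconstruction from that scheme and the test assertions in this PR, not a verbatim copy of the function body:

    def update_quantization_param_sketch(bits, rmin, rmax):
        # Extend the range to cover 0 so that zero is exactly representable.
        rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
        qmin, qmax = 0, (1 << bits) - 1
        scale = (rmax - rmin) / (qmax - qmin)
        # Nudge the zero point onto the integer grid and clamp into [qmin, qmax].
        zero_point = min(max(round(qmin - rmin / scale), qmin), qmax)
        return scale, zero_point

    print(update_quantization_param_sketch(8, 1.0, 5.0))   # (5/255, 0)
    print(update_quantization_param_sketch(8, -1.0, 5.0))  # (6/255, 42)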
@@ -116,6 +119,7 @@ class QAT_Quantizer(Quantizer):
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
+
     def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
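For context when reading the `quantize_weight` refactor below, this is roughly how the quantizer is driven. The config keys follow NNI's quantizer schema, but treat the import path and the exact values as assumptions that may vary across NNI versions:

    import torch
    from nni.compression.torch import QAT_Quantizer  # import path may differ by NNI version

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'quant_start_step': 10,            # stay full precision for the first 10 steps
        'op_types': ['Conv2d']
    }]
    quantizer = QAT_Quantizer(model, config_list, optimizer)
    quantizer.compress()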
@@ -215,20 +219,34 @@ def _dequantize(self, op, quantized_val):
         real_val = op.scale * (quantized_val - op.zero_point)
         return real_val

-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
+        weight = wrapper.module.weight.data
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"

         if quant_start_step > self.steps:
-            return weight
+            return

+        # if bias exists, quantize bias to uint32
+        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
+            bias = wrapper.module.bias.data
+            bias_bits = 32
+            rmin, rmax = torch.min(bias), torch.max(bias)
+            module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax)
+            bias = self._quantize(bias_bits, module, bias)
+            bias = self._dequantize(module, bias)
+            wrapper.module.bias.data = bias
+
+        # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
         module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
-        out = self._quantize(weight_bits, module, weight)
-        out = self._dequantize(module, out)
-        return out
+        weight = self._quantize(weight_bits, module, weight)
+        weight = self._dequantize(module, weight)
+        wrapper.module.weight.data = weight
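The `_quantize`/`_dequantize` pair simulates integer quantization while keeping tensors in floating point. A self-contained sketch of that round trip under the affine scheme (`fake_quantize` is an illustrative helper, not NNI API):

    import torch

    def fake_quantize(x, bits):
        # Affine (asymmetric) fake-quantization; the range is extended to
        # include zero, as the scheme requires.
        rmin, rmax = min(x.min().item(), 0.0), max(x.max().item(), 0.0)
        qmin, qmax = 0, (1 << bits) - 1
        scale = (rmax - rmin) / (qmax - qmin)
        zero_point = int(min(max(round(qmin - rmin / scale), qmin), qmax))
        q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
        return scale * (q - zero_point)   # dequantize, cf. _dequantize above

    w = torch.tensor([[-1., 2.], [3., 5.]])
    print(fake_quantize(w, 8))  # values snapped onto the 8-bit grid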
     def quantize_output(self, output, wrapper, **kwargs):
         config = wrapper.config

@@ -241,8 +259,10 @@ def quantize_output(self, output, wrapper, **kwargs):
             return output

         current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
+        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                   module.ema_decay, self.steps)
+        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                   module.ema_decay, self.steps)
         module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
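The `update_ema` calls above track the running output range. Only the bias-correction line of the helper is visible in this diff; the sketch below matches that line and the standard EMA recurrence it implies:

    def update_ema_sketch(biased_ema, value, decay, step):
        # Exponential moving average with Adam-style bias correction,
        # consistent with the correction line shown in the diff above.
        biased_ema = biased_ema * decay + (1 - decay) * value
        unbiased_ema = biased_ema / (1 - decay ** step)  # bias correction
        return biased_ema, unbiased_ema

    # Early steps: the unbiased estimate already tracks the data scale.
    ema, unbiased = 0.0, 0.0
    for step, v in enumerate([4.0, 4.2, 3.9], start=1):
        ema, unbiased = update_ema_sketch(ema, v, decay=0.99, step=step)
    print(unbiased)  # ~4.0, even though the biased ema itself is still tiny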
@@ -264,6 +284,7 @@ class DoReFaQuantizer(Quantizer):
     Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
     (https://arxiv.org/abs/1606.06160)
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
@@ -287,17 +308,18 @@ def validate_config(self, model, config_list):
         schema.validate(config_list)

-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
         weight_bits = get_bits_length(wrapper.config, 'weight')
-        out = weight.tanh()
-        out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, weight_bits)
-        out = 2 * out - 1
-        return out
+        weight = weight.tanh()
+        weight = weight / (2 * weight.abs().max()) + 0.5
+        weight = self.quantize(weight, weight_bits)
+        weight = 2 * weight - 1
+        wrapper.module.weight.data = weight

     def quantize(self, input_ri, q_bits):
-        scale = pow(2, q_bits)-1
-        output = torch.round(input_ri*scale)/scale
+        scale = pow(2, q_bits) - 1
+        output = torch.round(input_ri * scale) / scale
         return output
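Restating the refactored DoReFa transformation as a standalone function makes the output range easier to see; this is a sketch mirroring the diff above, not the NNI method itself:

    import torch

    def dorefa_weight(weight, bits):
        # DoReFa-Net weight quantization: squash with tanh, normalize into
        # [0, 1], round onto a (2^bits - 1)-level grid, map back to [-1, 1].
        w = weight.tanh()
        w = w / (2 * w.abs().max()) + 0.5
        scale = (1 << bits) - 1
        w = torch.round(w * scale) / scale
        return 2 * w - 1

    print(dorefa_weight(torch.randn(2, 2), bits=2))  # values in {-1, -1/3, 1/3, 1}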
@@ -314,6 +336,7 @@ class BNNQuantizer(Quantizer):
     Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
     (https://arxiv.org/abs/1602.02830)
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.quant_grad = ClipGrad
@@ -339,11 +362,12 @@ def validate_config(self, model, config_list):
         schema.validate(config_list)

-    def quantize_weight(self, weight, wrapper, **kwargs):
-        out = torch.sign(weight)
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
+        weight = torch.sign(weight)
         # remove zeros
-        out[out == 0] = 1
-        return out
+        weight[weight == 0] = 1
+        wrapper.module.weight.data = weight
Review comment: The returned value used to be the quantized weight. There is no need to return it in this version, because the weight is quantized in place inside quantize_weight().

     def quantize_output(self, output, wrapper, **kwargs):
         out = torch.sign(output)
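A minimal standalone restatement of the binarization above, showing why the zero-removal step matters (`torch.sign` maps 0 to 0, which is not a valid binary weight):

    import torch

    def binarize(weight):
        # Binarize to {-1, +1}; zeros are pushed to +1, as in the diff above.
        out = torch.sign(weight)
        out[out == 0] = 1
        return out

    print(binarize(torch.tensor([-0.7, 0.0, 0.3])))  # tensor([-1., 1., 1.])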
@@ -234,20 +234,34 @@ def test_torch_QAT_quantizer(self):
     model.relu = torch.nn.ReLU()
     quantizer = torch_compressor.QAT_Quantizer(model, config_list)
     quantizer.compress()

     # test quantize
     # range not including 0
     eps = 1e-7
     weight = torch.tensor([[1, 2], [3, 5]]).float()
-    quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+    model.conv2.module.weight.data = weight
+    quantizer.quantize_weight(model.conv2)
     assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
     assert model.conv2.module.zero_point == 0
     # range including 0
     weight = torch.tensor([[-1, 2], [3, 5]]).float()
-    quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+    model.conv2.module.weight.data = weight
+    quantizer.quantize_weight(model.conv2)
     assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
+    assert model.conv2.module.zero_point in (42, 43)

+    # test value of weight and bias after quantization
+    weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
+    weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
Review comment: Just curious, how are the values of weight_valid and bias_valid computed?
Reply: weight_valid and bias_valid are calculated by applying the quantization function by hand. I will revise the test case and its comments after the code freeze.
+    bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+    bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+    model.conv2.module.weight.data = weight
+    model.conv2.module.bias.data = bias
+    quantizer.quantize_weight(model.conv2)
+    assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
+    assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))

     # test ema
     eps = 1e-7
     x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
     out = model.relu(x)
     assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
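For reference, the asserted constants can be checked by hand, assuming the affine scheme with the range extended to include 0 and an 8-bit grid (qmin = 0, qmax = 255):

    # weight [[1, 2], [3, 5]]:  rmin -> min(1, 0) = 0, rmax = 5
    #   scale = (5 - 0) / 255 = 5/255, zero_point = round(0 - 0/scale) = 0
    # weight [[-1, 2], [3, 5]]: rmin = -1, rmax = 5
    #   scale = 6/255, zero_point = round(0 - (-1)/(6/255)) = round(42.5)
    # round(42.5) gives 42 under half-to-even and 43 under half-away-from-zero,
    # which is why the test accepts either value.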
Review comment: Add a return here.
Reply: OK, I have fixed it.