This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add bias quantization in QAT and refactor the code of weight quantization #2914

Merged · 11 commits · Oct 10, 2020
6 changes: 3 additions & 3 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -481,11 +481,10 @@ def forward(self, *inputs):
                 self)
 
         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
-            new_weight = self.quantizer.quant_grad.apply(
+            self.quantizer.quant_grad.apply(
                 self.module.old_weight,
                 QuantType.QUANT_WEIGHT,
                 self)
-            self.module.weight = new_weight
             result = self.module(*inputs)
         else:
             result = self.module(*inputs)
@@ -617,7 +616,8 @@ def forward(ctx, tensor, quant_type, wrapper, **kwargs):
         if quant_type == QuantType.QUANT_INPUT:
             return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_WEIGHT:
-            return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
+            wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+            return tensor
         elif quant_type == QuantType.QUANT_OUTPUT:
             return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
         else:
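Taken together, the two hunks above move quantization from "return a new tensor" to "write into the wrapped module": `quant_grad.apply` now lets `quantize_weight` mutate `module.weight` itself and hands the original full-precision tensor back to autograd. A minimal sketch of that straight-through pattern, with hypothetical names (`QuantGradSketch` and `quantize_fn` are illustrative, not the NNI API):

```python
import torch

class QuantGradSketch(torch.autograd.Function):
    """Straight-through estimator: quantize on the forward pass,
    pass the gradient through unchanged on the backward pass."""

    @staticmethod
    def forward(ctx, tensor, quantize_fn):
        # quantize_fn mutates the wrapped module's weight in place;
        # returning the original tensor keeps a differentiable path
        # to the full-precision weight.
        quantize_fn()
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient for the tensor, None for the callable.
        return grad_output, None

w = torch.randn(3, requires_grad=True)
out = QuantGradSketch.apply(w, lambda: None)
out.sum().backward()  # w.grad is all ones: gradient passed straight through
```

Because `backward` returns the incoming gradient unchanged, the full-precision `old_weight` keeps receiving updates even though the forward pass ran on quantized values.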
66 changes: 45 additions & 21 deletions src/sdk/pynni/nni/compression/torch/quantization/quantizers.py
@@ -15,6 +15,7 @@
 class NaiveQuantizer(Quantizer):
     """quantize weight to 8 bits
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.layer_scale = {}
@@ -29,12 +30,13 @@ def validate_config(self, model, config_list):
 
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
         new_scale = weight.abs().max() / 127
         scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
         self.layer_scale[wrapper.name] = scale
         orig_type = weight.type()  # TODO: user layer
-        return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        wrapper.module.weight.data = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
 
 
 def update_ema(biased_ema, value, decay, step):
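For reference, the arithmetic in `NaiveQuantizer.quantize_weight` above is plain symmetric int8 rounding against a running per-layer scale. A quick numeric check of the round trip (illustrative values only):

```python
import torch

weight = torch.tensor([[0.5, -1.27], [0.635, 1.0]])
scale = weight.abs().max() / 127                # 1.27 / 127 = 0.01
q = weight.div(scale).type(torch.int8)          # [[50, -127], [63, 100]]
deq = q.type(weight.dtype).mul(scale)           # [[0.50, -1.27], [0.63, 1.00]]
print(q)
print(deq)
```

Note that `.type(torch.int8)` truncates toward zero rather than rounding, so 63.5 becomes 63.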
@@ -60,6 +62,7 @@ def update_ema(biased_ema, value, decay, step):
     unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
     return biased_ema, unbiased_ema
 
+
 def update_quantization_param(bits, rmin, rmax):
     """
     calculate the `zero_point` and `scale`.
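The docstring is truncated by the diff view, but the quantity `update_quantization_param` computes follows the affine scheme of Jacob et al.: choose `scale` and `zero_point` so the integer grid `[0, 2**bits - 1]` covers `[rmin, rmax]` while representing 0.0 exactly. A sketch under that assumption (`affine_quant_params` is a hypothetical stand-in; the real function's exact rounding and clamping live in quantizers.py):

```python
def affine_quant_params(bits, rmin, rmax):
    # The range must contain zero so that zero padding quantizes exactly.
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    qmin, qmax = 0, (1 << bits) - 1
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    return scale, int(min(max(zero_point, qmin), qmax))

# Weight range [-1, 5] at 8 bits: scale = 6/255, zero_point = round(42.5).
print(affine_quant_params(8, -1.0, 5.0))  # (0.0235..., 42)
```

Python's `round()` sends 42.5 to 42 (half-to-even); a round-half-up convention gives 43, which is presumably why the updated test below accepts either value.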
@@ -116,6 +119,7 @@ class QAT_Quantizer(Quantizer):
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
+
     def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
@@ -215,20 +219,34 @@ def _dequantize(self, op, quantized_val):
         real_val = op.scale * (quantized_val - op.zero_point)
         return real_val
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
+        weight = wrapper.module.weight.data
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
 
         if quant_start_step > self.steps:
-            return weight
+            return
[Review thread on the bare `return` above]
Contributor: add return here
Contributor Author: Ok, I have fixed it.

+
+        # if bias exists, quantize bias to uint32
+        if not hasattr(wrapper.module, 'bias') or wrapper.module.bias is None:
+            return
+        bias = wrapper.module.bias.data
+        bias_bits = 32
+        rmin, rmax = torch.min(bias), torch.max(bias)
+        module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax)
+        bias = self._quantize(bias_bits, module, bias)
+        bias = self._dequantize(module, bias)
+        wrapper.module.bias.data = bias
 
+        # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
         module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
-        out = self._quantize(weight_bits, module, weight)
-        out = self._dequantize(module, out)
-        return out
+        weight = self._quantize(weight_bits, module, weight)
+        weight = self._dequantize(module, weight)
+        wrapper.module.weight.data = weight

     def quantize_output(self, output, wrapper, **kwargs):
         config = wrapper.config
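To see what the new in-place path actually writes, here is a sketch of the `_quantize` -> `_dequantize` round trip (simulated quantization: values snap onto the integer grid but stay floating point). `fake_quantize` is an illustrative helper, not NNI code; assuming the affine scheme sketched after `update_quantization_param`, it reproduces the `weight_valid` values used in the updated test at the bottom of this PR:

```python
import torch

def fake_quantize(x, bits, scale, zero_point):
    # Snap to the integer grid, clamp, then map back to float.
    qmin, qmax = 0, (1 << bits) - 1
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return scale * (q - zero_point)

w = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
# rmin is clamped up to 0.0, so scale = rmax / 255 and zero_point = 0.
scale, zero_point = w.max().item() / 255, 0
print(fake_quantize(w, 8, scale, zero_point))
# tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
```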
@@ -241,8 +259,10 @@ def quantize_output(self, output, wrapper, **kwargs):
             return output
 
         current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
+        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                   module.ema_decay, self.steps)
+        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                   module.ema_decay, self.steps)
         module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
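The re-wrapped calls above feed `update_ema`, shown earlier in this file; its bias correction keeps the first steps from being dragged toward the zero-initialized trackers. A small self-contained check of that behaviour:

```python
def ema_step(biased_ema, value, decay, step):
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)  # bias correction
    return biased_ema, unbiased_ema

biased = 0.0
for step in range(1, 4):          # feed a constant signal of 1.0
    biased, unbiased = ema_step(biased, 1.0, 0.99, step)
    print(step, round(biased, 4), round(unbiased, 4))
# The raw EMA crawls up from 0.01, but the corrected estimate is
# exactly 1.0 at every step.
```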
@@ -264,6 +284,7 @@ class DoReFaQuantizer(Quantizer):
     Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
     (https://arxiv.org/abs/1606.06160)
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)

@@ -287,17 +308,18 @@ def validate_config(self, model, config_list):
 
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
         weight_bits = get_bits_length(wrapper.config, 'weight')
-        out = weight.tanh()
-        out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, weight_bits)
-        out = 2 * out - 1
-        return out
+        weight = weight.tanh()
+        weight = weight / (2 * weight.abs().max()) + 0.5
+        weight = self.quantize(weight, weight_bits)
+        weight = 2 * weight - 1
+        wrapper.module.weight.data = weight
 
     def quantize(self, input_ri, q_bits):
-        scale = pow(2, q_bits)-1
-        output = torch.round(input_ri*scale)/scale
+        scale = pow(2, q_bits) - 1
+        output = torch.round(input_ri * scale) / scale
         return output
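Tracing the rewritten DoReFa weight path on a tiny tensor shows how the k-bit rounding lands every value on a symmetric grid in [-1, 1] (the function below just folds the diff's logic into one place for illustration):

```python
import torch

def dorefa_weight(weight, w_bits):
    weight = weight.tanh()
    weight = weight / (2 * weight.abs().max()) + 0.5   # map into [0, 1]
    scale = pow(2, w_bits) - 1
    weight = torch.round(weight * scale) / scale       # k-bit rounding
    return 2 * weight - 1                              # back to [-1, 1]

w = torch.tensor([-2.0, -0.5, 0.5, 2.0])
print(dorefa_weight(w, 2))  # tensor([-1.0000, -0.3333,  0.3333,  1.0000])
```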


@@ -314,6 +336,7 @@ class BNNQuantizer(Quantizer):
     Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
     (https://arxiv.org/abs/1602.02830)
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.quant_grad = ClipGrad
@@ -339,11 +362,12 @@ def validate_config(self, model, config_list):
 
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
-        out = torch.sign(weight)
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
+        weight = torch.sign(weight)
         # remove zeros
-        out[out == 0] = 1
-        return out
+        weight[weight == 0] = 1
+        wrapper.module.weight.data = weight
[Review thread on `quantize_weight`]
Contributor: return is removed? Is the returned value never used?
Contributor Author: The returned value was the quantized weight. There is no need to return it in this version because the weight is quantized in place inside `quantize_weight()`.

 
     def quantize_output(self, output, wrapper, **kwargs):
         out = torch.sign(output)
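`BNNQuantizer` pairs the in-place `sign()` above with `self.quant_grad = ClipGrad` (set in `__init__`). A sketch of what a clipped straight-through estimator of this kind typically looks like; the real `ClipGrad` lives in compressor.py, so treat this as an assumption about its shape rather than a copy:

```python
import torch

class ClipGradSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        out = torch.sign(x)
        out[out == 0] = 1        # sign(0) is 0; BNN weights must be +/-1
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Pass the gradient through, but zero it where |x| > 1 so the
        # full-precision weights cannot drift without bound.
        return grad_output * (x.abs() <= 1).type_as(grad_output)
```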
18 changes: 16 additions & 2 deletions src/sdk/pynni/tests/test_compressor_torch.py
@@ -234,20 +234,34 @@ def test_torch_QAT_quantizer(self):
         model.relu = torch.nn.ReLU()
         quantizer = torch_compressor.QAT_Quantizer(model, config_list)
         quantizer.compress()
+
         # test quantize
         # range not including 0
         eps = 1e-7
         weight = torch.tensor([[1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+        model.conv2.module.weight.data = weight
+        quantizer.quantize_weight(model.conv2)
         assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point == 0
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+        model.conv2.module.weight.data = weight
+        quantizer.quantize_weight(model.conv2)
         assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point in (42, 43)
+        # test value of weight and bias after quantization
+        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
+        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
[Review thread on `weight_valid`]
Contributor: Just curious, how are the values of `weight_valid` and `bias_valid` calculated?
Contributor Author: `weight_valid` and `bias_valid` were calculated by applying the quantization function manually. I will modify the test case and its annotations after the code freeze.

+        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+        model.conv2.module.weight.data = weight
+        model.conv2.module.bias.data = bias
+        quantizer.quantize_weight(model.conv2)
+        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
+        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))

         # test ema
         eps = 1e-7
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
         out = model.relu(x)
         assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
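Two arithmetic notes for anyone re-deriving the expected values in this test: for the weight range [-1, 5], `zero_point = round(1 / (6 / 255)) = round(42.5)`, which is 42 or 43 depending on the rounding convention, hence the `in (42, 43)`. And `bias_valid` can equal `bias` to within `rtol=1e-7` because the 32-bit grid is extremely fine:

```python
# Grid spacing of 32-bit affine quantization over the bias range
# [0, 5.2341] (rmin clamped to 0): far below the 1e-7 tolerance.
rmin, rmax = 0.0, 5.2341
print((rmax - rmin) / (2 ** 32 - 1))  # ~1.2e-09
```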