This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add bias quantization in QAT and refactor the weight-quantization code #2914

Merged · 11 commits · Oct 10, 2020
6 changes: 3 additions & 3 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -481,11 +481,10 @@ def forward(self, *inputs):
self)

if 'weight' in self.config['quant_types'] and _check_weight(self.module):
-            new_weight = self.quantizer.quant_grad.apply(
+            self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
self)
-            self.module.weight = new_weight
result = self.module(*inputs)
else:
result = self.module(*inputs)
@@ -617,7 +616,8 @@ def forward(ctx, tensor, quant_type, wrapper, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
-            return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
+            wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+            return tensor
elif quant_type == QuantType.QUANT_OUTPUT:
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
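Taken together, the two hunks above change the contract of quant_grad.apply: quantize_weight now writes the quantized values into wrapper.module.weight itself, and the autograd function returns its input tensor unchanged so the straight-through gradient path is preserved. A minimal self-contained sketch of that pattern (the class and method bodies here are simplified assumptions, not the full NNI implementation):

import torch

class QuantType:
    QUANT_INPUT, QUANT_WEIGHT, QUANT_OUTPUT = 0, 1, 2

class QuantGrad(torch.autograd.Function):
    """Straight-through estimator: quantize in forward, pass the
    gradient through unchanged in backward."""
    @staticmethod
    def forward(ctx, tensor, quant_type, wrapper):
        if quant_type == QuantType.QUANT_WEIGHT:
            # New contract: the weight is quantized in place on the
            # wrapped module; the return value is just the input.
            wrapper.quantizer.quantize_weight(wrapper)
            return tensor
        return wrapper.quantizer.quantize_output(tensor, wrapper)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient for the tensor; None for the
        # non-tensor arguments quant_type and wrapper.
        return grad_output, None, None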
64 changes: 43 additions & 21 deletions src/sdk/pynni/nni/compression/torch/quantization/quantizers.py
@@ -15,6 +15,7 @@
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.layer_scale = {}
@@ -29,12 +30,13 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
new_scale = weight.abs().max() / 127
scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
self.layer_scale[wrapper.name] = scale
orig_type = weight.type() # TODO: user layer
-        return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        wrapper.module.weight.data = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)

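As a standalone illustration of what the rewritten method computes, with hypothetical values (8-bit symmetric scheme; the scale is the running maximum of |w| / 127):

import torch

weight = torch.tensor([[0.6, -1.2], [2.5, -0.1]])
scale = weight.abs().max() / 127            # ~0.0197
q = weight.div(scale).type(torch.int8)      # truncates toward zero, e.g. 0.6 -> 30
dequant = q.type(weight.type()).mul(scale)  # back to float; error is at most ~scale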

def update_ema(biased_ema, value, decay, step):
@@ -60,6 +62,7 @@ def update_ema(biased_ema, value, decay, step):
unbiased_ema = biased_ema / (1 - decay ** step) # Bias correction
return biased_ema, unbiased_ema

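For reference, the bias-corrected EMA whose tail is visible above behaves like this (a sketch; the accumulation line is an assumption since it is collapsed in the hunk, and step is assumed to be at least 1):

def update_ema(biased_ema, value, decay, step):
    # Exponential moving average with Adam-style bias correction,
    # so early steps are not dragged toward the zero initialization.
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)  # bias correction
    return biased_ema, unbiased_ema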

def update_quantization_param(bits, rmin, rmax):
"""
calculate the `zero_point` and `scale`.
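The function body is collapsed in this diff. A sketch of the standard asymmetric scheme it implements, consistent with the test expectations further down (scale = 6/255 and zero_point of 42 or 43 for a [-1, 5] range); the exact body is an assumption:

def update_quantization_param(bits, rmin, rmax):
    # The quantized range must contain zero so that padding/ReLU
    # zeros stay exactly representable.
    rmin, rmax = min(float(rmin), 0.0), max(float(rmax), 0.0)
    qmin, qmax = 0, (1 << bits) - 1
    # Degenerate all-zero range is left unhandled in this sketch.
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    return scale, min(max(zero_point, qmin), qmax)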
@@ -116,6 +119,7 @@ class QAT_Quantizer(Quantizer):
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""

def __init__(self, model, config_list, optimizer=None):
"""
Parameters
@@ -215,20 +219,32 @@ def _dequantize(self, op, quantized_val):
real_val = op.scale * (quantized_val - op.zero_point)
return real_val

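The paired _quantize is collapsed in this hunk; given _dequantize above, it is presumably the affine inverse, roughly (an assumption, not the NNI source):

def _quantize(self, bits, op, real_val):
    # Inverse of _dequantize: map onto the integer grid and clamp
    # to the representable unsigned range.
    qmin, qmax = 0, (1 << bits) - 1
    quantized_val = torch.round(real_val / op.scale + op.zero_point)
    return torch.clamp(quantized_val, qmin, qmax)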
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
+        weight = wrapper.module.weight.data
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.steps:
-            return weight
+            return
Review comment (Contributor): add return here

Reply (Author): Ok, I have fixed it.

rmin, rmax = torch.min(weight), torch.max(weight)
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
-        out = self._quantize(weight_bits, module, weight)
-        out = self._dequantize(module, out)
-        return out
+        weight = self._quantize(weight_bits, module, weight)
+        weight = self._dequantize(module, weight)
+        wrapper.module.weight.data = weight

+        # if bias exists, quantize bias to uint32
+        if not hasattr(wrapper.module, 'bias') or wrapper.module.bias is None:
+            return
+        bias = wrapper.module.bias.data
+        bias_bits = 32
+        rmin, rmax = torch.min(bias), torch.max(bias)
+        module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax)
+        bias = self._quantize(bias_bits, module, bias)
+        bias = self._dequantize(module, bias)
+        wrapper.module.bias.data = bias

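End to end, the new behavior can be exercised roughly like this (a usage sketch; the model is hypothetical and the import path is an assumption based on the repository layout above):

import torch
from nni.compression.torch import QAT_Quantizer

model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU())
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},
    'quant_start_step': 100,   # stay at full precision for the first 100 steps
    'op_types': ['Conv2d'],
}]
QAT_Quantizer(model, config_list).compress()
# Each forward pass now fake-quantizes the wrapped conv's weight in
# place, and (new in this PR) its bias to 32 bits when one exists.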
def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
@@ -241,8 +257,10 @@ def quantize_output(self, output, wrapper, **kwargs):
return output

current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
+        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                   module.ema_decay, self.steps)
+        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                   module.ema_decay, self.steps)
module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
@@ -264,6 +282,7 @@ class DoReFaQuantizer(Quantizer):
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
(https://arxiv.org/abs/1606.06160)
"""

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)

@@ -287,17 +306,18 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
weight_bits = get_bits_length(wrapper.config, 'weight')
-        out = weight.tanh()
-        out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, weight_bits)
-        out = 2 * out - 1
-        return out
+        weight = weight.tanh()
+        weight = weight / (2 * weight.abs().max()) + 0.5
+        weight = self.quantize(weight, weight_bits)
+        weight = 2 * weight - 1
+        wrapper.module.weight.data = weight

def quantize(self, input_ri, q_bits):
-        scale = pow(2, q_bits)-1
-        output = torch.round(input_ri*scale)/scale
+        scale = pow(2, q_bits) - 1
+        output = torch.round(input_ri * scale) / scale
return output

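The DoReFa weight path in isolation, with hypothetical values (a sketch of the same arithmetic as the method above):

import torch

def dorefa_weight(w, bits):
    w = w.tanh()                          # squash into (-1, 1)
    w = w / (2 * w.abs().max()) + 0.5     # affine map into [0, 1]
    levels = (1 << bits) - 1
    w = torch.round(w * levels) / levels  # uniform grid in [0, 1]
    return 2 * w - 1                      # map back to [-1, 1]

print(dorefa_weight(torch.tensor([-2.0, 0.3, 1.5]), bits=2))
# tensor([-1.0000,  0.3333,  1.0000])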

@@ -314,6 +334,7 @@ class BNNQuantizer(Quantizer):
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
"""

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
@@ -339,11 +360,12 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

-    def quantize_weight(self, weight, wrapper, **kwargs):
-        out = torch.sign(weight)
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = wrapper.module.weight.data
+        weight = torch.sign(weight)
        # remove zeros
-        out[out == 0] = 1
-        return out
+        weight[weight == 0] = 1
+        wrapper.module.weight.data = weight
Review comment (Contributor): return is removed? returned value is never used?

Reply (Author): The returned value was the quantized weight. There is no need to return it in this version because the weight is quantized in place inside quantize_weight().


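A quick standalone check of the binarization rule (hypothetical values; torch.sign maps 0 to 0, hence the explicit fix-up in the diff):

import torch

w = torch.tensor([-0.7, 0.0, 0.3])
b = torch.sign(w)   # tensor([-1., 0., 1.])
b[b == 0] = 1       # tensor([-1., 1., 1.]), weights constrained to +/-1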
def quantize_output(self, output, wrapper, **kwargs):
out = torch.sign(output)
13 changes: 1 addition & 12 deletions src/sdk/pynni/tests/test_compressor_torch.py
@@ -234,20 +234,9 @@ def test_torch_QAT_quantizer(self):
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
quantizer.compress()
-        # test quantize
-        # range not including 0
-        eps = 1e-7
-        weight = torch.tensor([[1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
-        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
-        assert model.conv2.module.zero_point == 0
-        # range including 0
-        weight = torch.tensor([[-1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
-        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
-        assert model.conv2.module.zero_point in (42, 43)

# test ema
+        eps = 1e-7
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x)
assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
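Under the new in-place signature, the deleted assertions could be rewritten along these lines (a sketch only; wrapper attribute names follow the diffs above, and the setup is assumed to mirror the remaining test):

eps = 1e-7
# range including 0: weights span [-1, 5], so scale = 6 / 255 and
# zero_point = round(1 / scale) = round(42.5), i.e. 42 or 43
# (conv2 is assumed to have bias=False here: the new bias branch
# would otherwise overwrite module.scale / zero_point afterwards)
model.conv2.module.weight.data = torch.tensor([[-1., 2.], [3., 5.]])
quantizer.quantize_weight(model.conv2)
assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43)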