From af929fdb848092a7b225a498e6d82d82cf6babfa Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 18 May 2021 11:08:10 +0800 Subject: [PATCH] Add LSQ quantizer (#3503) --- docs/en_US/Compression/Overview.rst | 2 + docs/en_US/Compression/Quantizer.rst | 56 +++++ .../quantization/LSQ_torch_quantizer.py | 142 ++++++++++++ .../pytorch/quantization/quantizers.py | 214 +++++++++++++++++- nni/compression/pytorch/compressor.py | 40 ++-- 5 files changed, 435 insertions(+), 19 deletions(-) create mode 100644 examples/model_compress/quantization/LSQ_torch_quantizer.py diff --git a/docs/en_US/Compression/Overview.rst b/docs/en_US/Compression/Overview.rst index 5b63927af6..262d9631f1 100644 --- a/docs/en_US/Compression/Overview.rst +++ b/docs/en_US/Compression/Overview.rst @@ -87,6 +87,8 @@ Quantization algorithms compress the original network by reducing the number of - DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. `Reference Paper `__ * - `BNN Quantizer <../Compression/Quantizer.rst#bnn-quantizer>`__ - Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1. `Reference Paper `__ + * - `LSQ Quantizer <../Compression/Quantizer.rst#lsq-quantizer>`__ + - Learned step size quantization. `Reference Paper `__ Model Speedup diff --git a/docs/en_US/Compression/Quantizer.rst b/docs/en_US/Compression/Quantizer.rst index cc164c5296..2af973e4f2 100644 --- a/docs/en_US/Compression/Quantizer.rst +++ b/docs/en_US/Compression/Quantizer.rst @@ -8,6 +8,7 @@ Index of supported quantization algorithms * `QAT Quantizer <#qat-quantizer>`__ * `DoReFa Quantizer <#dorefa-quantizer>`__ * `BNN Quantizer <#bnn-quantizer>`__ +* `LSQ Quantizer <#lsq-quantizer>`__ Naive Quantizer --------------- @@ -86,6 +87,61 @@ note batch normalization folding is currently not supported. +---- + +LSQ Quantizer +------------- + +In `LEARNED STEP SIZE QUANTIZATION `__\ , authors Steven K. Esser and Jeffrey L. McKinstry provide an algorithm to train the scales with gradients. + +.. + + The authors introduce a novel means to estimate and scale the task loss gradient at each weight and activation layer’s quantizer step size, such that it can be learned in conjunction with other network parameters. + + +Usage +^^^^^ +You can add codes below before your training codes. Three things must be done: + + +1. configure which layer to be quantized and which tensor (input/output/weight) of that layer to be quantized. +2. construct the lsq quantizer +3. call the `compress` API + + +PyTorch code + +.. code-block:: python + + from nni.algorithms.compression.pytorch.quantization import LsqQuantizer + model = Mnist() + + configure_list = [{ + 'quant_types': ['weight', 'input'], + 'quant_bits': { + 'weight': 8, + 'input': 8, + }, + 'op_names': ['conv1'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8,}, + 'op_names': ['relu1'] + }] + + quantizer = LsqQuantizer(model, configure_list, optimizer) + quantizer.compress() + +You can view example for more information. :githublink:`examples/model_compress/quantization/LSQ_torch_quantizer.py ` + +User configuration for LSQ Quantizer +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +common configuration needed by compression algorithms can be found at `Specification of `config_list <./QuickStart.rst>`__. + +configuration needed by this algorithm : + + ---- DoReFa Quantizer diff --git a/examples/model_compress/quantization/LSQ_torch_quantizer.py b/examples/model_compress/quantization/LSQ_torch_quantizer.py new file mode 100644 index 0000000000..449a4e179c --- /dev/null +++ b/examples/model_compress/quantization/LSQ_torch_quantizer.py @@ -0,0 +1,142 @@ +import torch +import torch.nn.functional as F +from torchvision import datasets, transforms +from nni.algorithms.compression.pytorch.quantization import LsqQuantizer +from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT + + +class Mnist(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) + self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) + self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) + self.fc2 = torch.nn.Linear(500, 10) + self.relu1 = torch.nn.ReLU6() + self.relu2 = torch.nn.ReLU6() + self.relu3 = torch.nn.ReLU6() + self.max_pool1 = torch.nn.MaxPool2d(2, 2) + self.max_pool2 = torch.nn.MaxPool2d(2, 2) + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.max_pool1(x) + x = self.relu2(self.conv2(x)) + x = self.max_pool2(x) + x = x.view(-1, 4 * 4 * 50) + x = self.relu3(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def train(model, quantizer, device, train_loader, optimizer): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % 100 == 0: + print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%)\n'.format( + test_loss, 100 * correct / len(test_loader.dataset))) + + +def test_trt(engine, test_loader): + test_loss = 0 + correct = 0 + time_elasped = 0 + for data, target in test_loader: + output, time = engine.inference(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + time_elasped += time + test_loss /= len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%'.format( + test_loss, 100 * correct / len(test_loader.dataset))) + print("Inference elapsed_time (whole dataset): {}s".format(time_elasped)) + + +def main(): + torch.manual_seed(0) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + train_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=True, download=True, transform=trans), + batch_size=64, shuffle=True) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=False, transform=trans), + batch_size=1000, shuffle=True) + + model = Mnist() + configure_list = [{ + 'quant_types': ['weight', 'input'], + 'quant_bits': {'weight': 8, 'input': 8}, + 'op_names': ['conv1'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8, }, + 'op_names': ['relu1'] + }, { + 'quant_types': ['weight', 'input'], + 'quant_bits': {'weight': 8, 'input': 8}, + 'op_names': ['conv2'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8}, + 'op_names': ['relu2'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8}, + 'op_names': ['max_pool2'] + } + ] + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + quantizer = LsqQuantizer(model, configure_list, optimizer) + quantizer.compress() + + model.to(device) + for epoch in range(40): + print('# Epoch {} #'.format(epoch)) + train(model, quantizer, device, train_loader, optimizer) + test(model, device, test_loader) + + model_path = "mnist_model.pth" + calibration_path = "mnist_calibration.pth" + calibration_config = quantizer.export_model(model_path, calibration_path) + + test(model, device, test_loader) + + print("calibration_config: ", calibration_config) + + batch_size = 32 + input_shape = (batch_size, 1, 28, 28) + + engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size) + engine.compress() + + test_trt(engine, test_loader) + + +if __name__ == '__main__': + main() diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index ca40e30e45..62703d449b 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -6,9 +6,9 @@ import torch from schema import Schema, And, Or, Optional from nni.compression.pytorch.utils.config_validation import CompressorSchema -from nni.compression.pytorch.compressor import Quantizer, QuantGrad, QuantType +from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType -__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer'] +__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer'] logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def update_ema(biased_ema, value, decay): float, float """ biased_ema = biased_ema * decay + (1 - decay) * value - return biased_ema + return biased_ema def update_quantization_param(bits, rmin, rmax): @@ -146,7 +146,7 @@ def __init__(self, model, config_list, optimizer=None): types of nn.module you want to apply quantization, eg. 'Conv2d' """ super().__init__(model, config_list, optimizer) - self.quant_grad = QATGrad + self.quant_grad = QATGrad.apply modules_to_compress = self.get_modules_to_compress() self.bound_model.register_buffer("steps", torch.Tensor([1])) for layer, config in modules_to_compress: @@ -474,7 +474,7 @@ class BNNQuantizer(Quantizer): def __init__(self, model, config_list, optimizer=None): super().__init__(model, config_list, optimizer) - self.quant_grad = ClipGrad + self.quant_grad = ClipGrad.apply modules_to_compress = self.get_modules_to_compress() for layer, config in modules_to_compress: if "weight" in config.get("quant_types", []): @@ -559,4 +559,206 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device) - return calibration_config \ No newline at end of file + return calibration_config + + +class LsqQuantizer(Quantizer): + """Quantizer defined in: + Learned Step Size Quantization (ICLR 2020) + https://arxiv.org/pdf/1902.08153.pdf + """ + + def __init__(self, model, config_list, optimizer=None): + """ + Parameters + ---------- + model : torch.nn.Module + the model to be quantized + config_list : list of dict + list of configurations for quantization + supported keys for dict: + - quant_types : list of string + type of quantization you want to apply, currently support 'weight', 'input', 'output' + - quant_bits : int or dict of {str : int} + bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}, + when the type is int, all quantization types share same bits length + - quant_start_step : int + disable quantization until model are run by certain number of steps, this allows the network to enter a more stable + state where activation quantization ranges do not exclude a significant fraction of values, default value is 0 + - op_types : list of string + types of nn.module you want to apply quantization, eg. 'Conv2d' + """ + super().__init__(model, config_list, optimizer) + self.quant_grad = QuantForward() + modules_to_compress = self.get_modules_to_compress() + self.bound_model.register_buffer("steps", torch.Tensor([1])) + for layer, config in modules_to_compress: + if "weight" in config.get("quant_types", []): + layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0]))) + # todo: support per-channel quantization for weight since TensorRT use it for conv weight + q_bit = get_bits_length(config, "weight") + layer.module.register_buffer('weight_bit', torch.Tensor([q_bit])) + qmax = 2 ** (q_bit - 1) - 1 + qmin = -2 ** (q_bit - 1) + init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5) + layer.module.weight_scale = torch.nn.Parameter(init_weight_scale) + layer.module.weight_qmax = qmax + layer.module.weight_qmin = qmin + + self.optimizer.add_param_group({"params": layer.module.weight_scale}) + + if "output" in config.get("quant_types", []): + # scale of activation will be initialized using the first batch data + layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0]))) + q_bit = get_bits_length(config, "output") + layer.module.register_buffer('output_bit', torch.Tensor([q_bit])) + qmax = 2 ** (q_bit - 1) - 1 + qmin = -2 ** (q_bit - 1) + layer.module.output_qmax = qmax + layer.module.output_qmin = qmin + + self.optimizer.add_param_group({"params": layer.module.output_scale}) + + if "input" in config.get("quant_types", []): + # scale of input will be initialized using the first batch data + layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0]))) + q_bit = get_bits_length(config, "input") + layer.module.register_buffer('input_bit', torch.Tensor([q_bit])) + qmax = 2 ** (q_bit - 1) - 1 + qmin = -2 ** (q_bit - 1) + layer.module.input_qmax = qmax + layer.module.input_qmin = qmin + + self.optimizer.add_param_group({"params": layer.module.input_scale}) + + @staticmethod + def grad_scale(x, scale): + """ + Used to scale the gradient. Give tensor `x`, we have `y=grad_scale(x, scale)=x` in the forward pass, + which means that this function will not change the value of `x`. In the backward pass, we have: + + :math:`\frac{\alpha_L}{\alpha_x}=\frac{\alpha_L}{\alpha_y}*\frac{\alpha_y}{\alpha_x}=sclae*\frac{\alpha_L}{\alpha_x}` + + This means that the origin gradient of x is scaled by a factor of `scale`. Applying this function + to a nn.Parameter will scale the gradient of it without changing its value. + """ + y = x + y_grad = x * scale + return (y - y_grad).detach() + y_grad + + @staticmethod + def round_pass(x): + """ + A simple way to achieve STE operation. + """ + y = x.round() + y_grad = x + return (y - y_grad).detach() + y_grad + + def quantize(self, x, scale, qmin, qmax): + grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5) + scale = self.grad_scale(scale, grad_scale_factor) + x = x / scale + x = torch.clamp(x, qmin, qmax) + x = self.round_pass(x) + x = x * scale + return x + + def quantize_weight(self, wrapper, **kwargs): + module = wrapper.module + + # todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize + # bias + old_weight = module.old_weight + weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax) + module.weight = weight + return weight + + def quantize_output(self, output, wrapper, **kwargs): + module = wrapper.module + + # initialize the scale + if self.bound_model.steps == 1: + qmax = module.output_qmax + init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5) + module.output_scale.data = init_oup_scale + + output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax) + return output + + def quantize_input(self, *inputs, wrapper, **kwargs): + # This is hacky since it is not recommended to modify a tuple + # NB: support layers with multi inputs + module = wrapper.module + # initialize the scale + if self.bound_model.steps == 1: + qmax = module.input_qmax + init_oup_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5) + module.input_scale.data = init_oup_scale + + new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax) + list_inp = list(inputs) + list_inp[0] = new_input + return tuple(list_inp) + + def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None): + """ + Export quantized model weights and calibration parameters(optional) + + Parameters + ---------- + model_path : str + path to save quantized model weight + calibration_path : str + (optional) path to save quantize parameters after calibration + onnx_path : str + (optional) path to save onnx model + input_shape : list or tuple + input shape to onnx model + device : torch.device + device of the model, used to place the dummy input tensor for exporting onnx file. + the tensor is placed on cpu if ```device``` is None + + Returns + ------- + Dict + """ + assert model_path is not None, 'model_path must be specified' + self._unwrap_model() + calibration_config = {} + + for name, module in self.bound_model.named_modules(): + if hasattr(module, 'input_bit') or hasattr(module, 'output_bit'): + calibration_config[name] = {} + if hasattr(module, 'weight_bit'): + calibration_config[name]['weight_bit'] = int(module.weight_bit) + abs_max_input = float(module.input_scale * module.input_qmax) + calibration_config[name]['tracked_min_input'] = -abs_max_input + calibration_config[name]['tracked_max_input'] = abs_max_input + if hasattr(module, 'output_bit'): + calibration_config[name]['activation_bit'] = int(module.output_bit) + abs_max_output = float(module.output_scale * module.output_qmax) + calibration_config[name]['tracked_min_activation'] = -abs_max_output + calibration_config[name]['tracked_max_activation'] = abs_max_output + self._del_simulated_attr(module) + + self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, + input_shape, device) + + return calibration_config + + def _del_simulated_attr(self, module): + """ + delete redundant parameters in quantize module + """ + del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation', \ + 'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale','weight_bit', 'output_bit', 'input_bit'] + for attr in del_attr_list: + if hasattr(module, attr): + delattr(module, attr) + + def step_with_optimizer(self): + """ + override `compressor` `step` method, quantization only happens after certain number of steps + """ + self.bound_model.steps += 1 diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py index 7fecdc3b4f..08543caf1a 100644 --- a/nni/compression/pytorch/compressor.py +++ b/nni/compression/pytorch/compressor.py @@ -474,13 +474,13 @@ def __init__(self, module, module_name, module_type, config, quantizer): def forward(self, *inputs): if 'input' in self.config['quant_types']: - inputs = self.quantizer.quant_grad.apply( + inputs = self.quantizer.quant_grad( inputs, QuantType.QUANT_INPUT, self) if 'weight' in self.config['quant_types'] and _check_weight(self.module): - self.quantizer.quant_grad.apply( + self.quantizer.quant_grad( self.module.old_weight, QuantType.QUANT_WEIGHT, self, inputs[0]) @@ -489,12 +489,13 @@ def forward(self, *inputs): result = self.module(*inputs) if 'output' in self.config['quant_types']: - result = self.quantizer.quant_grad.apply( + result = self.quantizer.quant_grad( result, QuantType.QUANT_OUTPUT, self) return result + class Quantizer(Compressor): """ Base quantizer for pytorch quantizer @@ -502,7 +503,7 @@ class Quantizer(Compressor): def __init__(self, model, config_list, optimizer=None): super().__init__(model, config_list, optimizer) - self.quant_grad = QuantGrad + self.quant_grad = QuantGrad.apply if self.optimizer is not None: self.patch_optimizer(self.step_with_optimizer) for wrapper in self.get_modules_wrapper(): @@ -719,15 +720,7 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma @staticmethod def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs): - if quant_type == QuantType.QUANT_INPUT: - output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs) - elif quant_type == QuantType.QUANT_WEIGHT: - output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs) - elif quant_type == QuantType.QUANT_OUTPUT: - output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) - else: - raise ValueError("unrecognized QuantType.") - + output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs) bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type]) qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device) @@ -750,3 +743,24 @@ def _check_weight(module): return isinstance(module.weight.data, torch.Tensor) except AttributeError: return False + +def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs): + if quant_type == QuantType.QUANT_INPUT: + output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs) + elif quant_type == QuantType.QUANT_WEIGHT: + output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs) + elif quant_type == QuantType.QUANT_OUTPUT: + output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) + else: + raise ValueError("unrecognized QuantType.") + + return output + +class QuantForward(torch.nn.Module): + """ + Base class for executing quantization operations. This is for quantization algorithms + that do not need to customize gradient. + """ + + def forward(self, tensor, quant_type, wrapper, input_tensor=None, **kwargs): + return quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)