From 92e0a9d54134dda66d2579083200f6707528b51f Mon Sep 17 00:00:00 2001
From: vmpuri
Date: Thu, 24 Oct 2024 12:43:48 -0700
Subject: [PATCH] Replace WeightOnlyInt8Linear with TorchAO int8_weight_only
 quantization

---
 torchchat/utils/quantize.py | 229 +++++++++++-------------------------
 1 file changed, 66 insertions(+), 163 deletions(-)

diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py
index 31c639dfd..bda695ae2 100644
--- a/torchchat/utils/quantize.py
+++ b/torchchat/utils/quantize.py
@@ -26,7 +26,7 @@
 
 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional, Callable, Any, List
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -37,6 +37,7 @@
 from torchao.quantization.quant_api import (
     int4_weight_only,
     Int4WeightOnlyQuantizer,
+    int8_weight_only,
     Int8DynActInt4WeightQuantizer,
     quantize_,
 )
@@ -45,8 +46,8 @@
     find_multiple,
     get_device_str,
     get_precision,
-    set_precision,
     name_to_dtype,
+    set_precision,
     state_dict_device,
     use_et_backend,
 )
@@ -60,28 +61,36 @@
 
 import inspect
 
+
 def get_named_parameters(func: Callable) -> List[str]:
     # Get the signature of the function
     signature = inspect.signature(func)
-    
+
     # Extract the parameters from the signature
     parameters = signature.parameters
-    
+
     # Filter and return named parameters
     named_params = [
-        name for name, param in parameters.items()
-        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+        name
+        for name, param in parameters.items()
+        if param.kind
+        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
     ]
     return named_params
 
-def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+
+def validate_args(
+    named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
+) -> Dict[str, Any]:
     for key in q_kwargs.keys():
         if key not in named_params:
-            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            print(
+                f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
+            )
             del q_kwargs[key]
     return q_kwargs
-    
-    
+
+
 #########################################################################
 ###                  torchchat quantization API                      ###
 
@@ -110,21 +119,30 @@ def quantize_model(
         if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
+            ao_quant = True
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            elif quantizer == "linear:int8":
+                print("quantizer is linear int8")
+                quantize_(model, int8_weight_only())
+            else:
+                ao_quant = False
+            if ao_quant:
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
-    
+
         if quantizer in ["linear:a8wxdq", "embedding:wx"]:
             # These quantizers require float32 input weights. Note that after quantization,
             # the weights will no longer be float32, but lowbit integers
             if get_precision() != torch.float32:
-                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                print(
+                    f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
+                )
                 set_precision(torch.float32)
-        
-        # We set global precision from quantize options if it is specified at cli.py:485 
+
+        # We set global precision from quantize options if it is specified at cli.py:485
         # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
         precision = get_precision()
 
@@ -141,14 +159,19 @@ def quantize_model(
 
         model = quant_handler.quantize(model)
 
-
 #########################################################################
 ###                QuantHandler API definition                        ###
 ###               (unify with torchao in future)                      ###
 
 
 class QuantHandler:
-    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+    ):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module:
 
 
 class PrecisionHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        dtype,
+    ):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module:
 
 
 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        accelerator,
+    ):
         self.model_ = model
 
         if isinstance(accelerator, str):
@@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales):
     )
 
 
-class WeightOnlyInt8Linear(nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales: torch.Tensor
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        bias=None,
-        device=None,
-        dtype=None,
-        *,
-        weight: Optional[torch.Tensor] = None,
-        scales: Optional[torch.Tensor] = None,
-        groupsize: Optional[int] = None,
-    ):
-        super().__init__()
-        if dtype is None:
-            dtype = torch.get_default_dtype()
-
-        if device is None:
-            device = "cpu"
-
-        assert not bias, "Bias is not supported by LinearInt8"
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert (weight is None) == bool(
-            scales is None
-        ), "must specify both weights and scales, or neither"
-        if weight is None:
-            weight = torch.empty(
-                (out_features, in_features),
-                dtype=torch.int8,
-                device=device,
-            )
-            if groupsize is None or (groupsize == 0):
-                scales = torch.empty(out_features, dtype=dtype, device=device)
-            else:
-                n_groups = (in_features + groupsize - 1) // groupsize
-                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
-        self.register_buffer("weight", weight.to(device))
-        self.register_buffer("scales", scales.to(device))
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_aoti(input, self.weight, self.scales)
-
-    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_et(input, self.weight, self.scales)
-
-
-class WeightOnlyInt8QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device = None,
-        precision=None,
-        tokenizer=None,
-        *,
-        node_type: str = "*",
-        bitwidth: Optional[int] = None,
-        groupsize: Optional[int] = None,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.node_type = node_type
-        if bitwidth is None:
-            self.bitwidth = 8
-        else:
-            self.bitwidth = bitwidth
-
-    @torch.no_grad()
-    def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu" # self.device
-
-        if self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, nn.Linear):
-                if (
-                    (self.node_type == "*")
-                    or (self.node_type == "output" and name == "output")
-                    or (self.node_type == "!output" and name != "output")
-                ):
-                    # print(f"{name, child}")
-                    input_weight = child.weight.float()
-                    # print(f"{name, child}")
-                    # print(f"in_features: {child.in_features}")
-                    # print(f"out_features: {child.out_features}")
-
-                    # print(f"expanded weight shape {input_weight.shape}")
-                    weight, scales, _ = dynamically_quantize_per_channel(
-                        input_weight,
-                        range_min,
-                        range_max,
-                        torch.int8,
-                        self.groupsize,
-                        scales_dtype=child.weight.dtype,
-                    )
-
-                    setattr(
-                        module,
-                        name,
-                        WeightOnlyInt8Linear(
-                            in_features=child.in_features,
-                            out_features=child.out_features,
-                            device=self.device,
-                            # update variables from quantization
-                            weight=weight,
-                            scales=scales,
-                            groupsize=self.groupsize,
-                        ),
-                    )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 #########################################################################
 #####                   embedding table quantization               ######
 ###                    (unify with torchao in future)                 ###
@@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:int8": int8_weight_only,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
 
@@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module:
         print("Slow fallback kernels will be used.")
 
 except Exception as e:
+
     class ErrorHandler(QuantHandler):
-        def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
+        def __init__(
+            self, model: Optional[nn.Module] = None, device="cpu", precision=None
+        ):
             global torchao_experimental_load_error
-            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
-    
+            raise Exception(
+                f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}"
+            )
+
     torchao_experimental_load_error = e
     quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
    quantizer_class_dict["embedding:wx"] = ErrorHandler
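
Note (not part of the patch): a minimal sketch of what the new "linear:int8" branch in quantize_model() boils down to, assuming the torchao tensor-subclass API that this patch imports (quantize_ and int8_weight_only from torchao.quantization.quant_api). The toy model below is hypothetical and simply stands in for a torchchat Transformer.

    import torch
    import torch.nn as nn

    from torchao.quantization.quant_api import int8_weight_only, quantize_

    # Hypothetical stand-in for a torchchat model; any nn.Linear modules get
    # their weights replaced with int8 weight-only quantized tensors in place.
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

    # Mirrors the new quantize_model() branch taken when quantizer == "linear:int8".
    quantize_(model, int8_weight_only())

    out = model(torch.randn(2, 64))

With this change, a quantize config containing a "linear:int8" entry reaches the quantize_ path above instead of instantiating the removed WeightOnlyInt8QuantHandler; the dict entry for "linear:int8" now only serves the membership check at the top of the loop.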