Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization
vmpuri authored and vmpuri committed Oct 24, 2024
1 parent 7fe2c86 commit 92e0a9d
Showing 1 changed file with 66 additions and 163 deletions.
229 changes: 66 additions & 163 deletions torchchat/utils/quantize.py
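
For context before the diff: the net effect of this change is that the `linear:int8` quantizer now routes through torchao's tensor-subclass API via `quantize_(model, int8_weight_only())`, replacing the hand-rolled `WeightOnlyInt8Linear` module and `WeightOnlyInt8QuantHandler` removed below. A minimal sketch of the new path, assuming a torchao build that exposes `int8_weight_only`; the toy model and layer sizes are made up for illustration:

import torch
import torch.nn as nn

# torchao's tensor-subclass quantization API, matching the imports added in this diff.
from torchao.quantization.quant_api import int8_weight_only, quantize_

# Toy stand-in for the torchchat Transformer; the real model comes from
# torchchat's model builder, and the layer size here is arbitrary.
model = nn.Sequential(nn.Linear(1024, 1024, bias=False))

# New "linear:int8" path: quantize_ swaps each nn.Linear weight for an int8
# weight-only quantized tensor subclass in place.
quantize_(model, int8_weight_only())

# Inference runs as before; dequantization happens inside the tensor subclass.
with torch.no_grad():
    out = model(torch.randn(2, 1024))

On CUDA the same entry point is used with `int4_weight_only(groupsize)` for the `linear:int4` path, as shown in the `quantize_model` hunk below.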
@@ -26,7 +26,7 @@

# from functools import reduce
# from math import gcd
from typing import Dict, Optional, Callable, Any, List
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
@@ -37,6 +37,7 @@
from torchao.quantization.quant_api import (
int4_weight_only,
Int4WeightOnlyQuantizer,
int8_weight_only,
Int8DynActInt4WeightQuantizer,
quantize_,
)
@@ -45,8 +46,8 @@
find_multiple,
get_device_str,
get_precision,
set_precision,
name_to_dtype,
set_precision,
state_dict_device,
use_et_backend,
)
@@ -60,28 +61,36 @@

import inspect


def get_named_parameters(func: Callable) -> List[str]:
# Get the signature of the function
signature = inspect.signature(func)

# Extract the parameters from the signature
parameters = signature.parameters

# Filter and return named parameters
named_params = [
name for name, param in parameters.items()
if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
name
for name, param in parameters.items()
if param.kind
in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
]
return named_params

def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:

def validate_args(
named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
) -> Dict[str, Any]:
for key in q_kwargs.keys():
if key not in named_params:
print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
print(
f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
)
del q_kwargs[key]
return q_kwargs


#########################################################################
### torchchat quantization API ###

@@ -110,21 +119,30 @@ def quantize_model(
if quantizer not in quantizer_class_dict:
raise RuntimeError(f"unknown quantizer {quantizer} specified")
else:
ao_quant = True
# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
print("quantizer is linear int8")
quantize_(model, int8_weight_only())
else:
ao_quant = False
if ao_quant:
if not support_tensor_subclass:
unwrap_tensor_subclass(model)
continue

if quantizer in ["linear:a8wxdq", "embedding:wx"]:
# These quantizers require float32 input weights. Note that after quantization,
# the weights will no longer be float32, but lowbit integers
if get_precision() != torch.float32:
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
print(
f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
)
set_precision(torch.float32)
# We set global precision from quantize options if it is specified at cli.py:485

# We set global precision from quantize options if it is specified at cli.py:485
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
precision = get_precision()

@@ -141,14 +159,19 @@ def quantize_model(
model = quant_handler.quantize(model)



#########################################################################
### QuantHandler API definition ###
### (unify with torchao in future) ###


class QuantHandler:
def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
@@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module:


class PrecisionHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
*,
dtype,
):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
@@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module:


class ExecutorHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
*,
accelerator,
):
self.model_ = model

if isinstance(accelerator, str):
@@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales):
)


class WeightOnlyInt8Linear(nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
scales: torch.Tensor

def __init__(
self,
in_features,
out_features,
bias=None,
device=None,
dtype=None,
*,
weight: Optional[torch.Tensor] = None,
scales: Optional[torch.Tensor] = None,
groupsize: Optional[int] = None,
):
super().__init__()
if dtype is None:
dtype = torch.get_default_dtype()

if device is None:
device = "cpu"

assert not bias, "Bias is not supported by LinearInt8"
self.in_features = in_features
self.out_features = out_features

assert (weight is None) == bool(
scales is None
), "must specify both weights and scales, or neither"
if weight is None:
weight = torch.empty(
(out_features, in_features),
dtype=torch.int8,
device=device,
)
if groupsize is None or (groupsize == 0):
scales = torch.empty(out_features, dtype=dtype, device=device)
else:
n_groups = (in_features + groupsize - 1) // groupsize
scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)

self.register_buffer("weight", weight.to(device))
self.register_buffer("scales", scales.to(device))

if use_et_backend():
self.forward = self.et_forward
else:
self.forward = self.aoti_forward

def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
return linear_int8_aoti(input, self.weight, self.scales)

def et_forward(self, input: torch.Tensor) -> torch.Tensor:
return linear_int8_et(input, self.weight, self.scales)


class WeightOnlyInt8QuantHandler(QuantHandler):
def __init__(
self,
model: Optional[nn.Module] = None,
device = None,
precision=None,
tokenizer=None,
*,
node_type: str = "*",
bitwidth: Optional[int] = None,
groupsize: Optional[int] = None,
):
self.model_ = model
self.device = device
self.groupsize = groupsize
self.node_type = node_type
if bitwidth is None:
self.bitwidth = 8
else:
self.bitwidth = bitwidth

@torch.no_grad()
def quantize(self, module):
# cur_state_dict = state_dict_device(self.model_.state_dict())
# dict_device = "cpu" # self.device

if self.bitwidth == 4:
range_min = -8
range_max = 7
elif self.bitwidth == 8:
range_min = -128
range_max = 127
else:
raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

for name, child in module.named_children():
# print(f"name: {name}")
if isinstance(child, nn.Linear):
if (
(self.node_type == "*")
or (self.node_type == "output" and name == "output")
or (self.node_type == "!output" and name != "output")
):
# print(f"{name, child}")
input_weight = child.weight.float()
# print(f"{name, child}")
# print(f"in_features: {child.in_features}")
# print(f"out_features: {child.out_features}")

# print(f"expanded weight shape {input_weight.shape}")
weight, scales, _ = dynamically_quantize_per_channel(
input_weight,
range_min,
range_max,
torch.int8,
self.groupsize,
scales_dtype=child.weight.dtype,
)

setattr(
module,
name,
WeightOnlyInt8Linear(
in_features=child.in_features,
out_features=child.out_features,
device=self.device,
# update variables from quantization
weight=weight,
scales=scales,
groupsize=self.groupsize,
),
)
else:
self.quantize(child)

return module

def quantized_model(self) -> nn.Module:
return self.quantize(self.model_)


#########################################################################
##### embedding table quantization ######
### (unify with torchao in future) ###
@@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module:
# class references
quantizer_class_dict = {
"embedding": EmbeddingOnlyQuantHandler,
"linear:int8": WeightOnlyInt8QuantHandler,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
"linear:int4": Int4WeightOnlyQuantizer,
"linear:int8": int8_weight_only,
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}

@@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module:
print("Slow fallback kernels will be used.")

except Exception as e:

class ErrorHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
def __init__(
self, model: Optional[nn.Module] = None, device="cpu", precision=None
):
global torchao_experimental_load_error
raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")

raise Exception(
f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}"
)

torchao_experimental_load_error = e
quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
quantizer_class_dict["embedding:wx"] = ErrorHandler
