Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization #1328

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 66 additions & 163 deletions torchchat/utils/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# from functools import reduce
# from math import gcd
from typing import Dict, Optional, Callable, Any, List
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
Expand All @@ -37,6 +37,7 @@
from torchao.quantization.quant_api import (
int4_weight_only,
Int4WeightOnlyQuantizer,
int8_weight_only,
Int8DynActInt4WeightQuantizer,
quantize_,
)
Expand All @@ -45,8 +46,8 @@
find_multiple,
get_device_str,
get_precision,
set_precision,
name_to_dtype,
set_precision,
state_dict_device,
use_et_backend,
)
Expand All @@ -60,28 +61,36 @@

import inspect


def get_named_parameters(func: Callable) -> List[str]:
    """Return the names of ``func``'s parameters that can be passed by keyword.

    Only parameters whose kind is ``POSITIONAL_OR_KEYWORD`` or
    ``KEYWORD_ONLY`` are reported; positional-only parameters, ``*args`` and
    ``**kwargs`` are left out.

    Args:
        func: Any callable accepted by ``inspect.signature``.

    Returns:
        Parameter names in declaration order.
    """
    # Kinds that a caller is able to supply as ``name=value``.
    keyword_capable = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return [
        name
        for name, param in inspect.signature(func).parameters.items()
        if param.kind in keyword_capable
    ]

def validate_args(
    named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
) -> Dict[str, Any]:
    """Drop entries of ``q_kwargs`` whose key is not listed in ``named_params``.

    A message is printed for every dropped key. ``q_kwargs`` is modified in
    place and the same dict object is returned.

    Args:
        named_params: Keys that are allowed to remain in ``q_kwargs``.
        q_kwargs: Keyword arguments from a quantizer specification.
        quantizer: Quantizer name, used only in the printed message.

    Returns:
        ``q_kwargs`` itself, with the extraneous keys removed.
    """
    # Iterate over a snapshot of the keys: deleting from a dict while
    # iterating its live key view raises
    # "RuntimeError: dictionary changed size during iteration".
    for key in list(q_kwargs):
        if key not in named_params:
            print(
                f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
            )
            del q_kwargs[key]
    return q_kwargs


#########################################################################
### torchchat quantization API ###

Expand Down Expand Up @@ -110,21 +119,30 @@ def quantize_model(
if quantizer not in quantizer_class_dict:
raise RuntimeError(f"unknown quantizer {quantizer} specified")
else:
ao_quant = True
# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
print("quantizer is linear int8")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("quantizer is linear int8")

quantize_(model, int8_weight_only())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not integrate it into a QuantHandler class dispatched through the handler dict at a single call site, rather than building a chain of if statements?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mikekgfb, I think we will refactor this part in the future, after all quant APIs have been moved to torchao.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torchAO already has a class-based API that is used for other quantizers? Why do these differently, and then later refactor them? Or why not do them all a consistent way now, and if you refactor later, do that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the quantizer API is deprecated in favor of quantize_, which is why we are gradually refactoring the quantizer APIs to use quantize_. We are doing it one by one because there might be missing support, numerics alignment, etc. that we need to address during the migration.

else:
ao_quant = False
if ao_quant:
if not support_tensor_subclass:
unwrap_tensor_subclass(model)
continue

if quantizer in ["linear:a8wxdq", "embedding:wx"]:
# These quantizers require float32 input weights. Note that after quantization,
# the weights will no longer be float32, but lowbit integers
if get_precision() != torch.float32:
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
print(
f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
)
set_precision(torch.float32)
# We set global precision from quantize options if it is specified at cli.py:485

# We set global precision from quantize options if it is specified at cli.py:485
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
precision = get_precision()

Expand All @@ -141,14 +159,19 @@ def quantize_model(
model = quant_handler.quantize(model)



#########################################################################
### QuantHandler API definition ###
### (unify with torchao in future) ###


class QuantHandler:
def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
Expand Down Expand Up @@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module:


class PrecisionHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
*,
dtype,
):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
Expand Down Expand Up @@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module:


class ExecutorHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
*,
accelerator,
):
self.model_ = model

if isinstance(accelerator, str):
Expand Down Expand Up @@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales):
)


class WeightOnlyInt8Linear(nn.Module):
    """Bias-free linear layer that keeps an int8 ``weight`` and its ``scales`` as buffers.

    ``forward`` is bound per instance in ``__init__``: it is ``et_forward``
    when ``use_et_backend()`` is truthy and ``aoti_forward`` otherwise.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor
    scales: torch.Tensor

    def __init__(
        self,
        in_features,
        out_features,
        bias=None,
        device=None,
        dtype=None,
        *,
        weight: Optional[torch.Tensor] = None,
        scales: Optional[torch.Tensor] = None,
        groupsize: Optional[int] = None,
    ):
        """Create the layer from given buffers, or allocate uninitialized ones.

        Args:
            in_features: Second dimension of the allocated ``weight``.
            out_features: First dimension of the allocated ``weight``/``scales``.
            bias: Must be falsy; a truthy value fails the assertion below.
            device: Device for the buffers; ``None`` means ``"cpu"``.
            dtype: dtype of allocated ``scales``; ``None`` means the torch
                default dtype.
            weight: Existing weight tensor; must be given together with ``scales``.
            scales: Existing scales tensor; must be given together with ``weight``.
            groupsize: When ``None`` or ``0``, allocated ``scales`` is 1-D of
                length ``out_features``; otherwise it is 2-D with one column
                per group of ``groupsize`` input features (rounded up).
        """
        super().__init__()
        dtype = torch.get_default_dtype() if dtype is None else dtype
        device = "cpu" if device is None else device

        assert not bias, "Bias is not supported by LinearInt8"
        self.in_features = in_features
        self.out_features = out_features

        assert (weight is None) == (
            scales is None
        ), "must specify both weights and scales, or neither"

        if weight is None:
            # Nothing supplied: allocate uninitialized placeholders.
            weight = torch.empty(
                (out_features, in_features), dtype=torch.int8, device=device
            )
            per_channel = groupsize is None or groupsize == 0
            if per_channel:
                scales = torch.empty(out_features, dtype=dtype, device=device)
            else:
                group_count = (in_features + groupsize - 1) // groupsize
                scales = torch.empty(
                    out_features, group_count, dtype=dtype, device=device
                )

        self.register_buffer("weight", weight.to(device))
        self.register_buffer("scales", scales.to(device))

        # Select the kernel once, at construction time.
        self.forward = self.et_forward if use_et_backend() else self.aoti_forward

    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``linear_int8_aoti`` to ``input`` with this layer's buffers."""
        return linear_int8_aoti(input, self.weight, self.scales)

    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``linear_int8_et`` to ``input`` with this layer's buffers."""
        return linear_int8_et(input, self.weight, self.scales)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Int8 seems to be special-cased for ET; reminder to check that as well.



class WeightOnlyInt8QuantHandler(QuantHandler):
    """Replace ``nn.Linear`` children of a module with ``WeightOnlyInt8Linear``.

    ``node_type`` selects which linear layers are replaced: ``"*"`` for all of
    them, ``"output"`` for only the child named ``output``, and ``"!output"``
    for every linear child except the one named ``output``.
    """

    def __init__(
        self,
        model: Optional[nn.Module] = None,
        device=None,
        precision=None,
        tokenizer=None,
        *,
        node_type: str = "*",
        bitwidth: Optional[int] = None,
        groupsize: Optional[int] = None,
    ):
        """Store the quantization settings; ``bitwidth`` defaults to 8 when ``None``.

        ``precision`` and ``tokenizer`` are accepted but not stored here.
        """
        self.model_ = model
        self.device = device
        self.groupsize = groupsize
        self.node_type = node_type
        self.bitwidth = 8 if bitwidth is None else bitwidth

    @torch.no_grad()
    def quantize(self, module):
        """Recursively quantize the linear children of ``module`` in place.

        Args:
            module: Module whose children are inspected; non-linear children
                are descended into recursively.

        Returns:
            The same ``module`` object.

        Raises:
            ValueError: If ``self.bitwidth`` is neither 4 nor 8.
        """
        # Signed integer range for the requested bit width.
        if self.bitwidth == 4:
            range_min, range_max = -8, 7
        elif self.bitwidth == 8:
            range_min, range_max = -128, 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for name, child in module.named_children():
            if not isinstance(child, nn.Linear):
                self.quantize(child)
                continue

            selected = (
                self.node_type == "*"
                or (self.node_type == "output" and name == "output")
                or (self.node_type == "!output" and name != "output")
            )
            if not selected:
                continue

            # Quantize from a float32 copy; scales keep the layer's own dtype.
            input_weight = child.weight.float()
            weight, scales, _ = dynamically_quantize_per_channel(
                input_weight,
                range_min,
                range_max,
                torch.int8,
                self.groupsize,
                scales_dtype=child.weight.dtype,
            )

            replacement = WeightOnlyInt8Linear(
                in_features=child.in_features,
                out_features=child.out_features,
                device=self.device,
                # update variables from quantization
                weight=weight,
                scales=scales,
                groupsize=self.groupsize,
            )
            setattr(module, name, replacement)

        return module

    def quantized_model(self) -> nn.Module:
        """Quantize and return the model given at construction time."""
        return self.quantize(self.model_)


#########################################################################
##### embedding table quantization ######
### (unify with torchao in future) ###
Expand Down Expand Up @@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module:
# class references
quantizer_class_dict = {
"embedding": EmbeddingOnlyQuantHandler,
"linear:int8": WeightOnlyInt8QuantHandler,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
"linear:int4": Int4WeightOnlyQuantizer,
"linear:int8": int8_weight_only,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can probably use None for now, and remove this later

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we check for int8_weight_only and finish that check before it looks at the table.

@vmpuri can you check?

"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}

Expand Down Expand Up @@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module:
print("Slow fallback kernels will be used.")

except Exception as e:

class ErrorHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
def __init__(
self, model: Optional[nn.Module] = None, device="cpu", precision=None
):
global torchao_experimental_load_error
raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")

raise Exception(
f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}"
)

torchao_experimental_load_error = e
quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
quantizer_class_dict["embedding:wx"] = ErrorHandler
Loading