fix cmake version, padding for a8w4dq, lint... (#680)
* fix cmake version, padding for a8w4dq, lint...

* fix cmake version, padding for a8w4dq, lint...

* updates

* fix
mikekgfb authored May 5, 2024
1 parent a4eefc3 commit d3d77fc
Showing 11 changed files with 259 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-readme-pr-macos.yml
@@ -1,4 +1,4 @@
-name: Run the README instructions - with stories - to ensure they work
+name: Run the README instructions - with stories - on MacOS
on:
pull_request:
push:
2 changes: 1 addition & 1 deletion .github/workflows/run-readme-pr.yml
@@ -1,4 +1,4 @@
-name: Run the README instructions - with stories - to ensure they work
+name: Run the README instructions - with stories

on:
pull_request:
2 changes: 1 addition & 1 deletion README.md
@@ -199,7 +199,7 @@ export TORCHCHAT_ROOT=${PWD}
### Export for mobile
The following example uses the Llama3 8B Instruct model.

-[shell default]: echo '{"embedding": {"bitwidth": 4, "groupsize" : 32}, "linear:a8w4dq": {"groupsize" : 32}}' >./config/data/mobile.json
+[#shell default]: echo '{"embedding": {"bitwidth": 4, "groupsize" : 32}, "linear:a8w4dq": {"groupsize" : 32}}' >./config/data/mobile.json
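For reference, the directive above (now disabled by the `#`) writes the mobile quantization recipe: 4-bit groupwise embedding quantization plus a8w4dq linear quantization, both with group size 32. Below is a minimal Python sketch that produces the same `config/data/mobile.json`; how the file is then passed to the export command is not shown in this truncated hunk, so treat that wiring as an assumption.

```
import json
from pathlib import Path

# Same payload as the echo command above: 4-bit groupwise embeddings and
# a8w4dq (8-bit dynamic activation, 4-bit weight) linears, group size 32.
mobile_config = {
    "embedding": {"bitwidth": 4, "groupsize": 32},
    "linear:a8w4dq": {"groupsize": 32},
}

path = Path("config/data/mobile.json")
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(mobile_config))
```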

```
# Export
2 changes: 2 additions & 0 deletions build/utils.py
@@ -10,11 +10,13 @@
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple

import torch

##########################################################################
### unpack packed weights ###


def unpack_packed_weights(
packed_weights: Dict[str, Any],
packed_linear: Callable,
10 changes: 9 additions & 1 deletion cli.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from pathlib import Path

@@ -13,6 +14,13 @@
from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
from download import download_and_convert, is_model_downloaded

FORMAT = (
"%(levelname)s: %(asctime)-15s: %(filename)s: %(funcName)s: %(module)s: %(message)s"
)
logging.basicConfig(filename="/tmp/torchchat.log", level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


default_device = "fast"
default_model_dir = Path(
os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
@@ -316,7 +324,7 @@ def arg_init(args):
if args.output_pte_path:
if args.device not in ["cpu", "fast"]:
raise RuntimeError("Device not supported by ExecuTorch")
-        args.device="cpu"
+        args.device = "cpu"
else:
args.device = get_device_str(
args.quantize.get("executor", {}).get("accelerator", args.device)
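The tail of this hunk is the device-resolution rule the lint fix touches: ExecuTorch (`.pte`) export forces the device to `cpu`, otherwise an `executor`/`accelerator` entry in the quantization config may override the requested device. A simplified, self-contained sketch of that rule follows; names mirror the diff, and the real code additionally normalizes the result through `get_device_str`.

```
from typing import Optional

def resolve_device(device: str, quantize: dict, output_pte_path: Optional[str]) -> str:
    # Simplified restatement of the logic in arg_init above.
    if output_pte_path:
        # ExecuTorch export only runs on CPU.
        if device not in ["cpu", "fast"]:
            raise RuntimeError("Device not supported by ExecuTorch")
        return "cpu"
    # Otherwise the quantize config's executor section may pick the accelerator.
    return quantize.get("executor", {}).get("accelerator", device)

assert resolve_device("fast", {}, "model.pte") == "cpu"
assert resolve_device("fast", {"executor": {"accelerator": "mps"}}, None) == "mps"
```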
2 changes: 1 addition & 1 deletion docs/quantization.md
@@ -11,9 +11,9 @@ While quantization can potentially degrade the model's performance, the methods
| compression | FP Precision | bitwidth| group size | dynamic activation quantization | Eager | AOTI | ExecuTorch |
|--|--|--|--|--|--|--|--|
| linear (asymmetric) | fp32, fp16, bf16 | [8, 4]* | [32, 64, 128, 256]** | ||| 🚧 |
-| linear with dynamic activations (symmetric) | fp32^ | | [32, 64, 128, 256]** | a8w4dq | 🚧 |🚧 ||
| linear with GPTQ*** (asymmetric) | | |[32, 64, 128, 256]** | ||||
| linear with HQQ*** (asymmetric) | | |[32, 64, 128, 256]** | ||||
+| linear with dynamic activations (symmetric) | fp32^ | | [32, 64, 128, 256] | a8w4dq | 🚧 |🚧 ||

### Embedding Quantization
Due to the larger vocabulary size of llama3, we also recommend quantizing the embeddings to further reduce the model size for on-device use cases.
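A hedged example of such a combined recipe, mirroring the mobile config used in the README (4-bit groupwise embeddings on top of a8w4dq linears); the group size of 32 is one valid choice from the table above, not a requirement.

```
import json

quantize_config = {
    # 4-bit groupwise embedding quantization
    "embedding": {"bitwidth": 4, "groupsize": 32},
    # int8 dynamic activation + int4 weight linears; groupsize must be
    # one of 32, 64, 128, 256 per the table above
    "linear:a8w4dq": {"groupsize": 32},
}
print(json.dumps(quantize_config))
```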
5 changes: 1 addition & 4 deletions generate.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import argparse
import itertools

import logging
import sys
import time
@@ -25,9 +24,7 @@
)
from build.model import Transformer
from build.utils import device_sync, set_precision
-from cli import add_arguments_for_generate, arg_init, check_args
-
-logger = logging.getLogger(__name__)
+from cli import add_arguments_for_generate, arg_init, check_args, logger

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
131 changes: 131 additions & 0 deletions qops.py
@@ -405,3 +405,134 @@ def _prepare_weight_and_scales_and_zeros(
@classmethod
def _calc_padded_size(cls, *, k, groupsize=1, innner_k_tiles=1):
return find_multiple(k, 1024)


def linear_8da4w(
input,
weight_int8,
scales,
zeros,
out_features,
groupsize,
precision,
):
from torchao.quantization.quant_primitives import per_token_dynamic_quant

input = per_token_dynamic_quant(input)
# TODO: verify and remove following reshape code
# origin_input_size = input.size()
# input = input.reshape(-1, origin_input_size[-1])

# TODO: better API
# weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
n_bit = 4
quant_min = -(2 ** (n_bit - 1))
quant_max = 2 ** (n_bit - 1) - 1
w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
weight_int8,
scales,
zeros,
quant_min,
quant_max,
torch.int8,
groupsize,
precision,
)

# input = input.to(torch.float16)
# w_dq = w_dq.to(torch.float16)
c = torch.nn.functional.linear(input, w_dq)

# new_shape = origin_input_size[:-1] + (out_features,)
# c = c.reshape(new_shape)

return c
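For intuition, the dequantize-then-matmul path above is equivalent, in the symmetric case where all zero points are zero, to scaling each group of int8 weight values by its per-group scale and running an ordinary float linear. A simplified pure-PyTorch reference is sketched below (assumed shapes: `weight_int8` is `[out_features, in_features]`, `scales` is `[out_features, in_features // groupsize]`; the per-token dynamic quantization of the activations is omitted).

```
import torch
import torch.nn.functional as F

def linear_8da4w_reference(x, weight_int8, scales, groupsize):
    # Groupwise symmetric dequant: w_dq[o, j] = weight_int8[o, j] * scales[o, j // groupsize]
    out_features, in_features = weight_int8.shape
    w_dq = (
        weight_int8.float().view(out_features, in_features // groupsize, groupsize)
        * scales.float().unsqueeze(-1)
    ).view(out_features, in_features)
    return F.linear(x.float(), w_dq)
```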


class LinearAct8Int4DQ(torch.nn.Module):
    __constants__ = ["in_features", "origin_in_features", "out_features"]
in_features: int
origin_in_features: int
out_features: int
weight: torch.Tensor
scales: torch.Tensor
zeros: torch.Tensor

"""
This module implements a dynamic quantized linear layer with
int4 weight. Weights are per channel groupwise
quantized. Parameters of importance groupsize: the number of
elements in each quantized group precision: precision of input and
output. e.g. torch.float32 means input activation is float32 and
output is float32. scales_precision: precision of per group
scale. """

def __init__(
self,
in_features: int,
out_features: int,
bias=True,
device=None,
dtype=None,
*,
groupsize: int = 256,
weight: Optional[torch.Tensor] = None,
scales: Optional[torch.Tensor] = None,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
) -> None:
super().__init__()
# always pad if needed since it becomes a noop at runtime if not needed
# self.origin_in_features = in_features
self.origin_in_features = in_features
in_features = find_multiple(in_features, groupsize)
        self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"

self.groupsize = groupsize
# Precision of the activation which also indicates
# output precision of the dynamically quantized linear layer
        # that this module represents.
self.precision = precision

assert (weight is None) == bool(
scales is None
), "must specify both weights and scales_and_zeros, or neither"

if weight is None:
weight = torch.empty((out_features, in_features), dtype=torch.int8)
scales = torch.empty(
(out_features, in_features // groupsize),
dtype=scales_precision,
)

# we received an unpadded weight, so pad it
if weight.shape[1] != in_features:
weight = F.pad(weight, pad=(0, self.in_features - self.origin_in_features))

# currently storing unpacked int8 weights
self.register_buffer("weight", weight)
self.register_buffer("scales", scales)
self.register_buffer(
"zeros",
torch.empty(
(out_features, in_features // groupsize),
dtype=scales_precision,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        # This operator currently supports only FP32, so cast the input here.
        # Eventually, push this cast into linear_8da4w.
return linear_8da4w(
input.float(),
self.weight,
self.scales,
self.zeros,
self.out_features,
self.groupsize,
self.precision,
).to(dtype=input.dtype)
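The padding behaviour added here is the heart of the a8w4dq fix in this commit: `in_features` is rounded up to a multiple of `groupsize`, the int8 weight is zero-padded to that width at construction, and `forward()` pads each input row to match. A small sketch of the arithmetic, with a stand-in `find_multiple` (assumed to round up to the nearest multiple, matching how it is used above):

```
def find_multiple(n: int, k: int) -> int:
    # assumed behaviour: smallest multiple of k that is >= n
    return n if n % k == 0 else n + k - (n % k)

groupsize = 256
in_features = 4000                                            # not 256-aligned
padded_in_features = find_multiple(in_features, groupsize)    # 4096
pad_cols = padded_in_features - in_features                   # 96 zero columns added to the weight
# forward() then does F.pad(input, pad=(0, 96)), so the padded
# columns contribute nothing to the matmul.
```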
134 changes: 109 additions & 25 deletions quantize.py
@@ -24,6 +24,7 @@
)

from qops import (
LinearAct8Int4DQ,
LinearInt4 as WeightOnlyInt4Linear,
LinearInt8 as WeightOnlyInt8Linear,
QuantizedEmbedding,
@@ -83,29 +84,29 @@ def quantized_model(self) -> nn.Module:

#########################################################################
### QuantHandler wrapper for a8w4dq from torchao ###


class Int8DynActInt4WeightQuantizer(QuantHandler):
def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
import torchao.quantization.quant_api as quant_api

self.model_ = model
self.device = device
self.tokenizer = tokenizer
self.quantizer = quant_api.Int8DynActInt4WeightQuantizer(
**kwargs, precision=get_precision(), scales_precision=get_precision()
)

def create_quantized_state_dict(self) -> Dict: # "StateDict"
pass

def convert_for_runtime(self) -> nn.Module:
pass

def quantized_model(self) -> nn.Module:
return self.quantizer.quantize(self.model_)


#
#
# class Int8DynActInt4WeightQuantizer(QuantHandler):
# def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
# import torchao.quantization.quant_api as quant_api
#
# self.model_ = model
# self.device = device
# self.tokenizer = tokenizer
# self.quantizer = quant_api.Int8DynActInt4WeightQuantizer(
# **kwargs, precision=get_precision(), scales_precision=get_precision()
# )
#
# def create_quantized_state_dict(self) -> Dict: # "StateDict"
# pass
#
# def convert_for_runtime(self) -> nn.Module:
# pass
#
# def quantized_model(self) -> nn.Module:
# return self.quantizer.quantize(self.model_)
#
#
#########################################################################
### wrapper for setting precision as a QuantHandler ###

@@ -547,8 +548,6 @@ def __init__(
groupsize=128,
inner_k_tiles=8,
padding_allowed=True,
weight: Optional[torch.Tensor] = None,
scales_and_zeros: Optional[torch.Tensor] = None,
):
self.model_ = model
self.device = device
@@ -620,6 +619,91 @@ def quantized_model(self) -> nn.Module:
return self.quantize(self.model_)


#########################################################################
#####   int8 dynamic activation + int4 weight per channel groupwise quantized code   ######


class Int8DynActInt4WeightQuantizer(QuantHandler):
def __init__(
self,
model: nn.Module,
device=None,
dtype=None,
*,
tokenizer=None,
groupsize=128,
padding_allowed=True,
precision=torch.float32,
scales_precision=torch.float32,
):
if dtype is None:
dtype = torch.float32

self.model_ = model
self.device = device
self.dtype = dtype

self.groupsize = groupsize
self.padding_allowed = padding_allowed
self.precision = precision
self.scales_precision = scales_precision
assert groupsize in [32, 64, 128, 256]

@torch.no_grad()
def quantize(self, module):
from torchao.quantization.quant_primitives import (
group_quantize_tensor_symmetric,
)

for name, child in module.named_children():
# print(f"name: {name}")
if isinstance(child, torch.nn.Linear):
out_features = child.out_features
in_features = child.in_features
weight = child.weight.data
assert not child.bias
assert out_features % 8 == 0, "require out_features % 8 == 0"
# print(f"linear: {fqn}, in={in_features}, out={out_features}")

# if self.padding_allowed:
# padding_multiple=max(self.groupsize, 1024)
padding_multiple = self.groupsize
padded_in_features = find_multiple(in_features, padding_multiple)
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
(
weight_int8,
scales,
zeros,
) = group_quantize_tensor_symmetric(
weight.float(),
4, # n_bit
self.groupsize,
self.scales_precision,
)

setattr(
module,
name,
LinearAct8Int4DQ(
child.in_features,
child.out_features,
bias=False,
device=self.device,
dtype=self.dtype,
groupsize=self.groupsize,
weight=weight_int8.to(device=self.device),
scales=scales.to(device=self.device),
),
)
else:
self.quantize(child)

return module

def quantized_model(self) -> nn.Module:
return self.quantize(self.model_)
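A hedged end-to-end sketch of how this quantizer is meant to be used (assumes torchao is installed and that this repository's `quantize.py`/`qops.py` are importable; the tiny model is illustrative only):

```
import torch
from quantize import Int8DynActInt4WeightQuantizer

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # out_features must be a multiple of 8; in_features=1000 is not a
        # multiple of groupsize=256, so the quantizer pads the weight to 1024.
        self.fc = torch.nn.Linear(1000, 64, bias=False)

    def forward(self, x):
        return self.fc(x)

quantizer = Int8DynActInt4WeightQuantizer(TinyMLP(), device="cpu", groupsize=256)
model = quantizer.quantized_model()   # fc replaced by a LinearAct8Int4DQ module
y = model(torch.randn(2, 1000))       # input is padded to 1024 features inside forward()
```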


#########################################################################
##### GPTQ #####
