From d3d77fcf5875160bba3fa8c738d3e1e5e63d04ae Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Sun, 5 May 2024 09:06:53 -0700
Subject: [PATCH] fix cmake version, padding for a8w4dq, lint... (#680)

* fix cmake version, padding for a8w4dq, lint...

* fix cmake version, padding for a8w4dq, lint...

* updates

* fix
---
 .github/workflows/run-readme-pr-macos.yml |   2 +-
 .github/workflows/run-readme-pr.yml       |   2 +-
 README.md                                 |   2 +-
 build/utils.py                            |   2 +
 cli.py                                    |  10 +-
 docs/quantization.md                      |   2 +-
 generate.py                               |   5 +-
 qops.py                                   | 131 +++++++++++++++++++++
 quantize.py                               | 134 ++++++++++++++++++----
 requirements.txt                          |   2 +-
 scripts/process-readme.py                 |   4 +-
 11 files changed, 259 insertions(+), 37 deletions(-)

diff --git a/.github/workflows/run-readme-pr-macos.yml b/.github/workflows/run-readme-pr-macos.yml
index 522fd9963..071c0256e 100644
--- a/.github/workflows/run-readme-pr-macos.yml
+++ b/.github/workflows/run-readme-pr-macos.yml
@@ -1,4 +1,4 @@
-name: Run the README instructions - with stories - to ensure they work
+name: Run the README instructions - with stories - on macOS
 on:
   pull_request:
   push:
diff --git a/.github/workflows/run-readme-pr.yml b/.github/workflows/run-readme-pr.yml
index cebfa1650..6199e1e20 100644
--- a/.github/workflows/run-readme-pr.yml
+++ b/.github/workflows/run-readme-pr.yml
@@ -1,4 +1,4 @@
-name: Run the README instructions - with stories - to ensure they work
+name: Run the README instructions - with stories
 
 on:
   pull_request:
diff --git a/README.md b/README.md
index 69e5d96fb..00142f595 100644
--- a/README.md
+++ b/README.md
@@ -199,7 +199,7 @@ export TORCHCHAT_ROOT=${PWD}
 ### Export for mobile
 The following example uses the Llama3 8B Instruct model.
 
-[shell default]: echo '{"embedding": {"bitwidth": 4, "groupsize" : 32}, "linear:a8w4dq": {"groupsize" : 32}}' >./config/data/mobile.json
+[#shell default]: echo '{"embedding": {"bitwidth": 4, "groupsize" : 32}, "linear:a8w4dq": {"groupsize" : 32}}' >./config/data/mobile.json
 
 ```
 # Export
diff --git a/build/utils.py b/build/utils.py
index ae14b7939..2c4c44a6a 100644
--- a/build/utils.py
+++ b/build/utils.py
@@ -10,11 +10,13 @@
 import os
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Tuple
+
 import torch
 
 ##########################################################################
 ###                     unpack packed weights                         ###
+
 
 def unpack_packed_weights(
     packed_weights: Dict[str, Any],
     packed_linear: Callable,
diff --git a/cli.py b/cli.py
index 1487ba607..db5417c0a 100644
--- a/cli.py
+++ b/cli.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import json
+import logging
 import os
 
 from pathlib import Path
@@ -13,6 +14,13 @@
 from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
 from download import download_and_convert, is_model_downloaded
 
+FORMAT = (
+    "%(levelname)s: %(asctime)-15s: %(filename)s: %(funcName)s: %(module)s: %(message)s"
+)
+logging.basicConfig(filename="/tmp/torchchat.log", level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+
 default_device = "fast"
 default_model_dir = Path(
     os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
@@ -316,7 +324,7 @@ def arg_init(args):
     if args.output_pte_path:
         if args.device not in ["cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
-        args.device="cpu"
+        args.device = "cpu"
     else:
         args.device = get_device_str(
             args.quantize.get("executor", {}).get("accelerator", args.device)
diff --git a/docs/quantization.md b/docs/quantization.md
index ac7c07408..fcb3198c4 100644
--- a/docs/quantization.md
+++ b/docs/quantization.md
@@ -11,9 +11,9 @@ While quantization can potentially degrade the model's performance, the methods
 | compression | FP Precision | bitwidth| group size | dynamic activation quantization | Eager | AOTI | ExecuTorch |
 |--|--|--|--|--|--|--|--|
 | linear (asymmetric) | fp32, fp16, bf16 | [8, 4]* | [32, 64, 128, 256]** | | ✅ | ✅ | 🚧 |
-| linear with dynamic activations (symmetric) | fp32^ | | [32, 64, 128, 256]** | a8w4dq | 🚧 |🚧 | ✅ |
 | linear with GPTQ*** (asymmetric) | | |[32, 64, 128, 256]** | | ✅ | ✅ | ❌ |
 | linear with HQQ*** (asymmetric) | | |[32, 64, 128, 256]** | | ✅ | ✅ | ❌ |
+| linear with dynamic activations (symmetric) | fp32^ | | [32, 64, 128, 256] | a8w4dq | 🚧 | 🚧 | ✅ |
 
 ### Embedding Quantization
 Due to the larger vocabulary size of llama3, we also recommend quantizing the embeddings to further reduce the model size for on-device usecases.
diff --git a/generate.py b/generate.py
index a226fd755..aee5b00a3 100644
--- a/generate.py
+++ b/generate.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 import argparse
 import itertools
-
 import logging
 import sys
 import time
@@ -25,9 +24,7 @@
 )
 from build.model import Transformer
 from build.utils import device_sync, set_precision
-from cli import add_arguments_for_generate, arg_init, check_args
-
-logger = logging.getLogger(__name__)
+from cli import add_arguments_for_generate, arg_init, check_args, logger
 
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
diff --git a/qops.py b/qops.py
index 9145bda77..e59b5dde8 100644
--- a/qops.py
+++ b/qops.py
@@ -405,3 +405,134 @@ def _prepare_weight_and_scales_and_zeros(
     @classmethod
     def _calc_padded_size(cls, *, k, groupsize=1, innner_k_tiles=1):
         return find_multiple(k, 1024)
+
+
+def linear_8da4w(
+    input,
+    weight_int8,
+    scales,
+    zeros,
+    out_features,
+    groupsize,
+    precision,
+):
+    from torchao.quantization.quant_primitives import per_token_dynamic_quant
+
+    input = per_token_dynamic_quant(input)
+    # TODO: verify and remove following reshape code
+    # origin_input_size = input.size()
+    # input = input.reshape(-1, origin_input_size[-1])
+
+    # TODO: better API
+    # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
+    n_bit = 4
+    quant_min = -(2 ** (n_bit - 1))
+    quant_max = 2 ** (n_bit - 1) - 1
+    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
+        weight_int8,
+        scales,
+        zeros,
+        quant_min,
+        quant_max,
+        torch.int8,
+        groupsize,
+        precision,
+    )
+
+    # input = input.to(torch.float16)
+    # w_dq = w_dq.to(torch.float16)
+    c = torch.nn.functional.linear(input, w_dq)
+
+    # new_shape = origin_input_size[:-1] + (out_features,)
+    # c = c.reshape(new_shape)
+
+    return c
+
+
+class LinearAct8Int4DQ(torch.nn.Module):
+    """
+    Dynamically quantized linear layer with int8 activations and int4
+    weights.  Weights are quantized per channel, groupwise.
+
+    Parameters of importance:
+      groupsize: number of elements in each quantized group
+      precision: precision of input and output, e.g. torch.float32 means
+        the input activation is float32 and the output is float32
+      scales_precision: precision of the per-group scales
+    """
+
+    __constants__ = ["in_features", "origin_in_features", "out_features"]
+    in_features: int
+    origin_in_features: int
+    out_features: int
+    weight: torch.Tensor
+    scales: torch.Tensor
+    zeros: torch.Tensor
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias=True,
+        device=None,
+        dtype=None,
+        *,
+        groupsize: int = 256,
+        weight: Optional[torch.Tensor] = None,
+        scales: Optional[torch.Tensor] = None,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__()
+        # always pad if needed since it becomes a noop at runtime if not needed
+        self.origin_in_features = in_features
+        in_features = find_multiple(in_features, groupsize)
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "require bias=False"
+
+        self.groupsize = groupsize
+        # Precision of the activation, which also indicates the
+        # output precision of the dynamically quantized linear layer
+        # that this module represents.
+        self.precision = precision
+
+        assert (weight is None) == (scales is None), "must specify both weight and scales, or neither"
+
+        if weight is None:
+            weight = torch.empty((out_features, in_features), dtype=torch.int8)
+            scales = torch.empty(
+                (out_features, in_features // groupsize),
+                dtype=scales_precision,
+            )
+
+        # we received an unpadded weight, so pad it
+        if weight.shape[1] != in_features:
+            weight = F.pad(weight, pad=(0, self.in_features - self.origin_in_features))
+
+        # currently storing unpacked int8 weights
+        self.register_buffer("weight", weight)
+        self.register_buffer("scales", scales)
+        # symmetric quantization: per-group zero points are all zero, so
+        # initialize the buffer with zeros rather than uninitialized memory
+        self.register_buffer(
+            "zeros",
+            torch.zeros(
+                (out_features, in_features // groupsize),
+                dtype=scales_precision,
+            ),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        # The quantized op currently supports only FP32, so cast here.
+        # Eventually push the cast into linear_8da4w.
+        return linear_8da4w(
+            input.float(),
+            self.weight,
+            self.scales,
+            self.zeros,
+            self.out_features,
+            self.groupsize,
+            self.precision,
+        ).to(dtype=input.dtype)
diff --git a/quantize.py b/quantize.py
index 226c2b74c..df12e71d6 100644
--- a/quantize.py
+++ b/quantize.py
@@ -24,6 +24,7 @@
 )
 
 from qops import (
+    LinearAct8Int4DQ,
     LinearInt4 as WeightOnlyInt4Linear,
     LinearInt8 as WeightOnlyInt8Linear,
     QuantizedEmbedding,
@@ -83,29 +84,29 @@ def quantized_model(self) -> nn.Module:
 
 #########################################################################
 ###          QuantHandler wrapper for a8w4dq from torchao             ###
-
-
-class Int8DynActInt4WeightQuantizer(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
-        import torchao.quantization.quant_api as quant_api
-
-        self.model_ = model
-        self.device = device
-        self.tokenizer = tokenizer
-        self.quantizer = quant_api.Int8DynActInt4WeightQuantizer(
-            **kwargs, precision=get_precision(), scales_precision=get_precision()
-        )
-
-    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
-        pass
-
-    def convert_for_runtime(self) -> nn.Module:
-        pass
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantizer.quantize(self.model_)
-
-
+#
+#
+# class Int8DynActInt4WeightQuantizer(QuantHandler):
+#     def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
+#         import torchao.quantization.quant_api as quant_api
+#
+#         self.model_ = model
+#         self.device = device
+#         self.tokenizer = tokenizer
+#         self.quantizer = quant_api.Int8DynActInt4WeightQuantizer(
+#             **kwargs, precision=get_precision(), scales_precision=get_precision()
+#         )
+#
+#     def create_quantized_state_dict(self) -> Dict:  # "StateDict"
+#         pass
+#
+#     def convert_for_runtime(self) -> nn.Module:
+#         pass
+#
+#     def quantized_model(self) -> nn.Module:
+#         return self.quantizer.quantize(self.model_)
+#
+#
 #########################################################################
 ###         wrapper for setting precision as a QuantHandler           ###
@@ -547,8 +548,6 @@ def __init__(
         groupsize=128,
         inner_k_tiles=8,
         padding_allowed=True,
-        weight: Optional[torch.Tensor] = None,
-        scales_and_zeros: Optional[torch.Tensor] = None,
     ):
         self.model_ = model
         self.device = device
@@ -620,6 +619,91 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)
 
 
+#########################################################################
+#####  int8 dynamic activation + int4 weight (a8w4dq) quantizer    ######
+
+
+class Int8DynActInt4WeightQuantizer(QuantHandler):
+    def __init__(
+        self,
+        model: nn.Module,
+        device=None,
+        dtype=None,
+        *,
+        tokenizer=None,
+        groupsize=128,
+        padding_allowed=True,
+        precision=torch.float32,
+        scales_precision=torch.float32,
+    ):
+        if dtype is None:
+            dtype = torch.float32
+
+        self.model_ = model
+        self.device = device
+        self.dtype = dtype
+
+        self.groupsize = groupsize
+        self.padding_allowed = padding_allowed
+        self.precision = precision
+        self.scales_precision = scales_precision
+        assert groupsize in [32, 64, 128, 256]
+
+    @torch.no_grad()
+    def quantize(self, module):
+        from torchao.quantization.quant_primitives import (
+            group_quantize_tensor_symmetric,
+        )
+
+        for name, child in module.named_children():
+            # print(f"name: {name}")
+            if isinstance(child, torch.nn.Linear):
+                out_features = child.out_features
+                in_features = child.in_features
+                weight = child.weight.data
+                assert child.bias is None, "require bias=False"
+                assert out_features % 8 == 0, "require out_features % 8 == 0"
+                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                # if self.padding_allowed:
+                #     padding_multiple = max(self.groupsize, 1024)
+                padding_multiple = self.groupsize
+                padded_in_features = find_multiple(in_features, padding_multiple)
+                weight = F.pad(weight, pad=(0, padded_in_features - in_features))
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.float(),
+                    4,  # n_bit
+                    self.groupsize,
+                    self.scales_precision,
+                )
+
+                setattr(
+                    module,
+                    name,
+                    LinearAct8Int4DQ(
+                        child.in_features,
+                        child.out_features,
+                        bias=False,
+                        device=self.device,
+                        dtype=self.dtype,
+                        groupsize=self.groupsize,
+                        weight=weight_int8.to(device=self.device),
+                        scales=scales.to(device=self.device),
+                    ),
+                )
+            else:
+                self.quantize(child)
+
+        return module
+
+    def quantized_model(self) -> nn.Module:
+        return self.quantize(self.model_)
+
+
 #########################################################################
 #####                           GPTQ                                #####
diff --git a/requirements.txt b/requirements.txt
index 5001aabbf..7e9202bf6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,7 +22,7 @@
 blobfile
 
 # Build tools
 wheel
-cmake
+cmake>=3.24
 ninja
 zstd
diff --git a/scripts/process-readme.py b/scripts/process-readme.py
index ba7113cf6..b370e4621 100644
--- a/scripts/process-readme.py
+++ b/scripts/process-readme.py
@@ -18,7 +18,7 @@ def print_between_triple_backticks(filename, predicate):
                 print("exit 0")
                 return
             elif line.startswith(skip):
-                keyword = line[len(skip):-1].strip()
+                keyword = line[len(skip) : -1].strip()
                 if keyword == "begin":
                     print("if false; then")
                 elif keyword == "end":
@@ -35,6 +35,6 @@ def print_between_triple_backticks(filename, predicate):
 if len(sys.argv) > 1:
     predicate = sys.argv[1]
 else:
-    predicate="default"
+    predicate = "default"
 
 print_between_triple_backticks("README.md", predicate)
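
Reviewer note: below is a minimal sketch of how the new a8w4dq path added by this patch can be exercised directly. It is illustrative only and not part of the patch; the toy module, tensor shapes, and `groupsize` choice are assumptions, and it requires torchao plus the patched `quantize.py` / `qops.py` to be importable.

```python
# Hypothetical usage sketch (not part of this patch): quantize a toy model
# with the new Int8DynActInt4WeightQuantizer and run one forward pass.
import torch
import torch.nn as nn

from quantize import Int8DynActInt4WeightQuantizer  # patched module above


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # in_features is a multiple of the groupsize; out_features % 8 == 0
        self.fc = nn.Linear(512, 256, bias=False)

    def forward(self, x):
        return self.fc(x)


model = Toy().eval()
# groupsize must be one of [32, 64, 128, 256]; 256 matches the LinearAct8Int4DQ default
quantizer = Int8DynActInt4WeightQuantizer(model, device="cpu", groupsize=256)
qmodel = quantizer.quantized_model()  # nn.Linear children become LinearAct8Int4DQ

with torch.no_grad():
    print(qmodel(torch.randn(1, 512)).shape)  # expected: torch.Size([1, 256])
```

In torchchat itself this handler is normally reached through the quantization config rather than direct construction, e.g. the `"linear:a8w4dq": {"groupsize" : 32}` entry written to `config/data/mobile.json` in the README change above.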