From 66cfc1e0c703c678450f69843b48412fe85d5a48 Mon Sep 17 00:00:00 2001
From: Hanxian97
Date: Thu, 1 Aug 2024 06:56:17 -0700
Subject: [PATCH] mixed-precision quantization milestone1: naive_intNwo +
 eval/benchmark framework (#531)

* milestone1: naive_intNwo + eval/benchmark

* remove experiment scripts

* remove exp files

* use default ZeroPointDomain.INT for int2/3/5/6

* renamed test_naive_intNwo.py to test_mixed_precision.py

* updated intNwo with _get_linear_subclass_inserter

* adjust sqnr threshold according to bit width

* fixed test for int4wo and add __init__.py

* skip test_aq_int8_weight_only_quant_3_subclass due to seg fault on nightly

* edit the sqnr threshold

* add unittest

* correct import path
---
 test/quantization/test_mixed_precision.py    | 32 +++++++
 .../prototype/mixed_precision/__init__.py    |  0
 .../mixed_precision/scripts/__init__.py      |  1 +
 .../mixed_precision/scripts/mp_quant_eval.py | 95 +++++++++++++++++++
 .../mixed_precision/scripts/naive_intNwo.py  | 60 ++++++++++++
 5 files changed, 188 insertions(+)
 create mode 100644 test/quantization/test_mixed_precision.py
 create mode 100644 torchao/quantization/prototype/mixed_precision/__init__.py
 create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/__init__.py
 create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py
 create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py

diff --git a/test/quantization/test_mixed_precision.py b/test/quantization/test_mixed_precision.py
new file mode 100644
index 0000000000..8afd022d3c
--- /dev/null
+++ b/test/quantization/test_mixed_precision.py
@@ -0,0 +1,32 @@
+import unittest
+
+import torch
+import torch.nn as nn
+from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
+from torchao.quantization.utils import compute_error
+from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
+
+_CUDA_IS_AVAILABLE = torch.cuda.is_available()
+
+class TestWeightOnlyQuantNaive(unittest.TestCase):
+
+    def test_quantization_intNwo(self):
+        # skip the int4wo test for now since it is under development in torchao
+        for quantization_bit in [2, 3, 5, 6, 8]:
+            for symmetric in [False, True]:
+                with self.subTest(quantization_bit=quantization_bit, symmetric=symmetric):
+                    for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]:
+                        x = torch.randn(*x_shape, dtype=torch.bfloat16)
+                        m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
+                        y_ref = m(x)
+                        quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric))
+                        y_wo = m(x)
+                        sqnr = compute_error(y_ref, y_wo)
+                        # SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization
+                        # e.g., the threshold for 8-bit is 44, which the expected 6.02 * 8 = 48.16 dB comfortably clears
+                        expected_sqnr_threshold = 44.0 - (8 - quantization_bit) * 6.02
+                        self.assertGreater(sqnr, expected_sqnr_threshold, f"sqnr: {sqnr} is too low")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/torchao/quantization/prototype/mixed_precision/__init__.py b/torchao/quantization/prototype/mixed_precision/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/__init__.py b/torchao/quantization/prototype/mixed_precision/scripts/__init__.py
new file mode 100644
index 0000000000..1b0cae6ab3
--- /dev/null
+++ b/torchao/quantization/prototype/mixed_precision/scripts/__init__.py
@@ -0,0 +1 @@
+from .naive_intNwo import intN_weight_only
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py b/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py
new file mode 100644
index 0000000000..d17b76159e
--- /dev/null
+++ b/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+
+from naive_intNwo import intN_weight_only
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from lm_eval.models.huggingface import HFLM
+from lm_eval.evaluator import evaluate
+from lm_eval.tasks import get_task_dict
+
+from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
+from torchao._models._eval import TransformerEvalWrapper
+
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
+
+from torchao.quantization.quant_api import autoquant
+
+
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.fx_graph_cache = True
+
+
+def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, sensi_bit, non_sensi_bit, quant_sym, group_size):
+
+    tokenizer = AutoTokenizer.from_pretrained(repo_id)
+    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
+
+    if quantization == "autoquant":
+        model = autoquant(model.to(device=device))
+
+    # naive uniform-precision quantization applied to all layers
+    elif quantization in ["2", "3", "4", "5", "6", "8"]:
+        quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym))
+
+    # mixed-precision quantization for Llama3
+    elif quantization == "MP_llama3":
+
+        # filter for sensitive layers (the first 3 and last 2 layers for Llama3)
+        def filter_fn_sen(child: torch.nn.Module, cur_fqn: str) -> bool:
+            return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])
+
+        # filter for non-sensitive layers (the other 27 layers for Llama3)
+        def filter_fn_nonsen(child: torch.nn.Module, cur_fqn: str) -> bool:
+            return isinstance(child, nn.Linear) and not any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])
+
+        # quantize the sensitive layers
+        if sensi_bit != 16:
+            quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen)
+
+        # quantize the less-sensitive layers
+        if sensi_bit == 4:
+            quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
+        else:
+            quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
+
+    if compile:
+        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
+    with torch.no_grad():
+
+        result = evaluate(
+            HFLM(
+                pretrained=model,
+                tokenizer=tokenizer,
+                batch_size=batch_size,
+                max_length=max_length),
+            get_task_dict(tasks),
+            limit=limit,
+        )
+
+    for task, res in result["results"].items():
+        print(f"{task}: {res}")
+
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
+    parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
+    parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm_eval tasks to evaluate; usage: --tasks task1 task2')
+    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
+    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
+    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
+    parser.add_argument('-q', '--quantization', default="None", choices=["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply: choose from ["2", "3", "4", "5", "6", "8"] for uniform weight-only quantization, "MP_llama3" for mixed-precision quantization of Llama3 (set sensi_bit and non_sensi_bit accordingly), or "None" for no quantization')
+    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation; note that int8wo and int4wo work best with small batch sizes, while int8dq works better with large batch sizes')
+    parser.add_argument('--max_length', type=int, default=None, help='Maximum length of text to process at one time')
+    parser.add_argument('--sensi_bit', type=int, default=16, choices=[16, 8, 6, 5, 4, 3], help='Bit setting for sensitive layers')
+    parser.add_argument('--non_sensi_bit', type=int, default=8, choices=[8, 6, 5, 4, 3, 2], help='Bit setting for non-sensitive layers')
+    parser.add_argument('--quant_sym', action='store_true', help='Use symmetric quantization (asymmetric by default)')
+    parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
+    args = parser.parse_args()
+    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py
new file mode 100644
index 0000000000..6ebe458a46
--- /dev/null
+++ b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py
@@ -0,0 +1,60 @@
+import torch
+
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
+
+from torchao.quantization import int8_weight_only, int4_weight_only
+from torchao.quantization.quant_api import _get_linear_subclass_inserter
+
+def intN_weight_only(group_size=32, n=8, symmetric=False):
+    '''
+    Apply int N-bit weight-only quantization to a linear layer.
+    Args:
+        `group_size`: controls the granularity of quantization; smaller group sizes are more fine-grained; choices are [512, 256, 128, 64, 32]
+        `n`: number of bits to quantize to; choices are [8, 6, 5, 4, 3, 2]
+    Usage:
+        from torchao.quantization import quantize_
+        quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
+    '''
+    # for asymmetric quantization
+    def apply_intN_weight_only_quant_asym(weight):
+        # avoid circular dependency
+        from torchao.dtypes import to_affine_quantized
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, group_size)
+        target_dtype = torch.uint8
+        quant_min = 0
+        quant_max = 2**n - 1
+        eps = 1e-6
+        preserve_zero = True
+        zero_point_dtype = torch.int64
+        zero_point_domain = ZeroPointDomain.INT
+        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)  # , preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+
+    # for symmetric quantization
+    def apply_intN_weight_only_quant_sym(weight):
+        # avoid circular dependency
+        from torchao.dtypes import to_affine_quantized
+        mapping_type = MappingType.SYMMETRIC
+        block_size = (1, group_size)
+        target_dtype = torch.int8
+        eps = 1e-6
+        zero_point_dtype = torch.int64
+        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+
+    try:
+        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
+        if n == 8:
+            return int8_weight_only()
+        elif n == 4:
+            return int4_weight_only(group_size=group_size)
+        else:
+            if symmetric:
+                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
+            else:
+                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
+    except Exception:
+        raise
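A minimal usage sketch of the new intN_weight_only API, mirroring the unit test above; the 6-bit setting and tensor shapes are illustrative choices, not fixed by the patch:

    import torch
    import torch.nn as nn
    from torchao.quantization import quantize_
    from torchao.quantization.utils import compute_error
    from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

    # toy model with a single Linear layer, as in test_mixed_precision.py
    m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
    x = torch.randn(16, 32, dtype=torch.bfloat16)
    y_ref = m(x)

    # asymmetric 6-bit weight-only quantization with group size 32
    quantize_(m, intN_weight_only(n=6, group_size=32, symmetric=False))
    y_q = m(x)

    # the test above expects roughly 44 - (8 - n) * 6.02 dB, i.e. about 32 dB for n=6
    print(f"SQNR: {compute_error(y_ref, y_q):.2f} dB")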
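A second sketch, of the MP_llama3 path in mp_quant_eval.py: the first 3 and last 2 decoder layers are kept at a higher bit width than the rest. The checkpoint name, the 8-bit/4-bit split, the group size, and the CUDA device below are illustrative assumptions (matching the script's defaults); the layer-index filters are the ones used in the script above:

    import torch
    import torch.nn as nn
    from transformers import AutoModelForCausalLM
    from torchao.quantization import quantize_
    from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

    # load a Llama3-style checkpoint, as mp_quant_eval.py does (GPU assumed here)
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16).to("cuda")

    # the first 3 and last 2 decoder layers are treated as sensitive for Llama3
    SENSITIVE = ['.0.', '.1.', '.2.', '.30.', '.31.']

    def filter_sensitive(child: nn.Module, fqn: str) -> bool:
        return isinstance(child, nn.Linear) and any(s in fqn for s in SENSITIVE)

    def filter_non_sensitive(child: nn.Module, fqn: str) -> bool:
        return isinstance(child, nn.Linear) and not any(s in fqn for s in SENSITIVE)

    # keep the sensitive layers at 8 bits and quantize the remaining layers to 4 bits
    quantize_(model, intN_weight_only(n=8, group_size=32, symmetric=False), filter_sensitive)
    quantize_(model, intN_weight_only(n=4, group_size=32, symmetric=False), filter_non_sensitive)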