From eaa3a580c8a9f27268d3c27e551054dd5053f01c Mon Sep 17 00:00:00 2001 From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com> Date: Fri, 24 May 2024 15:39:13 +0800 Subject: [PATCH] Add 'excluded_precisions' to static_quant and smooth_quant (#1814) Signed-off-by: Cheng, Zixuan --- .../algorithms/smooth_quant/smooth_quant.py | 76 +++++++++++++------ .../torch/algorithms/smooth_quant/utility.py | 1 + .../algorithms/static_quant/static_quant.py | 68 ++++++++++++----- .../torch/quantization/algorithm_entry.py | 2 + .../torch/quantization/config.py | 6 ++ .../torch/quantization/test_smooth_quant.py | 17 +++++ .../torch/quantization/test_static_quant.py | 11 +++ 7 files changed, 137 insertions(+), 44 deletions(-) diff --git a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py index f2b2cdf8542..e9d6fde3524 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py +++ b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py @@ -16,6 +16,7 @@ # limitations under the License. import json +import os import torch @@ -32,6 +33,7 @@ from neural_compressor.torch.algorithms import Quantizer from .utility import ( + CpuInfo, TorchSmoothQuant, cfg_to_qconfig, dump_model_op_stats, @@ -84,6 +86,8 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) model.eval() + use_bf16 = self.quant_config.get("use_bf16", None) + # check smoothquant alpha and act_algo value recipe_cfgs = self.quant_config.get("recipe_cfgs", None) alpha = recipe_cfgs["smooth_quant_args"]["alpha"] @@ -133,8 +137,10 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs): Returns: A quantized model. """ + use_bf16 = self.quant_config.get("use_bf16", None) + model.save_qconf_summary(qconf_summary=ipex_config_path) - model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) + model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace) with open(ipex_config_path, "r") as f: model.tune_cfg = json.load(f) @@ -169,6 +175,8 @@ def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args, model.output_tensor_id_op_name, ) + use_bf16 = tune_cfg.get("use_bf16", None) + # check smoothquant folding value recipe_cfgs = tune_cfg.get("recipe_cfgs", None) if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]: @@ -223,7 +231,7 @@ def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args, model.load_qconf_summary(qconf_summary=ipex_config_path) run_fn(model) model.save_qconf_summary(qconf_summary=ipex_config_path) - model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) + model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace) # Recover model parameter when smoothquant folding = True if ( @@ -247,6 +255,8 @@ def qdq_quantize( smoothquant_scale_info = sq.sq_scale_info sq_minmax_init = True if tune_cfg.get("act_algo", "kl") == "minmax" else False + use_bf16 = tune_cfg.get("use_bf16", None) + # Check save_qconf_summary part is a workaround for IPEX bug. # Sometimes the prepared model from get_op_capablitiy loss this attribute if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover @@ -296,7 +306,7 @@ def qdq_quantize( update_sq_scale(ipex_config_path, smoothquant_scale_info) model.load_qconf_summary(qconf_summary=ipex_config_path) model.save_qconf_summary(qconf_summary=ipex_config_path) - model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) + model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace) with open(ipex_config_path, "r") as f: model.tune_cfg = json.load(f) @@ -334,33 +344,53 @@ def _apply_pre_optimization(model, tune_cfg, sq, recover=False): logger.debug(f"Current smoothquant scale of {op_name} is {scale}, alpha is {alpha}") -def _ipex_post_quant_process(model, example_inputs, inplace=False): +def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False): """Convert to a jit model. Args: model: a prepared model. example_inputs: used to trace torch model. + use_bf16: whether to use bf16 for mixed precision. inplace: whether to carry out model transformations in-place. Returns: A converted jit model. """ - model = ipex.quantization.convert(model, inplace=inplace) - with torch.no_grad(): - try: - if isinstance(example_inputs, dict): - model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) - else: - model = torch.jit.trace(model, example_inputs) - model = torch.jit.freeze(model.eval()) - except: # pragma: no cover - if isinstance(example_inputs, dict): - model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) - else: - model = torch.jit.trace(model, example_inputs, strict=False) - model = torch.jit.freeze(model.eval()) - # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile - # At the 2nd run, the llga pass will be triggered and the model is turned into - # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph - simple_inference(model, example_inputs, iterations=2) - return model + if use_bf16 and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1"): # pragma: no cover + with torch.no_grad(): + with torch.cpu.amp.autocast(): + model = ipex.quantization.convert(model, inplace=inplace) + try: + if isinstance(example_inputs, dict): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) + else: + model = torch.jit.trace(model, example_inputs) + model = torch.jit.freeze(model.eval()) + except: + if isinstance(example_inputs, dict): + model = torch.jit.trace( + model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False + ) + else: + model = torch.jit.trace(model, example_inputs, strict=False) + model = torch.jit.freeze(model.eval()) + else: + model = ipex.quantization.convert(model, inplace=inplace) + with torch.no_grad(): + try: + if isinstance(example_inputs, dict): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) + else: + model = torch.jit.trace(model, example_inputs) + model = torch.jit.freeze(model.eval()) + except: # pragma: no cover + if isinstance(example_inputs, dict): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) + else: + model = torch.jit.trace(model, example_inputs, strict=False) + model = torch.jit.freeze(model.eval()) + # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile + # At the 2nd run, the llga pass will be triggered and the model is turned into + # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph + simple_inference(model, example_inputs, iterations=2) + return model diff --git a/neural_compressor/torch/algorithms/smooth_quant/utility.py b/neural_compressor/torch/algorithms/smooth_quant/utility.py index 51af8ba43cf..39b19883d3b 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/utility.py +++ b/neural_compressor/torch/algorithms/smooth_quant/utility.py @@ -25,6 +25,7 @@ from packaging.version import Version from neural_compressor.torch.algorithms.static_quant import ( + CpuInfo, TransformerBasedModelBlockPatternDetector, dump_model_op_stats, generate_activation_observer, diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index 5839a23b6bf..b9c476e9e80 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -16,6 +16,7 @@ # limitations under the License. import json +import os from copy import deepcopy from types import MethodType @@ -34,6 +35,7 @@ from neural_compressor.torch.utils import logger from .utility import ( + CpuInfo, cfg_to_qconfig, dump_model_op_stats, get_ipex_version, @@ -75,6 +77,8 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) model.eval() + use_bf16 = self.quant_config.get("use_bf16", None) + # Check save_qconf_summary part is a workaround for IPEX bug. # Sometimes the prepared model from get_op_capablitiy loss this attribute if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): @@ -108,10 +112,12 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs): Returns: A quantized model. """ + use_bf16 = self.quant_config.get("use_bf16", None) + from neural_compressor.torch.algorithms.static_quant import save model.save_qconf_summary(qconf_summary=ipex_config_path) - model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) + model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace) with open(ipex_config_path, "r") as f: model.tune_cfg = json.load(f) @@ -125,33 +131,53 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs): return model -def _ipex_post_quant_process(model, example_inputs, inplace=False): +def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False): """Convert to a jit model. Args: model: a prepared model. example_inputs: used to trace torch model. + use_bf16: whether to use bf16 for mixed precision. inplace: whether to carry out model transformations in-place. Returns: A converted jit model. """ - model = ipex.quantization.convert(model, inplace=inplace) - with torch.no_grad(): - try: - if isinstance(example_inputs, dict): - model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) - else: - model = torch.jit.trace(model, example_inputs) - model = torch.jit.freeze(model.eval()) - except: - if isinstance(example_inputs, dict): - model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) - else: - model = torch.jit.trace(model, example_inputs, strict=False) - model = torch.jit.freeze(model.eval()) - # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile - # At the 2nd run, the llga pass will be triggered and the model is turned into - # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph - simple_inference(model, example_inputs, iterations=2) - return model + if use_bf16 and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1"): # pragma: no cover + with torch.no_grad(): + with torch.cpu.amp.autocast(): + model = ipex.quantization.convert(model, inplace=inplace) + try: + if isinstance(example_inputs, dict): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) + else: + model = torch.jit.trace(model, example_inputs) + model = torch.jit.freeze(model.eval()) + except: + if isinstance(example_inputs, dict): + model = torch.jit.trace( + model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False + ) + else: + model = torch.jit.trace(model, example_inputs, strict=False) + model = torch.jit.freeze(model.eval()) + else: # pragma: no cover + model = ipex.quantization.convert(model, inplace=inplace) + with torch.no_grad(): + try: + if isinstance(example_inputs, dict): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) + else: + model = torch.jit.trace(model, example_inputs) + model = torch.jit.freeze(model.eval()) + except: + if isinstance(example_inputs, dict): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) + else: + model = torch.jit.trace(model, example_inputs, strict=False) + model = torch.jit.freeze(model.eval()) + # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile + # At the 2nd run, the llga pass will be triggered and the model is turned into + # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph + simple_inference(model, example_inputs, iterations=2) + return model diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 6a024843d90..03e3bf23115 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -181,6 +181,7 @@ def static_quant_entry( "algorithm": cfg.act_algo, }, } + quant_config_mapping["use_bf16"] = "bf16" not in cfg.excluded_precisions run_fn = kwargs.get("run_fn", None) example_inputs = kwargs.get("example_inputs", None) @@ -278,6 +279,7 @@ def smooth_quant_entry( "last_conv_or_matmul_quantization": True, "pre_post_process_quantization": True, } + quant_config_mapping["use_bf16"] = "bf16" not in cfg.excluded_precisions run_fn = kwargs.get("run_fn", None) example_inputs = kwargs.get("example_inputs", None) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 4f47b94b1d5..160b17e8f96 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -971,6 +971,7 @@ class StaticQuantConfig(BaseConfig): "act_sym", "act_granularity", "act_algo", + "excluded_precisions", ] supported_configs: List[OperatorConfig] = [] @@ -984,6 +985,7 @@ def __init__( act_sym: bool = False, act_granularity: str = "per_tensor", act_algo: str = "kl", + excluded_precisions: list = [], white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): """Init Static Quant Configs.""" @@ -996,6 +998,7 @@ def __init__( self.act_sym = act_sym self.act_granularity = act_granularity self.act_algo = act_algo + self.excluded_precisions = excluded_precisions self._post_init() @classmethod @@ -1057,6 +1060,7 @@ class SmoothQuantConfig(BaseConfig): "act_sym", "act_granularity", "act_algo", + "excluded_precisions", "alpha", "folding", "scale_sharing", @@ -1074,6 +1078,7 @@ def __init__( act_sym: bool = False, act_granularity: str = "per_tensor", act_algo: str = "kl", + excluded_precisions: list = [], alpha: float = 0.5, folding: bool = False, # below for autotune @@ -1097,6 +1102,7 @@ def __init__( self.act_sym = act_sym self.act_granularity = act_granularity self.act_algo = act_algo + self.excluded_precisions = excluded_precisions self.alpha = alpha self.folding = folding # below for autotune diff --git a/test/3x/torch/quantization/test_smooth_quant.py b/test/3x/torch/quantization/test_smooth_quant.py index 7d8b1730ff1..2967f386cb1 100644 --- a/test/3x/torch/quantization/test_smooth_quant.py +++ b/test/3x/torch/quantization/test_smooth_quant.py @@ -179,3 +179,20 @@ def test_smooth_quant_with_quantize_API(self): example_dict = {"x": example_inputs} q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_dict) assert q_model is not None, "Quantization failed!" + + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + def test_smooth_quant_mixed_precision(self): + fp32_model = copy.deepcopy(model) + quant_config = get_default_sq_config() # do mixed_precison by default. + example_inputs = torch.randn([1, 3]) + + # prepare/convert API + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) + assert q_model is not None, "Quantization failed!" + + # quantize API + quant_config.excluded_precisions = ["bf16"] + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + assert q_model is not None, "Quantization failed!" diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py index 185b7eb9ee7..072f4774e3e 100644 --- a/test/3x/torch/quantization/test_static_quant.py +++ b/test/3x/torch/quantization/test_static_quant.py @@ -191,3 +191,14 @@ def test_static_quant_with_quantize_API(self): example_inputs = self.input q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) assert q_model is not None, "Quantization failed!" + + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + def test_static_quant_mixed_precision(self): + fp32_model = copy.deepcopy(self.fp32_model) + quant_config = get_default_static_config() + quant_config.excluded_precisions = ["bf16"] + example_inputs = self.input + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) + assert q_model is not None, "Quantization failed!"