Skip to content

Commit

Permalink
Add 'excluded_precisions' to static_quant and smooth_quant (#1814)
Browse files Browse the repository at this point in the history
Signed-off-by: Cheng, Zixuan <[email protected]>
  • Loading branch information
violetch24 authored May 24, 2024
1 parent e832cd3 commit eaa3a58
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 44 deletions.
76 changes: 53 additions & 23 deletions neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.

import json
import os

import torch

Expand All @@ -32,6 +33,7 @@
from neural_compressor.torch.algorithms import Quantizer

from .utility import (
CpuInfo,
TorchSmoothQuant,
cfg_to_qconfig,
dump_model_op_stats,
Expand Down Expand Up @@ -84,6 +86,8 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
model.eval()

use_bf16 = self.quant_config.get("use_bf16", None)

# check smoothquant alpha and act_algo value
recipe_cfgs = self.quant_config.get("recipe_cfgs", None)
alpha = recipe_cfgs["smooth_quant_args"]["alpha"]
Expand Down Expand Up @@ -133,8 +137,10 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
Returns:
A quantized model.
"""
use_bf16 = self.quant_config.get("use_bf16", None)

model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
Expand Down Expand Up @@ -169,6 +175,8 @@ def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args,
model.output_tensor_id_op_name,
)

use_bf16 = tune_cfg.get("use_bf16", None)

# check smoothquant folding value
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]:
Expand Down Expand Up @@ -223,7 +231,7 @@ def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args,
model.load_qconf_summary(qconf_summary=ipex_config_path)
run_fn(model)
model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)

# Recover model parameter when smoothquant folding = True
if (
Expand All @@ -247,6 +255,8 @@ def qdq_quantize(
smoothquant_scale_info = sq.sq_scale_info
sq_minmax_init = True if tune_cfg.get("act_algo", "kl") == "minmax" else False

use_bf16 = tune_cfg.get("use_bf16", None)

# Check save_qconf_summary part is a workaround for IPEX bug.
# Sometimes the prepared model from get_op_capablitiy loss this attribute
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover
Expand Down Expand Up @@ -296,7 +306,7 @@ def qdq_quantize(
update_sq_scale(ipex_config_path, smoothquant_scale_info)
model.load_qconf_summary(qconf_summary=ipex_config_path)
model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
Expand Down Expand Up @@ -334,33 +344,53 @@ def _apply_pre_optimization(model, tune_cfg, sq, recover=False):
logger.debug(f"Current smoothquant scale of {op_name} is {scale}, alpha is {alpha}")


def _ipex_post_quant_process(model, example_inputs, inplace=False):
def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False):
"""Convert to a jit model.
Args:
model: a prepared model.
example_inputs: used to trace torch model.
use_bf16: whether to use bf16 for mixed precision.
inplace: whether to carry out model transformations in-place.
Returns:
A converted jit model.
"""
model = ipex.quantization.convert(model, inplace=inplace)
with torch.no_grad():
try:
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
else:
model = torch.jit.trace(model, example_inputs)
model = torch.jit.freeze(model.eval())
except: # pragma: no cover
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
else:
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.freeze(model.eval())
# After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
# At the 2nd run, the llga pass will be triggered and the model is turned into
# an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
simple_inference(model, example_inputs, iterations=2)
return model
if use_bf16 and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1"): # pragma: no cover
with torch.no_grad():
with torch.cpu.amp.autocast():
model = ipex.quantization.convert(model, inplace=inplace)
try:
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
else:
model = torch.jit.trace(model, example_inputs)
model = torch.jit.freeze(model.eval())
except:
if isinstance(example_inputs, dict):
model = torch.jit.trace(
model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False
)
else:
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.freeze(model.eval())
else:
model = ipex.quantization.convert(model, inplace=inplace)
with torch.no_grad():
try:
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
else:
model = torch.jit.trace(model, example_inputs)
model = torch.jit.freeze(model.eval())
except: # pragma: no cover
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
else:
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.freeze(model.eval())
# After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
# At the 2nd run, the llga pass will be triggered and the model is turned into
# an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
simple_inference(model, example_inputs, iterations=2)
return model
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/smooth_quant/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from packaging.version import Version

from neural_compressor.torch.algorithms.static_quant import (
CpuInfo,
TransformerBasedModelBlockPatternDetector,
dump_model_op_stats,
generate_activation_observer,
Expand Down
68 changes: 47 additions & 21 deletions neural_compressor/torch/algorithms/static_quant/static_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.

import json
import os
from copy import deepcopy
from types import MethodType

Expand All @@ -34,6 +35,7 @@
from neural_compressor.torch.utils import logger

from .utility import (
CpuInfo,
cfg_to_qconfig,
dump_model_op_stats,
get_ipex_version,
Expand Down Expand Up @@ -75,6 +77,8 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
model.eval()

use_bf16 = self.quant_config.get("use_bf16", None)

# Check save_qconf_summary part is a workaround for IPEX bug.
# Sometimes the prepared model from get_op_capablitiy loss this attribute
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
Expand Down Expand Up @@ -108,10 +112,12 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
Returns:
A quantized model.
"""
use_bf16 = self.quant_config.get("use_bf16", None)

from neural_compressor.torch.algorithms.static_quant import save

model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
Expand All @@ -125,33 +131,53 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
return model


def _ipex_post_quant_process(model, example_inputs, inplace=False):
def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False):
"""Convert to a jit model.
Args:
model: a prepared model.
example_inputs: used to trace torch model.
use_bf16: whether to use bf16 for mixed precision.
inplace: whether to carry out model transformations in-place.
Returns:
A converted jit model.
"""
model = ipex.quantization.convert(model, inplace=inplace)
with torch.no_grad():
try:
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
else:
model = torch.jit.trace(model, example_inputs)
model = torch.jit.freeze(model.eval())
except:
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
else:
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.freeze(model.eval())
# After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
# At the 2nd run, the llga pass will be triggered and the model is turned into
# an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
simple_inference(model, example_inputs, iterations=2)
return model
if use_bf16 and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1"): # pragma: no cover
with torch.no_grad():
with torch.cpu.amp.autocast():
model = ipex.quantization.convert(model, inplace=inplace)
try:
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
else:
model = torch.jit.trace(model, example_inputs)
model = torch.jit.freeze(model.eval())
except:
if isinstance(example_inputs, dict):
model = torch.jit.trace(
model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False
)
else:
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.freeze(model.eval())
else: # pragma: no cover
model = ipex.quantization.convert(model, inplace=inplace)
with torch.no_grad():
try:
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
else:
model = torch.jit.trace(model, example_inputs)
model = torch.jit.freeze(model.eval())
except:
if isinstance(example_inputs, dict):
model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
else:
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.freeze(model.eval())
# After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
# At the 2nd run, the llga pass will be triggered and the model is turned into
# an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
simple_inference(model, example_inputs, iterations=2)
return model
2 changes: 2 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def static_quant_entry(
"algorithm": cfg.act_algo,
},
}
quant_config_mapping["use_bf16"] = "bf16" not in cfg.excluded_precisions

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
Expand Down Expand Up @@ -278,6 +279,7 @@ def smooth_quant_entry(
"last_conv_or_matmul_quantization": True,
"pre_post_process_quantization": True,
}
quant_config_mapping["use_bf16"] = "bf16" not in cfg.excluded_precisions

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
Expand Down
6 changes: 6 additions & 0 deletions neural_compressor/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ class StaticQuantConfig(BaseConfig):
"act_sym",
"act_granularity",
"act_algo",
"excluded_precisions",
]
supported_configs: List[OperatorConfig] = []

Expand All @@ -984,6 +985,7 @@ def __init__(
act_sym: bool = False,
act_granularity: str = "per_tensor",
act_algo: str = "kl",
excluded_precisions: list = [],
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init Static Quant Configs."""
Expand All @@ -996,6 +998,7 @@ def __init__(
self.act_sym = act_sym
self.act_granularity = act_granularity
self.act_algo = act_algo
self.excluded_precisions = excluded_precisions
self._post_init()

@classmethod
Expand Down Expand Up @@ -1057,6 +1060,7 @@ class SmoothQuantConfig(BaseConfig):
"act_sym",
"act_granularity",
"act_algo",
"excluded_precisions",
"alpha",
"folding",
"scale_sharing",
Expand All @@ -1074,6 +1078,7 @@ def __init__(
act_sym: bool = False,
act_granularity: str = "per_tensor",
act_algo: str = "kl",
excluded_precisions: list = [],
alpha: float = 0.5,
folding: bool = False,
# below for autotune
Expand All @@ -1097,6 +1102,7 @@ def __init__(
self.act_sym = act_sym
self.act_granularity = act_granularity
self.act_algo = act_algo
self.excluded_precisions = excluded_precisions
self.alpha = alpha
self.folding = folding
# below for autotune
Expand Down
17 changes: 17 additions & 0 deletions test/3x/torch/quantization/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,20 @@ def test_smooth_quant_with_quantize_API(self):
example_dict = {"x": example_inputs}
q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_dict)
assert q_model is not None, "Quantization failed!"

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
def test_smooth_quant_mixed_precision(self):
fp32_model = copy.deepcopy(model)
quant_config = get_default_sq_config() # do mixed_precison by default.
example_inputs = torch.randn([1, 3])

# prepare/convert API
prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(prepared_model)
q_model = convert(prepared_model)
assert q_model is not None, "Quantization failed!"

# quantize API
quant_config.excluded_precisions = ["bf16"]
q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
assert q_model is not None, "Quantization failed!"
11 changes: 11 additions & 0 deletions test/3x/torch/quantization/test_static_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,14 @@ def test_static_quant_with_quantize_API(self):
example_inputs = self.input
q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
assert q_model is not None, "Quantization failed!"

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
def test_static_quant_mixed_precision(self):
fp32_model = copy.deepcopy(self.fp32_model)
quant_config = get_default_static_config()
quant_config.excluded_precisions = ["bf16"]
example_inputs = self.input
prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(prepared_model)
q_model = convert(prepared_model)
assert q_model is not None, "Quantization failed!"

0 comments on commit eaa3a58

Please sign in to comment.