Add pt2e dynamic quantization (#1795)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored May 16, 2024
1 parent d4b0f0f commit 30b36b8
Showing 10 changed files with 167 additions and 41 deletions.
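For orientation, this commit wires PT2E dynamic quantization into the 3.x PyTorch workflow. A minimal end-to-end sketch is below; it assumes the existing prepare/convert entry points in neural_compressor.torch.quantization and the export_model_for_pt2e_quant helper used by the tests in this diff, so exact signatures may differ.

    import torch
    from neural_compressor.torch.export import export_model_for_pt2e_quant
    from neural_compressor.torch.quantization import convert, get_default_dynamic_config, prepare

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    example_inputs = (torch.randn(4, 16),)
    # PT2E quantization operates on an exported GraphModule (torch>=2.3 per the tests below)
    exported_model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)

    quant_config = get_default_dynamic_config()  # added by this commit
    prepared_model = prepare(exported_model, quant_config, example_inputs=example_inputs)
    # dynamic quantization computes activation ranges at run time, so no calibration pass is needed
    quantized_model = convert(prepared_model)
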
2 changes: 1 addition & 1 deletion neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -13,4 +13,4 @@
# limitations under the License.


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
14 changes: 9 additions & 5 deletions neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -18,9 +18,7 @@

from typing import Any

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.fx.graph_module import GraphModule
@@ -30,15 +28,21 @@
from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config


class W8A8StaticQuantizer(Quantizer):
class W8A8PT2EQuantizer(Quantizer):
is_dynamic = False

def __init__(self, quant_config=None):
super().__init__(quant_config)

@staticmethod
def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
if not quant_config:
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
quantizer.set_global(
xiq.get_default_x86_inductor_quantization_config(is_dynamic=W8A8PT2EQuantizer.is_dynamic)
)
else:
quantizer = create_xiq_quantizer_from_pt2e_config(quant_config)
quantizer = create_xiq_quantizer_from_pt2e_config(quant_config, is_dynamic=W8A8PT2EQuantizer.is_dynamic)
return quantizer

def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
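The static and dynamic paths now share a single quantizer class; only the class-level is_dynamic flag differs. A small illustration of the helper above (a sketch, not code from this commit):

    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer

    # default: static behaviour, identical to the former W8A8StaticQuantizer
    static_xiq = W8A8PT2EQuantizer.update_quantizer_based_on_quant_config(None)

    # flipping the class attribute makes the same helper emit a dynamic config,
    # i.e. get_default_x86_inductor_quantization_config(is_dynamic=True)
    W8A8PT2EQuantizer.is_dynamic = True
    dynamic_xiq = W8A8PT2EQuantizer.update_quantizer_based_on_quant_config(None)
    W8A8PT2EQuantizer.is_dynamic = False  # restore the default for later callers
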
2 changes: 2 additions & 0 deletions neural_compressor/torch/quantization/__init__.py
@@ -35,6 +35,8 @@
get_default_fp8_config,
get_default_fp8_config_set,
get_woq_tuning_config,
DynamicQuantConfig,
get_default_dynamic_config,
)

from neural_compressor.torch.quantization.autotune import (
28 changes: 24 additions & 4 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -14,7 +14,7 @@

from copy import deepcopy
from types import MethodType
from typing import Any, Callable, Dict, Tuple
from typing import Callable, Dict, Tuple

import torch

@@ -42,7 +42,7 @@
TEQConfig,
)
from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo
from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT
from neural_compressor.torch.utils.constants import PT2E_DYNAMIC_QUANT, PT2E_STATIC_QUANT


###################### RTN Algo Entry ##################################
@@ -186,19 +186,39 @@ def static_quant_entry(
return model


###################### PT2E Dynamic Quant Algo Entry ##################################
@register_algo(name=PT2E_DYNAMIC_QUANT)
@torch.no_grad()
def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
logger.info("Quantize model with the PT2E dynamic quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
W8A8PT2EQuantizer.is_dynamic = True
for _, quant_config in configs_mapping.items():
if quant_config.name == PT2E_DYNAMIC_QUANT:
w8a8_quantizer = W8A8PT2EQuantizer(quant_config=quant_config)
model = w8a8_quantizer.execute(
model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
)
return model


###################### PT2E Static Quant Algo Entry ##################################
@register_algo(name=PT2E_STATIC_QUANT)
@torch.no_grad()
def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
logger.info("Quantize model with the PT2E static quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
for _, quant_config in configs_mapping.items():
if quant_config.name == STATIC_QUANT:
w8a8_quantizer = W8A8StaticQuantizer(quant_config=quant_config)
w8a8_quantizer = W8A8PT2EQuantizer(quant_config=quant_config)
model = w8a8_quantizer.execute(
model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
)
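The new entry only acts on configs whose name matches PT2E_DYNAMIC_QUANT; the mapping it iterates over is produced by DynamicQuantConfig.to_config_mapping(), added further down in this diff. A quick illustration (assumed usage, not part of the commit):

    from neural_compressor.torch.quantization import DynamicQuantConfig

    cfg = DynamicQuantConfig()
    configs_mapping = cfg.to_config_mapping()
    # an OrderedDict with a single entry keyed by the algorithm name
    for name, quant_config in configs_mapping.items():
        assert name == quant_config.name == "pt2e_dynamic_quant"
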
77 changes: 76 additions & 1 deletion neural_compressor/torch/quantization/config.py
@@ -17,7 +17,7 @@
# pylint:disable=import-error

from collections import OrderedDict
from typing import Any, Callable, Dict, List, NamedTuple, Optional
from typing import Callable, Dict, List, NamedTuple, Optional
from typing import OrderedDict as OrderedDictType
from typing import Tuple, Union

@@ -50,6 +50,7 @@
PRIORITY_HQQ,
PRIORITY_RTN,
PRIORITY_TEQ,
PT2E_DYNAMIC_QUANT,
)

__all__ = [
@@ -778,6 +779,80 @@ def get_default_AutoRound_config() -> AutoRoundConfig:
return AutoRoundConfig()


######################## Dynamic Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
class DynamicQuantConfig(BaseConfig):
"""Config class for dynamic quantization."""

name = PT2E_DYNAMIC_QUANT
params_list = [
"w_dtype",
"w_sym",
"w_granularity",
"w_algo",
"act_dtype",
"act_sym",
"act_granularity",
"act_algo",
]
supported_configs: List[OperatorConfig] = []

def __init__(
self,
w_dtype: str = "int8",
w_sym: bool = True,
w_granularity: str = "per_tensor",
w_algo: str = "minmax",
act_dtype: str = "uint8",
act_sym: bool = False,
act_granularity: str = "per_tensor",
act_algo: str = "kl",
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init Dynamic Quant Configs."""
super().__init__(white_list=white_list)
self.w_dtype = w_dtype
self.w_sym = w_sym
self.w_granularity = w_granularity
self.w_algo = w_algo
self.act_dtype = act_dtype
self.act_sym = act_sym
self.act_granularity = act_granularity
self.act_algo = act_algo
self._post_init()

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
linear_static_config = cls()
operators = [torch.nn.Linear]
supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module, example_inputs=None):
return None

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
config_mapping = OrderedDict({self.name: self})
return config_mapping

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "DynamicQuantConfig", List["DynamicQuantConfig"]]:
return cls(act_sym=[True, False], act_algo=["kl", "minmax"])


def get_default_dynamic_config() -> DynamicQuantConfig:
"""Generate the default dynamic quant config.
Returns:
the default dynamic quant config.
"""
return DynamicQuantConfig()


######################## Static Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
class StaticQuantConfig(BaseConfig):
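For reference, a short sketch of how the new DynamicQuantConfig is expected to be used (parameter values mirror the defaults defined above; the tuning behaviour is only implied by get_config_set_for_tuning):

    from neural_compressor.torch.quantization import DynamicQuantConfig, get_default_dynamic_config

    default_cfg = get_default_dynamic_config()          # int8 weights, uint8 activations, per_tensor
    custom_cfg = DynamicQuantConfig(act_algo="minmax", act_sym=True)
    tuning_set = DynamicQuantConfig.get_config_set_for_tuning()
    # tuning_set holds list-valued fields (act_sym=[True, False], act_algo=["kl", "minmax"])
    # that the autotuner can expand into concrete candidate configs
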
1 change: 1 addition & 0 deletions neural_compressor/torch/utils/constants.py
@@ -52,3 +52,4 @@


PT2E_STATIC_QUANT = "pt2e_static_quant"
PT2E_DYNAMIC_QUANT = "pt2e_dynamic_quant"
18 changes: 9 additions & 9 deletions neural_compressor/torch/utils/environ.py
@@ -31,20 +31,20 @@ def is_hpex_available():
return _hpex_available


try:
import intel_extension_for_pytorch as ipex

_ipex_available = True
except:
_ipex_available = False


def is_ipex_available():
try:
import intel_extension_for_pytorch as ipex

_ipex_available = True
except:
_ipex_available = False
return _ipex_available


def get_ipex_version():
if _ipex_available:
if is_ipex_available():
import intel_extension_for_pytorch as ipex

try:
ipex_version = ipex.__version__.split("+")[0]
except ValueError as e: # pragma: no cover
29 changes: 20 additions & 9 deletions neural_compressor/torch/utils/utility.py
@@ -17,7 +17,7 @@

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver, PlaceholderObserver
from torch.ao.quantization.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
from typing_extensions import TypeAlias
@@ -172,30 +172,41 @@ def postprocess_model(model, mode, quantizer):
del model.quantizer


def create_quant_spec_from_config(dtype, sym, granularity, algo) -> QuantizationSpec:
def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
select_dtype = dtype_mapping[dtype]
min_max_mapping = {torch.int8: (-128, 127), torch.uint8: (0, 255)}
qscheme_mapping = {
"per_channel": {True: torch.per_channel_symmetric, False: torch.per_tensor_affine},
"per_tensor": {True: torch.per_tensor_symmetric, False: torch.per_tensor_affine},
}
observer_mapping = {
"placeholder": PlaceholderObserver,
"minmax": MinMaxObserver,
"kl": HistogramObserver,
}
# Force the use of PlaceholderObserver for dynamic quantization
if is_dynamic:
algo = "placeholder"
# algo
observer_or_fake_quant_ctr = observer_mapping[algo]
# qscheme
qscheme = qscheme_mapping[granularity][sym]
quantization_spec = QuantizationSpec(
dtype=dtype_mapping[dtype], observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, qscheme=qscheme
dtype=select_dtype,
quant_min=min_max_mapping[select_dtype][0],
quant_max=min_max_mapping[select_dtype][1],
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
qscheme=qscheme,
is_dynamic=is_dynamic,
)
return quantization_spec


def _map_inc_config_to_torch_quant_config(inc_config) -> QuantizationConfig:
default_quant_config = xiq.get_default_x86_inductor_quantization_config()
def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
input_act_quant_spec = create_quant_spec_from_config(
inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo
inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo, is_dynamic=is_dynamic
)
weight_quant_spec = create_quant_spec_from_config(
inc_config.w_dtype, inc_config.w_sym, inc_config.w_granularity, inc_config.w_algo
@@ -210,14 +221,14 @@ def _map_inc_config_to_torch_quant_config(inc_config) -> QuantizationConfig:
return quant_config


def create_xiq_quantizer_from_pt2e_config(config) -> X86InductorQuantizer:
def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
quantizer = xiq.X86InductorQuantizer()
# set global
global_config = _map_inc_config_to_torch_quant_config(config)
global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
quantizer.set_global(global_config)
# set local
for module_or_func_name, local_config in config.local_config.items():
local_quant_config = _map_inc_config_to_torch_quant_config(local_config)
local_quant_config = _map_inc_config_to_torch_quant_config(local_config, is_dynamic)
if isinstance(module_or_func_name, torch.nn.Module):
quantizer.set_module_type_qconfig(module_or_func_name, local_quant_config)
else:
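A small check of the new is_dynamic path in create_quant_spec_from_config (a sketch; attribute names come from torch.ao.quantization.quantizer.QuantizationSpec):

    from neural_compressor.torch.utils.utility import create_quant_spec_from_config

    spec = create_quant_spec_from_config("uint8", sym=False, granularity="per_tensor",
                                         algo="kl", is_dynamic=True)
    # the requested "kl" observer is overridden: dynamic quantization always uses
    # PlaceholderObserver, and the spec now carries explicit quant_min/quant_max bounds
    print(spec.observer_or_fake_quant_ctr)  # PlaceholderObserver
    print(spec.is_dynamic, spec.quant_min, spec.quant_max)  # True 0 255
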
8 changes: 4 additions & 4 deletions test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py
@@ -5,12 +5,12 @@
import torch

from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
from neural_compressor.torch.export import export_model_for_pt2e_quant
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version


class TestW8A8StaticQuantizer:
class TestW8A8PT2EQuantizer:

@staticmethod
def get_toy_model():
@@ -52,7 +52,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
@pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
def test_quantizer_on_simple_model(self):
model, example_inputs = self.build_simple_torch_model_and_example_inputs()
w8a8_static_quantizer = W8A8StaticQuantizer()
w8a8_static_quantizer = W8A8PT2EQuantizer()
# prepare
prepare_model = w8a8_static_quantizer.prepare(model, example_inputs=example_inputs)
# calibrate
@@ -81,7 +81,7 @@ def test_quantizer_on_llm(self):
model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)

quant_config = None
w8a8_static_quantizer = W8A8StaticQuantizer()
w8a8_static_quantizer = W8A8PT2EQuantizer()
# prepare
prepare_model = w8a8_static_quantizer.prepare(model)
# calibrate
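A dynamic-mode companion test could follow the same shape as the static one above. The sketch below is not part of the commit and assumes the quantizer exposes a convert() counterpart to prepare(), as the calibrate/convert steps elided in this hunk suggest:

    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
    def test_quantizer_dynamic_on_simple_model(self):
        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
        W8A8PT2EQuantizer.is_dynamic = True
        try:
            quantizer = W8A8PT2EQuantizer()
            prepared_model = quantizer.prepare(model, example_inputs=example_inputs)
            # no calibration loop: activation ranges are computed at run time
            converted_model = quantizer.convert(prepared_model)
            out = converted_model(*example_inputs)
            assert out is not None
        finally:
            W8A8PT2EQuantizer.is_dynamic = False  # restore the class-level default
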
