pt 3.x config for static quant and smooth quant (#1568)
Signed-off-by: Cheng, Zixuan <[email protected]>
violetch24 authored Jan 25, 2024
1 parent 52ea445 commit 7766454
Showing 5 changed files with 219 additions and 1 deletion.
1 change: 1 addition & 0 deletions neural_compressor/common/utils/constants.py
@@ -29,6 +29,7 @@
COMPOSABLE_CONFIG = "composable_config"
RTN = "rtn"
STATIC_QUANT = "static_quant"
SMOOTH_QUANT = "smooth_quant"
GPTQ = "gptq"
FP8_QUANT = "fp8_quant"

4 changes: 4 additions & 0 deletions neural_compressor/torch/__init__.py
@@ -21,6 +21,10 @@
    get_default_rtn_config,
    GPTQConfig,
    get_default_gptq_config,
    StaticQuantConfig,
    get_default_static_config,
    SmoothQuantConfig,
    get_default_sq_config,
)

from neural_compressor.common.base_tuning import TuningConfig
4 changes: 4 additions & 0 deletions neural_compressor/torch/quantization/__init__.py
@@ -18,4 +18,8 @@
    get_default_rtn_config,
    GPTQConfig,
    get_default_gptq_config,
    StaticQuantConfig,
    get_default_static_config,
    SmoothQuantConfig,
    get_default_sq_config,
)
195 changes: 194 additions & 1 deletion neural_compressor/torch/quantization/config.py
@@ -24,7 +24,15 @@
import torch

from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
-from neural_compressor.common.utils import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN
+from neural_compressor.common.utils import (
+    DEFAULT_WHITE_LIST,
+    FP8_QUANT,
+    GPTQ,
+    OP_NAME_OR_MODULE_TYPE,
+    RTN,
+    SMOOTH_QUANT,
+    STATIC_QUANT,
+)
from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger

@@ -282,6 +290,191 @@ def get_default_gptq_config() -> GPTQConfig:
    return GPTQConfig()


######################## Static Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
class StaticQuantConfig(BaseConfig):
    """Config class for static quantization."""

    name = STATIC_QUANT
    params_list = [
        "w_dtype",
        "w_sym",
        "w_granularity",
        "w_algo",
        "act_dtype",
        "act_sym",
        "act_granularity",
        "act_algo",
    ]
    supported_configs: List[OperatorConfig] = []

    def __init__(
        self,
        w_dtype: str = "int8",
        w_sym: bool = True,
        w_granularity: str = "per_channel",
        w_algo: str = "minmax",
        act_dtype: str = "uint8",
        act_sym: bool = False,
        act_granularity: str = "per_tensor",
        act_algo: str = "kl",
        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
    ):
        """Init Static Quant Configs."""
        super().__init__(white_list=white_list)
        self.w_dtype = w_dtype
        self.w_sym = w_sym
        self.w_granularity = w_granularity
        self.w_algo = w_algo
        self.act_dtype = act_dtype
        self.act_sym = act_sym
        self.act_granularity = act_granularity
        self.act_algo = act_algo
        self._post_init()

    @classmethod
    def register_supported_configs(cls) -> List[OperatorConfig]:
        supported_configs = []
        # TODO(Yi)
        linear_static_config = StaticQuantConfig()
        operators = [torch.nn.Linear]
        supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
        cls.supported_configs = supported_configs

    @staticmethod
    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
        white_list = (torch.nn.Linear,)
        filter_result = []
        for op_name, module in model.named_modules():
            if isinstance(module, white_list):
                pair = (op_name, type(module).__name__)
                filter_result.append(pair)
        logger.debug(f"Get model info: {filter_result}")
        return filter_result


# TODO(Yi) run `register_supported_configs` for all registered config.
StaticQuantConfig.register_supported_configs()


def get_default_static_config() -> StaticQuantConfig:
    """Generate the default static quant config.

    Returns:
        the default static quant config.
    """
    return StaticQuantConfig()


######################## Smooth Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
class SmoothQuantConfig(BaseConfig):
    """Config class for smooth quantization."""

    name = SMOOTH_QUANT
    params_list = [
        "w_dtype",
        "w_sym",
        "w_granularity",
        "w_algo",
        "act_dtype",
        "act_sym",
        "act_granularity",
        "act_algo",
        "alpha",
        "folding",
        "scale_sharing",
        "auto_alpha_args",
    ]
    supported_configs: List[OperatorConfig] = []

    def __init__(
        self,
        w_dtype: str = "int8",
        w_sym: bool = True,
        w_granularity: str = "per_channel",
        w_algo: str = "minmax",
        act_dtype: str = "uint8",
        act_sym: bool = False,
        act_granularity: str = "per_tensor",
        act_algo: str = "kl",
        alpha: float = 0.5,
        folding: bool = False,
        # below for autotune
        scale_sharing: bool = False,
        init_alpha: float = 0.5,
        alpha_min: float = 0.0,
        alpha_max: float = 1.0,
        alpha_step: float = 0.1,
        shared_criterion: str = "max",
        enable_blockwise_loss: bool = False,
        auto_alpha_args: dict = None,
        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
    ):
        """Init SmoothQuant Configs."""
        super().__init__(white_list=white_list)
        self.w_dtype = w_dtype
        self.w_sym = w_sym
        self.w_granularity = w_granularity
        self.w_algo = w_algo
        self.act_dtype = act_dtype
        self.act_sym = act_sym
        self.act_granularity = act_granularity
        self.act_algo = act_algo
        self.alpha = alpha
        self.folding = folding
        # below for autotune
        self.scale_sharing = scale_sharing
        self.init_alpha = init_alpha
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max
        self.alpha_step = alpha_step
        self.shared_criterion = shared_criterion
        self.enable_blockwise_loss = enable_blockwise_loss
        self.auto_alpha_args = {
            "init_alpha": self.init_alpha,
            "alpha_min": self.alpha_min,
            "alpha_max": self.alpha_max,
            "alpha_step": self.alpha_step,
            "shared_criterion": self.shared_criterion,
            "enable_blockwise_loss": self.enable_blockwise_loss,
        }
        self._post_init()

    @classmethod
    def register_supported_configs(cls) -> List[OperatorConfig]:
        supported_configs = []
        # TODO(Yi)
        linear_sq_config = SmoothQuantConfig()
        operators = [torch.nn.Linear]
        supported_configs.append(OperatorConfig(config=linear_sq_config, operators=operators))
        cls.supported_configs = supported_configs

    @staticmethod
    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
        white_list = (torch.nn.Linear,)
        filter_result = []
        for op_name, module in model.named_modules():
            if isinstance(module, white_list):
                pair = (op_name, type(module).__name__)
                filter_result.append(pair)
        logger.debug(f"Get model info: {filter_result}")
        return filter_result


# TODO(Yi) run `register_supported_configs` for all registered config.
SmoothQuantConfig.register_supported_configs()


def get_default_sq_config() -> SmoothQuantConfig:
    """Generate the default smoothquant config.

    Returns:
        the default smoothquant config.
    """
    return SmoothQuantConfig()


######################## FP8 Config ###############################
if is_hpex_avaliable():

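For reference, a minimal usage sketch of the two config classes added above. The sketch is not part of the commit; the imports follow the `__init__.py` changes in this diff, and the dict round-trip mirrors the tests added below.

from neural_compressor.torch import (
    StaticQuantConfig,
    get_default_static_config,
    SmoothQuantConfig,
    get_default_sq_config,
)

# Defaults per the constructor above: int8 per-channel symmetric weights,
# uint8 per-tensor asymmetric activations calibrated with the "kl" algorithm.
static_cfg = get_default_static_config()

# Override selected fields; anything not passed keeps its default.
static_cfg = StaticQuantConfig(act_sym=True, act_algo="minmax")

# SmoothQuant: `alpha` balances how much quantization difficulty is migrated
# from activations to weights. The init_alpha/alpha_min/alpha_max/alpha_step/
# shared_criterion/enable_blockwise_loss arguments are bundled into
# `auto_alpha_args` for the auto-alpha search during autotune.
sq_cfg = SmoothQuantConfig(alpha=0.8, folding=True)
default_sq = get_default_sq_config()

# Configs round-trip through dicts (this mirrors test_config.py below).
sq_cfg2 = SmoothQuantConfig.from_dict({"alpha": 0.8, "folding": True})
assert sq_cfg.to_dict() == sq_cfg2.to_dict()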
16 changes: 16 additions & 0 deletions test/3x/torch/test_config.py
@@ -321,6 +321,22 @@ def test_gptq_config(self):
        gptq_config2 = GPTQConfig.from_dict(quant_config_dict["gptq"])
        self.assertEqual(gptq_config1.to_dict(), gptq_config2.to_dict())

    def test_static_quant_config(self):
        from neural_compressor.torch.quantization import StaticQuantConfig

        static_config1 = StaticQuantConfig(w_dtype="int8", act_sym=True, act_algo="minmax")
        quant_config_dict = {"static": {"w_dtype": "int8", "act_sym": True, "act_algo": "minmax"}}
        static_config2 = StaticQuantConfig.from_dict(quant_config_dict["static"])
        self.assertEqual(static_config1.to_dict(), static_config2.to_dict())

    def test_smooth_quant_config(self):
        from neural_compressor.torch.quantization import SmoothQuantConfig

        sq_config1 = SmoothQuantConfig(alpha=0.8, folding=True)
        quant_config_dict = {"sq": {"alpha": 0.8, "folding": True}}
        sq_config2 = SmoothQuantConfig.from_dict(quant_config_dict["sq"])
        self.assertEqual(sq_config1.to_dict(), sq_config2.to_dict())


class TestQuantConfigForAutotune(unittest.TestCase):
    def test_expand_config(self):
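And a small sanity-check sketch (again not part of the commit) of what `get_model_info` reports: both new classes match only `torch.nn.Linear` and return `(op_name, class name)` pairs.

import torch

from neural_compressor.torch import StaticQuantConfig


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.act = torch.nn.ReLU()  # not a Linear, so it is skipped
        self.fc2 = torch.nn.Linear(8, 2)


# Expected output: [('fc1', 'Linear'), ('fc2', 'Linear')]
print(StaticQuantConfig.get_model_info(TinyNet()))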
