Register pt2e static quantization (#1761)
This PR 1) aligns the `W8A8StaticQuantizer` with the `Quantizer` base class, 2) adds an export API, and 3) maps the `StaticQuantConfig` to `X86InductorQuantizer`'s config.
---------

Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored May 9, 2024
1 parent 4e31b4d commit 43c3580
Showing 11 changed files with 359 additions and 70 deletions.
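
For orientation, a minimal usage sketch (not part of the diff) of the workflow this commit enables on the PT2E path: export the eager model to an Aten-IR GraphModule, prepare it with the aligned `W8A8StaticQuantizer`, calibrate, then convert. Calling the prepared module directly for calibration, passing `quant_config=None` to get the default X86Inductor configuration, and the exact behavior of `convert` (collapsed in this view) are assumptions.

import torch
from neural_compressor.torch.export import export
from neural_compressor.torch.algorithms.pt2e_quant import W8A8StaticQuantizer

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(2, 8),)

# 1) Export the eager model to Aten IR (marks the module with `_exported = True`).
exported_model = export(model, example_inputs=example_inputs)

# 2) Insert observers; `quant_config=None` falls back to the default
#    X86InductorQuantizer configuration.
quantizer = W8A8StaticQuantizer(quant_config=None)
prepared_model = quantizer.prepare(exported_model, example_inputs=example_inputs)

# 3) Calibrate with representative data, then convert to the quantized model.
prepared_model(*example_inputs)
quantized_model = quantizer.convert(prepared_model)
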
3 changes: 3 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
75 changes: 16 additions & 59 deletions neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -12,13 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Note - The `W8A8StaticQuantizer` is aligned with the pytorch-labs/ao's unified quantization API.
# https://github.com/pytorch-labs/ao/blob/5401df093564825c06691f4c2c10cdcf1a32a40c/torchao/quantization/unified.py#L15-L26
# Some code snippets are taken from the X86InductorQuantizer tutorial.
# https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html


from typing import Any, Dict, Optional, Tuple, Union
from typing import Any

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
@@ -28,71 +26,30 @@
from torch.fx.graph_module import GraphModule

from neural_compressor.common.utils import logger
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
from neural_compressor.torch.algorithms.base_algorithm import Quantizer
from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config


class W8A8StaticQuantizer:
class W8A8StaticQuantizer(Quantizer):

@staticmethod
def update_quantizer_based_on_quant_config(quantizer: X86InductorQuantizer, quant_config) -> X86InductorQuantizer:
# TODO: add the logic to update the quantizer based on the quant_config
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
if not quant_config:
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
else:
quantizer = create_xiq_quantizer_from_pt2e_config(quant_config)
return quantizer

@staticmethod
def export_model(
model,
example_inputs: Tuple[Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Optional[GraphModule]:
exported_model = None
try:
with torch.no_grad():
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
# updated to use the official `torch.export` API when that is ready.
cur_version = get_torch_version()
if cur_version <= TORCH_VERSION_2_2_2: # pragma: no cover
logger.warning(
(
"`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
"If you want to use `dynamic_shapes` to export model, "
"please upgrade to 2.3.0 or later."
),
cur_version,
)
exported_model = capture_pre_autograd_graph(model, args=example_inputs)
else: # pragma: no cover
exported_model = capture_pre_autograd_graph( # pylint: disable=E1123
model, args=example_inputs, dynamic_shapes=dynamic_shapes
)
except Exception as e:
logger.error(f"Failed to export the model: {e}")
return exported_model

def prepare(
self, model: torch.nn.Module, quant_config, example_inputs: Tuple[Any], *args: Any, **kwargs: Any
) -> GraphModule:
def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
"""Prepare the model for calibration.
There are two steps in this process:
1) export the eager model into model with Aten IR.
2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
"""
assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
# Set the model to eval mode
model = model.eval()

# 1) Capture the FX Graph to be quantized
dynamic_shapes = kwargs.get("dynamic_shapes", None)
exported_model = self.export_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
logger.info("Exported the model to Aten IR successfully.")
if exported_model is None:
return

# 2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
quantizer = X86InductorQuantizer()
quantizer = self.update_quantizer_based_on_quant_config(quantizer, quant_config)
prepared_model = prepare_pt2e(exported_model, quantizer)
quant_config = self.quant_config
assert model._exported, "The model should be exported before preparing it for calibration."
quantizer = self.update_quantizer_based_on_quant_config(quant_config)
prepared_model = prepare_pt2e(model, quantizer)
return prepared_model

def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
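
For reference, a short sketch of the underlying torch.ao PT2E flow that `W8A8StaticQuantizer` wraps, following the X86InductorQuantizer tutorial linked in this file; the use of `convert_pt2e` for the convert step is an assumption, since that part of the diff is collapsed here.

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

model = torch.nn.Linear(4, 4).eval()
example_inputs = (torch.randn(1, 4),)

# Capture the FX graph, attach the default X86 Inductor config, calibrate, convert.
exported = capture_pre_autograd_graph(model, args=example_inputs)
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass(es)
quantized = convert_pt2e(prepared)
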
15 changes: 15 additions & 0 deletions neural_compressor/torch/export/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from neural_compressor.torch.export._export import export_model_for_pt2e_quant, export
73 changes: 73 additions & 0 deletions neural_compressor/torch/export/_export.py
@@ -0,0 +1,73 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch._export import capture_pre_autograd_graph
from torch.fx.graph_module import GraphModule

from neural_compressor.common.utils import logger
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version, is_ipex_imported

__all__ = ["export", "export_model_for_pt2e_quant"]


def export_model_for_pt2e_quant(
model: torch.nn.Module,
example_inputs: Tuple[Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Optional[GraphModule]:
"""Export the eager model into model with Aten IR."""
assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
# Set the model to eval mode
model = model.eval()
exported_model = None
try:
with torch.no_grad():
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
# updated to use the official `torch.export` API when that is ready.
cur_version = get_torch_version()
if cur_version <= TORCH_VERSION_2_2_2: # pragma: no cover
logger.warning(
(
"`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
"If you want to use `dynamic_shapes` to export model, "
"please upgrade to 2.3.0 or later."
),
cur_version,
)
exported_model = capture_pre_autograd_graph(model, args=example_inputs)
else:
exported_model = capture_pre_autograd_graph( # pylint: disable=E1123
model, args=example_inputs, dynamic_shapes=dynamic_shapes
)
exported_model._exported = True
logger.info("Exported the model to Aten IR successfully.")
except Exception as e:
logger.error(f"Failed to export the model: {e}")

return exported_model


def export(
model: torch.nn.Module,
example_inputs: Tuple[Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Optional[GraphModule]:
if not is_ipex_imported():
return export_model_for_pt2e_quant(model, example_inputs, dynamic_shapes)
else:
# TODO, add `export` for ipex
pass
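
A hedged usage sketch for the new export helper; the `torch.export.Dim` marker and the tuple form of `dynamic_shapes` follow torch.export conventions on PyTorch 2.3+ and are assumptions here, not part of the diff.

import torch
from torch.export import Dim
from neural_compressor.torch.export import export

model = torch.nn.Linear(16, 4)
example_inputs = (torch.randn(8, 16),)

# Static export works on any supported PyTorch version.
exported = export(model, example_inputs=example_inputs)

# On PyTorch 2.3+ a dynamic batch dimension can be requested (assumed format).
batch = Dim("batch")
exported_dyn = export(model, example_inputs=example_inputs, dynamic_shapes=({0: batch},))
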
24 changes: 23 additions & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
Expand Up @@ -30,7 +30,8 @@
StaticQuantConfig,
TEQConfig,
)
from neural_compressor.torch.utils import Mode, logger, register_algo
from neural_compressor.torch.utils import Mode, is_ipex_imported, logger, register_algo
from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT


###################### RTN Algo Entry ##################################
@@ -147,6 +148,8 @@ def static_quant_entry(
*args,
**kwargs,
) -> torch.nn.Module:
if not is_ipex_imported():
return pt2e_static_quant_entry(model, configs_mapping, mode, *args, **kwargs)
logger.info("Quantize model with the static quant algorithm.")
from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer

@@ -191,6 +194,25 @@ def static_quant_entry(
return model


###################### PT2E Static Quant Algo Entry ##################################
@register_algo(name=PT2E_STATIC_QUANT)
@torch.no_grad()
def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
logger.info("Quantize model with the PT2E static quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
for _, quant_config in configs_mapping.items():
if quant_config.name == STATIC_QUANT:
w8a8_quantizer = W8A8StaticQuantizer(quant_config=quant_config)
model = w8a8_quantizer.execute(
model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
)
return model


###################### Smooth Quant Algo Entry ##################################
@register_algo(name=SMOOTH_QUANT)
@torch.no_grad()
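
To illustrate how this entry is reached (normally it is dispatched through `register_algo` when IPEX is not imported), a direct-call sketch follows; the `Mode.QUANTIZE` value and the `neural_compressor.torch.quantization.config` import path are assumptions based on the surrounding imports.

import torch
from neural_compressor.torch.export import export
from neural_compressor.torch.quantization.algorithm_entry import pt2e_static_quant_entry
from neural_compressor.torch.quantization.config import StaticQuantConfig
from neural_compressor.torch.utils import Mode

model = torch.nn.Linear(4, 4)
example_inputs = (torch.randn(1, 4),)
exported_model = export(model, example_inputs=example_inputs)

quant_config = StaticQuantConfig(w_granularity="per_tensor")
configs_mapping = quant_config.to_config_mapping()  # {config.name: config} on the PT2E path

def run_fn(mod):
    mod(*example_inputs)  # calibration

quantized_model = pt2e_static_quant_entry(
    exported_model,
    configs_mapping,
    mode=Mode.QUANTIZE,
    run_fn=run_fn,
    example_inputs=example_inputs,
)
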
22 changes: 18 additions & 4 deletions neural_compressor/torch/quantization/config.py
@@ -17,7 +17,7 @@
# pylint:disable=import-error

from collections import OrderedDict
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch

@@ -40,7 +40,7 @@
STATIC_QUANT,
TEQ,
)
from neural_compressor.torch.utils import is_hpex_available, logger
from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, logger
from neural_compressor.torch.utils.constants import (
PRIORITY_AUTOROUND,
PRIORITY_AWQ,
@@ -820,19 +820,31 @@ def __init__(
@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_static_config = StaticQuantConfig()
operators = [torch.nn.Linear]
supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively

_, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
return model_info

@staticmethod
def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
if is_ipex_imported():
return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
if is_ipex_imported():
return super().to_config_mapping(config_list, model_info)
config_mapping = OrderedDict({self.name: self})
return config_mapping

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:
return StaticQuantConfig(act_sym=[True, False], act_algo=["kl", "minmax"])
@@ -844,6 +856,8 @@ def get_default_static_config() -> StaticQuantConfig:
Returns:
the default static quant config.
"""
if not is_ipex_imported():
return StaticQuantConfig(w_granularity="per_tensor")
return StaticQuantConfig()


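
A small sketch of the behavioral split introduced above: without `intel_extension_for_pytorch` imported, the default config uses per-tensor weight granularity and `to_config_mapping` collapses to a single global entry keyed by the config name. The `config.name` attribute used below is an assumption based on how the entry matches configs.

from neural_compressor.torch.quantization.config import StaticQuantConfig, get_default_static_config
from neural_compressor.torch.utils import is_ipex_imported

config = get_default_static_config()
if not is_ipex_imported():
    mapping = config.to_config_mapping()
    assert list(mapping.keys()) == [config.name]
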
3 changes: 3 additions & 0 deletions neural_compressor/torch/utils/constants.py
@@ -49,3 +49,6 @@
PRIORITY_AWQ = 70
PRIORITY_TEQ = 60
PRIORITY_AUTOROUND = 50


PT2E_STATIC_QUANT = "pt2e_static_quant"
9 changes: 9 additions & 0 deletions neural_compressor/torch/utils/environ.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch
from packaging.version import Version

@@ -65,6 +67,13 @@ def get_torch_version():
return version


def is_ipex_imported() -> bool:
for name, _ in sys.modules.items():
if name == "intel_extension_for_pytorch":
return True
return False


def get_device(device_name="auto"):
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

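
A quick note on the helper above: it reports True only after `intel_extension_for_pytorch` has actually been imported in the current process, which is what routes static quantization to either the IPEX or the PT2E path. A tiny illustration:

from neural_compressor.torch.utils import is_ipex_imported

print(is_ipex_imported())  # False in a plain PyTorch environment
# import intel_extension_for_pytorch  # noqa: F401  -- if installed,
# print(is_ipex_imported())           # subsequent calls would return True
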
