Support PT2E save and load (#1918)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel authored Jul 15, 2024
1 parent 34f0a9f commit 7a4715c
Showing 5 changed files with 69 additions and 3 deletions.
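Taken together, the five files add a save/load round trip for models quantized through the PT2E path. A rough usage sketch, pieced together from the test added below; the export/prepare/convert helpers and the default config are assumed from the repository's existing PT2E flow and are not part of this commit:

```python
import torch

# Assumed pre-existing INC 3.x PT2E helpers (not part of this diff).
from neural_compressor.torch.export import export
from neural_compressor.torch.quantization import convert, get_default_static_config, load, prepare

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = (torch.randn(2, 8),)

exported_model = export(model, example_inputs=example_inputs)    # capture the graph for PT2E
prepared_model = prepare(exported_model, quant_config=get_default_static_config())
prepared_model(*example_inputs)                                  # calibration pass
q_model = convert(prepared_model)

# New in this commit: q_model.save(...) and load(...) for PT2E models.
q_model.save(example_inputs=example_inputs, output_dir="./saved_results")
loaded_model = load("./saved_results")
assert torch.equal(loaded_model(*example_inputs), q_model(*example_inputs))
```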
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -14,3 +14,4 @@


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
from .save_load import save, load
42 changes: 42 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/save_load.py
@@ -0,0 +1,42 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import torch

from neural_compressor.common.utils import load_config_mapping, save_config_mapping
from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger


def save(model, example_inputs, output_dir="./saved_results"):
    os.makedirs(output_dir, exist_ok=True)
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
    quantized_ep = torch.export.export(model, example_inputs)
    torch.export.save(quantized_ep, qmodel_file_path)
    for key, op_config in model.qconfig.items():
        model.qconfig[key] = op_config.to_dict()
    with open(qconfig_file_path, "w") as f:
        json.dump(model.qconfig, f, indent=4)

    logger.info("Save quantized model to {}.".format(qmodel_file_path))
    logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))


def load(output_dir="./saved_results"):
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    loaded_quantized_ep = torch.export.load(qmodel_file_path)
    return loaded_quantized_ep.module()
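Both helpers lean on the standard torch.export round trip: save() re-exports the quantized module to an ExportedProgram and serializes it next to a JSON dump of the per-op qconfig, while load() restores the ExportedProgram and returns its callable module. A minimal stand-alone illustration of that round trip on a toy module (no quantization involved):

```python
import torch


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


example_inputs = (torch.randn(2),)
ep = torch.export.export(ToyModel(), example_inputs)         # ExportedProgram
torch.export.save(ep, "toy_model.pt2")                       # serialize graph + weights
restored = torch.export.load("toy_model.pt2").module()       # callable module, as load() returns
assert torch.equal(restored(*example_inputs), ToyModel()(*example_inputs))
```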
6 changes: 6 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -210,6 +210,7 @@ def static_quant_entry(
def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
    logger.info("Quantize model with the PT2E static quant algorithm.")
    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
    from neural_compressor.torch.algorithms.pt2e_quant.save_load import save

    run_fn = kwargs.get("run_fn", None)
    example_inputs = kwargs.get("example_inputs", None)
@@ -221,6 +222,8 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
    model = w8a8_quantizer.execute(
        model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
    )
    model.qconfig = configs_mapping
    model.save = MethodType(save, model)
    return model
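`model.save = MethodType(save, model)` binds the module-level save() function to this particular model instance, so callers can write `model.save(example_inputs=..., output_dir=...)` and the model is passed implicitly as the first argument. A small stand-alone illustration of the binding (DummyModel and the stub save() are illustrative only):

```python
from types import MethodType


def save(model, example_inputs, output_dir="./saved_results"):
    # Stub standing in for pt2e_quant.save_load.save(); it only shows how the
    # bound call receives `model` implicitly.
    print(type(model).__name__, example_inputs, output_dir)


class DummyModel:
    pass


m = DummyModel()
m.save = MethodType(save, m)                      # bind `m` as the implicit first argument
m.save(example_inputs=(1,), output_dir="./out")   # == save(m, example_inputs=(1,), output_dir="./out")
```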


@@ -230,6 +233,7 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
    logger.info("Quantize model with the PT2E static quant algorithm.")
    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
    from neural_compressor.torch.algorithms.pt2e_quant.save_load import save

    run_fn = kwargs.get("run_fn", None)
    example_inputs = kwargs.get("example_inputs", None)
@@ -240,6 +244,8 @@ def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode,
    model = w8a8_quantizer.execute(
        model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
    )
    model.qconfig = configs_mapping
    model.save = MethodType(save, model)
    return model


5 changes: 5 additions & 0 deletions neural_compressor/torch/quantization/load_entry.py
@@ -85,6 +85,10 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
            from neural_compressor.torch.algorithms import static_quant

            return static_quant.load(model_name_or_path)
        elif "static_quant" in per_op_qconfig.keys() or "pt2e_dynamic_quant" in per_op_qconfig.keys():  # PT2E
            from neural_compressor.torch.algorithms import pt2e_quant

            return pt2e_quant.load(model_name_or_path)
        else:
            config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
            # select load function
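The dispatch keys off the algorithm names recorded at the top level of the saved qconfig.json (the configs_mapping attached as model.qconfig by the PT2E entries). A hypothetical qconfig.json and the resulting routing decision; the nested layout of the config values is an assumption, only the top-level keys matter for the branch:

```python
import json

# Hypothetical qconfig.json content written by save(); only the top-level key
# ("static_quant" / "pt2e_dynamic_quant") drives the dispatch below.
per_op_qconfig = json.loads('{"static_quant": {"w_dtype": "int8", "act_dtype": "int8"}}')

if " " in per_op_qconfig:                                                          # IPEX-style qconfig
    print("route to static_quant.load")
elif "static_quant" in per_op_qconfig or "pt2e_dynamic_quant" in per_op_qconfig:   # PT2E
    print("route to pt2e_quant.load")
else:
    print("fall back to config-mapping based loading")
```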
@@ -102,6 +106,7 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
            from neural_compressor.torch.algorithms import habana_fp8

            return habana_fp8.load(model_name_or_path, original_model)

    elif format == LoadFormat.HUGGINGFACE.value:
        # now only support load huggingface WOQ causal language model
        from neural_compressor.torch.algorithms import weight_only
18 changes: 15 additions & 3 deletions test/3x/torch/quantization/test_pt2e_quant.py
@@ -1,6 +1,4 @@
import os
import unittest
from unittest.mock import patch
import shutil

import pytest
import torch
@@ -33,6 +31,8 @@ def _is_ipex_imported():


class TestPT2EQuantization:
    def teardown_class(self):
        shutil.rmtree("saved_results", ignore_errors=True)

    @staticmethod
    def get_toy_model():
@@ -114,6 +114,18 @@ def calib_fn(model):
        config.freezing = True
        q_model_out = q_model(*example_inputs)
        assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"

        # test save and load
        q_model.save(
            example_inputs=example_inputs,
            output_dir="./saved_results",
        )
        from neural_compressor.torch.quantization import load

        loaded_quantized_model = load("./saved_results")
        loaded_q_model_out = loaded_quantized_model(*example_inputs)
        assert torch.equal(loaded_q_model_out, q_model_out)

        opt_model = torch.compile(q_model)
        out = opt_model(*example_inputs)
        logger.warning("out shape is %s", out.shape)
