diff --git a/neural_compressor/torch/algorithms/pt2e_quant/__init__.py b/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
index b6187ba214a..b3c530ce2fd 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -14,3 +14,4 @@
 from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
+from .save_load import save, load
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/save_load.py b/neural_compressor/torch/algorithms/pt2e_quant/save_load.py
new file mode 100644
index 00000000000..606c31f41c2
--- /dev/null
+++ b/neural_compressor/torch/algorithms/pt2e_quant/save_load.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+
+import torch
+
+from neural_compressor.common.utils import load_config_mapping, save_config_mapping
+from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger
+
+
+def save(model, example_inputs, output_dir="./saved_results"):
+    os.makedirs(output_dir, exist_ok=True)
+    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
+    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
+    quantized_ep = torch.export.export(model, example_inputs)
+    torch.export.save(quantized_ep, qmodel_file_path)
+    for key, op_config in model.qconfig.items():
+        model.qconfig[key] = op_config.to_dict()
+    with open(qconfig_file_path, "w") as f:
+        json.dump(model.qconfig, f, indent=4)
+
+    logger.info("Save quantized model to {}.".format(qmodel_file_path))
+    logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
+
+
+def load(output_dir="./saved_results"):
+    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
+    loaded_quantized_ep = torch.export.load(qmodel_file_path)
+    return loaded_quantized_ep.module()
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index b8a1e3b9202..2a3eada9bf5 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -210,6 +210,7 @@ def static_quant_entry(
 def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
     logger.info("Quantize model with the PT2E static quant algorithm.")
     from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
+    from neural_compressor.torch.algorithms.pt2e_quant.save_load import save

     run_fn = kwargs.get("run_fn", None)
     example_inputs = kwargs.get("example_inputs", None)
@@ -221,6 +222,8 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
     model = w8a8_quantizer.execute(
         model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
     )
+    model.qconfig = configs_mapping
+    model.save = MethodType(save, model)
     return model
@@ -230,6 +233,7 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
 def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
     logger.info("Quantize model with the PT2E static quant algorithm.")
     from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
+    from neural_compressor.torch.algorithms.pt2e_quant.save_load import save

     run_fn = kwargs.get("run_fn", None)
     example_inputs = kwargs.get("example_inputs", None)
@@ -240,6 +244,8 @@ def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode,
     model = w8a8_quantizer.execute(
         model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
     )
+    model.qconfig = configs_mapping
+    model.save = MethodType(save, model)
     return model
diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py
index 584b853014f..7cc7f8075d0 100644
--- a/neural_compressor/torch/quantization/load_entry.py
+++ b/neural_compressor/torch/quantization/load_entry.py
@@ -85,6 +85,10 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
             from neural_compressor.torch.algorithms import static_quant

             return static_quant.load(model_name_or_path)
+        elif "static_quant" in per_op_qconfig.keys() or "pt2e_dynamic_quant" in per_op_qconfig.keys():  # PT2E
+            from neural_compressor.torch.algorithms import pt2e_quant
+
+            return pt2e_quant.load(model_name_or_path)
         else:
             config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
             # select load function
@@ -102,6 +106,7 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
                 from neural_compressor.torch.algorithms import habana_fp8

                 return habana_fp8.load(model_name_or_path, original_model)
+
     elif format == LoadFormat.HUGGINGFACE.value:
         # now only support load huggingface WOQ causal language model
         from neural_compressor.torch.algorithms import weight_only
diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index e2c643f07c6..d55e9004a3a 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -1,6 +1,4 @@
-import os
-import unittest
-from unittest.mock import patch
+import shutil

 import pytest
 import torch
@@ -33,6 +31,8 @@ def _is_ipex_imported():
 class TestPT2EQuantization:
+    def teardown_class(self):
+        shutil.rmtree("saved_results", ignore_errors=True)

     @staticmethod
     def get_toy_model():
@@ -114,6 +114,18 @@ def calib_fn(model):
         config.freezing = True
         q_model_out = q_model(*example_inputs)
         assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"
+
+        # test save and load
+        q_model.save(
+            example_inputs=example_inputs,
+            output_dir="./saved_results",
+        )
+        from neural_compressor.torch.quantization import load
+
+        loaded_quantized_model = load("./saved_results")
+        loaded_q_model_out = loaded_quantized_model(*example_inputs)
+        assert torch.equal(loaded_q_model_out, q_model_out)
+
         opt_model = torch.compile(q_model)
         out = opt_model(*example_inputs)
         logger.warning("out shape is %s", out.shape)
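
A minimal usage sketch of the API this patch adds, for reviewers. It assumes `q_model` is the object returned by the PT2E static or dynamic quant entry above (which now carries `.qconfig` and a bound `.save`), and the `example_inputs` tuple is a placeholder that must match the quantized model's forward signature:

    import torch

    from neural_compressor.torch.quantization import load

    # Hypothetical inputs; use whatever the model was exported/calibrated with.
    example_inputs = (torch.randn(2, 3),)

    # q_model comes from the PT2E quantization flow (see algorithm_entry.py above).
    # save() re-exports the model, writes the ExportedProgram to WEIGHT_NAME and the
    # per-op qconfig mapping as JSON to QCONFIG_NAME under ./saved_results.
    q_model.save(example_inputs=example_inputs, output_dir="./saved_results")

    # load() in load_entry.py inspects the keys of the saved qconfig, routes to
    # pt2e_quant.load(), and returns torch.export.load(...).module(), so the
    # result is directly callable without the original model object.
    loaded_q_model = load("./saved_results")
    out = loaded_q_model(*example_inputs)

Note that `save()` rewrites `model.qconfig` in place (each op config is replaced by its `to_dict()` form), so a second `save()` on the same object would operate on already-converted plain dicts.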