diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py
index 92f21936e52..4de5a0232e1 100644
--- a/neural_compressor/torch/quantization/__init__.py
+++ b/neural_compressor/torch/quantization/__init__.py
@@ -26,7 +26,12 @@
 )
 
 # TODO(Yi): move config to config.py
-from neural_compressor.torch.quantization.autotune import autotune, TuningConfig, get_all_config_set
+from neural_compressor.torch.quantization.autotune import (
+    autotune,
+    TuningConfig,
+    get_all_config_set,
+    get_rtn_double_quant_config_set,
+)
 
 ### Quantization Function Registration ###
 import neural_compressor.torch.quantization.algorithm_entry
diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py
index 6265908d18e..bd6d8ddcbae 100644
--- a/neural_compressor/torch/quantization/autotune.py
+++ b/neural_compressor/torch/quantization/autotune.py
@@ -17,19 +17,26 @@
 
 import torch
 
-from neural_compressor.common import Logger
 from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
 from neural_compressor.torch.quantization import quantize
-from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig
+from neural_compressor.torch.utils import constants, logger
 
 __all__ = [
     "autotune",
     "get_all_config_set",
+    "get_rtn_double_quant_config_set",
 ]
 
 
+def get_rtn_double_quant_config_set() -> List[RTNConfig]:
+    rtn_double_quant_config_set = []
+    for double_quant_type, double_quant_config in constants.DOUBLE_QUANT_CONFIGS.items():
+        rtn_double_quant_config_set.append(RTNConfig.from_dict(double_quant_config))
+    return rtn_double_quant_config_set
+
+
 def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
@@ -52,7 +59,7 @@ def autotune(
     for trial_index, quant_config in enumerate(config_loader):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.quantization_start()
-        logger.info(f"quant config: {quant_config}")
+        logger.info(quant_config.to_dict())
         # !!! Make sure to use deepcopy only when inplace is set to `True`.
         q_model = quantize(deepcopy(model), quant_config=quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
         tuning_logger.quantization_end()
@@ -62,6 +69,7 @@ def autotune(
         tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
+            logger.info("Stopped tuning.")
             best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
             # !!! Make sure to use deepcopy only when inplace is set to `True`.
             quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index 8350772e05e..89db92bea76 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -55,7 +55,8 @@ def quantize(
     assert isinstance(
         quant_config, BaseConfig
     ), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}."
- logger.info(f"Quantize model with config: \n {quant_config.to_json_string()} \n") + logger.info("Quantize model with config:") + logger.info(quant_config.to_dict()) # select quantization algo according to config model_info = quant_config.get_model_info(model=q_model) diff --git a/test/3x/torch/test_autotune.py b/test/3x/torch/test_autotune.py index 086e58593cc..3dca0ebb612 100644 --- a/test/3x/torch/test_autotune.py +++ b/test/3x/torch/test_autotune.py @@ -1,12 +1,37 @@ import unittest from functools import wraps +from unittest.mock import patch import torch import transformers from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune, get_all_config_set -from neural_compressor.torch.utils import logger +from neural_compressor.torch.utils import constants, logger + +FAKE_DOUBLE_QUANT_CONFIGS = { + "BNB_NF4": { + "dtype": "nf4", + "bits": 4, + "group_size": 32, + "use_double_quant": True, + "double_quant_bits": 8, + "double_quant_dtype": "int", + "double_quant_use_sym": False, + "double_quant_group_size": 256, + }, + "GGML_TYPE_Q4_K": { + "dtype": "int", + "bits": 4, + "use_sym": False, + "group_size": 32, + "use_double_quant": True, + "double_quant_bits": 6, + "double_quant_dtype": "int", + "double_quant_use_sym": True, + "double_quant_group_size": 8, + }, +} def reset_tuning_target(test_func): @@ -239,6 +264,59 @@ def eval_acc_fn(model): best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_acc_fn) self.assertIsNone(best_model) + @reset_tuning_target + def test_rtn_double_quant_config_set(self) -> None: + from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set + from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS + + rtn_double_quant_config_set = get_rtn_double_quant_config_set() + self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS)) + + def eval_acc_fn(model) -> float: + return 1.0 + + custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=10) + best_model = autotune( + model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}] + ) + self.assertIsNotNone(best_model) + + @reset_tuning_target + def test_rtn_double_quant_config_set2(self) -> None: + from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set + from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS + + rtn_double_quant_config_set = get_rtn_double_quant_config_set() + self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS)) + + def eval_acc_fn(model) -> float: + return 1.0 + + custom_tune_config = TuningConfig( + config_set=get_rtn_double_quant_config_set(), max_trials=10, tolerable_loss=-1 + ) + best_model = autotune( + model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}] + ) + self.assertIsNone(best_model) + + @patch("neural_compressor.torch.utils.constants.DOUBLE_QUANT_CONFIGS", FAKE_DOUBLE_QUANT_CONFIGS) + def test_rtn_double_quant_config_set3(self) -> None: + from neural_compressor.torch.quantization import get_rtn_double_quant_config_set + + rtn_double_quant_config_set = get_rtn_double_quant_config_set() + print(len(rtn_double_quant_config_set)) + self.assertEqual(len(constants.DOUBLE_QUANT_CONFIGS), len(FAKE_DOUBLE_QUANT_CONFIGS)) + + def eval_acc_fn(model) 
+            return 1.0
+
+        custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), tolerable_loss=-1)
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNone(best_model)
+
 
 if __name__ == "__main__":
     unittest.main()
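
Usage sketch for the new get_rtn_double_quant_config_set entry point, mirroring the calls exercised in the tests above; my_model and my_eval_acc_fn are hypothetical placeholders for a user-supplied torch.nn.Module and accuracy function.

import torch

from neural_compressor.torch.quantization import (
    TuningConfig,
    autotune,
    get_rtn_double_quant_config_set,
)

# Hypothetical toy model; any torch.nn.Module to be quantized works here.
my_model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU())


def my_eval_acc_fn(model) -> float:
    # Placeholder metric; a real run would score the model on a validation set.
    return 1.0


# One tuning trial per predefined double-quant scheme (e.g. BNB_NF4, GGML_TYPE_Q4_K).
tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=10)
best_model = autotune(
    model=my_model,
    tune_config=tune_config,
    eval_fns=[{"eval_fn": my_eval_acc_fn}],
)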