Support double quant tuning (#1591)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Feb 2, 2024
1 parent e7b3478 commit b8d98eb
Showing 4 changed files with 99 additions and 7 deletions.
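
The change wires the predefined double-quant presets in neural_compressor.torch.utils.constants.DOUBLE_QUANT_CONFIGS into the autotune flow through a new get_rtn_double_quant_config_set() helper. Below is a minimal usage sketch mirroring the test added in this commit; the toy model builder and the constant-accuracy eval function are placeholders, not part of the changed API.

import torch

from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set


def build_toy_model() -> torch.nn.Module:
    # Placeholder model; any module with Linear layers is enough to exercise RTN weight-only quantization.
    return torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))


def eval_acc_fn(model) -> float:
    # Placeholder metric; a real flow would score the quantized model on validation data.
    return 1.0


# Each trial converts one double-quant preset (e.g. BNB_NF4, GGML_TYPE_Q4_K) into an RTNConfig.
custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=10)
best_model = autotune(model=build_toy_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}])

Each trial quantizes a fresh copy of the model with one preset, and the tuning monitor keeps the best configuration according to eval_fns, as shown in the autotune loop changed below.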
neural_compressor/torch/quantization/__init__.py: 6 additions & 1 deletion
@@ -26,7 +26,12 @@
 )
 
 # TODO(Yi): move config to config.py
-from neural_compressor.torch.quantization.autotune import autotune, TuningConfig, get_all_config_set
+from neural_compressor.torch.quantization.autotune import (
+    autotune,
+    TuningConfig,
+    get_all_config_set,
+    get_rtn_double_quant_config_set,
+)
 
 ### Quantization Function Registration ###
 import neural_compressor.torch.quantization.algorithm_entry
neural_compressor/torch/quantization/autotune.py: 12 additions & 4 deletions
@@ -17,19 +17,26 @@
 
 import torch
 
-from neural_compressor.common import Logger
 from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
 from neural_compressor.torch.quantization import quantize
-from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig
+from neural_compressor.torch.utils import constants, logger
 
 __all__ = [
     "autotune",
     "get_all_config_set",
+    "get_rtn_double_quant_config_set",
 ]
 
 
+def get_rtn_double_quant_config_set() -> List[RTNConfig]:
+    rtn_double_quant_config_set = []
+    for double_quant_type, double_quant_config in constants.DOUBLE_QUANT_CONFIGS.items():
+        rtn_double_quant_config_set.append(RTNConfig.from_dict(double_quant_config))
+    return rtn_double_quant_config_set
+
+
 def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
 
@@ -52,7 +59,7 @@ def autotune(
     for trial_index, quant_config in enumerate(config_loader):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.quantization_start()
-        logger.info(f"quant config: {quant_config}")
+        logger.info(quant_config.to_dict())
         # !!! Make sure to use deepcopy only when inplace is set to `True`.
         q_model = quantize(deepcopy(model), quant_config=quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
         tuning_logger.quantization_end()
@@ -62,6 +69,7 @@
         tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
+            logger.info("Stopped tuning.")
             best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
             # !!! Make sure to use deepcopy only when inplace is set to `True`.
             quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
neural_compressor/torch/quantization/quantize.py: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ def quantize(
     assert isinstance(
         quant_config, BaseConfig
     ), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}."
-    logger.info(f"Quantize model with config: \n {quant_config.to_json_string()} \n")
+    logger.info("Quantize model with config:")
+    logger.info(quant_config.to_dict())
     # select quantization algo according to config
 
     model_info = quant_config.get_model_info(model=q_model)
test/3x/torch/test_autotune.py: 79 additions & 1 deletion
@@ -1,12 +1,37 @@
 import unittest
 from functools import wraps
+from unittest.mock import patch
 
 import torch
 import transformers
 
 from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor
 from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune, get_all_config_set
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import constants, logger
 
+FAKE_DOUBLE_QUANT_CONFIGS = {
+    "BNB_NF4": {
+        "dtype": "nf4",
+        "bits": 4,
+        "group_size": 32,
+        "use_double_quant": True,
+        "double_quant_bits": 8,
+        "double_quant_dtype": "int",
+        "double_quant_use_sym": False,
+        "double_quant_group_size": 256,
+    },
+    "GGML_TYPE_Q4_K": {
+        "dtype": "int",
+        "bits": 4,
+        "use_sym": False,
+        "group_size": 32,
+        "use_double_quant": True,
+        "double_quant_bits": 6,
+        "double_quant_dtype": "int",
+        "double_quant_use_sym": True,
+        "double_quant_group_size": 8,
+    },
+}
+
 
 def reset_tuning_target(test_func):
@@ -239,6 +264,59 @@ def eval_acc_fn(model):
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_acc_fn)
         self.assertIsNone(best_model)
 
+    @reset_tuning_target
+    def test_rtn_double_quant_config_set(self) -> None:
+        from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set
+        from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
+
+        rtn_double_quant_config_set = get_rtn_double_quant_config_set()
+        self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS))
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=10)
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNotNone(best_model)
+
+    @reset_tuning_target
+    def test_rtn_double_quant_config_set2(self) -> None:
+        from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set
+        from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
+
+        rtn_double_quant_config_set = get_rtn_double_quant_config_set()
+        self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS))
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        custom_tune_config = TuningConfig(
+            config_set=get_rtn_double_quant_config_set(), max_trials=10, tolerable_loss=-1
+        )
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNone(best_model)
+
+    @patch("neural_compressor.torch.utils.constants.DOUBLE_QUANT_CONFIGS", FAKE_DOUBLE_QUANT_CONFIGS)
+    def test_rtn_double_quant_config_set3(self) -> None:
+        from neural_compressor.torch.quantization import get_rtn_double_quant_config_set
+
+        rtn_double_quant_config_set = get_rtn_double_quant_config_set()
+        print(len(rtn_double_quant_config_set))
+        self.assertEqual(len(constants.DOUBLE_QUANT_CONFIGS), len(FAKE_DOUBLE_QUANT_CONFIGS))
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), tolerable_loss=-1)
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNone(best_model)
+
 
 if __name__ == "__main__":
     unittest.main()
