Integrate AutoRound v0.3 to 2x (#1926)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel authored Jul 17, 2024
1 parent bfa27e4 commit fd96851
Showing 7 changed files with 64 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/ut/env_setup.sh
@@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
fi

if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-pip install auto-round
+pip install git+https://github.com/intel/auto-round.git@24b2e74070f2b4e6f26ff069ec75af74cf5b177c
fi

# test deps
22 changes: 17 additions & 5 deletions neural_compressor/adaptor/pytorch.py
@@ -4905,13 +4905,13 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
enable_minmax_tuning = self.recipes["autoround_args"].get("enable_minmax_tuning", True)
lr = self.recipes["autoround_args"].get("lr", None)
minmax_lr = self.recipes["autoround_args"].get("minmax_lr", None)
-low_gpu_mem_usage = self.recipes["autoround_args"].get("low_gpu_mem_usage", True)
+low_gpu_mem_usage = self.recipes["autoround_args"].get("low_gpu_mem_usage", False)
iters = self.recipes["autoround_args"].get("iters", 200)
seqlen = self.recipes["autoround_args"].get("seqlen", 2048)
-n_samples = self.recipes["autoround_args"].get("n_samples", 512)
+nsamples = self.recipes["autoround_args"].get("nsamples", 128)
sampler = self.recipes["autoround_args"].get("sampler", "rand")
seed = self.recipes["autoround_args"].get("seed", 42)
-n_blocks = self.recipes["autoround_args"].get("n_blocks", 1)
+nblocks = self.recipes["autoround_args"].get("nblocks", 1)
gradient_accumulate_steps = self.recipes["autoround_args"].get("gradient_accumulate_steps", 1)
not_use_best_mse = self.recipes["autoround_args"].get("not_use_best_mse", False)
dynamic_max_gap = self.recipes["autoround_args"].get("dynamic_max_gap", -1)
@@ -4922,6 +4922,12 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
bits = self.recipes["autoround_args"].get("bits", 4)
group_size = self.recipes["autoround_args"].get("group_size", 128)
sym = self.recipes["autoround_args"].get("scheme", "asym") == "sym"
+act_bits = self.recipes["autoround_args"].get("act_bits", 32)
+act_group_size = self.recipes["autoround_args"].get("act_group_size", None)
+act_sym = self.recipes["autoround_args"].get("act_sym", None)
+act_dynamic = self.recipes["autoround_args"].get("act_dynamic", True)
+multimodal = self.recipes["autoround_args"].get("multimodal", False)
+use_layer_wise = self.recipes["autoround_args"].get("use_layer_wise", False)

if dataloader is not None:
dataset = dataloader
@@ -4944,15 +4950,21 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
low_gpu_mem_usage=low_gpu_mem_usage,
iters=iters,
seqlen=seqlen,
-n_samples=n_samples,
+nsamples=nsamples,
sampler=sampler,
seed=seed,
-n_blocks=n_blocks,
+nblocks=nblocks,
gradient_accumulate_steps=gradient_accumulate_steps,
not_use_best_mse=not_use_best_mse,
dynamic_max_gap=dynamic_max_gap,
data_type=data_type,
scale_dtype=scale_dtype,
+multimodal=multimodal,
+act_bits=act_bits,
+act_group_size=act_group_size,
+act_sym=act_sym,
+act_dynamic=act_dynamic,
+use_layer_wise=use_layer_wise,
)
return model, autoround_config

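For reference, a minimal sketch of how these recipe keys reach the adaptor through the 2.x `PostTrainingQuantConfig` API. The model and dataloader names are placeholders, and the `op_type_dict` values are assumptions based on the existing weight-only AutoRound flow rather than part of this commit:

```python
from neural_compressor import PostTrainingQuantConfig, quantization

# fp32_model and dataloader are placeholders: a loaded PyTorch causal-LM and a
# calibration dataloader (e.g. one built with get_dataloader() further below).
conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # match every supported op type
            "weight": {
                "bits": 4,
                "group_size": 128,
                "scheme": "asym",
                "algorithm": "AUTOROUND",
            },
        },
    },
    recipes={
        "autoround_args": {
            "iters": 200,
            "seqlen": 2048,
            "nsamples": 128,           # renamed from n_samples
            "nblocks": 1,              # renamed from n_blocks
            "low_gpu_mem_usage": False,
            "act_bits": 32,            # new activation-quantization knobs
            "act_dynamic": True,
            "use_layer_wise": False,
        },
    },
)
q_model = quantization.fit(fp32_model, conf, calib_dataloader=dataloader)
```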
8 changes: 3 additions & 5 deletions neural_compressor/adaptor/torch_utils/auto_round.py
@@ -13,7 +13,7 @@
# limitations under the License.


-def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
+def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=128):
"""Generate a DataLoader for calibration using specified parameters.
Args:
@@ -25,14 +25,12 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
split (str, optional): The data split to use. Defaults to None.
seed (int, optional): The random seed for reproducibility. Defaults to 42.
bs (int, optional): The batch size. Defaults to 4.
-n_samples (int, optional): The total number of samples to include. Defaults to 512.
+nsamples (int, optional): The total number of samples to include. Defaults to 128.
Returns:
DataLoader: The DataLoader for the calibrated dataset.
"""
from auto_round.calib_dataset import get_dataloader # pylint: disable=E0401

-dataloader = get_dataloader(
-tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
-)
+dataloader = get_dataloader(tokenizer, seqlen, dataset_name=dataset_name, seed=seed, bs=bs, nsamples=nsamples)
return dataloader
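
A quick usage sketch of the updated helper; the tokenizer checkpoint below is only illustrative:

```python
from transformers import AutoTokenizer

from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader

# "facebook/opt-125m" is an arbitrary example checkpoint, not one used by this commit.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", trust_remote_code=True)

# nsamples replaces the old n_samples keyword and now defaults to 128;
# dataset_name is now forwarded instead of being hard-coded.
dataloader = get_dataloader(tokenizer, seqlen=2048, seed=42, bs=8, nsamples=128)
```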
59 changes: 40 additions & 19 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -694,21 +694,28 @@ def autoround_quantize(
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
-low_gpu_mem_usage: bool = True,
+low_gpu_mem_usage: bool = False,
iters: int = 200,
seqlen: int = 2048,
-n_samples: int = 512,
+nsamples: int = 128,
sampler: str = "rand",
seed: int = 42,
-n_blocks: int = 1,
+nblocks: int = 1,
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
data_type: str = "int", ##only support int for now
scale_dtype: str = "fp16",
+multimodal: bool = False,
+act_bits: int = 32,
+act_group_size: int = None,
+act_sym: bool = None,
+act_dynamic: bool = True,
+use_layer_wise: bool = False,
**kwargs,
):
"""Run autoround weight-only quantization.
Args:
model: The PyTorch model to be quantized.
tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied.
@@ -717,15 +724,19 @@ def autoround_quantize(
sym (bool): Whether symmetric quantization is to be used (default is False).
weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
weight_config={
-'layer1':##layer_name
-{
-'data_type': 'int',
-'bits': 4,
-'group_size': 32,
-'sym': False
-}
-...
-}
+'layer1':##layer_name
+{
+'data_type': 'int',
+'bits': 4,
+'group_size': 32,
+'sym': False,
+'act_data_type': None,
+'act_bits': 32,
+'act_sym': None,
+'act_dynamic': True,
+}
+...,
+}
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
@@ -737,20 +748,24 @@ def autoround_quantize(
enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
lr (float): The learning rate (default is None, will be set to 1.0/iters).
minmax_lr (float): The learning rate for min-max tuning (default is None, it will be set to lr automatically).
-low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
+low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
iters (int): Number of iterations (default is 200).
seqlen (int): Data length of the sequence for tuning (default is 2048).
-n_samples (int): Number of samples (default is 512).
+nsamples (int): Number of samples (default is 128).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
-n_blocks (int): Number of blocks (default is 1).
+nblocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
data_type (str): The data type to be used (default is "int").
scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
have different choices.
+multimodal(bool): Enable multimodal model quantization, (default is "False").
+act_bits (int): Number of bits for activation quantization. Default is 32.
+act_group_size (int): Group size for activation quantization. Default is None.
+act_sym (bool): Whether to use symmetric activation quantization. Default is None.
+act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
Returns:
The quantized model.
"""
@@ -762,7 +777,7 @@ def autoround_quantize(
bits=bits,
group_size=group_size,
sym=sym,
-weight_config=weight_config,
+layer_config=weight_config,
enable_full_range=enable_full_range, ##for symmetric, TODO support later
batch_size=batch_size,
amp=amp,
@@ -776,15 +791,21 @@ def autoround_quantize(
low_gpu_mem_usage=low_gpu_mem_usage,
iters=iters,
seqlen=seqlen,
-n_samples=n_samples,
+nsamples=nsamples,
sampler=sampler,
seed=seed,
-n_blocks=n_blocks,
+nblocks=nblocks,
gradient_accumulate_steps=gradient_accumulate_steps,
not_use_best_mse=not_use_best_mse,
dynamic_max_gap=dynamic_max_gap,
data_type=data_type, ## only support data_type
scale_dtype=scale_dtype,
+multimodal=multimodal,
+act_bits=act_bits,
+act_group_size=act_group_size,
+act_sym=act_sym,
+act_dynamic=act_dynamic,
+low_cpu_mem_usage=use_layer_wise,
**kwargs,
)
qdq_model, weight_config = rounder.quantize()
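For orientation, a rough sketch of calling the updated function directly. The model and tokenizer are assumed to be an already-loaded causal-LM and its tokenizer; the two-value return mirrors the adaptor code above, which unpacks the quantized model together with the recorded per-layer configuration:

```python
from neural_compressor.adaptor.torch_utils.weight_only import autoround_quantize

# model/tokenizer are placeholders for an already-loaded PyTorch LLM and its tokenizer.
# Arguments not listed here keep their defaults from the signature in this commit.
q_model, layer_config = autoround_quantize(
    model,
    tokenizer=tokenizer,
    bits=4,
    group_size=128,
    sym=False,
    iters=200,
    seqlen=2048,
    nsamples=128,            # renamed from n_samples
    nblocks=1,               # renamed from n_blocks
    low_gpu_mem_usage=False,
    act_bits=32,
    act_dynamic=True,
    use_layer_wise=False,    # forwarded to AutoRound as low_cpu_mem_usage
)
```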
2 changes: 1 addition & 1 deletion neural_compressor/model/torch_model.py
@@ -609,7 +609,7 @@ def export_compressed_model(

self.model = pack_model(
self.model,
-weight_config=autoround_config,
+layer_config=autoround_config,
enable_full_range=enable_full_range,
compression_dtype=compression_dtype,
compression_dim=compression_dim,
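Callers of the export path are unaffected by this rename; only the keyword forwarded to auto_round's `pack_model()` changes. A minimal sketch, assuming `q_model` is the weight-only quantized model wrapper returned by `quantization.fit()`:

```python
# q_model is a placeholder for the quantized model object; export_compressed_model()
# internally packs the weights via auto_round's pack_model(..., layer_config=...).
compressed_model = q_model.export_compressed_model()
```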
@@ -760,7 +760,7 @@ def test_AutoRound_quant(self):
tokenizer = transformers.AutoTokenizer.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
)
-dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=20)
+dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=20)
fp32_model = copy.deepcopy(self.gptj)
conf = PostTrainingQuantConfig(
approach="weight_only",
2 changes: 1 addition & 1 deletion test/requirements.txt
@@ -1,6 +1,6 @@
--find-links https://download.pytorch.org/whl/torch_stable.html
accelerate==0.21.0
-auto-round
+auto-round @ git+https://github.com/intel/auto-round.git@24b2e74070f2b4e6f26ff069ec75af74cf5b177c
dynast==1.6.0rc1
horovod
intel-extension-for-pytorch
