Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add repack_awq_to_optimum_format function #1998

Merged
merged 7 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight-Only utility."""
import numpy as np
import torch

from neural_compressor.torch.utils import accelerator, device_synchronize, logger
Expand Down Expand Up @@ -1228,3 +1229,221 @@ def convert_dtype_str2torch(str_dtype):
return torch.bfloat16
else:
assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype)


# ref reverse reorder from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L491
def awq_reverse_reorder_int_tensor(int_tensor, bits: int):
"""Awq tensor convert tool.

Reverse_reorder_int_tensor
"""
assert bits == 4

int_tensor = int_tensor.T.contiguous()
compress_ratio = 32 // bits
assert int_tensor.shape[-1] % compress_ratio == 0

order_map = [0, 2, 4, 6, 1, 3, 5, 7]
order_tensor = torch.tensor(order_map, dtype=torch.int32, device=int_tensor.device).reshape(1, -1)
order_tensor = order_tensor.repeat(int_tensor.shape[1] // compress_ratio, 1)
order_tensor = order_tensor + torch.arange(
0,
int_tensor.shape[1],
compress_ratio,
dtype=torch.int32,
device=int_tensor.device,
).reshape(-1, 1)
order_tensor = order_tensor.reshape(-1)

reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor]
reverse_order_tensor = reverse_order_tensor[order_tensor]
int_tensor = int_tensor[:, reverse_order_tensor]
return int_tensor


# ref weight unpack from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516
def unpack_awq(
awq_qweight: torch.Tensor,
awq_qzeros: torch.Tensor,
awq_scales: torch.Tensor,
bits: int,
group_size: int,
):
"""Unpack awq format to actual values.

Args:
awq_qweight (`torch.LongTensor`):
Expected shape: (in_features, out_features // (32 // bits))
awq_qzeros (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features // (32 // bits))
awq_scales (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)

Returns:
fp16_weight (`torch.LongTensor`):
With shape (in_features, out_features).
zeros (`torch.LongTensor`):
With shape (in_features // group_size, out_features).
"""
assert bits == 4

qzeros = awq_qzeros
qweight = awq_qweight
qweight = qweight.T.contiguous()

infeatures = awq_qweight.shape[0]

wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qzeros.device).unsqueeze(0)
zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0)).to(
torch.int16 if bits == 8 else torch.int8
)

# zeros = zeros + 1

torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)

zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1), wf.unsqueeze(-1)).to(
torch.int16 if bits == 8 else torch.int8
)
torch.bitwise_and(weight, (2**bits) - 1, out=weight)
weight = weight.reshape(-1, group_size, weight.shape[2])

weight = weight.view(-1, weight.shape[-1])
zeros = zeros.view(-1, zeros.shape[-1])

zeros = zeros.T.contiguous()
zeros = awq_reverse_reorder_int_tensor(zeros, bits)
weight = awq_reverse_reorder_int_tensor(weight, bits)

# Dequantize weights.
scales = awq_scales
zeros = zeros.contiguous()
scale_zeros = zeros * scales

g_idx = torch.tensor([i // group_size for i in range(infeatures)], dtype=torch.int32)
scale_mat = scales[g_idx]
scale_zeros_mat = scale_zeros[g_idx].half()

qdq_weight_T = weight * scale_mat - scale_zeros_mat.half()

fp16_weight = qdq_weight_T.T

return fp16_weight, zeros


# ref weight unpack from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516
def pack_from_tensors(
unpacked_qweight: torch.Tensor,
unpacked_qzeros: torch.Tensor,
awq_scales: torch.Tensor,
bits: int,
group_size: int,
):
"""Pack the tensor to optimum format.

Args:
unpacked_qweight (`torch.LongTensor`):
Expected shape: (in_features, out_features)
unpacked_qzeros (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)
awq_scales (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)

Returns:
qweight (`torch.LongTensor`):
With shape (in_features // (32 // bits), out_features)
qzeros (`torch.LongTensor`):
With shape (in_features // group_size, out_features // (32 // bits))
"""
assert bits == 4
W = unpacked_qweight.clone().cpu()

# TODO: This should be checked somehow.
# if isinstance(linear, nn.Conv2d):
# W = W.flatten(1)
# if isinstance(linear, transformers.pytorch_utils.Conv1D):
# W = W.t()

awq_scales = awq_scales.t().contiguous()
unpacked_qzeros = unpacked_qzeros.contiguous()
unpacked_qzeros = unpacked_qzeros.cpu()

awq_scales = awq_scales.cpu()
scale_zeros = unpacked_qzeros.t() * awq_scales
scales = awq_scales.clone()

infeatures = unpacked_qweight.shape[1]

intweight = []
for idx in range(infeatures):
g_idx = idx // group_size

intweight.append(torch.round((W[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)

i = 0
row = 0
qweight = np.zeros((intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=np.uint32)
while row < qweight.shape[0]:
for j in range(i, i + (32 // bits)):
qweight[row] |= intweight[j] << (bits * (j - i))
i += 32 // bits
row += 1

qweight = qweight.astype(np.int32)
qweight = torch.from_numpy(qweight)

unpacked_qzeros = unpacked_qzeros - 1
torch.bitwise_and(unpacked_qzeros, (2**bits) - 1, out=unpacked_qzeros)

unpacked_qzeros = unpacked_qzeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(unpacked_qzeros.shape[0], unpacked_qzeros.shape[1] // 32 * bits),
dtype=np.uint32,
)
i = 0
col = 0
while col < qzeros.shape[1]:
for j in range(i, i + (32 // bits)):
qzeros[:, col] |= unpacked_qzeros[:, j] << (bits * (j - i))
i += 32 // bits
col += 1

qzeros = qzeros.astype(np.int32)
qzeros = torch.from_numpy(qzeros)

return qweight, qzeros


def repack_awq_to_optimum_format(
awq_qweight: torch.Tensor,
awq_qzeros: torch.Tensor,
awq_scales: torch.Tensor,
bits: int,
group_size: int,
):
"""The function to repack_awq_to_optimum_format.

Args:
awq_qweight (`torch.LongTensor`):
Expected shape: (in_features, out_features // (32 // bits))
awq_qzeros (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features // (32 // bits))
awq_scales (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)

Returns:
qweight (`torch.LongTensor`):
With shape (in_features // (32 // bits), out_features)
qzeros (`torch.LongTensor`):
With shape (in_features // group_size, out_features // (32 // bits))
scales (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)
"""
unpack_qweight, unpack_qzeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size)
qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales, bits, group_size)
return qweight, qzeros, awq_scales
66 changes: 41 additions & 25 deletions neural_compressor/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
from neural_compressor.torch.utils import set_module

from ..quantization.utils import convert_dtype_torch2str, convert_to_quantized_model, replace_linear, save_low_bit
from ..quantization.utils import (
convert_dtype_torch2str,
convert_to_quantized_model,
repack_awq_and_load_state_dict,
replace_linear,
save_low_bit,
)
from ..utils import AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig


Expand Down Expand Up @@ -179,6 +185,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
) and model.config.model_type == "chatglm":
model = model.float()
model = convert_to_quantized_model(model, quantization_config, device=device_map)
if isinstance(quantization_config, AwqConfig):
quantization_config.backend = "inc"
quantization_config.remove_redundant_parameters()
model.config.quantization_config = quantization_config
else:
Expand Down Expand Up @@ -295,6 +303,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
quantization_config = GPTQConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "autoround":
quantization_config = AutoRoundConfig.from_dict(quantization_config)

assert quantization_config is not None, "Detect this model is not a low-bit model."

if commit_hash is None:
Expand Down Expand Up @@ -613,41 +622,48 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):

with ContextManagers(init_contexts):
model = model_class(config, *model_args, **kwargs)

if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
if quantization_config.modules_to_not_convert is None:
quantization_config.modules_to_not_convert = ["lm_head", "transformer.output_layer", "embed_out"]
else:
quantization_config.modules_to_not_convert += ["lm_head", "transformer.output_layer", "embed_out"]
model = build_woq_model(model, quantization_config)

if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
loaded_state_dict_keys = list(state_dict.keys())

# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)
if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
model = repack_awq_and_load_state_dict(
model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
)
else:
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
Expand Down
38 changes: 38 additions & 0 deletions neural_compressor/transformers/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from neural_compressor.common.utils import LazyImport, logger
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format
from neural_compressor.torch.quantization import (
AutoRoundConfig,
AWQConfig,
Expand Down Expand Up @@ -654,3 +655,40 @@ def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: boo
token=kwargs.get("token"),
)
self.quantization_config.save_pretrained(save_directory, **kwargs)


def repack_awq_and_load_state_dict(
model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
):
from transformers.modeling_utils import load_state_dict

bits = quantization_config.bits
group_size = quantization_config.group_size

state_dict = {}
if isinstance(resolved_archive_file, str):
resolved_archive_file = [resolved_archive_file]
assert isinstance(resolved_archive_file, list), "Please check if the loading weight is shared."
for shard_file in resolved_archive_file:
assert shard_file.endswith("safetensors"), "Please check the loading weight saved format."
state_dict.update(load_state_dict(shard_file))
assert len(state_dict.keys()) > 0, "Please check the state_dict loading."
for name, module in model.named_modules():
if isinstance(module, INCWeightOnlyLinear):
assert name + ".qweight" in loaded_state_dict_keys, f"Please check the state_dict key { name + '.qweight'}"
assert name + ".qzeros" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.qzeros'}"
assert name + ".scales" in loaded_state_dict_keys, f"Please check the state_dict key { name + '.scales'}"
if name + ".scales" in loaded_state_dict_keys:
awq_qweight = state_dict[name + ".qweight"]
awq_qzeros = state_dict[name + ".qzeros"]
awq_scales = state_dict[name + ".scales"]
qweight, qzeros, awq_scales = repack_awq_to_optimum_format(
awq_qweight, awq_qzeros, awq_scales, bits, group_size
)
state_dict[name + ".qweight"] = qweight
state_dict[name + ".qzeros"] = qzeros
state_dict[name + ".scales"] = awq_scales

model.load_state_dict(state_dict, strict=False, assign=True)

return model
2 changes: 2 additions & 0 deletions neural_compressor/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def __init__(
zero_point: bool = True,
absorb_layer_dict: dict = {},
quant_lm_head: bool = False,
backend: str = None,
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ
Expand All @@ -427,6 +428,7 @@ def __init__(
self.seq_len = seq_len
self.absorb_layer_dict = absorb_layer_dict
self.quant_lm_head = quant_lm_head
self.backend = backend
self.modules_to_not_convert = kwargs.get(
"modules_to_not_convert", ["lm_head", "transformer.output_layer", "embed_out"]
)
Expand Down
Loading