support RTN on Gaudi2 and make UTs auto detect device (#1811)
Signed-off-by: xin3he <[email protected]>
xin3he authored May 24, 2024
1 parent 8dac9f2 commit 4b9b447
Showing 15 changed files with 165 additions and 132 deletions.
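The recurring change across these files replaces the old get_device() helper with the accelerator API, so RTN and the unit tests pick up Gaudi2 (HPU), CUDA, or CPU automatically. A minimal sketch of the pattern, assuming only the get_accelerator call shown in the hunks below (the Linear layer is illustrative, not part of the commit):

import torch
from neural_compressor.torch.utils import get_accelerator

# "auto" resolves to the best available backend; current_device_name()
# returns the torch device string (e.g. "hpu", "cuda:0", or "cpu").
device = get_accelerator("auto").current_device_name()
model = torch.nn.Linear(8, 8).to(device)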
@@ -22,7 +22,7 @@

from neural_compressor.common import logger
from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
from neural_compressor.torch.utils import get_device
from neural_compressor.torch.utils import get_accelerator


class HalfPrecisionConverter:
@@ -40,7 +40,7 @@ def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs):
configs_mapping (Dict): config class for mix-precision.
"""
self.configs_mapping = configs_mapping
self.device = get_device()
self.device = get_accelerator().current_device_name()

def convert(self, model: torch.nn.Module):
"""Convert to FP16 or BF16 model.
6 changes: 4 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/awq.py
@@ -20,7 +20,7 @@
import torch

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_device, logger
from neural_compressor.torch.utils import get_accelerator, logger

from .modules import MulLinear
from .utility import (
@@ -124,13 +124,15 @@ def __init__(
weight_config={},
total_block_args=[],
total_block_kwargs=[],
device="auto",
):

self.example_inputs = example_inputs
self.model = model
if example_inputs is None:
assert dataloader is not None, "dataloader or example_inputs is required."
self.example_inputs = get_example_input(dataloader)
self.device = device
self._move_model_and_data_to_device()
self.total_block_args = total_block_args
self.total_block_kwargs = total_block_kwargs
@@ -146,7 +148,7 @@

def _move_model_and_data_to_device(self):
# Put the model and example_inputs into target device
device = get_device()
device = get_accelerator(self.device).current_device_name()
self.model.to(device)
self.example_inputs = self.example_inputs.to(device)

4 changes: 2 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -27,7 +27,7 @@
import torch.nn as nn
from tqdm import tqdm

from neural_compressor.torch.utils import fetch_module, get_device, is_transformers_imported, logger, set_module
from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .modules import WeightOnlyLinear
@@ -258,7 +258,7 @@ def __init__(
self.check_layer_config()

# device
self.device = get_device(kwargs.pop("device", "auto"))
self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
self.model.to(self.device)
self.is_ready = False

135 changes: 61 additions & 74 deletions neural_compressor/torch/algorithms/weight_only/modules.py
@@ -23,7 +23,7 @@
from torch.autograd import Function
from torch.nn import functional as F

from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils import accelerator, logger

from .utility import quant_tensor

@@ -174,9 +174,9 @@ def __init__(

def pack(self, int_weight, scale, zp, bias, g_idx=None):
if self.use_optimum_format:
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
self.scales = self.scales.T.contiguous()
self.qweight = self.qweight.T.contiguous()
self.qzeros = self.qzeros.T.contiguous()
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
@@ -197,124 +197,111 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
assert scale.shape == self.scales.shape, f"{scale.shape} != {self.scales.shape} Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_optimum_format and self.compression_dim == 0:
int_weight = int_weight.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
int_weight = int_weight.T.contiguous()
self.qweight = self.qweight.T.contiguous()
origin_shape = int_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)

# pack weight
for j in range(target_shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = int_weight[:, start:end].type(self.compression_dtype)
for e in range(tmp.shape[1]):
tmp[:, e] &= mask
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qweight[:, j] |= tmp[:, e]
self.qweight.copy_(self.pack_tensor(int_weight))
if not self.use_optimum_format and self.compression_dim == 0:
self.qweight = self.qweight.t_().contiguous()
self.qweight = self.qweight.T.contiguous()

if zp is not None:
zp = zp.to(self.device)
if self.use_optimum_format:
zp -= 1
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
zp = zp.T.contiguous()
self.qzeros = self.qzeros.T.contiguous()
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = zp[:, start:end].type(self.compression_dtype)
for e in range(tmp.shape[1]):
tmp[:, e] &= mask
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qzeros[:, j] |= tmp[:, e]
self.qzeros.copy_(self.pack_tensor(zp))
if self.use_optimum_format or self.compression_dim == 0:
self.qzeros = self.qzeros.t_().contiguous()
self.qzeros = self.qzeros.T.contiguous()
if self.use_optimum_format:
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
self.scales = self.scales.T.contiguous()
self.qweight = self.qweight.T.contiguous()
self.qzeros = self.qzeros.T.contiguous()

def recover(self):
logger.debug(f"Recovering {self} weight")
scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight

device = scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
if self.g_idx is None:
# used for recovering fp32_weight
self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device)
if hasattr(self, "qzeros"):
weight_dtype = torch.uint8
else:
weight_dtype = torch.int8
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.t_().contiguous()
qweight = qweight.t_().contiguous()
origin_shape = weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
if index >= origin_shape[1]:
continue
tmp = qweight[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
if weight_dtype == torch.uint8:
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
qweight = qweight.T.contiguous()
weight = self.unpack_tensor(qweight)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.t_().contiguous()
weight = weight.T.contiguous()
weight = weight[: self.out_features, : self.in_features] # avoid oversize
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
new_weight += torch.where(weight == k, v, 0)
weight = new_weight
# unpack zero_point
if hasattr(self, "qzeros"):
zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
qzeros = self.qzeros.T.contiguous() if self.use_optimum_format else self.qzeros
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.t_().contiguous()
qzeros = qzeros.t_().contiguous()
origin_shape = zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
if index >= origin_shape[1]:
continue
tmp = qzeros[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
qzeros = qzeros.T.contiguous()
zp = self.unpack_tensor(qzeros)
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.t_().contiguous()
zp = zp.T.contiguous()
zp = zp[: scales.shape[0], : scales.shape[1]] # avoid oversize
if self.use_optimum_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
zp = torch.where(zp > (2**self.bits - 1), 0, zp)
# recover fp32 weight with int_weight, scale, and zero_point
for idx in range(self.in_features):
fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]]
fp32_weight[:, idx] = (torch.subtract(weight[:, idx], zp[:, self.g_idx[idx]]).to(torch.int8)) * scales[
:, self.g_idx[idx]
]
else:
# recover fp32 weight with int_weight, scale
for idx in range(self.in_features):
fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]]
return fp32_weight

def pack_tensor(self, raw_tensor):
target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
for j in range(packed_tensor.shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = raw_tensor[:, start:end].type(self.compression_dtype)
tmp &= mask
for e in range(tmp.shape[1]):
tmp[:, e] = tmp[:, e] << (self.bits * e)
packed_tensor[:, j] |= tmp[:, e]
accelerator.synchronize()
return packed_tensor

def unpack_tensor(self, packed_tensor):
target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
target_len = packed_tensor.shape[1] * self.n_pack
unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
for j in range(packed_tensor.shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
tmp = packed_tensor[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
if target_dtype == torch.uint8:
tmp &= mask # remove sign bit
unpacked_tensor[:, index].copy_(tmp.type(target_dtype))
accelerator.synchronize()
return unpacked_tensor

def forward(self, input):
if not hasattr(self, "weight"):
weight = self.recover()
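The repeated inline packing loops in pack() and recover() are consolidated into the pack_tensor/unpack_tensor helpers above, which place n_pack = compress_bits // bits low-bit values into each element of the compression dtype. A standalone sketch of that packing scheme (the function name and the fixed 4-bit/int32 choice are illustrative, not the module's API):

import torch

def pack_int4_to_int32(int_weight: torch.Tensor) -> torch.Tensor:
    """Pack 8 unsigned 4-bit values per int32 column (a sketch of the scheme, not the module's API)."""
    bits, n_pack = 4, 32 // 4
    rows, cols = int_weight.shape
    packed = torch.zeros(rows, (cols + n_pack - 1) // n_pack, dtype=torch.int32)
    mask = torch.tensor(2**bits - 1, dtype=torch.int32)
    for j in range(packed.shape[1]):
        chunk = int_weight[:, n_pack * j : n_pack * (j + 1)].to(torch.int32) & mask
        for e in range(chunk.shape[1]):
            packed[:, j] |= chunk[:, e] << (bits * e)  # value e occupies bit positions [4e, 4e+4)
    return packed

# e.g. pack_int4_to_int32(torch.randint(0, 16, (2, 16))) yields a (2, 2) int32 tensor;
# unpacking reverses this with shifts and the same mask, as the unpack_tensor hunk shows.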
13 changes: 6 additions & 7 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -24,7 +24,7 @@
import torch

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_device, is_transformers_imported, logger, set_module
from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module

from .utility import cast_fp8, quant_tensor, search_clip

@@ -90,7 +90,7 @@ def convert(
model: fake quantized torch module
"""
weight_config = self.quant_config
device = get_device(kwargs.pop("device", "auto"))
device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()

# Put model on device explicitly
# TODO: refine it later, Put module on device one by one instead of the whole model
@@ -165,9 +165,9 @@ def convert(
else:
transpose = group_dim == 0
if transpose:
weight = m.weight.t_().contiguous()
weight = m.weight.detach().T.contiguous()
else:
weight = m.weight
weight = m.weight.detach()
if use_mse_search:
quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
if export_compressed_model:
@@ -189,8 +189,8 @@
in_features = m.in_features
out_features = m.out_features
elif is_transformers_imported() and isinstance(m, transformers.Conv1D):
in_features = m.weight.shape[1]
out_features = m.weight.shape[0]
in_features = m.weight.shape[0]
out_features = m.weight.shape[1]
int_weight = int_weight.t_().contiguous()
scale = scale.t_().contiguous()
zp = zp.t_().contiguous() if zp is not None else zp
@@ -227,6 +227,5 @@ def convert(
# for only group_dim is 0 or only `transformers.Conv1D`,
# we need to transpose the quantized tensor and module's weight back
weight = weight.t_().contiguous()
m.weight.t_().contiguous()
m.weight.data.copy_(weight)
return model
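The in_features/out_features swap for transformers.Conv1D above corrects for the fact that Conv1D stores its weight transposed relative to torch.nn.Linear. A quick check of the two layouts, assuming transformers is installed (the layer sizes are arbitrary):

import torch
import transformers

linear = torch.nn.Linear(in_features=4, out_features=8)
conv1d = transformers.Conv1D(nf=8, nx=4)  # nf = out_features, nx = in_features
print(linear.weight.shape)  # torch.Size([8, 4]) -> (out_features, in_features)
print(conv1d.weight.shape)  # torch.Size([4, 8]) -> (in_features, out_features)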
4 changes: 2 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/teq.py
@@ -22,7 +22,7 @@
import torch

from neural_compressor.torch.algorithms.base_algorithm import Quantizer
from neural_compressor.torch.utils import get_device, is_transformers_imported, logger
from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger

from .modules import MulLinear, TEQLinearFakeQuant
from .utility import get_module, quant_tensor, set_module
@@ -63,7 +63,7 @@ def _post_init(self):
def _get_device(self):
"""Get the model device
:return:Model device."""
device = get_device()
device = get_accelerator().current_device_name()
return device

def _get_dtype(self):
17 changes: 10 additions & 7 deletions neural_compressor/torch/algorithms/weight_only/utility.py
@@ -16,7 +16,7 @@

import torch

from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils import accelerator, device_synchronize, logger

__all__ = [
"FLOAT_MAPPING",
@@ -205,12 +205,12 @@ def qdq_weight_sym(weight, bits=4, quantile=1.0, return_int=False, full_range=Fa
wmax = torch.max(torch.abs(max_val), torch.abs(min_val))
wmax = wmax * quantile
tmp = wmax == 0
wmax[tmp] = +1
wmax[tmp] = torch.tensor(1, dtype=wmax.dtype, device=wmax.device)
if full_range:
# use -8, 8 to make sure amax is not changed after fake quant
scale = wmax / (-minq)
tmp = scale * flip_flag.int()
scale -= 2 * tmp # set negative scale with flip_flag
# set negative scale with flip_flag
scale = torch.where(flip_flag, -scale, scale)
else:
scale = wmax / maxq
scale.unsqueeze_(dim=-1)
@@ -248,6 +248,7 @@ def qdq_weight_actor(weight, bits, scheme, quantile=1.0, dtype="int", return_int
return qdq_weight_asym(weight, bits, quantile, return_int, **kwargs)


@device_synchronize
def quant_tensor(
weight,
bits=4,
@@ -343,10 +344,12 @@ def quant_tensor(
)
if return_int or quant_scale:
weight2, scale2, zp2 = weight2
orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
weight = torch.cat([weight1, weight2], dim=1)
scale = torch.cat([scale1, scale2], dim=1)
zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1)
q_state = (weight, scale, zp)
accelerator.synchronize()
orig_weight.copy_(weight)
return orig_weight, scale, zp
else:
orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
return orig_weight
@@ -444,7 +447,7 @@ def search_clip(m, bits=4, group_size=32, scheme="asym", dtype="int", enable_ful
full_range=enable_full_range,
quantile=ratio,
)
loss = (org_weight - m.weight.data).float().pow(2).mean().item()
loss = (org_weight - m.weight.data).float().pow(2).mean()
m.weight.data.copy_(org_weight)
history.append(loss)
is_best = loss < best_error
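In the qdq_weight_sym hunk above, the in-place "scale -= 2 * tmp" trick is replaced by an explicit torch.where, which expresses the flip-flag sign change directly instead of through subtraction. A toy check of the equivalence (the values are illustrative, not from the diff):

import torch

scale = torch.tensor([0.5, 0.25, 1.0])
flip_flag = torch.tensor([True, False, True])

# old formulation: subtract twice the flagged entries to negate them
old = scale - 2 * (scale * flip_flag.int())
# new formulation: negate flagged entries directly
new = torch.where(flip_flag, -scale, scale)
assert torch.equal(old, new)  # tensor([-0.5000, 0.2500, -1.0000])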