Skip to content

Commit

Permalink
Port auto-detect absorb layers for TEQ (#1895)
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <[email protected]>
  • Loading branch information
yiliu30 authored Jul 4, 2024
1 parent 856118e commit 1386ac5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 14 deletions.
45 changes: 33 additions & 12 deletions neural_compressor/torch/algorithms/weight_only/teq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
# limitations under the License.
#

import copy
from typing import Any
from typing import Any, List

import torch

Expand All @@ -36,10 +35,10 @@
class TrainableEquivalentTransformation:
"""Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""

_PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
_PREPARE_ATTRS: List[str] = ["weight_config", "trained_alphas"]
_PREPARE_ATTRS_PREFIX = "_prepare_"

def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
def __init__(self, model, weight_config={}, absorb_to_layer=None, folding=True, example_inputs=None):
"""
:param model: the model for quantization
:param weight_config (dict, optional): contains all info required by RTN. Defaults to {}.
Expand All @@ -54,6 +53,24 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
self.absorb_to_layer = absorb_to_layer
self._post_initialized = False

def _detect_absorb_to_layer(self, model, folding, example_inputs):
# If user not provide the layers to absorb the quantization, detect layers automatically
supported_layers = ["Linear"]
detected_absorb_layers = {}
# Detect the layers that can be absorbed automatically
if folding:
from neural_compressor.torch.algorithms.weight_only.utility import GraphTrace

tg = GraphTrace()
detected_absorb_layers, _ = tg.get_absorb_to_layer(model, example_inputs, supported_layers)
else: # pragma: no cover
for name, module in model.named_modules():
if module.__class__.__name__ in supported_layers:
detected_absorb_layers[name] = [name]
logger.info("Detected **absorb layer**: **absorbed layers**")
logger.info(detected_absorb_layers)
return detected_absorb_layers

def _post_init(self):
self.dtype = self._get_dtype()
self.model.to(self.device)
Expand All @@ -75,6 +92,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
to the paper for more details
:param sqrt_w_init: use sqrt weight to init."""

if not self.absorb_to_layer:
self.absorb_to_layer = self._detect_absorb_to_layer(self.model, self.folding, self.example_inputs)
if not self._post_initialized:
self._post_init()
# freeze model.
Expand Down Expand Up @@ -104,7 +123,7 @@ def add_tuning_scale(self, sqrt_w_init=False):

self.trained_alphas[layer_norm] = alpha
for layer_name in self.absorb_to_layer[layer_norm]:
if self.weight_config.get(layer_name) is None: # pragma: no cover
if not self.weight_config.get(layer_name): # pragma: no cover
logger.info(f"layer {layer_name} not in weight config, skip.")
continue
num_bits = self.weight_config[layer_name]["bits"]
Expand All @@ -117,10 +136,10 @@ def add_tuning_scale(self, sqrt_w_init=False):
)
set_module(self.model, layer_name, wrapper_module)

for n, m in self.model.named_modules():
for layer_name, m in self.model.named_modules():
if isinstance(m, torch.nn.Linear) and "orig_layer" not in n:
if self.weight_config.get(n) is None: # pragma: no cover
logger.info(f"out of absorbed layer {n} not in weight config, skip.")
if not self.weight_config.get(layer_name): # pragma: no cover
logger.info(f"out of absorbed layer {layer_name} not in weight config, skip.")
continue
num_bits = self.weight_config[layer_name]["bits"]
group_size = self.weight_config[layer_name]["group_size"]
Expand All @@ -131,7 +150,7 @@ def add_tuning_scale(self, sqrt_w_init=False):
wrapper_module = TEQLinearFakeQuant(
orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
)
set_module(self.model, n, wrapper_module)
set_module(self.model, layer_name, wrapper_module)
# Attach the weight config captured at prepare stage to the model
self.model._weight_config = self.weight_config
self.model._trained_alphas = self.trained_alphas
Expand Down Expand Up @@ -190,7 +209,9 @@ def _absorb_scales(self, layer, scale, layer_name=""):
scale = scale.view(scale.shape[0], 1)
layer.weight *= scale

elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm": ##quite tricky
elif (
layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm"
): # pragma: no cover
layer.weight *= scale

else: # pragma: no cover
Expand Down Expand Up @@ -222,7 +243,7 @@ def _scale_layer_weight(self, layer, scale): ##input channel
@torch.no_grad()
def transform(self):
"""Apply alpha/scale."""
if not self._post_initialized:
if not self._post_initialized: # pragma: no cover
self._post_init()
for ln_name, layer_names in self.absorb_to_layer.items():
module = get_module(self.model, ln_name)
Expand Down Expand Up @@ -272,7 +293,7 @@ def save(self, save_scale_file="", save_state_dict_file=""):

class TEQuantizer(Quantizer):

def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
def __init__(self, quant_config, folding, example_inputs, absorb_to_layer=None):
super().__init__(quant_config=quant_config)
self.folding = folding
self.absorb_to_layer = absorb_to_layer
Expand Down
17 changes: 15 additions & 2 deletions test/3x/torch/algorithms/weight_only/test_teq_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,21 @@ def setUpClass(self):
)
self.gptj.seqlen = 512

def train_func(self):
pass
def test_teq_detect_absorb_layers(self):
example_inputs = torch.ones([1, 512], dtype=torch.long)
test_input = torch.ones([1, 512], dtype=torch.long)
model = copy.deepcopy(self.gptj)
out0 = model(test_input)

weight_config = {
# 'op_name': (bit, group_size, scheme)
"transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"},
"transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"},
}
quantizer = TEQuantizer(quant_config=weight_config, folding=True, example_inputs=example_inputs)
model = quantizer.quantize(copy.deepcopy(self.gptj), run_fn=train)
out1 = model(test_input)
self.assertTrue(torch.allclose(out1[0], out0[0], atol=0.03))

def test_teq(self):
example_inputs = torch.ones([1, 512], dtype=torch.long)
Expand Down

0 comments on commit 1386ac5

Please sign in to comment.