Refine base config for 3.x (#1595)
* Refine base config for 3.x

* fixed group_dim is 0

---------

Signed-off-by: Mengni Wang <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: yiliu30 <[email protected]>
mengniwang95 and yiliu30 authored Feb 5, 2024
1 parent c6f9cca commit efea089
Showing 3 changed files with 13 additions and 5 deletions.
neural_compressor/common/base_config.py (5 changes: 3 additions & 2 deletions)
@@ -210,8 +210,9 @@ def to_dict(self):
 
     def get_params_dict(self):
         result = dict()
-        for param in self.params_list:
-            result[param] = getattr(self, param)
+        for param, value in self.__dict__.items():
+            if param not in ["_global_config", "_local_config", "_white_list"]:
+                result[param] = value
         return result
 
     @classmethod
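The rewritten get_params_dict no longer depends on a hand-maintained params_list: it walks the instance __dict__ and filters out the internal bookkeeping attributes, so any attribute a subclass sets in __init__ is exported automatically. A minimal sketch of that behavior (DemoConfig is a hypothetical class, not part of the library):

```python
class DemoConfig:
    def __init__(self):
        self._global_config = None  # internal bookkeeping, excluded
        self._local_config = None   # internal bookkeeping, excluded
        self._white_list = []       # internal bookkeeping, excluded
        self.bits = 4               # user-facing param, exported
        self.group_size = 32        # user-facing param, exported

    def get_params_dict(self):
        # same filtering logic as the patched method
        result = dict()
        for param, value in self.__dict__.items():
            if param not in ["_global_config", "_local_config", "_white_list"]:
                result[param] = value
        return result

print(DemoConfig().get_params_dict())  # {'bits': 4, 'group_size': 32}
```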
neural_compressor/onnxrt/quantization/config.py (3 changes: 2 additions & 1 deletion)
@@ -224,7 +224,8 @@ def __init__(
         https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/quantize.py#L78
         """
         BaseConfig.__init__(self)
-        StaticQuantConfig.__init__(self, calibration_data_reader=None, **kwargs)
+        kwargs.update({"calibration_data_reader": None})
+        StaticQuantConfig.__init__(self, **kwargs)
         self.alpha = alpha
         self.folding = folding
         self.op_types = op_types
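One plausible motivation for this change: if the caller's kwargs already contain calibration_data_reader, the old form passes the argument twice and Python raises a TypeError, whereas updating kwargs first simply overrides the caller's value with None. A minimal sketch with a stand-in class (not the library's actual StaticQuantConfig):

```python
class StaticQuantConfig:  # stand-in with the same keyword of interest
    def __init__(self, calibration_data_reader=None, **kwargs):
        self.calibration_data_reader = calibration_data_reader

def old_style(**kwargs):
    # raises if kwargs already holds calibration_data_reader
    return StaticQuantConfig(calibration_data_reader=None, **kwargs)

def new_style(**kwargs):
    kwargs.update({"calibration_data_reader": None})
    return StaticQuantConfig(**kwargs)

try:
    old_style(calibration_data_reader="some_reader")
except TypeError as err:
    print(err)  # got multiple values for keyword argument 'calibration_data_reader'

cfg = new_style(calibration_data_reader="some_reader")
print(cfg.calibration_data_reader)  # None: the value is forced regardless
```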
neural_compressor/torch/algorithms/weight_only/rtn.py (10 changes: 8 additions & 2 deletions)
@@ -122,7 +122,10 @@ def rtn_quantize(
             continue
         logger.debug(f"RTN quantized module:{name, m}")
         logger.debug(log_msg)
-        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
+        if group_dim == 0:
+            weight = m.weight.t_().contiguous()
+        else:
+            weight = m.weight
         if use_mse_search:
             quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
         if export_compressed_model:
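The ternary here is expanded into an explicit if/else for readability; the behavior is unchanged. Worth noting for the sketch below: t_() is an in-place transpose, so when group_dim == 0 it mutates m.weight itself, not just the local weight. A minimal sketch of the assumed grouping semantics (plain tensors, not the library code):

```python
import torch

weight = torch.randn(8, 6)      # stand-in for m.weight
group_dim, group_size = 0, 4

if group_dim == 0:
    # transpose so grouping can always run along the last dimension;
    # t_() also mutates `weight` itself (now 6 x 8)
    w = weight.t_().contiguous()
else:
    w = weight

groups = w.reshape(w.shape[0], -1, group_size)
print(groups.shape)  # torch.Size([6, 2, 4])
```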
@@ -169,6 +172,9 @@ def rtn_quantize(
                 full_range=use_full_range,
                 **double_quant_config,
             )
-            weight = weight.t_().contiguous() if group_dim == 0 else weight
+            if group_dim == 0:
+                # for group_dim is 0, we need to transpose the quantized tensor and module's weight back
+                weight = weight.t_().contiguous()
+                m.weight.t_().contiguous()
             m.weight.data.copy_(weight)
     return model
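This is the "fixed group_dim is 0" part of the commit: the earlier m.weight.t_() transposed the module's weight in place, so without transposing both tensors back, m.weight.data.copy_(weight) would see mismatched shapes. Note that only the in-place t_() does the work on the added m.weight line; the .contiguous() result is discarded. A minimal sketch with plain tensors standing in for the Parameter:

```python
import torch

w_module = torch.randn(8, 6)            # stand-in for m.weight (out x in)
weight = w_module.t_().contiguous()     # w_module is now 6 x 8, in place

weight = weight.round()                 # stand-in for quantize/dequantize

# the fix: undo both transposes before copying back
weight = weight.t_().contiguous()       # quantized values back to 8 x 6
w_module.t_()                           # module weight back to 8 x 6, in place
w_module.copy_(weight)                  # shapes match again
```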
