support auto_host2device on RTN and GPTQ (#1894)
Signed-off-by: He, Xin3 <[email protected]>
xin3he authored Jul 3, 2024
1 parent b9e73f5 commit f75ff40
Showing 6 changed files with 65 additions and 13 deletions.
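In outline, the change records the model's original device, moves each module to the accelerator only while it is being quantized, and restores the original placement afterwards, so a CPU-resident model comes back on CPU. A minimal sketch of that pattern in plain PyTorch (not the library code; quantize_weight is a hypothetical stand-in for the per-module RTN/GPTQ step):

import torch
import torch.nn as nn


def quantize_weight(module: nn.Module) -> None:
    # Hypothetical stand-in for the real per-module RTN/GPTQ quantization work.
    with torch.no_grad():
        module.weight.copy_(torch.round(module.weight * 16) / 16)


def quantize_layer_by_layer(model: nn.Module, accel: str) -> nn.Module:
    model_device = next(model.parameters()).device  # remember where the model started, e.g. "cpu"
    for _, module in model.named_modules():
        if isinstance(module, nn.Linear):
            module.to(accel)          # host -> accelerator for this layer only
            quantize_weight(module)
            module.to(model_device)   # back to the host right after
    return model.to(model_device)     # whole model ends up where it began


accel = "cuda" if torch.cuda.is_available() else "cpu"
toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
quantize_layer_by_layer(toy, accel)
print(next(toy.parameters()).device)  # still the original device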
10 changes: 9 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -27,7 +27,13 @@
import torch.nn as nn
from tqdm import tqdm

from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
from neural_compressor.torch.utils import (
get_accelerator,
get_model_device,
is_transformers_imported,
logger,
set_module,
)
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .modules import WeightOnlyLinear
@@ -995,6 +1001,7 @@ def prepare(
if use_layer_wise: # pragma: no cover
assert model_path is not None, "model_path should not be None when use layer wise mode"

self.model_device = get_model_device(model)  # record the model's original device
self.gptq_quantizer = RAWGPTQuantizer(
model,
weight_config=self.quant_config,
Expand All @@ -1013,6 +1020,7 @@ def convert(self, model, *args, **kwargs):
self.gptq_quantizer.model = model
self.gptq_quantizer.remove_prepare_for_calibration()
q_model, gptq_config = self.gptq_quantizer.execute_quantization()
q_model = q_model.to(self.model_device)
q_model.gptq_config = gptq_config
logger.info("GPTQ quantizing done.")
return q_model
12 changes: 6 additions & 6 deletions neural_compressor/torch/algorithms/weight_only/modules.py
@@ -270,8 +270,8 @@ def recover(self):

def pack_tensor_with_torch(self, raw_tensor):
target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(raw_tensor.device)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(raw_tensor.device)
for j in range(packed_tensor.shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
@@ -286,8 +286,8 @@ def pack_tensor_with_torch(self, raw_tensor):
def unpack_tensor_with_torch(self, packed_tensor):
target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
target_len = packed_tensor.shape[1] * self.n_pack
unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(packed_tensor.device)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(packed_tensor.device)
for j in range(packed_tensor.shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
@@ -338,13 +338,13 @@ def unpack_tensor_with_numpy(self, packed_tensor):
return unpacked_tensor

def pack_tensor(self, raw_tensor):
if "cuda" in self.device:
if "cuda" in raw_tensor.device.type:
return self.pack_tensor_with_torch(raw_tensor)
else:
return self.pack_tensor_with_numpy(raw_tensor)

def unpack_tensor(self, packed_tensor):
if "cuda" in self.device:
if "cuda" in packed_tensor.device.type:
return self.unpack_tensor_with_torch(packed_tensor)
else:
return self.unpack_tensor_with_numpy(packed_tensor)
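The hunks above also switch the pack/unpack dispatch and scratch-tensor placement from the module-level self.device attribute to the device of the tensor actually being processed, so packing follows the data rather than a possibly stale attribute. A small hedged sketch of that kind of device-driven dispatch (standalone placeholder helpers, not the WeightOnlyLinear methods):

import torch


def pack_with_torch(raw: torch.Tensor) -> torch.Tensor:
    # Placeholder for the torch-based packing path; buffers stay on raw.device.
    return raw.to(torch.int32)


def pack_with_numpy(raw: torch.Tensor) -> torch.Tensor:
    # Placeholder for the numpy-based packing path (CPU only).
    return torch.from_numpy(raw.detach().cpu().numpy().astype("int32"))


def pack(raw: torch.Tensor) -> torch.Tensor:
    # Dispatch on the tensor's own device, not on a stored module attribute.
    if "cuda" in raw.device.type:
        return pack_with_torch(raw)
    return pack_with_numpy(raw)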
12 changes: 8 additions & 4 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -28,6 +28,7 @@
from neural_compressor.torch.utils import (
get_accelerator,
get_attr,
get_model_device,
is_transformers_imported,
logger,
set_attr,
@@ -99,10 +100,7 @@ def convert(
"""
weight_config = self.quant_config
device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()

# Put model on device explicitly
# TODO: refine it later, Put module on device one by one instead of the whole model
model.to(device)
model_device = get_model_device(model)  # record the model's original device

# for transformers model. If lm_head is tied from embedding, we deepcopy it.
if quant_lm_head and getattr(getattr(model, "config", None), "tie_word_embeddings", False):
@@ -132,6 +130,8 @@
dtype = weight_config[name].get("dtype", "int")
if dtype == "fp32":
continue
# Move modules to the accelerator device layer-by-layer
m.to(device)
### FP8 cast part
if dtype in ["fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"]:
logger.debug("Cast module {} to FP8 using qdq mode, no scaling".format(name))
@@ -223,4 +223,8 @@ def convert(
return new_module
else:
set_module(model, name, new_module)
# Move modules back to the model device layer-by-layer
m.to(model_device)
new_module.to(model_device)
model.to(model_device)
return model
13 changes: 13 additions & 0 deletions neural_compressor/torch/utils/utility.py
@@ -265,3 +265,16 @@ def dump_model_op_stats(mode, tune_cfg):
output_data.append(field_results)

Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()


def get_model_device(model: torch.nn.Module):
"""Get the device.
Args:
model (torch.nn.Module): the input model.
Returns:
device (str): a string.
"""
for n, p in model.named_parameters():
return p.data.device.type # p.data.device == device(type='cpu')
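For reference, a quick sketch of how the new helper behaves on a toy module (assuming this commit is installed; get_model_device returns the device type of the first named parameter):

import torch
import torch.nn as nn

from neural_compressor.torch.utils import get_model_device

lin = nn.Linear(4, 4)                 # parameters default to CPU
print(get_model_device(lin))          # -> "cpu"

if torch.cuda.is_available():
    print(get_model_device(lin.to("cuda")))  # -> "cuda"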
18 changes: 16 additions & 2 deletions test/3x/torch/quantization/weight_only/test_gptq.py
@@ -36,6 +36,20 @@ def setup_class(self):
def teardown_class(self):
shutil.rmtree("saved_results", ignore_errors=True)

@pytest.mark.skipif(device == "cpu", reason="no available accelerator")
def test_auto_host2device(self):
# if the model is on CPU, we move it to the accelerator device layer-by-layer for speed,
# and then move it back to CPU after quantization.
model = copy.deepcopy(self.tiny_gptj).to("cpu")
example_inputs = copy.deepcopy(self.example_inputs).to("cpu")
quant_config = get_default_gptq_config()
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
gptq_label = model(example_inputs)[0]
gptq_atol = (gptq_label - self.label.to("cpu")).amax()
assert gptq_atol < 0.06, "GPTQ should have low atol."

def test_accuracy_improvement(self):
# test_default_rtn_config
model = copy.deepcopy(self.tiny_gptj)
@@ -215,9 +229,9 @@ def test_conv1d(self):
from transformers import GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2").to(device)
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors="pt")
encoded_input = tokenizer(text, return_tensors="pt").to(device)

def run_fn_conv1d(model):
model(**encoded_input)
13 changes: 13 additions & 0 deletions test/3x/torch/quantization/weight_only/test_rtn.py
@@ -352,3 +352,16 @@ def mock_is_transformers_imported():
model = convert(model)
out = model(self.example_inputs)[0]
assert torch.allclose(out, self.label, atol=1e-1), "Accuracy gap atol > 0.1 is unexpected."

@pytest.mark.skipif(device == "cpu", reason="no available accelerator")
def test_auto_host2device(self):
# if the model is on CPU, we move it to the accelerator device layer-by-layer for speed,
# and then move it back to CPU after quantization.
model = copy.deepcopy(self.tiny_gptj).to("cpu")
example_inputs = copy.deepcopy(self.example_inputs).to("cpu")
quant_config = get_default_rtn_config()
model = prepare(model, quant_config)
model = convert(model)
rtn_label = model(example_inputs)[0]
rtn_atol = (rtn_label - self.label.to("cpu")).amax()
assert rtn_atol < 0.08, "RTN should have low atol."
