From 56f2e97223fa68af1bdc8f4a8bc5cb05f9c0178b Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 8 Jul 2024 11:20:51 -0700
Subject: [PATCH] Enable `model.to(device)` for int8 weight only quantized
 model

Summary:
Fix some implementation issue for `int8_wo_quantized_model.to(device)`

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_model_to_device

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_api.py       | 14 ++++++++++++++
 torchao/dtypes/affine_quantized_tensor.py |  9 ++++++---
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index b137cd22dc..3767f8a88e 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -619,6 +619,20 @@ def test_quantized_tensor_subclass_save_load(self):
         res = m_copy(*example_inputs)
         self.assertEqual(res, ref)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_model_to_device(self):
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
+
+        quantize_(m, int8_weight_only())
+        ref = m(*example_inputs)
+
+        example_inputs_cuda = (example_inputs[0].to("cuda"),)
+        m.to(device="cuda")
+        cuda_res = m(*example_inputs_cuda)
+        self.assertEqual(cuda_res.cpu(), ref)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index d4f607a8f4..3cde983e9c 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -259,8 +259,11 @@ def _get_to_kwargs(self, *args, **kwargs):
 
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs.pop("device")
+        # not supported yet
+        kwargs.pop("memory_format")
         return self.__class__(
-            self.layout_tensor.to(kwargs["device"]),
+            self.layout_tensor.to(device),
             self.block_size,
             self.shape,
             self.quant_min,
@@ -470,8 +473,8 @@ def to(self, *args, **kwargs):
         if device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda"):
             raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device")
         return self.__class__(
-            self.packed_weight.to(kwargs["device"]),
-            self.scale_and_zero.to(kwargs["device"]),
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
             self.transposed
         )