diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py
index 03da3d17d6..bfbdabf5a0 100644
--- a/src/peft/utils/loftq_utils.py
+++ b/src/peft/utils/loftq_utils.py
@@ -155,7 +155,7 @@ def dequantize_block(self, qweight, weight_max, weight_shape):
         weight = torch.zeros((qweight.shape[0], 8 // self.num_bits), dtype=torch.float32, device=device)
         for i in range(8 // self.num_bits):
             lookup_table_idx = qweight.to(torch.long) % 2**self.num_bits  # get the most right 2 bits
-            lookup_table_idx = lookup_table_idx.to(torch.int)
+            lookup_table_idx = lookup_table_idx.to(torch.long)
             weight[:, i] = self.norm_lookup_table[lookup_table_idx].squeeze()
             qweight = qweight >> self.num_bits  # right shift 2 bits of the original data
 
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index 4fdc3f6e78..4e27b4aaca 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -944,45 +944,51 @@ def test_causal_lm_training_multi_gpu(self):
 @require_torch_gpu
 class LoftQTests(unittest.TestCase):
     r"""
-    Tests for LoftQ
+    Tests for LoftQ to ensure that it reduces the quantization error compared to normal LoRA quantization.
     """
 
     def setUp(self):
         self.error_factor = 3
-        self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
 
-    def get_input(self, device):
-        inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt")
+    def get_input(self, model_id, device):
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer("All I want is", padding=True, return_tensors="pt")
         if device == "cuda":
             inputs = inputs.to("cuda")
         return inputs
 
     def get_base_model(self, model_id, device, **kwargs):
-        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()
+        cls = AutoModelForSeq2SeqLM if "t5" in model_id else AutoModelForCausalLM
+        model = cls.from_pretrained(model_id, **kwargs).eval()
         if device == "cuda":
             model = model.to("cuda")
         return model
 
-    def get_errors(self, bits=4, loftq_iter=1, device="cuda"):
+    def get_logits(self, model, inputs):
+        if model.config.is_encoder_decoder:
+            input_ids = inputs["input_ids"]
+            return model(input_ids=input_ids, decoder_input_ids=input_ids).logits
+        return model(**inputs).logits
+
+    def get_errors(
+        self, bits=4, loftq_iter=1, device="cuda", model_id="hf-internal-testing/tiny-random-BloomForCausalLM"
+    ):
         # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model
         # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to
         # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is
         # already somewhat dampened because of the softmax.
-        model = self.get_base_model(self.model_id, device)
-        if device == "cuda":
-            model = model.to("cuda")
-        torch.manual_seed(0)
-        inputs = self.get_input(device)
-        logits_base = model(**inputs).logits
+        model = self.get_base_model(model_id, device)
+        task_type = TaskType.SEQ_2_SEQ_LM if model.config.is_encoder_decoder else TaskType.CAUSAL_LM
+        inputs = self.get_input(model_id, device)
+        logits_base = self.get_logits(model, inputs)
 
         # clean up
         del model
         gc.collect()
         torch.cuda.empty_cache()
 
         # logits from the normal quantized LoRA model
-        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM)
+        lora_config = LoraConfig(task_type=task_type)
         kwargs = {}
         if bits == 4:
             kwargs["load_in_4bit"] = True
@@ -992,11 +998,11 @@ def get_errors(self, bits=4, loftq_iter=1, device="cuda"):
             raise ValueError("bits must be 4 or 8")
 
         quantized_model = get_peft_model(
-            self.get_base_model(self.model_id, device=None, **kwargs),
+            self.get_base_model(model_id, device=None, **kwargs),
             lora_config,
         )
         torch.manual_seed(0)
-        logits_quantized = quantized_model(**inputs).logits
+        logits_quantized = self.get_logits(quantized_model, inputs)
         del quantized_model
         gc.collect()
         torch.cuda.empty_cache()
@@ -1004,7 +1010,7 @@ def get_errors(self, bits=4, loftq_iter=1, device="cuda"):
         # logits from quantized LoRA model using LoftQ
         loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
         lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config)
-        model = self.get_base_model(self.model_id, device)
+        model = self.get_base_model(model_id, device)
         if device == "cuda":
             model = model.to("cuda")
         loftq_model = get_peft_model(model, lora_config)
@@ -1012,7 +1018,7 @@ def get_errors(self, bits=4, loftq_iter=1, device="cuda"):
             loftq_model = loftq_model.to("cuda")
 
         torch.manual_seed(0)
-        logits_loftq = loftq_model(**inputs).logits
+        logits_loftq = self.get_logits(loftq_model, inputs)
         del loftq_model
         gc.collect()
         torch.cuda.empty_cache()
@@ -1088,6 +1094,38 @@ def test_bloomz_loftq_8bit_iter_5(self, device):
         self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
         self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
 
+    @parameterized.expand(["cuda", "cpu"])
+    def test_t5_loftq_4bit(self, device):
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
+            bits=4, device=device, model_id="t5-small"
+        )
+        # first, sanity check that all errors are > 0.0
+        self.assertTrue(mae_quantized > 0.0)
+        self.assertTrue(mse_quantized > 0.0)
+        self.assertTrue(mae_loftq > 0.0)
+        self.assertTrue(mse_loftq > 0.0)
+
+        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
+        factor = 3
+        self.assertTrue(mae_loftq < mae_quantized / factor)
+        self.assertTrue(mse_loftq < mse_quantized / factor)
+
+    @parameterized.expand(["cuda", "cpu"])
+    def test_t5_loftq_8bit(self, device):
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
+            bits=8, device=device, model_id="t5-small"
+        )
+        # first, sanity check that all errors are > 0.0
+        self.assertTrue(mae_quantized > 0.0)
+        self.assertTrue(mse_quantized > 0.0)
+        self.assertTrue(mae_loftq > 0.0)
+        self.assertTrue(mse_loftq > 0.0)
+
+        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
+        factor = 3
+        self.assertTrue(mae_loftq < mae_quantized / factor)
+        self.assertTrue(mse_loftq < mse_quantized / factor)
+
 
 @require_bitsandbytes
 @require_torch_gpu
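
The new T5 tests exercise LoftQ initialization purely through the public LoftQConfig/LoraConfig API shown in the patch. For orientation, a minimal sketch of that path (not part of the patch; it assumes peft and transformers are installed and uses the same t5-small checkpoint as the tests):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

    # load the full-precision seq2seq base model, as the tests do for the LoftQ branch
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

    # LoftQ replaces the usual random LoRA init with one that compensates the
    # quantization error of the base weights; loftq_iter controls the number of
    # alternating quantize/low-rank-approximation steps
    loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )
    peft_model = get_peft_model(model, lora_config)

    # encoder-decoder models need decoder_input_ids for a plain forward pass,
    # which is why the tests route logits through the get_logits() helper
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    inputs = tokenizer("All I want is", return_tensors="pt")
    logits = peft_model(input_ids=inputs["input_ids"], decoder_input_ids=inputs["input_ids"]).logits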