FIX Use torch.long instead of torch.int in LoftQ for PyTorch versions <2.x (#1320)

Solves #1307

For PyTorch versions before 2.x, a torch.int tensor does not work for indexing, so torch.long is used instead.
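
For reference, the failing pattern is indexing one tensor with another integer tensor. A minimal standalone sketch of the difference (illustrative values only, not PEFT code):

import torch

table = torch.tensor([0.0, 0.5, 1.0])            # any lookup tensor
idx_int = torch.tensor([2, 0], dtype=torch.int)   # rejected as an index on PyTorch < 2.x (per the commit message)
idx_long = idx_int.to(torch.long)                 # accepted on all supported versions

print(table[idx_long])  # tensor([1., 0.])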
BenjaminBossan authored Jan 8, 2024
1 parent 8665e2b commit 4186c9b
Showing 2 changed files with 57 additions and 19 deletions.
src/peft/utils/loftq_utils.py (2 changes: 1 addition & 1 deletion)
@@ -155,7 +155,7 @@ def dequantize_block(self, qweight, weight_max, weight_shape):
         weight = torch.zeros((qweight.shape[0], 8 // self.num_bits), dtype=torch.float32, device=device)
         for i in range(8 // self.num_bits):
             lookup_table_idx = qweight.to(torch.long) % 2**self.num_bits  # get the most right 2 bits
-            lookup_table_idx = lookup_table_idx.to(torch.int)
+            lookup_table_idx = lookup_table_idx.to(torch.long)
             weight[:, i] = self.norm_lookup_table[lookup_table_idx].squeeze()
             qweight = qweight >> self.num_bits  # right shift 2 bits of the original data

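To put the changed line in context, here is a simplified, standalone version of the unpacking loop from the hunk above (hypothetical lookup values; the real dequantize_block also uses weight_max and weight_shape, which are omitted here):

import torch

# Simplified sketch of the 2-bit unpacking loop shown above.
num_bits = 2
norm_lookup_table = torch.tensor([-1.0, -0.33, 0.33, 1.0])  # hypothetical 2-bit levels
qweight = torch.tensor([[0b11100100]], dtype=torch.uint8)   # packs the codes 0, 1, 2, 3

weight = torch.zeros((qweight.shape[0], 8 // num_bits), dtype=torch.float32)
for i in range(8 // num_bits):
    # the rightmost num_bits bits select the lookup entry; the cast to torch.long is what the fix is about
    lookup_table_idx = (qweight.to(torch.long) % 2**num_bits).to(torch.long)
    weight[:, i] = norm_lookup_table[lookup_table_idx].squeeze()
    qweight = qweight >> num_bits  # shift the consumed bits out

print(weight)  # codes 0..3 map to [-1.0, -0.33, 0.33, 1.0]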
tests/test_gpu_examples.py (74 changes: 56 additions & 18 deletions)
@@ -944,45 +944,51 @@ def test_causal_lm_training_multi_gpu(self):
 @require_torch_gpu
 class LoftQTests(unittest.TestCase):
     r"""
-    Tests for LoftQ
+    Tests for LoftQ to ensure that it reduces the quantization error compared to normal LoRA quantization.
     """
 
     def setUp(self):
         self.error_factor = 3
         self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
 
-    def get_input(self, device):
-        inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt")
+    def get_input(self, model_id, device):
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer("All I want is", padding=True, return_tensors="pt")
         if device == "cuda":
             inputs = inputs.to("cuda")
         return inputs
 
     def get_base_model(self, model_id, device, **kwargs):
-        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()
+        cls = AutoModelForSeq2SeqLM if "t5" in model_id else AutoModelForCausalLM
+        model = cls.from_pretrained(model_id, **kwargs).eval()
         if device == "cuda":
             model = model.to("cuda")
         return model
 
-    def get_errors(self, bits=4, loftq_iter=1, device="cuda"):
+    def get_logits(self, model, inputs):
+        if model.config.is_encoder_decoder:
+            input_ids = inputs["input_ids"]
+            return model(input_ids=input_ids, decoder_input_ids=input_ids).logits
+        return model(**inputs).logits
+
+    def get_errors(
+        self, bits=4, loftq_iter=1, device="cuda", model_id="hf-internal-testing/tiny-random-BloomForCausalLM"
+    ):
         # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model
         # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to
         # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is
         # already somewhat dampened because of the softmax.
-        model = self.get_base_model(self.model_id, device)
-        if device == "cuda":
-            model = model.to("cuda")
-
-        torch.manual_seed(0)
-        inputs = self.get_input(device)
-        logits_base = model(**inputs).logits
+        model = self.get_base_model(model_id, device)
+        task_type = TaskType.SEQ_2_SEQ_LM if model.config.is_encoder_decoder else TaskType.CAUSAL_LM
+        inputs = self.get_input(model_id, device)
+        logits_base = self.get_logits(model, inputs)
         # clean up
         del model
         gc.collect()
         torch.cuda.empty_cache()
 
         # logits from the normal quantized LoRA model
-        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM)
+        lora_config = LoraConfig(task_type=task_type)
         kwargs = {}
         if bits == 4:
             kwargs["load_in_4bit"] = True
@@ -992,27 +998,27 @@ def get_errors(self, bits=4, loftq_iter=1, device="cuda"):
             raise ValueError("bits must be 4 or 8")
 
         quantized_model = get_peft_model(
-            self.get_base_model(self.model_id, device=None, **kwargs),
+            self.get_base_model(model_id, device=None, **kwargs),
             lora_config,
         )
         torch.manual_seed(0)
-        logits_quantized = quantized_model(**inputs).logits
+        logits_quantized = self.get_logits(quantized_model, inputs)
         del quantized_model
         gc.collect()
         torch.cuda.empty_cache()
 
         # logits from quantized LoRA model using LoftQ
         loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
         lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config)
-        model = self.get_base_model(self.model_id, device)
+        model = self.get_base_model(model_id, device)
         if device == "cuda":
             model = model.to("cuda")
         loftq_model = get_peft_model(model, lora_config)
         if device == "cuda":
             loftq_model = loftq_model.to("cuda")
 
         torch.manual_seed(0)
-        logits_loftq = loftq_model(**inputs).logits
+        logits_loftq = self.get_logits(loftq_model, inputs)
         del loftq_model
         gc.collect()
         torch.cuda.empty_cache()
@@ -1088,6 +1094,38 @@ def test_bloomz_loftq_8bit_iter_5(self, device):
         self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
         self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
 
+    @parameterized.expand(["cuda", "cpu"])
+    def test_t5_loftq_4bit(self, device):
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
+            bits=4, device=device, model_id="t5-small"
+        )
+        # first, sanity check that all errors are > 0.0
+        self.assertTrue(mae_quantized > 0.0)
+        self.assertTrue(mse_quantized > 0.0)
+        self.assertTrue(mae_loftq > 0.0)
+        self.assertTrue(mse_loftq > 0.0)
+
+        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
+        factor = 3
+        self.assertTrue(mae_loftq < mae_quantized / factor)
+        self.assertTrue(mse_loftq < mse_quantized / factor)
+
+    @parameterized.expand(["cuda", "cpu"])
+    def test_t5_loftq_8bit(self, device):
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
+            bits=8, device=device, model_id="t5-small"
+        )
+        # first, sanity check that all errors are > 0.0
+        self.assertTrue(mae_quantized > 0.0)
+        self.assertTrue(mse_quantized > 0.0)
+        self.assertTrue(mae_loftq > 0.0)
+        self.assertTrue(mse_loftq > 0.0)
+
+        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
+        factor = 3
+        self.assertTrue(mae_loftq < mae_quantized / factor)
+        self.assertTrue(mse_loftq < mse_quantized / factor)
+
 
 @require_bitsandbytes
 @require_torch_gpu
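Taken together, the new tests exercise the same LoftQ initialization path as the existing Bloom tests, but on an encoder-decoder model. Below is a condensed sketch of what they measure, outside the test harness. The t5-small model id, the LoftQ/LoRA configs, and the decoder_input_ids handling are taken from the diff; the MAE/MSE formulas are an assumption based on the get_errors comments, and depending on the environment LoftQ initialization may require bitsandbytes and/or a GPU.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

model_id = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("All I want is", padding=True, return_tensors="pt")

# base-model logits (encoder-decoder models need decoder_input_ids, as in get_logits)
base = AutoModelForSeq2SeqLM.from_pretrained(model_id).eval()
with torch.no_grad():
    logits_base = base(input_ids=inputs["input_ids"], decoder_input_ids=inputs["input_ids"]).logits

# LoftQ-initialized LoRA model
loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
lora_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, init_lora_weights="loftq", loftq_config=loftq_config)
loftq_model = get_peft_model(AutoModelForSeq2SeqLM.from_pretrained(model_id), lora_config).eval()
with torch.no_grad():
    logits_loftq = loftq_model(input_ids=inputs["input_ids"], decoder_input_ids=inputs["input_ids"]).logits

# assumed error metrics: mean absolute and mean squared difference of the logits
mae_loftq = (logits_base - logits_loftq).abs().mean()
mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
print(mae_loftq.item(), mse_loftq.item())

The tests compute the same two errors for a plain 4-bit or 8-bit quantized LoRA model and assert that the LoftQ errors are at least a factor of 3 smaller.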
