FIX Allow DoRA init on CPU when using BNB (#1724)
Resolves #1674

For some users, it is necessary to initialize the model on CPU, even
when using bitsandbytes, which eventually requires a GPU. Since DoRA
needs to dequantize the bnb weights at initialization, the corresponding
model weights are temporarily moved to GPU. After dequantization, they
are moved back to CPU.
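As a minimal sketch of the scenario being fixed (mirroring the new test added below; the model id and quantization settings are only examples, and a CUDA device must still be available for the dequantization step), DoRA can now be initialized while the bnb-quantized model sits on CPU:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Load a small model with bnb 4-bit quantization, then keep it on CPU.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=bnb_config)
model = model.cpu()

# DoRA init dequantizes the base weights; the bnb weights are moved to CUDA only for
# that step and then returned to CPU, so this call no longer raises.
peft_model = get_peft_model(model, LoraConfig(use_dora=True))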
BenjaminBossan authored May 14, 2024
1 parent 47b3712 commit 748f796
Showing 3 changed files with 70 additions and 18 deletions.
16 changes: 3 additions & 13 deletions src/peft/tuners/lora/layer.py
@@ -23,7 +23,7 @@
 from transformers.pytorch_utils import Conv1D

 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
-from peft.utils.integrations import dequantize_bnb_weight, gather_params_ctx
+from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
 from peft.utils.other import transpose

 from .config import LoraConfig
@@ -195,12 +195,7 @@ def dora_init(self, adapter_name: str) -> None:
         scaling = self.scaling[adapter_name]
         with gather_params_ctx(self.get_base_layer().parameters()):
             base_layer = self.get_base_layer()
-            if hasattr(base_layer, "W_q"):  # For handling HQQ quantized weight
-                weight = base_layer.dequantize()
-            else:
-                weight = base_layer.weight
-                quant_state = getattr(base_layer, "state", None)
-                weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
+            weight = dequantize_module_weight(base_layer)
             if weight.data.ndim == 4:  # For handling LoRAs applied to Conv2Ds.
                 lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
                 lora_weight = lora_weight.reshape(weight.shape)
@@ -231,12 +226,7 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):
         lora_weight = lora_B.weight @ lora_A.weight
         magnitude = self.lora_magnitude_vector[active_adapter]
         base_layer = self.get_base_layer()
-        if hasattr(base_layer, "W_q"):  # For handling HQQ quantized weight
-            weight = base_layer.dequantize()
-        else:
-            weight = base_layer.weight
-            quant_state = getattr(base_layer, "state", None)
-            weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
+        weight = dequantize_module_weight(base_layer)
         weight = weight.to(x.dtype)
         weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
         # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
45 changes: 40 additions & 5 deletions src/peft/utils/integrations.py
@@ -38,23 +38,55 @@ def gather_params_ctx(param, modifier_rank: int = 0, fwd_module: torch.nn.Module
         return


-def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
+def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter:
     """
-    Helper function to dequantize 4bit or 8bit bnb weights.
+    Helper function to dequantize a quantized weight.

     This function should be extended if more quantization schemes are added to the library.

-    If the weight is not a bnb quantized weight, it will be returned as is.
+    If the weight is not quantized, it will be returned as is.
     """
+    if hasattr(module, "W_q"):  # For handling HQQ quantized weight
+        weight = module.dequantize()
+        return weight

+    weight = module.weight
     if not isinstance(weight, torch.nn.Parameter):
         raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

     cls_name = weight.__class__.__name__
     if cls_name not in ("Params4bit", "Int8Params"):
         return weight

+    quant_state = getattr(module, "state", None)
+    device = weight.device
+    is_cpu = device.type == torch.device("cpu").type
+    weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
+    if is_cpu:
+        # dequantize_bnb_weight for 8bit moves the device in-place, thus we need to move it back to CPU if necessary
+        module.weight = module.weight.to(device)
+    return weight


+def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
+    """Helper function to dequantize 4bit or 8bit bnb weights.

+    Since dequantization is not supported on CPU, the weight will be temporarily moved to CUDA if necessary.
+    """
     import bitsandbytes as bnb

+    # BNB requires CUDA weights
+    device = weight.device
+    is_cpu = device.type == torch.device("cpu").type
+    if is_cpu:
+        weight = weight.to(torch.device("cuda"))

+    cls_name = weight.__class__.__name__
     if cls_name == "Params4bit":
-        return bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
+        dequantized = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
+        if is_cpu:
+            dequantized = dequantized.to(device)
+        return dequantized

     if state.SCB is None:
         state.SCB = weight.SCB
@@ -65,4 +97,7 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     if state.CxB is None:
         state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
     out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    dequantized = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if is_cpu:
+        dequantized = dequantized.to(device)
+    return dequantized
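For reference, a rough usage sketch of the new helper (assuming bitsandbytes and a CUDA device are available; the module path is just an example for opt-125m): dequantize_module_weight can be called on a CPU-resident quantized layer, and the module's quantized weight ends up back on CPU afterwards.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft.utils.integrations import dequantize_module_weight

bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=bnb_config).cpu()

# Example module path for opt-125m; any bnb-quantized linear layer works the same way.
layer = model.model.decoder.layers[0].self_attn.q_proj
weight = dequantize_module_weight(layer)  # dequantizes on CUDA internally
print(weight.shape, weight.dtype)
assert layer.weight.device.type == "cpu"  # the quantized weight is restored to CPU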
27 changes: 27 additions & 0 deletions tests/test_gpu_examples.py
@@ -1093,6 +1093,33 @@ def test_causal_lm_training_gpt2_dora(self):
         # assert loss is not None
         assert trainer.state.log_history[-1]["train_loss"] is not None

+    @parameterized.expand(["4bit", "8bit"])
+    def test_initialize_dora_with_bnb_on_cpu(self, kbit):
+        # 1674
+        # The issue is that to initialize DoRA, we need to dequantize the weights. That only works on GPU for bnb.
+        # Therefore, initializing DoRA with bnb on CPU used to fail.
+        model_id = "facebook/opt-125m"
+        if kbit == "4bit":
+            bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
+        elif kbit == "8bit":
+            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            raise ValueError("Only 4bit and 8bit bnb allowed")

+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
+        model = model.cpu()  # ensure that we're on CPU
+        # sanity check that all weights are on CPU
+        weights_not_cpu = [name for name, p in model.named_parameters() if p.device != torch.device("cpu")]
+        assert not weights_not_cpu

+        lora_config = LoraConfig(use_dora=True)

+        # should not raise
+        peft_model = get_peft_model(model, lora_config)
+        # check that the weights are still on CPU
+        weights_not_cpu = [name for name, p in peft_model.named_parameters() if p.device != torch.device("cpu")]
+        assert not weights_not_cpu


 @require_torch_gpu
 @require_auto_gptq
