Commit b39527d

Fix for "leaf Variable that requires grad" Error in In-Place Operation (huggingface#1372)

Avoid in-place operations for LoRA forward and merging.
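
To illustrate the error named in the title, a minimal standalone PyTorch sketch (not part of this diff; the tensor names are made up):

import torch

weight = torch.zeros(4, 4, requires_grad=True)  # a leaf tensor that requires grad
delta = torch.randn(4, 4)

# In-place addition on a leaf tensor that requires grad raises:
#   RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
# weight += delta

# Out-of-place addition builds a new tensor and leaves the leaf intact, so autograd accepts it.
weight = weight + delta
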
DopeorNope-Lee authored and BenjaminBossan committed Mar 14, 2024
1 parent 4db6206 commit b39527d
Showing 1 changed file with 8 additions and 8 deletions.
src/peft/tuners/lora/layer.py (16 changes: 8 additions & 8 deletions)
@@ -324,7 +324,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                     orig_weights = base_layer.weight.data.clone()
                     delta_weight = self.get_delta_weight(active_adapter)
                     if not self.use_dora[active_adapter]:
-                        orig_weights += delta_weight
+                        orig_weights = orig_weights + delta_weight
                     else:
                         # handle dora
                         # since delta_weight already includes scaling, set it to 1 here
@@ -345,7 +345,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                 else:
                     delta_weight = self.get_delta_weight(active_adapter)
                     if not self.use_dora[active_adapter]:
-                        base_layer.weight.data += delta_weight
+                        base_layer.weight.data = base_layer.weight.data + delta_weight
                     else:
                         # handle dora
                         # since delta_weight already includes scaling, set it to 1 here
@@ -537,7 +537,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
                     orig_weights = base_layer.weight.data.clone()
-                    orig_weights += self.get_delta_weight(active_adapter)
+                    orig_weights = orig_weights + self.get_delta_weight(active_adapter)
 
                     if not torch.isfinite(orig_weights).all():
                         raise ValueError(
@@ -546,7 +546,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
 
                     base_layer.weight.data = orig_weights
                 else:
-                    base_layer.weight.data += self.get_delta_weight(active_adapter)
+                    base_layer.weight.data = base_layer.weight.data + self.get_delta_weight(active_adapter)
                 self.merged_adapters.append(active_adapter)
 
     def unmerge(self) -> None:
@@ -625,7 +625,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 embedding_B = self.lora_embedding_B[active_adapter].T
                 scaling = self.scaling[active_adapter]
                 after_A = self._embed(x, embedding_A)
-                result += (after_A @ embedding_B) * scaling
+                result = result + (after_A @ embedding_B) * scaling
             result = result.to(torch_result_dtype)
 
         return result
@@ -726,15 +726,15 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                     # Note that safe_merge will be slower than the normal merge
                     # because of the copy operation.
                     orig_weights = base_layer.weight.data.clone()
-                    orig_weights += self.get_delta_weight(active_adapter)
+                    orig_weights = orig_weights + self.get_delta_weight(active_adapter)
 
                     if not torch.isfinite(orig_weights).all():
                         raise ValueError(
                             f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                         )
                     base_layer.weight.data = orig_weights
                 else:
-                    base_layer.weight.data += self.get_delta_weight(active_adapter)
+                    base_layer.weight.data = base_layer.weight.data + self.get_delta_weight(active_adapter)
                 self.merged_adapters.append(active_adapter)
 
     def unmerge(self) -> None:
@@ -816,7 +816,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 dropout = self.lora_dropout[active_adapter]
                 scaling = self.scaling[active_adapter]
                 x = x.to(lora_A.weight.dtype)
-                result += lora_B(lora_A(dropout(x))) * scaling
+                result = result + lora_B(lora_A(dropout(x))) * scaling
 
             result = result.to(torch_result_dtype)
         return result
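
For reference, a simplified sketch of the forward pattern the patched lines follow (illustrative only; TinyLoraLinear is a made-up class, not the PEFT implementation):

import torch
import torch.nn as nn

class TinyLoraLinear(nn.Module):
    # Minimal LoRA-style wrapper, used only to show the out-of-place update.
    def __init__(self, in_features: int, out_features: int, r: int = 8, scaling: float = 1.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        self.dropout = nn.Dropout(p=0.0)
        self.scaling = scaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base(x)
        # Out-of-place addition builds a new tensor instead of mutating `result`,
        # mirroring the `result = result + ...` change in this commit.
        result = result + self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        return result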
