[core / PEFT / LoRA] Integrate PEFT into Unet #5151

Merged
merged 159 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
159 commits
Select commit Hold shift + click to select a range
cf2c0ba
v1
younesbelkada Sep 22, 2023
8759f55
add tests and fix previous failing tests
younesbelkada Sep 25, 2023
c90aedc
fix CI
younesbelkada Sep 25, 2023
0bfb136
Merge remote-tracking branch 'upstream/main' into peft-part-2
younesbelkada Sep 25, 2023
3002ea3
add tests + v1 `PeftLayerScaler`
younesbelkada Sep 25, 2023
d6f500c
Merge branch 'main' into peft-part-2
younesbelkada Sep 25, 2023
64ca2bb
style
younesbelkada Sep 25, 2023
f62e506
add scale retrieving mechanism system
younesbelkada Sep 25, 2023
48842c0
fix CI
younesbelkada Sep 25, 2023
71d4990
Merge remote-tracking branch 'upstream/main' into peft-part-2
younesbelkada Sep 25, 2023
1fb4aa2
up
younesbelkada Sep 25, 2023
4c803f6
up
younesbelkada Sep 27, 2023
11a493a
simple approach --> not same results for some reason
younesbelkada Sep 27, 2023
4ea8959
fix issues
younesbelkada Sep 27, 2023
16b1161
fix copies
younesbelkada Sep 27, 2023
b3a02be
remove unneeded method
younesbelkada Sep 27, 2023
cc135f2
active adapters!
younesbelkada Sep 27, 2023
5c493e5
Merge branch 'main' into peft-part-2
younesbelkada Sep 27, 2023
a09530c
fix merge conflicts
younesbelkada Sep 27, 2023
d3ce092
up
younesbelkada Sep 27, 2023
9e500d2
up
younesbelkada Sep 27, 2023
edaea14
kohya - test-1
younesbelkada Sep 27, 2023
8781506
Apply suggestions from code review
younesbelkada Sep 27, 2023
0a14573
fix scale
younesbelkada Sep 27, 2023
c26c418
fix copies
younesbelkada Sep 27, 2023
6996b82
add comment
younesbelkada Sep 27, 2023
68912e4
multi adapters
younesbelkada Sep 28, 2023
10e0e61
fix tests
younesbelkada Sep 28, 2023
8c42fa1
Merge branch 'main' into peft-part-2
younesbelkada Oct 3, 2023
99fec57
oops
younesbelkada Oct 3, 2023
ac925f8
v1 faster loading - in progress
younesbelkada Oct 3, 2023
ebb16ca
Revert "v1 faster loading - in progress"
younesbelkada Oct 4, 2023
81f886e
kohya same generation
younesbelkada Oct 4, 2023
7376deb
fix some slow tests
younesbelkada Oct 4, 2023
ff82de4
peft integration features for unet lora
pacman100 Oct 4, 2023
94403c1
fix `get_peft_kwargs`
pacman100 Oct 4, 2023
e8fca9f
Update loaders.py
pacman100 Oct 4, 2023
4f21a7b
add some tests
younesbelkada Oct 4, 2023
3568e7f
add unfuse tests
younesbelkada Oct 4, 2023
459285f
fix tests
younesbelkada Oct 4, 2023
24dad33
up
younesbelkada Oct 4, 2023
ec04337
add set adapter from sourab and tests
younesbelkada Oct 4, 2023
b40592a
fix multi adapter tests
younesbelkada Oct 4, 2023
2646f3d
style & quality
younesbelkada Oct 4, 2023
0e771f0
Merge branch 'peft-part-2' into smangrul/unet-enhancements
pacman100 Oct 4, 2023
86bd6f5
Merge pull request #2 from huggingface/smangrul/unet-enhancements
younesbelkada Oct 4, 2023
02e73a4
style
younesbelkada Oct 4, 2023
86c7d69
remove comment
younesbelkada Oct 4, 2023
94abbc0
fix `adapter_name` issues
pacman100 Oct 5, 2023
61e316c
fix unet adapter name for sdxl
pacman100 Oct 5, 2023
32dd0d5
fix enabling/disabling adapters
pacman100 Oct 5, 2023
ba6c180
fix fuse / unfuse unet
younesbelkada Oct 5, 2023
892d1d3
Merge pull request #3 from huggingface/smangrul/fixes-peft-integration
younesbelkada Oct 5, 2023
c0d9d68
nit
younesbelkada Oct 5, 2023
7e1e252
fix
younesbelkada Oct 5, 2023
f4a5229
up
younesbelkada Oct 5, 2023
8dc6b87
fix cpu offloading
younesbelkada Oct 5, 2023
4746de1
Merge branch 'main' into peft-part-2
younesbelkada Oct 5, 2023
0413049
fix another slow test
younesbelkada Oct 5, 2023
1d517e3
Merge branch 'peft-part-2' of https://github.com/younesbelkada/diffus…
younesbelkada Oct 5, 2023
6fe1b2d
fix another offload test
younesbelkada Oct 5, 2023
2825d5b
Merge branch 'peft-part-2' of https://github.com/younesbelkada/diffus…
younesbelkada Oct 5, 2023
206f0de
add more tests
younesbelkada Oct 5, 2023
2265fc2
all slow tests pass
younesbelkada Oct 5, 2023
265a928
style
younesbelkada Oct 5, 2023
e7a3dc6
fix alpha pattern for unet and text encoder
pacman100 Oct 6, 2023
7868b48
Merge pull request #4 from huggingface/smangrul/fix-alpha-pattern
younesbelkada Oct 6, 2023
abb2325
Update src/diffusers/loaders.py
younesbelkada Oct 6, 2023
81db89f
Update src/diffusers/models/attention.py
younesbelkada Oct 6, 2023
fc643eb
up
younesbelkada Oct 6, 2023
957108b
up
younesbelkada Oct 6, 2023
5d9ce0d
clarify comment
younesbelkada Oct 6, 2023
bd44f56
comments
younesbelkada Oct 6, 2023
71c321e
change comment order
younesbelkada Oct 6, 2023
c42d974
change comment order
younesbelkada Oct 6, 2023
a0598e6
stylr & quality
younesbelkada Oct 6, 2023
a7a6cd6
Update tests/lora/test_lora_layers_peft.py
younesbelkada Oct 6, 2023
9992964
fix bugs and add tests
younesbelkada Oct 6, 2023
525743e
Update src/diffusers/models/modeling_utils.py
younesbelkada Oct 6, 2023
7183863
Update src/diffusers/models/modeling_utils.py
younesbelkada Oct 6, 2023
e44c17c
refactor
younesbelkada Oct 6, 2023
f435ce9
suggestion
younesbelkada Oct 6, 2023
7e8cb7a
add break statemebt
younesbelkada Oct 6, 2023
2af9bfd
add compile tests
younesbelkada Oct 6, 2023
8da2350
move slow tests to peft tests as I modified them
younesbelkada Oct 6, 2023
f497280
quality
younesbelkada Oct 6, 2023
c0ce809
Merge branch 'peft-part-2' of https://github.com/younesbelkada/diffus…
younesbelkada Oct 6, 2023
74cfc1c
refactor a bit
younesbelkada Oct 6, 2023
36ec721
style
younesbelkada Oct 6, 2023
2c94a86
Merge remote-tracking branch 'upstream/main' into peft-part-2
younesbelkada Oct 6, 2023
95d2b44
change import
younesbelkada Oct 8, 2023
e82d83c
style
younesbelkada Oct 9, 2023
92aef0b
Merge remote-tracking branch 'upstream/main' into peft-part-2
younesbelkada Oct 9, 2023
f939e04
fix CI
younesbelkada Oct 9, 2023
10f6352
refactor slow tests one last time
younesbelkada Oct 9, 2023
4b1a073
Merge branch 'peft-part-2' of https://github.com/younesbelkada/diffus…
younesbelkada Oct 9, 2023
22452b7
style
younesbelkada Oct 9, 2023
48ae256
oops
younesbelkada Oct 9, 2023
06db84d
oops
younesbelkada Oct 9, 2023
0723b55
oops
younesbelkada Oct 9, 2023
ca039a5
Merge branch 'peft-part-2' of https://github.com/younesbelkada/diffus…
younesbelkada Oct 9, 2023
44ae0a9
final tweak tests
younesbelkada Oct 9, 2023
a01d542
Merge branch 'main' into peft-part-2
sayakpaul Oct 9, 2023
d64dc6f
Apply suggestions from code review
younesbelkada Oct 9, 2023
f6d6e5d
Update src/diffusers/loaders.py
younesbelkada Oct 9, 2023
5394d37
comments
younesbelkada Oct 9, 2023
32043aa
Apply suggestions from code review
younesbelkada Oct 9, 2023
599f556
remove comments
younesbelkada Oct 9, 2023
6faee80
more comments
younesbelkada Oct 9, 2023
1c94452
try
younesbelkada Oct 9, 2023
a14779e
revert
younesbelkada Oct 9, 2023
18f3a25
Merge branch 'main' into peft-part-2
younesbelkada Oct 9, 2023
cad5a4b
add `safe_merge` tests
younesbelkada Oct 9, 2023
3708ed9
add comment
younesbelkada Oct 10, 2023
323612b
style, comments and run tests in fp16
younesbelkada Oct 10, 2023
ec65342
Merge remote-tracking branch 'upstream/main' into peft-part-2
younesbelkada Oct 10, 2023
64e2d87
add warnings
younesbelkada Oct 10, 2023
db0c3dc
fix doc test
younesbelkada Oct 10, 2023
6172b64
replace with `adapter_weights`
younesbelkada Oct 10, 2023
74d80a9
add `get_active_adapters()`
younesbelkada Oct 10, 2023
cb588ae
expose `get_list_adapters` method
younesbelkada Oct 10, 2023
b419b52
better error message
younesbelkada Oct 10, 2023
498dc17
Apply suggestions from code review
younesbelkada Oct 10, 2023
02d17b3
style
younesbelkada Oct 10, 2023
400c2da
trigger slow lora tests
younesbelkada Oct 10, 2023
1402506
Merge branch 'main' into peft-part-2
sayakpaul Oct 10, 2023
e6d8042
fix tests
younesbelkada Oct 10, 2023
3c4dc79
Merge remote-tracking branch 'upstream/test-peft-unet' into peft-part-2
younesbelkada Oct 10, 2023
d5e7647
maybe fix last test
younesbelkada Oct 10, 2023
a5cd549
Merge remote-tracking branch 'upstream/test-peft-unet' into peft-part-2
younesbelkada Oct 10, 2023
a02c162
revert
younesbelkada Oct 11, 2023
ffaf30f
Update src/diffusers/loaders.py
younesbelkada Oct 11, 2023
4145818
Merge branch 'main' into peft-part-2
younesbelkada Oct 11, 2023
836e32e
Update src/diffusers/loaders.py
younesbelkada Oct 11, 2023
2fa61fc
Update src/diffusers/loaders.py
younesbelkada Oct 11, 2023
9102399
Update src/diffusers/loaders.py
younesbelkada Oct 11, 2023
924222e
Apply suggestions from code review
younesbelkada Oct 12, 2023
21a279a
move `MIN_PEFT_VERSION`
younesbelkada Oct 12, 2023
0fe4203
Apply suggestions from code review
younesbelkada Oct 12, 2023
e981af2
let's not use class variable
younesbelkada Oct 12, 2023
61737cf
fix few nits
younesbelkada Oct 12, 2023
b2150d9
change a bit offloading logic
younesbelkada Oct 12, 2023
3521188
check earlier
younesbelkada Oct 12, 2023
d03d1a3
rm unneeded block
younesbelkada Oct 12, 2023
fabb521
break long line
younesbelkada Oct 12, 2023
fc55a1a
return empty list
younesbelkada Oct 12, 2023
7106b22
change logic a bit and address comments
younesbelkada Oct 12, 2023
1834f8e
add typehint
younesbelkada Oct 12, 2023
21a8d6c
remove parenthesis
younesbelkada Oct 12, 2023
44f658d
fix
younesbelkada Oct 12, 2023
7fd50a7
revert to fp16 in tests
younesbelkada Oct 12, 2023
e92c6de
add to gpu
younesbelkada Oct 12, 2023
4e382ee
revert to old test
younesbelkada Oct 13, 2023
6ae767f
style
younesbelkada Oct 13, 2023
a0f976e
Merge branch 'main' into peft-part-2
sayakpaul Oct 13, 2023
b4e1381
Update src/diffusers/loaders.py
younesbelkada Oct 13, 2023
f708dba
change indent
younesbelkada Oct 13, 2023
f17206c
Apply suggestions from code review
patrickvonplaten Oct 13, 2023
950d19c
Apply suggestions from code review
patrickvonplaten Oct 13, 2023
443 changes: 377 additions & 66 deletions src/diffusers/loaders.py

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions src/diffusers/models/attention.py
@@ -17,6 +17,7 @@
import torch.nn.functional as F
from torch import nn

from ..utils import USE_PEFT_BACKEND
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
@@ -300,6 +301,7 @@ def __init__(
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
@@ -316,14 +318,15 @@
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
self.net.append(linear_cls(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))

def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
for module in self.net:
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
if isinstance(module, compatible_cls):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
@@ -368,7 +371,9 @@ class GEGLU(nn.Module):

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

self.proj = linear_cls(dim_in, dim_out * 2)

def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
@@ -377,7 +382,8 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

def forward(self, hidden_states, scale: float = 1.0):
hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
args = () if USE_PEFT_BACKEND else (scale,)
Review comment (Member): Nice.
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)


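The attention.py changes above all follow one pattern: pick the layer class once at construction time (plain `nn.Linear` when the PEFT backend is active, `LoRACompatibleLinear` otherwise) and only pass the LoRA `scale` argument to the latter. Below is a minimal, self-contained sketch of that pattern; the hard-coded `USE_PEFT_BACKEND` constant and the `FakeLoRACompatibleLinear` class are stand-ins for the real `diffusers.utils.USE_PEFT_BACKEND` flag and `LoRACompatibleLinear`, used here only for illustration.

```python
# Sketch of the backend-dependent layer selection used throughout this diff.
import torch
from torch import nn

USE_PEFT_BACKEND = False  # assumption: in diffusers this is derived from the installed PEFT version


class FakeLoRACompatibleLinear(nn.Linear):
    """Stand-in for diffusers' LoRACompatibleLinear: its forward accepts a LoRA `scale`."""

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # The real class adds `scale * lora_layer(hidden_states)` when a LoRA layer is attached.
        return super().forward(hidden_states)


class TinyFeedForward(nn.Module):
    def __init__(self, dim: int, dim_out: int):
        super().__init__()
        # Choose the layer class once, exactly as the diff does.
        linear_cls = nn.Linear if USE_PEFT_BACKEND else FakeLoRACompatibleLinear
        self.proj = linear_cls(dim, dim_out)

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # With the PEFT backend, the module is a plain nn.Linear and takes no scale argument.
        args = () if USE_PEFT_BACKEND else (scale,)
        return self.proj(hidden_states, *args)


x = torch.randn(2, 8)
print(TinyFeedForward(8, 16)(x, scale=0.5).shape)  # torch.Size([2, 16])
```

With the flag flipped to `True`, the same forward path runs against plain `nn.Linear` modules, which is what allows PEFT to inject and manage the LoRA layers itself.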
44 changes: 29 additions & 15 deletions src/diffusers/models/attention_processor.py
@@ -18,7 +18,7 @@
import torch.nn.functional as F
from torch import nn

from ..utils import deprecate, logging
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRACompatibleLinear, LoRALinearLayer
@@ -137,22 +137,27 @@ def __init__(
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)

self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
if USE_PEFT_BACKEND:
linear_cls = nn.Linear
else:
linear_cls = LoRACompatibleLinear

self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None

if self.added_kv_proj_dim is not None:
self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)

self.to_out = nn.ModuleList([])
self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))

# set attention processor
@@ -545,6 +550,8 @@ def __call__(
):
residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -562,15 +569,15 @@
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, scale=scale)
query = attn.to_q(hidden_states, *args)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)

query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
@@ -581,7 +588,7 @@
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

@@ -1007,15 +1014,20 @@ def __call__(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, scale=scale)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
key = (
attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
)
value = (
attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1035,7 +1047,9 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = (
attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
)
# dropout
hidden_states = attn.to_out[1](hidden_states)

6 changes: 4 additions & 2 deletions src/diffusers/models/embeddings.py
@@ -18,6 +18,7 @@
import torch
from torch import nn

from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .lora import LoRACompatibleLinear

@@ -166,8 +167,9 @@ def __init__(
cond_proj_dim=None,
):
super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
self.linear_1 = linear_cls(in_channels, time_embed_dim)

if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -180,7 +182,7 @@ def __init__(
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)

if post_act_fn is None:
self.post_act = None
150 changes: 150 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -32,10 +32,12 @@
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
MIN_PEFT_VERSION,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_add_variant,
_get_model_file,
check_peft_version,
deprecate,
is_accelerate_available,
is_torch_version,
@@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
_hf_peft_config_loaded = False

def __init__(self):
super().__init__()
@@ -292,6 +295,153 @@ def disable_xformers_memory_efficient_attention(self):
"""
self.set_use_memory_efficient_attention_xformers(False)

def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
r"""
Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
to the adapter to follow the convention of the PEFT library.

If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
[documentation](https://huggingface.co/docs/peft).

Args:
adapter_config (`[~peft.PeftConfig]`):
The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
methods.
adapter_name (`str`, *optional*, defaults to `"default"`):
The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

from peft import PeftConfig, inject_adapter_in_model

if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
elif adapter_name in self.peft_config:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

if not isinstance(adapter_config, PeftConfig):
raise ValueError(
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
)

# Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
# handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
Review comment (Member): I don't quite get this. Does it hurt to have base_model_name_or_path?
Reply (Contributor Author): No, it does not, but I think there is no equivalent of it in diffusers, per my understanding.
adapter_config.base_model_name_or_path = None
inject_adapter_in_model(adapter_config, self, adapter_name)
self.set_adapter(adapter_name)

def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
"""
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.

If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft

Args:
adapter_name (Union[str, List[str]])):
The list of adapters to set or the adapter name in case of single adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")

if isinstance(adapter_name, str):
adapter_name = [adapter_name]

missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)

from peft.tuners.tuners_utils import BaseTunerLayer

_adapters_has_been_set = False

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
# Previous versions of PEFT does not support multi-adapter inference
elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
raise ValueError(
"You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
" `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
raise ValueError(
"Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
)

def disable_adapters(self) -> None:
r"""
Disable all adapters attached to the model and fallback to inference with the base model only.

If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")

from peft.tuners.tuners_utils import BaseTunerLayer

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
# support for older PEFT versions
module.disable_adapters = True

def enable_adapters(self) -> None:
"""
Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
list of adapters to enable.

If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")

from peft.tuners.tuners_utils import BaseTunerLayer

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
# support for older PEFT versions
module.disable_adapters = False

def active_adapters(self) -> List[str]:
"""
Gets the current list of active adapters of the model.

If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")

from peft.tuners.tuners_utils import BaseTunerLayer

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
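Taken together, the methods added to `ModelMixin` above give models such as the UNet a small adapter-management API. The sketch below shows how those calls fit together; the checkpoint id, LoRA ranks, and `target_modules` are illustrative assumptions and not part of this PR, while the method names come directly from the diff.

```python
# Usage sketch for the adapter API added to ModelMixin in this PR.
# Assumes `peft` is installed; hyperparameters and checkpoint are placeholders.
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # assumed checkpoint
)

# Attach two LoRA adapters targeting the attention projections.
config_a = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
config_b = LoraConfig(r=8, lora_alpha=8, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
unet.add_adapter(config_a, adapter_name="adapter_a")
unet.add_adapter(config_b, adapter_name="adapter_b")

# Pick which adapter drives the forward pass.
unet.set_adapter("adapter_b")
print(unet.active_adapters())  # e.g. ["adapter_b"]

# Temporarily fall back to the base weights, then re-enable the active adapter.
unet.disable_adapters()
unet.enable_adapters()
```

As the diff shows, `set_adapter` also accepts a list of adapter names for multi-adapter inference, provided the installed PEFT version supports it.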