-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PEFT
] Fix PEFT multi adapters support
#26407
Changes from all commits
05bbcd7
b2e32a5
d772435
87e304a
3dc62a3
9250ccf
c7ad2b4
9e93a26
f8e435f
a46e13b
4ebe629
9f538ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,7 +12,7 @@ | |||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||
# limitations under the License. | ||||||||||||||
import inspect | ||||||||||||||
from typing import TYPE_CHECKING, Any, Dict, Optional | ||||||||||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | ||||||||||||||
|
||||||||||||||
from ..utils import ( | ||||||||||||||
check_peft_version, | ||||||||||||||
|
@@ -245,20 +245,27 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> Non | |||||||||||||
|
||||||||||||||
self.set_adapter(adapter_name) | ||||||||||||||
|
||||||||||||||
def set_adapter(self, adapter_name: str) -> None: | ||||||||||||||
def set_adapter(self, adapter_name: Union[List[str], str]) -> None: | ||||||||||||||
""" | ||||||||||||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT | ||||||||||||||
official documentation: https://huggingface.co/docs/peft | ||||||||||||||
|
||||||||||||||
Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
adapter_name (`str`): | ||||||||||||||
The name of the adapter to set. | ||||||||||||||
adapter_name (`Union[List[str], str]`): | ||||||||||||||
The name of the adapter to set. Can be also a list of strings to set multiple adapters. | ||||||||||||||
""" | ||||||||||||||
check_peft_version(min_version=MIN_PEFT_VERSION) | ||||||||||||||
if not self._hf_peft_config_loaded: | ||||||||||||||
raise ValueError("No adapter loaded. Please load an adapter first.") | ||||||||||||||
elif isinstance(adapter_name, list): | ||||||||||||||
missing = set(adapter_name) - set(self.peft_config) | ||||||||||||||
if len(missing) > 0: | ||||||||||||||
raise ValueError( | ||||||||||||||
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)." | ||||||||||||||
f" current loaded adapters are: {list(self.peft_config.keys())}" | ||||||||||||||
) | ||||||||||||||
elif adapter_name not in self.peft_config: | ||||||||||||||
raise ValueError( | ||||||||||||||
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}" | ||||||||||||||
|
@@ -270,7 +277,11 @@ def set_adapter(self, adapter_name: str) -> None: | |||||||||||||
|
||||||||||||||
for _, module in self.named_modules(): | ||||||||||||||
if isinstance(module, BaseTunerLayer): | ||||||||||||||
module.active_adapter = adapter_name | ||||||||||||||
# For backward compatbility with previous PEFT versions | ||||||||||||||
if hasattr(module, "set_adapter"): | ||||||||||||||
module.set_adapter(adapter_name) | ||||||||||||||
else: | ||||||||||||||
module.active_adapter = adapter_name | ||||||||||||||
_adapters_has_been_set = True | ||||||||||||||
|
||||||||||||||
if not _adapters_has_been_set: | ||||||||||||||
|
@@ -294,7 +305,11 @@ def disable_adapters(self) -> None: | |||||||||||||
|
||||||||||||||
for _, module in self.named_modules(): | ||||||||||||||
if isinstance(module, BaseTunerLayer): | ||||||||||||||
module.disable_adapters = True | ||||||||||||||
# The recent version of PEFT need to call `enable_adapters` instead | ||||||||||||||
if hasattr(module, "enable_adapters"): | ||||||||||||||
module.enable_adapters(enabled=False) | ||||||||||||||
else: | ||||||||||||||
module.disable_adapters = True | ||||||||||||||
|
||||||||||||||
def enable_adapters(self) -> None: | ||||||||||||||
""" | ||||||||||||||
|
@@ -312,14 +327,22 @@ def enable_adapters(self) -> None: | |||||||||||||
|
||||||||||||||
for _, module in self.named_modules(): | ||||||||||||||
if isinstance(module, BaseTunerLayer): | ||||||||||||||
module.disable_adapters = False | ||||||||||||||
# The recent version of PEFT need to call `enable_adapters` instead | ||||||||||||||
if hasattr(module, "enable_adapters"): | ||||||||||||||
module.enable_adapters(enabled=True) | ||||||||||||||
else: | ||||||||||||||
module.disable_adapters = False | ||||||||||||||
|
||||||||||||||
def active_adapter(self) -> str: | ||||||||||||||
def active_adapters(self) -> List[str]: | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that way it won't be BC --> for users that have previous PEFT version they will get an str whereas for newest PEFT versions they will get a list of str There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The method didn't exist before so it's ok/good that it now always returns a list of strings |
||||||||||||||
""" | ||||||||||||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT | ||||||||||||||
official documentation: https://huggingface.co/docs/peft | ||||||||||||||
|
||||||||||||||
Gets the current active adapter of the model. | ||||||||||||||
Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters | ||||||||||||||
for inference) returns the list of all active adapters so that users can deal with them accordingly. | ||||||||||||||
|
||||||||||||||
For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return | ||||||||||||||
a single string. | ||||||||||||||
""" | ||||||||||||||
check_peft_version(min_version=MIN_PEFT_VERSION) | ||||||||||||||
|
||||||||||||||
|
@@ -333,7 +356,21 @@ def active_adapter(self) -> str: | |||||||||||||
|
||||||||||||||
for _, module in self.named_modules(): | ||||||||||||||
if isinstance(module, BaseTunerLayer): | ||||||||||||||
return module.active_adapter | ||||||||||||||
active_adapters = module.active_adapter | ||||||||||||||
break | ||||||||||||||
|
||||||||||||||
# For previous PEFT versions | ||||||||||||||
if isinstance(active_adapters, str): | ||||||||||||||
active_adapters = [active_adapters] | ||||||||||||||
|
||||||||||||||
return active_adapters | ||||||||||||||
Comment on lines
+362
to
+366
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
||||||||||||||
def active_adapter(self) -> str: | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not think it's a good idea to have both
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed let's just deprecate
Suggested change
|
||||||||||||||
logger.warning( | ||||||||||||||
"The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
return self.active_adapters()[0] | ||||||||||||||
|
||||||||||||||
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: | ||||||||||||||
""" | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2006,7 +2006,16 @@ def save_pretrained( | |
peft_state_dict[f"base_model.model.{key}"] = value | ||
state_dict = peft_state_dict | ||
|
||
current_peft_config = self.peft_config[self.active_adapter()] | ||
active_adapter = self.active_adapters() | ||
|
||
if len(active_adapter) > 1: | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||
"Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one " | ||
"by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`" | ||
) | ||
active_adapter = active_adapter[0] | ||
|
||
current_peft_config = self.peft_config[active_adapter] | ||
current_peft_config.save_pretrained(save_directory) | ||
|
||
# Save the model | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's do the same in diffusers :-) https://github.com/huggingface/diffusers/pull/5151/files#r1338774571