Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PEFT] Fix PEFT multi adapters support #26407

Merged
merged 12 commits into from
Sep 27, 2023
33 changes: 29 additions & 4 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,11 @@ def set_adapter(self, adapter_name: str) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
# For backward compatbility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
Expand All @@ -294,7 +298,11 @@ def disable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = True
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True

def enable_adapters(self) -> None:
"""
Expand All @@ -312,7 +320,11 @@ def enable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = False
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False

def active_adapter(self) -> str:
"""
Expand All @@ -333,7 +345,11 @@ def active_adapter(self) -> str:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
active_adapter = module.active_adapter
if isinstance(active_adapter, list):
# In case the adapter name is a list (multiple adapters), we only consider the first one
active_adapter = active_adapter[0]
return active_adapter

def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
"""
Expand All @@ -357,6 +373,15 @@ def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
if adapter_name is None:
adapter_name = self.active_adapter()

if isinstance(adapter_name, list):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for this method to be called with a list of str? Wouldn't this require the user to explicitly pass that argument? If not, I think this check is not necessary. If it is a valid argument, then the type annotation should also be adjusted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I don't think we should allow users to call this method with a list of str - I refactored a bit the logic of def active_adapter to extend it for multi-adapter inference. let me know what do you think

# In case the adapter name is a list (multiple adapters), we only consider the first one
adapter_name = adapter_name[0]

logger.warning(
"Multiple adapters detected, we will only consider the first adapter, to get all adapters state dict manually loop "
"over the list of adapters and call `get_adapter_state_dict` for each adapter."
)

adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
return adapter_state_dict

Expand Down