Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PEFT] Fix PEFT multi adapters support #26407

Merged
merged 12 commits into from
Sep 27, 2023
41 changes: 34 additions & 7 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import (
check_peft_version,
Expand Down Expand Up @@ -270,7 +270,11 @@ def set_adapter(self, adapter_name: str) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
# For backward compatbility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
Expand All @@ -294,7 +298,11 @@ def disable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = True
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True

def enable_adapters(self) -> None:
"""
Expand All @@ -312,14 +320,24 @@ def enable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = False
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False

def active_adapter(self) -> str:
def active_adapter(self, return_multi_adapters: bool = False) -> Union[str, List[str]]:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft

Gets the current active adapter of the model.
Gets the current active adapter of the model. In case of multi-adapter inference (combining multiple adapters
for inference) returns the list of active adapters so that users can deal with them accordingly.

Args:
return_multi_adapters (`bool`, *optional*, defaults to `False`):
Whether to return a list of active adapters or not. If `False`, only the first adapter is returned. If
`True`, returns the list of active adapters.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

Expand All @@ -333,7 +351,16 @@ def active_adapter(self) -> str:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
active_adapter = module.active_adapter
if isinstance(active_adapter, list) and not return_multi_adapters:
# In case the adapter name is a list (multiple adapters), we only consider the first one
active_adapter = active_adapter[0]

logger.warning(
"Multiple adapters detected, we will only consider the first adapter. If you want to get all active adapters, "
"call `active_adapter(return_multi_adapters=True)` instead."
)
return active_adapter

def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
"""
Expand Down