
[PEFT] Fix PEFT multi adapters support #26407

Merged · 12 commits · Sep 27, 2023
55 changes: 44 additions & 11 deletions src/transformers/integrations/peft.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import (
check_peft_version,
@@ -245,20 +245,26 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None

self.set_adapter(adapter_name)

-def set_adapter(self, adapter_name: str) -> None:
+def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft

Sets a specific adapter by forcing the model to use that adapter and disable the other adapters.

Args:
-adapter_name (`str`):
-The name of the adapter to set.
+adapter_name (`Union[List[str], str]`):
+The name of the adapter to set. Can also be a list of strings to set multiple adapters.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
+elif isinstance(adapter_name, list):
+for adapter in adapter_name:
Member:
In situations like this, I always prefer to identify all incorrect entries, not just the first. Otherwise, the user fixes the first, then gets an error on the second, fixes that one, then gets an error on the third, and so on. Much nicer to check them all, and runtime shouldn't really matter here.

missing = set(adapter_name) - set(self.peft_config)
if missing:
    raise ValueError(
        f"Following adapter(s) could not be found: {', '.join(missing)}. Please ...")

With a bit of extra work, this could even be merged with the code below which checks for a single adapter not being found.

+if adapter not in self.peft_config:
+raise ValueError(
+f"Adapter with name {adapter} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
+)
elif adapter_name not in self.peft_config:
raise ValueError(
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
@@ -270,7 +276,11 @@ def set_adapter(self, adapter_name: str) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
-module.active_adapter = adapter_name
+# For backward compatibility with previous PEFT versions
+if hasattr(module, "set_adapter"):
+module.set_adapter(adapter_name)
+else:
+module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
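
For context, a minimal usage sketch of the list-valued `set_adapter` introduced above. The base model and adapter repository names are placeholders, and activating several adapters at once assumes a PEFT version recent enough to expose `BaseTunerLayer.set_adapter`, as the hasattr check in the diff suggests.

from transformers import AutoModelForCausalLM

# Placeholder base model and adapter repositories, for illustration only.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model.load_adapter("some-user/opt-350m-lora-a", adapter_name="adapter-a")
model.load_adapter("some-user/opt-350m-lora-b", adapter_name="adapter-b")

# Previous behaviour: activate a single adapter by name.
model.set_adapter("adapter-a")

# With this PR: activate several adapters at once for multi-adapter inference.
model.set_adapter(["adapter-a", "adapter-b"])
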
@@ -294,7 +304,11 @@ def disable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
-module.disable_adapters = True
+# Recent versions of PEFT need to call `enable_adapters` instead
+if hasattr(module, "enable_adapters"):
+module.enable_adapters(enabled=False)
+else:
+module.disable_adapters = True

def enable_adapters(self) -> None:
"""
@@ -312,14 +326,24 @@ def enable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
-module.disable_adapters = False
+# Recent versions of PEFT need to call `enable_adapters` instead
+if hasattr(module, "enable_adapters"):
+module.enable_adapters(enabled=True)
+else:
+module.disable_adapters = False

-def active_adapter(self) -> str:
+def active_adapter(self, return_multi_adapters: bool = True) -> Union[str, List[str]]:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft

-Gets the current active adapter of the model.
+Gets the current active adapter of the model. In the case of multi-adapter inference (combining multiple adapters
+for inference), returns the list of active adapters so that users can deal with them accordingly.

+Args:
+return_multi_adapters (`bool`, *optional*, defaults to `True`):
+Whether to return a list of active adapters or not. If `False`, only the first adapter is returned. If
+`True`, returns the list of active adapters.
Member:
Are we sure about this API? It would seem better to me to have active_adapter changed to active_adapters and to always default to returning the list of active adapters.

Member:
BC would be handled by having another method active_adapter that would log a warning + call active_adapters under the hood

"""
check_peft_version(min_version=MIN_PEFT_VERSION)
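
A short sketch of how `active_adapter` is expected to behave with the new `return_multi_adapters` flag, continuing the hypothetical two-adapter setup from the earlier example:

model.set_adapter(["adapter-a", "adapter-b"])

# Default: the full list of active adapters is returned.
print(model.active_adapter())  # ["adapter-a", "adapter-b"]

# Opting out: only the first active adapter is returned and a warning is logged.
print(model.active_adapter(return_multi_adapters=False))  # "adapter-a"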

@@ -333,7 +357,16 @@ def active_adapter(self) -> str:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
-return module.active_adapter
+active_adapter = module.active_adapter
+if isinstance(active_adapter, list) and not return_multi_adapters:
+# In case the adapter name is a list (multiple adapters), we only consider the first one
+active_adapter = active_adapter[0]
+
+logger.warning(
+"Multiple adapters detected, we will only consider the first adapter. If you want to get all active adapters, "
+"call `active_adapter(return_multi_adapters=True)` instead."
+)
+return active_adapter

def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
"""
@@ -355,7 +388,7 @@ def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
from peft import get_peft_model_state_dict

if adapter_name is None:
-adapter_name = self.active_adapter()
+adapter_name = self.active_adapter(return_multi_adapters=False)

adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
return adapter_state_dict
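
With the change above, `get_adapter_state_dict` falls back to the first active adapter when no name is given; a brief sketch, where the adapter name is a placeholder:

# No adapter_name given: the state dict of the (first) active adapter is returned.
state_dict = model.get_adapter_state_dict()

# Equivalent to naming the adapter explicitly.
state_dict = model.get_adapter_state_dict(adapter_name="adapter-a")
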
12 changes: 11 additions & 1 deletion src/transformers/modeling_utils.py
@@ -2006,7 +2006,17 @@ def save_pretrained(
peft_state_dict[f"base_model.model.{key}"] = value
state_dict = peft_state_dict

-current_peft_config = self.peft_config[self.active_adapter()]
+active_adapter = self.active_adapter()
+
+if isinstance(active_adapter, list):
+if len(active_adapter) > 1:
+logger.warning(
+"Multiple active adapters detected, will only consider the first active adapter. In order to save them all, please iteratively call `set_adapter()` on each"
+" adapter name and save them one by one manually. "
+)
+active_adapter = active_adapter[0]
+
+current_peft_config = self.peft_config[active_adapter]
Member:
Should multi-adapter serialization be supported out of the box?

Contributor Author:
I would advocate educating users to save them one by one manually, as supporting this directly might require some important refactoring in save_pretrained

Contributor Author:
We could also expose a new method in PeftAdapterMixin save_adapters to support that by iteratively calling save_pretrained on all adapters

Member:
Ok understood, let's go with what you have right now then

current_peft_config.save_pretrained(save_directory)

# Save the model
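
As discussed in the thread above, serializing several adapters is left to the user; a minimal sketch of the suggested workaround, with placeholder adapter names and output directories:

# save_pretrained only serializes the first active adapter, so save each adapter explicitly.
for name in ["adapter-a", "adapter-b"]:
    model.set_adapter(name)
    model.save_pretrained(f"./adapters/{name}")
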
18 changes: 16 additions & 2 deletions tests/peft_integration/test_peft_integration.py
@@ -265,17 +265,31 @@ def test_peft_add_multi_adapter(self):
_ = model.generate(input_ids=dummy_input)

model.set_adapter("default")
self.assertTrue(model.active_adapter() == "default")
self.assertTrue(model.active_adapter() == ["default"])
Member:
Hmm, so the change is backwards incompatible for this specific case? Can we live with that?

+self.assertTrue(model.active_adapter(return_multi_adapters=False) == "default")
Member:
This is a breaking change; are we ok with that breaking change?

Contributor Author:
We can set the default return_multi_adapters to True, but users who used the main branch of PEFT for multiple active adapters will still see different behaviour depending on which of the two transformers versions they use - cc @BenjaminBossan, what do you think?

Member:
So I think the most common case would be users with only a single active adapter, as multiple active adapters just landed on PEFT main and have not been advertised yet. Therefore, I would suggest keeping BC for those users and accepting that the tiny minority who use multiple active adapters with the current transformers version will get a breaking change when they upgrade transformers.


model.set_adapter("adapter-2")
self.assertTrue(model.active_adapter() == "adapter-2")
self.assertTrue(model.active_adapter() == ["adapter-2"])
self.assertTrue(model.active_adapter(return_multi_adapters=False) == "adapter-2")

# Logits comparison
self.assertFalse(
torch.allclose(logits_adapter_1.logits, logits_adapter_2.logits, atol=1e-6, rtol=1e-6)
)
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))

model.set_adapter(["adapter-2", "default"])
self.assertTrue(model.active_adapter() == ["adapter-2", "default"])
Member:
Nice addition to the test. Can we also check model.active_adapter() and model.active_adapter(return_multi_adapters=False) here?


+logits_adapter_mixed = model(dummy_input)
+self.assertFalse(
+torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
+)
+
+self.assertFalse(
+torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
+)

@require_torch_gpu
def test_peft_from_pretrained_kwargs(self):
"""