Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PEFT] Fix PEFT multi adapters support #26407

Merged
merged 12 commits into from
Sep 27, 2023
70 changes: 60 additions & 10 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import (
check_peft_version,
Expand Down Expand Up @@ -245,20 +245,27 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> Non

self.set_adapter(adapter_name)

def set_adapter(self, adapter_name: str) -> None:
def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft

Sets a specific adapter by forcing the model to use that adapter and disabling the other adapters.

Args:
adapter_name (`str`):
The name of the adapter to set.
adapter_name (`Union[List[str], str]`):
The name of the adapter to set. Can be also a list of strings to set multiple adapters.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
elif isinstance(adapter_name, list):
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
elif adapter_name not in self.peft_config:
raise ValueError(
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
Expand All @@ -270,7 +277,11 @@ def set_adapter(self, adapter_name: str) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
# For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
Expand All @@ -294,7 +305,11 @@ def disable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = True
# Recent versions of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True

def enable_adapters(self) -> None:
"""
Expand All @@ -312,14 +327,22 @@ def enable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = False
# Recent versions of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False

def active_adapter(self) -> str:
def active_adapters(self) -> List[str]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that way it won't be BC --> for users that have previous PEFT version they will get an str whereas for newest PEFT versions they will get a list of str

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method didn't exist before so it's ok/good that it now always returns a list of strings

"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft

Gets the current active adapter of the model.
Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
for inference) returns the list of all active adapters so that users can deal with them accordingly.

For previous PEFT versions (that do not support multi-adapter inference), `module.active_adapter` will return
a single string.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

Expand All @@ -333,7 +356,34 @@ def active_adapter(self) -> str:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
active_adapters = module.active_adapter
break

# For previous PEFT versions
if isinstance(active_adapters, str):
active_adapters = [active_adapters]

return active_adapters
Comment on lines +362 to +366
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# For previous PEFT versions
if isinstance(active_adapters, str):
active_adapters = [active_adapters]
return active_adapters
return active_adapters


def active_adapter(self) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think it's a good idea to have both active_adapter and active_adapters:

  • It's easier for the user to just have to remember one method and deal with the output type (we can also deprecate returning a single string in favor of returning a list going forward)
  • Reduces the API surface - only a single source of truth

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed let's just deprecate active_adapter(...) in favor of active_adapters

Suggested change
def active_adapter(self) -> str:
def active_adapter(self) -> str:
# deprecate ...
return self.active_adapters[0]

"""
Gets the current active adapter of the model. In case of multi-adapter inference (combining multiple adapters
for inference) returns the first active adapter - kept for backward compatibility.

For higher versions of PEFT, users should use `model.active_adapters()` instead to get the list of active
adapters.
"""

active_adapters = self.active_adapters()

if isinstance(active_adapters, list):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this check is necessary as for previous PEFT versions active_adapter is an str

logger.warning(
"`active_adapter` will return the first adapter in case of multi-adapter inference. Make sure to know what you are doing."
" You should use `model.active_adapters()` instead to get the list of active adapters."
)
active_adapters = active_adapters[0]

return active_adapters

def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
"""
Expand Down
11 changes: 10 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2006,7 +2006,16 @@ def save_pretrained(
peft_state_dict[f"base_model.model.{key}"] = value
state_dict = peft_state_dict

current_peft_config = self.peft_config[self.active_adapter()]
active_adapter = self.active_adapters()

if len(active_adapter) > 1:
logger.warning(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of warning, I would raise here. It's important for the user to understand the difference between multiple active adapters and single. I would limit the functionality for multi active adapters heavily for now

"Multiple active adapters detected, will only consider the first active adapter. In order to save them all, please iteratively call `set_adapter()` on each"
" adapter name and save them one by one manually. "
)
active_adapter = active_adapter[0]

current_peft_config = self.peft_config[active_adapter]
current_peft_config.save_pretrained(save_directory)

# Save the model
Expand Down
15 changes: 15 additions & 0 deletions tests/peft_integration/test_peft_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,11 @@ def test_peft_add_multi_adapter(self):
_ = model.generate(input_ids=dummy_input)

model.set_adapter("default")
self.assertTrue(model.active_adapters() == ["default"])
self.assertTrue(model.active_adapter() == "default")

model.set_adapter("adapter-2")
self.assertTrue(model.active_adapters() == ["adapter-2"])
self.assertTrue(model.active_adapter() == "adapter-2")

# Logits comparison
Expand All @@ -276,6 +278,19 @@ def test_peft_add_multi_adapter(self):
)
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))

model.set_adapter(["adapter-2", "default"])
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
self.assertTrue(model.active_adapter() == "adapter-2")

logits_adapter_mixed = model(dummy_input)
self.assertFalse(
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)

self.assertFalse(
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)

@require_torch_gpu
def test_peft_from_pretrained_kwargs(self):
"""
Expand Down