ENH Warn when adapter name contains prefix (#2254)
Warn when adapter_name is contained in the tuner prefix, which can cause
weight reinitialization during model loading.
pzdkn authored Dec 11, 2024
1 parent 3c61b3e commit 5cdade9
Showing 3 changed files with 96 additions and 1 deletion.
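For context on the failure mode: PEFT state-dict keys embed both the tuner prefix (e.g. `lora_` for LoRA) and the adapter name, and saving strips the adapter name from each key. A minimal sketch of the collision, assuming the stripping is a plain string replace along the lines of what `get_peft_model_state_dict` does (simplified here, not the exact implementation):

```python
# Simplified sketch: assume saving strips ".{adapter_name}" from every state-dict key.
def strip_adapter_name(key: str, adapter_name: str) -> str:
    return key.replace(f".{adapter_name}", "")

# Safe: "my_adapter" does not overlap the "lora_" prefix.
print(strip_adapter_name("model.v_proj.lora_A.my_adapter.weight", "my_adapter"))
# -> "model.v_proj.lora_A.weight" (round-trips cleanly on load)

# Conflict: "lora" is a substring of the prefix "lora_", so the replace also
# eats part of "lora_A"; the saved key matches no model key on load, and the
# adapter weight is reported missing and left at its initialization.
print(strip_adapter_name("model.v_proj.lora_A.lora.weight", "lora"))
# -> "model.v_proj_A.weight"
```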
8 changes: 8 additions & 0 deletions src/peft/mapping.py
@@ -72,6 +72,7 @@
 )
 from .tuners.tuners_utils import BaseTuner
 from .utils import _prepare_prompt_learning_config
+from .utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING


 if TYPE_CHECKING:
@@ -204,6 +205,13 @@ def get_peft_model(
             "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization."
         )

+    prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
+    if prefix and adapter_name in prefix:
+        warnings.warn(
+            f"Adapter name {adapter_name} should not be contained in the prefix {prefix}. "
+            "This may lead to reinitialization of the adapter weights during loading."
+        )
+
     if mixed:
         # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
         return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
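A quick, hypothetical repro of the new warning (the tiny model id is the one used in the test suite below; for LoRA the prefix is `lora_`, so the name `lora` is contained in it):

```python
import warnings

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = get_peft_model(base, LoraConfig(), adapter_name="lora")  # "lora" is in "lora_"

print([str(w.message) for w in caught if "should not be contained" in str(w.message)])
```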
19 changes: 18 additions & 1 deletion src/peft/peft_model.py
@@ -596,7 +596,17 @@ def from_pretrained(
             # Let's warn here since (in contrast to load_adapter) we don't return the load result, so it could be quite
             # difficult for users to even notice that something might have gone wrong here. As we filter out non PEFT
             # keys from the missing keys, this gives no false positives.
-            warnings.warn(f"Found missing adapter keys while loading the checkpoint: {missing_keys}")
+
+            warn_message = f"Found missing adapter keys while loading the checkpoint: {missing_keys}."
+
+            prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(config.peft_type)
+            if prefix and adapter_name in prefix:
+                warn_message += (
+                    f" Adapter name {adapter_name} should not be contained in the prefix {prefix}. "
+                    "This may be the reason for the missing adapter keys."
+                )
+
+            warnings.warn(warn_message)

         return model
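The same hint is appended to the existing missing-keys warning on load. A sketch of a round trip that should hit this path (assuming, as in the test helper below, that a non-default adapter is saved to and reloaded from a subfolder named after it):

```python
import tempfile
import warnings

from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
with tempfile.TemporaryDirectory() as tmp:
    # "lora" conflicts with the LoRA prefix "lora_", so saved keys get mangled.
    model = get_peft_model(AutoModelForCausalLM.from_pretrained(model_id), LoraConfig(), adapter_name="lora")
    model.save_pretrained(tmp, selected_adapters=["lora"])

    base = AutoModelForCausalLM.from_pretrained(model_id)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        PeftModel.from_pretrained(base, f"{tmp}/lora", adapter_name="lora")
    # Expect "Found missing adapter keys ..." plus the naming-conflict hint.
    print([str(w.message) for w in caught])
```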

@@ -940,6 +950,13 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_us
                adapters. Don't use this option when creating a new PEFT adapter for training.
         """
+        prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
+        if prefix and adapter_name in prefix:
+            warnings.warn(
+                f"Adapter name {adapter_name} should not be contained in the prefix {prefix}. "
+                "This may lead to reinitialization of the adapter weights during loading."
+            )
+
         if peft_config.peft_type != self.peft_type:
             raise ValueError(
                 f"Cannot combine adapters with different peft types. "
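And the same check in add_adapter, sketched with a second adapter whose name (`ora_`, chosen only for illustration) is a substring of the `lora_` prefix:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
model = get_peft_model(base, LoraConfig(), adapter_name="adapter")  # no conflict
model.add_adapter("ora_", LoraConfig())  # warns: "ora_" is contained in "lora_"
```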
70 changes: 70 additions & 0 deletions tests/test_initialization.py
@@ -57,6 +57,7 @@
 )
 from peft.tuners.lora.layer import LoraLayer
 from peft.utils import infer_device
+from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
 from peft.utils.hotswap import hotswap_adapter


@@ -1747,6 +1748,75 @@ def new_state_dict():
         assert any(missing_key in str(w.message) for w in recwarn.list)


+class TestNamingConflictWarning:
+    """
+    Tests for warnings related to naming conflicts between adapter names and tuner prefixes. References: issue 2252
+    """
+
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        self.peft_config = LoraConfig()
+        self.prefix = PEFT_TYPE_TO_PREFIX_MAPPING[self.peft_config.peft_type]
+        self.base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
+
+    def _save_and_reload_model(self, model, adapter_name, tmp_path):
+        # Helper method to save and reload the PEFT model
+        model.save_pretrained(tmp_path, selected_adapters=[adapter_name])
+        del model
+        reloaded_base_model = AutoModelForCausalLM.from_pretrained(tmp_path / adapter_name)
+        return PeftModel.from_pretrained(reloaded_base_model, tmp_path / adapter_name)
+
+    def test_no_warning_without_naming_conflict_get_peft_model(self, recwarn):
+        # No warning should be raised when there is no naming conflict during get_peft_model.
+        non_conflict_adapter = "adapter"
+        _ = get_peft_model(self.base_model, self.peft_config, adapter_name=non_conflict_adapter)
+        expected_msg = f"Adapter name {non_conflict_adapter} should not be contained in the prefix {self.prefix}."
+        assert not any(expected_msg in str(w.message) for w in recwarn.list)
+
+    def test_no_warning_without_naming_conflict_add_adapter(self, recwarn):
+        # No warning should be raised when adding an adapter without a naming conflict.
+        non_conflict_adapter = "adapter"
+        other_non_conflict_adapter = "other_adapter"
+        model = get_peft_model(self.base_model, self.peft_config, adapter_name=non_conflict_adapter)
+        _ = model.add_adapter(other_non_conflict_adapter, self.peft_config)
+        expected_msg = (
+            f"Adapter name {other_non_conflict_adapter} should not be contained in the prefix {self.prefix}."
+        )
+        assert not any(expected_msg in str(w.message) for w in recwarn.list)
+
+    def test_no_warning_without_naming_conflict_save_and_load(self, recwarn, tmp_path):
+        # No warning should be raised when saving and loading the model without a naming conflict.
+        non_conflict_adapter = "adapter"
+        model = get_peft_model(self.base_model, self.peft_config, adapter_name=non_conflict_adapter)
+        _ = self._save_and_reload_model(model, non_conflict_adapter, tmp_path)
+        expected_msg = f"Adapter name {non_conflict_adapter} should not be contained in the prefix {self.prefix}."
+        assert not any(expected_msg in str(w.message) for w in recwarn.list)
+
+    def test_warning_naming_conflict_get_peft_model(self, recwarn):
+        # A warning is raised when the adapter name conflicts with the prefix in get_peft_model.
+        conflicting_adapter_name = self.prefix[:-1]
+        _ = get_peft_model(self.base_model, self.peft_config, adapter_name=conflicting_adapter_name)
+        expected_msg = f"Adapter name {conflicting_adapter_name} should not be contained in the prefix {self.prefix}."
+        assert any(expected_msg in str(w.message) for w in recwarn.list)
+
+    def test_warning_naming_conflict_add_adapter(self, recwarn):
+        # A warning is raised when adding an adapter whose name conflicts with the prefix.
+        conflicting_adapter = self.prefix[1:]
+        non_conflict_adapter = "adapter"
+        model = get_peft_model(self.base_model, self.peft_config, adapter_name=non_conflict_adapter)
+        _ = model.add_adapter(conflicting_adapter, self.peft_config)
+        expected_msg = f"Adapter name {conflicting_adapter} should not be contained in the prefix {self.prefix}."
+        assert any(expected_msg in str(w.message) for w in recwarn.list)
+
+    def test_warning_naming_conflict_save_and_load(self, recwarn, tmp_path):
+        # A warning is raised when saving and loading the model with a naming conflict.
+        conflicting_adapter = self.prefix[:-1]
+        model = get_peft_model(self.base_model, self.peft_config, adapter_name=conflicting_adapter)
+        _ = self._save_and_reload_model(model, conflicting_adapter, tmp_path)
+        expected_msg = f"Adapter name {conflicting_adapter} should not be contained in the prefix {self.prefix}."
+        assert any(expected_msg in str(w.message) for w in recwarn.list)
+
+
 class TestEvaInitialization:
     """Tests for the EVA (Explained Variance Adaptation) initialization method.
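To run just the new tests, pytest's keyword filter should work: `pytest tests/test_initialization.py -k TestNamingConflictWarning`.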
