Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: PEFT method registration function #2282

Merged
Prev Previous commit
Next Next commit
make style
BenjaminBossan committed Dec 13, 2024
commit 2db48c2dbcfc25fd3b7eb03b674de26bc270acca
9 changes: 4 additions & 5 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
@@ -38,11 +38,12 @@
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin

from peft.utils.constants import DUMMY_MODEL_CONFIG
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils.constants import DUMMY_MODEL_CONFIG

from . import __version__
from .config import PeftConfig
from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_TUNER_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING
from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING, PEFT_TYPE_TO_TUNER_MAPPING
from .utils import (
SAFETENSORS_WEIGHTS_NAME,
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
@@ -1736,9 +1737,7 @@ def forward(
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

def _cpt_forward(
self, input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs
):
def _cpt_forward(self, input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs):
githubnemo marked this conversation as resolved.
Show resolved Hide resolved
# Extract labels from kwargs
labels = kwargs.pop("labels")
device = [i.device for i in [input_ids, inputs_embeds, labels] if i is not None][0]
4 changes: 3 additions & 1 deletion src/peft/tuners/adalora/__init__.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,9 @@
__all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "SVDLinear", "RankAllocator", "SVDQuantLinear"]


register_peft_method(name="adalora", config_cls=AdaLoraConfig, model_cls=AdaLoraModel, prefix="lora_", is_mixed_compatible=True)
register_peft_method(
name="adalora", config_cls=AdaLoraConfig, model_cls=AdaLoraModel, prefix="lora_", is_mixed_compatible=True
)


def __getattr__(name):
4 changes: 2 additions & 2 deletions src/peft/tuners/adaption_prompt/__init__.py
Original file line number Diff line number Diff line change
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from peft.utils import register_peft_method

from .config import AdaptionPromptConfig
from .layer import AdaptedAttention
from .model import AdaptionPromptModel

from peft.utils import register_peft_method


__all__ = ["AdaptionPromptConfig", "AdaptedAttention", "AdaptionPromptModel"]

4 changes: 2 additions & 2 deletions src/peft/tuners/boft/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import BOFTConfig
from .layer import BOFTLayer
from .model import BOFTModel

from peft.utils import register_peft_method


__all__ = ["BOFTConfig", "BOFTLayer", "BOFTModel"]

4 changes: 2 additions & 2 deletions src/peft/tuners/bone/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import BoneConfig
from .layer import BoneLayer, BoneLinear
from .model import BoneModel

from peft.utils import register_peft_method


__all__ = ["BoneConfig", "BoneModel", "BoneLinear", "BoneLayer"]

4 changes: 2 additions & 2 deletions src/peft/tuners/cpt/__init__.py
Original file line number Diff line number Diff line change
@@ -13,11 +13,11 @@
# limitations under the License.


from peft.utils import register_peft_method

from .config import CPTConfig
from .model import CPTEmbedding

from peft.utils import register_peft_method


__all__ = ["CPTConfig", "CPTEmbedding"]

4 changes: 2 additions & 2 deletions src/peft/tuners/fourierft/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import FourierFTConfig
from .layer import FourierFTLayer, FourierFTLinear
from .model import FourierFTModel

from peft.utils import register_peft_method


__all__ = ["FourierFTConfig", "FourierFTLayer", "FourierFTLinear", "FourierFTModel"]

4 changes: 2 additions & 2 deletions src/peft/tuners/hra/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import HRAConfig
from .layer import HRAConv2d, HRALayer, HRALinear
from .model import HRAModel

from peft.utils import register_peft_method


__all__ = ["HRAConfig", "HRAModel", "HRAConv2d", "HRALinear", "HRALayer"]

3 changes: 2 additions & 1 deletion src/peft/tuners/ln_tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import LNTuningConfig
from .model import LNTuningModel

from peft.utils import register_peft_method

__all__ = ["LNTuningConfig", "LNTuningModel"]

4 changes: 2 additions & 2 deletions src/peft/tuners/loha/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import LoHaConfig
from .layer import Conv2d, Linear, LoHaLayer
from .model import LoHaModel

from peft.utils import register_peft_method


__all__ = ["LoHaConfig", "LoHaModel", "Conv2d", "Linear", "LoHaLayer"]

4 changes: 2 additions & 2 deletions src/peft/tuners/lokr/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import LoKrConfig
from .layer import Conv2d, Linear, LoKrLayer
from .model import LoKrModel

from peft.utils import register_peft_method


__all__ = ["LoKrConfig", "LoKrModel", "Conv2d", "Linear", "LoKrLayer"]

8 changes: 5 additions & 3 deletions src/peft/tuners/multitask_prompt_tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from .model import MultitaskPromptEmbedding

from peft.utils import register_peft_method


__all__ = ["MultitaskPromptTuningConfig", "MultitaskPromptTuningInit", "MultitaskPromptEmbedding"]

register_peft_method(name="multitask_prompt_tuning", config_cls=MultitaskPromptTuningConfig, model_cls=MultitaskPromptEmbedding)
register_peft_method(
name="multitask_prompt_tuning", config_cls=MultitaskPromptTuningConfig, model_cls=MultitaskPromptEmbedding
)
4 changes: 2 additions & 2 deletions src/peft/tuners/oft/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import OFTConfig
from .layer import Conv2d, Linear, OFTLayer
from .model import OFTModel

from peft.utils import register_peft_method


__all__ = ["OFTConfig", "OFTModel", "Conv2d", "Linear", "OFTLayer"]

4 changes: 2 additions & 2 deletions src/peft/tuners/p_tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import PromptEncoderConfig, PromptEncoderReparameterizationType
from .model import PromptEncoder

from peft.utils import register_peft_method


__all__ = ["PromptEncoder", "PromptEncoderConfig", "PromptEncoderReparameterizationType"]

4 changes: 2 additions & 2 deletions src/peft/tuners/poly/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import PolyConfig
from .layer import Linear, PolyLayer
from .model import PolyModel

from peft.utils import register_peft_method


__all__ = ["Linear", "PolyConfig", "PolyLayer", "PolyModel"]

4 changes: 2 additions & 2 deletions src/peft/tuners/prefix_tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import PrefixTuningConfig
from .model import PrefixEncoder

from peft.utils import register_peft_method


__all__ = ["PrefixTuningConfig", "PrefixEncoder"]

4 changes: 2 additions & 2 deletions src/peft/tuners/prompt_tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import PromptTuningConfig, PromptTuningInit
from .model import PromptEmbedding

from peft.utils import register_peft_method


__all__ = ["PromptTuningConfig", "PromptEmbedding", "PromptTuningInit"]

4 changes: 2 additions & 2 deletions src/peft/tuners/vblora/__init__.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import VBLoRAConfig
from .layer import Linear, VBLoRALayer
from .model import VBLoRAModel

from peft.utils import register_peft_method


__all__ = ["VBLoRAConfig", "VBLoRALayer", "Linear", "VBLoRAModel"]

3 changes: 1 addition & 2 deletions src/peft/tuners/vera/__init__.py
Original file line number Diff line number Diff line change
@@ -13,13 +13,12 @@
# limitations under the License.

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.utils import register_peft_method

from .config import VeraConfig
from .layer import Linear, VeraLayer
from .model import VeraModel

from peft.utils import register_peft_method


__all__ = ["VeraConfig", "VeraLayer", "Linear", "VeraModel"]

4 changes: 2 additions & 2 deletions src/peft/tuners/xlora/__init__.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import XLoraConfig
from .model import XLoraModel

from peft.utils import register_peft_method


__all__ = ["XLoraConfig", "XLoraModel"]

2 changes: 0 additions & 2 deletions src/peft/utils/constants.py
Original file line number Diff line number Diff line change
@@ -15,8 +15,6 @@
import torch
from transformers import BloomPreTrainedModel

from .peft_types import PeftType


# needed for prefix-tuning of bloom model
def bloom_model_postprocess_past_key_value(past_key_values):
21 changes: 17 additions & 4 deletions src/peft/utils/peft_types.py
Original file line number Diff line number Diff line change
@@ -87,9 +87,16 @@ class TaskType(str, enum.Enum):
FEATURE_EXTRACTION = "FEATURE_EXTRACTION"


def register_peft_method(*, name: str, config_cls, model_cls, prefix: Optional[str] = None, is_mixed_compatible=False) -> None:
def register_peft_method(
*, name: str, config_cls, model_cls, prefix: Optional[str] = None, is_mixed_compatible=False
) -> None:
"""TODO"""
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_MIXED_MODEL_MAPPING, PEFT_TYPE_TO_TUNER_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING
from peft.mapping import (
PEFT_TYPE_TO_CONFIG_MAPPING,
PEFT_TYPE_TO_MIXED_MODEL_MAPPING,
PEFT_TYPE_TO_PREFIX_MAPPING,
PEFT_TYPE_TO_TUNER_MAPPING,
)

if name.endswith("_"):
raise ValueError(f"Please pass the name of the PEFT method without '_' suffix, got {name}.")
@@ -106,15 +113,21 @@ def register_peft_method(*, name: str, config_cls, model_cls, prefix: Optional[s
if prefix is None:
prefix = name + "_"

if (peft_type in PEFT_TYPE_TO_CONFIG_MAPPING) or (peft_type in PEFT_TYPE_TO_TUNER_MAPPING) or (peft_type in PEFT_TYPE_TO_MIXED_MODEL_MAPPING):
if (
(peft_type in PEFT_TYPE_TO_CONFIG_MAPPING)
or (peft_type in PEFT_TYPE_TO_TUNER_MAPPING)
or (peft_type in PEFT_TYPE_TO_MIXED_MODEL_MAPPING)
):
raise KeyError(f"There is already PEFT method called '{name}', please choose a unique name.")

if prefix in PEFT_TYPE_TO_PREFIX_MAPPING:
raise KeyError(f"There is already a prefix called '{prefix}', please choose a unique prefix.")

model_cls_prefix = getattr(model_cls, "prefix", None)
if (model_cls_prefix is not None) and (model_cls_prefix != prefix):
raise ValueError(f"Inconsistent prefixes found: '{prefix}' and '{model_cls_prefix}' (they should be the same).")
raise ValueError(
f"Inconsistent prefixes found: '{prefix}' and '{model_cls_prefix}' (they should be the same)."
)

PEFT_TYPE_TO_PREFIX_MAPPING[peft_type] = prefix
PEFT_TYPE_TO_CONFIG_MAPPING[peft_type] = config_cls