From a0299fa9d656a5748395818091e4cadbc1bd73fb Mon Sep 17 00:00:00 2001
From: Sam Coward
Date: Sun, 4 Jun 2023 19:43:02 -0400
Subject: [PATCH] Support loading PEFT (LoRA) models

---
 basaran/__init__.py | 1 +
 basaran/__main__.py | 2 ++
 basaran/model.py    | 12 +++++++++++-
 requirements.txt    | 1 +
 4 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/basaran/__init__.py b/basaran/__init__.py
index 6b40c429..c5317d01 100644
--- a/basaran/__init__.py
+++ b/basaran/__init__.py
@@ -18,6 +18,7 @@ def is_true(value):
 PORT = int(os.getenv("PORT", "80"))
 
 # Model-related arguments:
+MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
 MODEL_REVISION = os.getenv("MODEL_REVISION", "")
 MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
 MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
diff --git a/basaran/__main__.py b/basaran/__main__.py
index 504f160b..5adac250 100644
--- a/basaran/__main__.py
+++ b/basaran/__main__.py
@@ -23,6 +23,7 @@
 from . import MODEL_LOAD_IN_4BIT
 from . import MODEL_4BIT_QUANT_TYPE
 from . import MODEL_4BIT_DOUBLE_QUANT
+from . import MODEL_PEFT
 from . import MODEL_LOCAL_FILES_ONLY
 from . import MODEL_TRUST_REMOTE_CODE
 from . import MODEL_HALF_PRECISION
@@ -44,6 +45,7 @@
     name_or_path=MODEL,
     revision=MODEL_REVISION,
     cache_dir=MODEL_CACHE_DIR,
+    is_peft=MODEL_PEFT,
     load_in_8bit=MODEL_LOAD_IN_8BIT,
     load_in_4bit=MODEL_LOAD_IN_4BIT,
     quant_type=MODEL_4BIT_QUANT_TYPE,
diff --git a/basaran/model.py b/basaran/model.py
index b9082c7c..362b57d9 100644
--- a/basaran/model.py
+++ b/basaran/model.py
@@ -14,6 +14,10 @@
     TopPLogitsWarper,
     BitsAndBytesConfig
 )
+from peft import (
+    PeftConfig,
+    PeftModel
+)
 
 from .choice import map_choice
 from .tokenizer import StreamTokenizer
@@ -302,6 +306,7 @@ def load_model(
     name_or_path,
     revision=None,
     cache_dir=None,
+    is_peft=False,
     load_in_8bit=False,
     load_in_4bit=False,
     quant_type="fp4",
@@ -319,7 +324,6 @@ def load_model(
         kwargs["revision"] = revision
     if cache_dir:
         kwargs["cache_dir"] = cache_dir
-    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
 
     # Set device mapping and quantization options if CUDA is available.
     if torch.cuda.is_available():
@@ -346,6 +350,12 @@ def load_model(
     if half_precision or load_in_8bit or load_in_4bit:
         kwargs["torch_dtype"] = torch.float16
 
+    if is_peft:
+        peft_config = PeftConfig.from_pretrained(name_or_path)
+        name_or_path = peft_config.base_model_name_or_path
+
+    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
+
     # Support both decoder-only and encoder-decoder models.
     try:
         model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)
diff --git a/requirements.txt b/requirements.txt
index db6c6216..85832f13 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,3 +12,4 @@ safetensors~=0.3.1
 torch>=1.12.1
 transformers[sentencepiece]~=4.29.2
 waitress~=2.1.2
+peft~=0.3.0
\ No newline at end of file
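
Notes:

With this patch applied, pointing MODEL at a PEFT adapter repository and
setting MODEL_PEFT to a truthy value (e.g. MODEL=someuser/llama-7b-lora
MODEL_PEFT=true python -m basaran, where the adapter name is a placeholder)
makes load_model() resolve the underlying base model before loading anything.
The sketch below isolates that resolution step; the adapter name is
hypothetical, not part of the patch.

    # Minimal sketch of the resolution step load_model() now performs.
    # "someuser/llama-7b-lora" is a hypothetical adapter repo used only
    # for illustration.
    from peft import PeftConfig

    adapter = "someuser/llama-7b-lora"
    peft_config = PeftConfig.from_pretrained(adapter)

    # Adapter repos ship only adapter weights and a small config; the
    # tokenizer and base weights live with the base model, which is why
    # the patch swaps name_or_path (and moves the tokenizer load after
    # the swap) before any from_pretrained() call.
    print(peft_config.base_model_name_or_path)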
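The hunks above import PeftModel but stop at loading the resolved base model,
so the adapter weights are not attached anywhere in this diff. Assuming the
standard peft 0.3.x API, applying them would look roughly like the following;
treat this as a sketch of the likely follow-up, not as part of the patch.

    # Sketch: attaching the LoRA adapter to the resolved base model,
    # assuming peft's standard PeftModel.from_pretrained() API. This
    # step does not appear in the hunks above.
    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForCausalLM

    adapter = "someuser/llama-7b-lora"  # hypothetical adapter repo
    base_name = PeftConfig.from_pretrained(adapter).base_model_name_or_path

    base = AutoModelForCausalLM.from_pretrained(base_name)
    model = PeftModel.from_pretrained(base, adapter)
    model.eval()  # inference only: put adapter modules in eval mode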