diff --git a/basaran/__init__.py b/basaran/__init__.py
index 28f20bd3..c5317d01 100644
--- a/basaran/__init__.py
+++ b/basaran/__init__.py
@@ -18,10 +18,13 @@ def is_true(value):
 PORT = int(os.getenv("PORT", "80"))
 
 # Model-related arguments:
+MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
 MODEL_REVISION = os.getenv("MODEL_REVISION", "")
 MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
 MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
 MODEL_LOAD_IN_4BIT = is_true(os.getenv("MODEL_LOAD_IN_4BIT", ""))
+MODEL_4BIT_QUANT_TYPE = os.getenv("MODEL_4BIT_QUANT_TYPE", "fp4")
+MODEL_4BIT_DOUBLE_QUANT = is_true(os.getenv("MODEL_4BIT_DOUBLE_QUANT", ""))
 MODEL_LOCAL_FILES_ONLY = is_true(os.getenv("MODEL_LOCAL_FILES_ONLY", ""))
 MODEL_TRUST_REMOTE_CODE = is_true(os.getenv("MODEL_TRUST_REMOTE_CODE", ""))
 MODEL_HALF_PRECISION = is_true(os.getenv("MODEL_HALF_PRECISION", ""))
diff --git a/basaran/__main__.py b/basaran/__main__.py
index bcd123b2..5adac250 100644
--- a/basaran/__main__.py
+++ b/basaran/__main__.py
@@ -21,6 +21,9 @@
 from . import MODEL_CACHE_DIR
 from . import MODEL_LOAD_IN_8BIT
 from . import MODEL_LOAD_IN_4BIT
+from . import MODEL_4BIT_QUANT_TYPE
+from . import MODEL_4BIT_DOUBLE_QUANT
+from . import MODEL_PEFT
 from . import MODEL_LOCAL_FILES_ONLY
 from . import MODEL_TRUST_REMOTE_CODE
 from . import MODEL_HALF_PRECISION
@@ -42,8 +45,11 @@
     name_or_path=MODEL,
     revision=MODEL_REVISION,
     cache_dir=MODEL_CACHE_DIR,
+    is_peft=MODEL_PEFT,
     load_in_8bit=MODEL_LOAD_IN_8BIT,
     load_in_4bit=MODEL_LOAD_IN_4BIT,
+    quant_type=MODEL_4BIT_QUANT_TYPE,
+    double_quant=MODEL_4BIT_DOUBLE_QUANT,
     local_files_only=MODEL_LOCAL_FILES_ONLY,
     trust_remote_code=MODEL_TRUST_REMOTE_CODE,
     half_precision=MODEL_HALF_PRECISION,
diff --git a/basaran/model.py b/basaran/model.py
index c1bda2d7..62a52c9b 100644
--- a/basaran/model.py
+++ b/basaran/model.py
@@ -12,6 +12,11 @@
     MinNewTokensLengthLogitsProcessor,
     TemperatureLogitsWarper,
     TopPLogitsWarper,
+    BitsAndBytesConfig,
+)
+from peft import (
+    PeftConfig,
+    PeftModel,
 )
 
 from .choice import map_choice
@@ -309,8 +314,11 @@ def load_model(
     name_or_path,
     revision=None,
     cache_dir=None,
+    is_peft=False,
     load_in_8bit=False,
     load_in_4bit=False,
+    quant_type="fp4",
+    double_quant=False,
     local_files_only=False,
     trust_remote_code=False,
     half_precision=False,
@@ -324,24 +332,46 @@ def load_model(
         kwargs["revision"] = revision
     if cache_dir:
         kwargs["cache_dir"] = cache_dir
+
+    # Resolve a PEFT adapter to its base model before loading tokenizer/weights.
+    if is_peft:
+        peft_config = PeftConfig.from_pretrained(name_or_path)
+        peft_model_name_or_path = name_or_path
+        name_or_path = peft_config.base_model_name_or_path
+
     tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
 
     # Set device mapping and quantization options if CUDA is available.
     if torch.cuda.is_available():
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Only one of load_in_8bit and load_in_4bit can be True")
+        # Build the bitsandbytes config only when quantization is requested.
+        quant_config = None
+        if load_in_8bit:
+            quant_config = BitsAndBytesConfig(load_in_8bit=True)
+        elif load_in_4bit:
+            quant_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type=quant_type,
+                bnb_4bit_use_double_quant=double_quant,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
         kwargs = kwargs.copy()
         kwargs["device_map"] = "auto"
-        kwargs["load_in_8bit"] = load_in_8bit
-        kwargs["load_in_4bit"] = load_in_4bit
+        kwargs["quantization_config"] = quant_config
 
         # Cast all parameters to float16 if quantization is enabled.
         if half_precision or load_in_8bit or load_in_4bit:
             kwargs["torch_dtype"] = torch.float16
 
     # Support both decoder-only and encoder-decoder models.
     try:
         model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)
     except ValueError:
         model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path, **kwargs)
+    # Attach the adapter weights on top of the freshly loaded base model.
+    if is_peft:
+        model = PeftModel.from_pretrained(model, peft_model_name_or_path)
 
     # Check if the model has text generation capabilities.
     if not model.can_generate():
diff --git a/requirements.txt b/requirements.txt
index f36f38c0..841a1005 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,3 +12,5 @@ safetensors~=0.3.1
 torch>=1.12.1
 transformers[sentencepiece]~=4.30.1
 waitress~=2.1.2
+peft~=0.3.0
+scipy~=1.10.1