This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Allow loading model with BitsAndBytes 4-bit quantization and PEFT LoRA adapters. #203

Open · wants to merge 6 commits into base: master
3 changes: 3 additions & 0 deletions basaran/__init__.py
@@ -18,10 +18,13 @@ def is_true(value):
PORT = int(os.getenv("PORT", "80"))

# Model-related arguments:
MODEL_PEFT = is_true(os.getenv("MODEL_PEFT", ""))
MODEL_REVISION = os.getenv("MODEL_REVISION", "")
MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
MODEL_LOAD_IN_4BIT = is_true(os.getenv("MODEL_LOAD_IN_4BIT", ""))
MODEL_4BIT_QUANT_TYPE = os.getenv("MODEL_4BIT_QUANT_TYPE", "fp4")
MODEL_4BIT_DOUBLE_QUANT = is_true(os.getenv("MODEL_4BIT_DOUBLE_QUANT", ""))
MODEL_LOCAL_FILES_ONLY = is_true(os.getenv("MODEL_LOCAL_FILES_ONLY", ""))
MODEL_TRUST_REMOTE_CODE = is_true(os.getenv("MODEL_TRUST_REMOTE_CODE", ""))
MODEL_HALF_PRECISION = is_true(os.getenv("MODEL_HALF_PRECISION", ""))
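For illustration, the new options are driven entirely by environment variables read in basaran/__init__.py at import time. A quick way to exercise them from Python might look like the sketch below (not part of the diff; the adapter name is hypothetical, and it assumes is_true() treats "true" as truthy):

# Sketch: configure the new 4-bit / PEFT options through environment
# variables before the basaran package reads them at import time.
import os

os.environ["MODEL"] = "your-org/your-lora-adapter"  # hypothetical PEFT adapter repo
os.environ["MODEL_PEFT"] = "true"
os.environ["MODEL_LOAD_IN_4BIT"] = "true"
os.environ["MODEL_4BIT_QUANT_TYPE"] = "nf4"
os.environ["MODEL_4BIT_DOUBLE_QUANT"] = "true"

import basaran  # must come after the environment is configured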
6 changes: 6 additions & 0 deletions basaran/__main__.py
@@ -21,6 +21,9 @@
from . import MODEL_CACHE_DIR
from . import MODEL_LOAD_IN_8BIT
from . import MODEL_LOAD_IN_4BIT
from . import MODEL_4BIT_QUANT_TYPE
from . import MODEL_4BIT_DOUBLE_QUANT
from . import MODEL_PEFT
from . import MODEL_LOCAL_FILES_ONLY
from . import MODEL_TRUST_REMOTE_CODE
from . import MODEL_HALF_PRECISION
@@ -42,8 +45,11 @@
name_or_path=MODEL,
revision=MODEL_REVISION,
cache_dir=MODEL_CACHE_DIR,
is_peft=MODEL_PEFT,
load_in_8bit=MODEL_LOAD_IN_8BIT,
load_in_4bit=MODEL_LOAD_IN_4BIT,
quant_type=MODEL_4BIT_QUANT_TYPE,
double_quant=MODEL_4BIT_DOUBLE_QUANT,
local_files_only=MODEL_LOCAL_FILES_ONLY,
trust_remote_code=MODEL_TRUST_REMOTE_CODE,
half_precision=MODEL_HALF_PRECISION,
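For reference, the same options can also be exercised by calling load_model() directly with the parameters added in this PR; a minimal sketch (not part of the diff; the adapter name is hypothetical, and 4-bit loading requires a CUDA GPU with bitsandbytes installed):

# Sketch: load a 4-bit quantized base model plus LoRA adapter through the
# new load_model() parameters.
from basaran.model import load_model

stream_model = load_model(
    name_or_path="your-org/your-lora-adapter",  # hypothetical PEFT adapter repo
    is_peft=True,          # resolve the base model from the adapter's config
    load_in_4bit=True,     # quantize the base model with bitsandbytes
    quant_type="nf4",
    double_quant=True,
)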
36 changes: 33 additions & 3 deletions basaran/model.py
@@ -12,6 +12,11 @@
MinNewTokensLengthLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
BitsAndBytesConfig
)
from peft import (
PeftConfig,
PeftModel
Member

It seems that PeftModel is not being used. Are you sure that PEFT is working correctly? (The GitHub Actions environment does not have a GPU for testing.)

Author

Oops 😅, well that explains why I wasn't getting any results from my LoRA fine-tunings.

Author

So I have now added the loading step. However, 4-bit only works with the development version of peft; loading in 4-bit with peft 0.3.0 errors at inference time.
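For context, the loading flow described here resolves the base model from the adapter's config and then wraps it with PeftModel. A condensed standalone sketch (not the exact diff; the adapter name is hypothetical):

# Sketch: resolve the base model from a LoRA adapter's config, wrap it with
# PeftModel, and generate once to confirm the adapter weights are applied.
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

adapter = "your-org/your-lora-adapter"  # hypothetical adapter repo
base = PeftConfig.from_pretrained(adapter).base_model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base)
model = PeftModel.from_pretrained(model, adapter)

inputs = tokenizer("Hello", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))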

)

from .choice import map_choice
@@ -309,8 +314,11 @@ def load_model(
name_or_path,
revision=None,
cache_dir=None,
is_peft=False,
load_in_8bit=False,
load_in_4bit=False,
quant_type="fp4",
double_quant=False,
local_files_only=False,
trust_remote_code=False,
half_precision=False,
@@ -324,24 +332,46 @@
kwargs["revision"] = revision
if cache_dir:
kwargs["cache_dir"] = cache_dir
tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)

# Set device mapping and quantization options if CUDA is available.
if torch.cuda.is_available():
# Set quantization options if specified.
quant_config = None
if load_in_8bit and load_in_4bit:
raise ValueError("Only one of load_in_8bit and load_in_4bit can be True")
if load_in_8bit:
quant_config = BitsAndBytesConfig(
load_in_8bit=True,
)
elif load_in_4bit:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type=quant_type,
bnb_4bit_use_double_quant=double_quant,
bnb_4bit_compute_dtype=torch.bfloat16,
Member

Hardcoding bnb_4bit_compute_dtype to bfloat16 is indeed a reasonable choice. But similarly, can't we use the recommended configuration for bnb_4bit_quant_type and bnb_4bit_use_double_quant in most scenarios? In fact, I prefer to keep only the load_in_4bit option to reduce user confusion. What do you think?

Reference: https://huggingface.co/blog/4bit-transformers-bitsandbytes#advanced-usage

Author

Yes, I agree it's a lot of configuration options. I just wasn't sure how much people experiment with the different settings, so I exposed them all!

I think hardcoding "nf4" is reasonable for the 4-bit quant type, since the QLoRA literature recommends it, even though the default for BitsAndBytesConfig is "fp4".

double_quant only seems to be suggested when memory is constrained. Maybe the extra ~0.4 bits per parameter aren't worth it? If you think so, I can hardcode it to False and remove the option.

Finally, I wasn't sure of the trade-offs for bnb_4bit_compute_dtype: QLoRA uses bfloat16, while the default for BitsAndBytesConfig is float32. The reference suggests bfloat16 for faster training; I assumed the same would apply to inference, but that isn't explicitly called out, so maybe it's only a memory benefit for inference? Is this a good choice? 🤔

Sorry for all the questions, I'm still leveling up on ML code, and once again I appreciate the feedback!
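For reference, the simplified setup being discussed here would look roughly like the sketch below (based on the recommendations in the linked post; not part of the current diff):

# Sketch: a single load_in_4bit switch with the QLoRA-recommended settings
# hardcoded, instead of exposing quant_type / double_quant as options.
import torch
from transformers import BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # recommended over the "fp4" default
    bnb_4bit_use_double_quant=True,         # saves roughly 0.4 bits per parameter
    bnb_4bit_compute_dtype=torch.bfloat16,  # faster compute than the float32 default
)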

)
kwargs = kwargs.copy()
kwargs["device_map"] = "auto"
kwargs["load_in_8bit"] = load_in_8bit
kwargs["load_in_4bit"] = load_in_4bit
kwargs["quantization_config"] = quant_config

# Cast all parameters to float16 if quantization is enabled.
if half_precision or load_in_8bit or load_in_4bit:
kwargs["torch_dtype"] = torch.float16

if is_peft:
peft_config = PeftConfig.from_pretrained(name_or_path)
peft_model_name_or_path = name_or_path
name_or_path = peft_config.base_model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)

# Support both decoder-only and encoder-decoder models.
try:
model = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs)
except ValueError:
model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path, **kwargs)
if is_peft:
model = PeftModel.from_pretrained(model, peft_model_name_or_path, **kwargs)

# Check if the model has text generation capabilities.
if not model.can_generate():
2 changes: 2 additions & 0 deletions requirements.txt
@@ -12,3 +12,5 @@ safetensors~=0.3.1
torch>=1.12.1
transformers[sentencepiece]~=4.30.1
waitress~=2.1.2
peft~=0.3.0
scipy~=1.10.1
Member

The missing dependency (scipy) is due to a bug in bitsandbytes: bitsandbytes-foundation/bitsandbytes#426

Instead of specifying the version of the indirect dependency, I suggest waiting for bitsandbytes to fix the issue in version 0.39.0.

Author

Thanks for pointing this out. It did feel a bit weird that I had to add this!