Commit

loading adapter model on top
chris-aeviator committed Aug 27, 2023
1 parent 694a535 commit aba56c1
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions server/text_generation_server/models/causal_lm.py
@@ -1,10 +1,12 @@
import torch
import inspect

import json

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
from typing import Optional, Tuple, List, Type, Dict
from peft import PeftModelForCausalLM, get_peft_config, PeftConfig

from text_generation_server.models import Model
from text_generation_server.models.types import (
@@ -476,11 +478,15 @@ def __init__(
)

should_quantize = quantize == "bitsandbytes"
if(should_quantize):
if should_quantize:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16
)

has_peft_model = True
peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"

model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
@@ -490,6 +496,15 @@ def __init__(
quantization_config = quantization_config if should_quantize else None,
trust_remote_code=trust_remote_code,
)
if has_peft_model:
with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
config = json.load(config_file)
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama
# model.load_adapter(peft_model_id_or_path)
# model.enable_adapters()

## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
# if torch.cuda.is_available() and torch.cuda.device_count() == 1:
# model = model.cuda()
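
For reference, the same adapter-loading step can also be expressed with PeftModel.from_pretrained, which reads adapter_config.json and loads the adapter weights itself rather than constructing PeftModelForCausalLM by hand. The sketch below is not the code committed here; the model id, adapter path, and helper name are placeholders.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

def load_base_with_adapter(model_id: str, adapter_path: str, quantize: bool = True):
    # Optionally quantize the base model to 4-bit, mirroring the diff above.
    quantization_config = (
        BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
        if quantize
        else None
    )
    base = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
    )
    # PeftModel.from_pretrained reads the adapter config and weights from
    # adapter_path and wraps the base model; no manual get_peft_config call needed.
    return PeftModel.from_pretrained(base, adapter_path)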
