FEAT: Supports LoRA for LLM and image models (#1080)
ChengjieLi28 authored Mar 8, 2024
1 parent 02f735c commit 45a8625
Showing 22 changed files with 228 additions and 14 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
@@ -47,6 +47,7 @@ install_requires =
aioprometheus[starlette]>=23.12.0
pynvml
async-timeout
peft

[options.packages.find]
exclude =
@@ -114,6 +115,7 @@ transformers =
tiktoken
auto-gptq
optimum
peft
vllm =
vllm>=0.2.6,<0.3.1
embedding =
9 changes: 9 additions & 0 deletions xinference/api/restful_api.py
@@ -656,6 +656,9 @@ async def launch_model(
replica = payload.get("replica", 1)
n_gpu = payload.get("n_gpu", "auto")
request_limits = payload.get("request_limits", None)
peft_model_path = payload.get("peft_model_path", None)
image_lora_load_kwargs = payload.get("image_lora_load_kwargs", None)
image_lora_fuse_kwargs = payload.get("image_lora_fuse_kwargs", None)

exclude_keys = {
"model_uid",
@@ -667,6 +670,9 @@
"replica",
"n_gpu",
"request_limits",
"peft_model_path",
"image_lora_load_kwargs",
"image_lora_fuse_kwargs",
}

kwargs = {
@@ -691,6 +697,9 @@
n_gpu=n_gpu,
request_limits=request_limits,
wait_ready=wait_ready,
peft_model_path=peft_model_path,
image_lora_load_kwargs=image_lora_load_kwargs,
image_lora_fuse_kwargs=image_lora_fuse_kwargs,
**kwargs,
)

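On the wire, the three new fields simply ride along in the JSON body of the launch request. A sketch of a payload the handler above would accept — field names are taken from the handler, all values are illustrative:

# Illustrative launch payload; only the three LoRA-related fields are new.
payload = {
    "model_uid": "my-model",                     # placeholder uid
    "model_name": "llama-2-chat",                # placeholder model
    "replica": 1,
    "n_gpu": "auto",
    "request_limits": None,
    "peft_model_path": "/path/to/lora_adapter",  # new in this commit
    "image_lora_load_kwargs": None,              # new; image models only
    "image_lora_fuse_kwargs": None,              # new; image models only
}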
12 changes: 12 additions & 0 deletions xinference/client/restful/restful_client.py
@@ -791,6 +791,9 @@ def launch_model(
replica: int = 1,
n_gpu: Optional[Union[int, str]] = "auto",
request_limits: Optional[int] = None,
peft_model_path: Optional[str] = None,
image_lora_load_kwargs: Optional[Dict] = None,
image_lora_fuse_kwargs: Optional[Dict] = None,
**kwargs,
) -> str:
"""
@@ -818,6 +821,12 @@ def launch_model(
request_limits: Optional[int]
The number of request limits for this model, default is None.
``request_limits=None`` means no limits for this model.
peft_model_path: Optional[str]
PEFT (Parameter-Efficient Fine-Tuning) model path.
image_lora_load_kwargs: Optional[Dict]
LoRA load parameters for the image model.
image_lora_fuse_kwargs: Optional[Dict]
LoRA fuse parameters for the image model.
**kwargs:
Any other parameters that are specified.
@@ -840,6 +849,9 @@
"replica": replica,
"n_gpu": n_gpu,
"request_limits": request_limits,
"peft_model_path": peft_model_path,
"image_lora_load_kwargs": image_lora_load_kwargs,
"image_lora_fuse_kwargs": image_lora_fuse_kwargs,
}

for key, value in kwargs.items():
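Putting the new client arguments to use: a minimal sketch of launching an LLM with a LoRA adapter through the RESTful client. The endpoint, model name, and adapter path are placeholders, and the model format must be pytorch-like per the worker-side check further below.

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint

# Launch a pytorch-format LLM with a PEFT (LoRA) adapter attached.
model_uid = client.launch_model(
    model_name="llama-2-chat",                # placeholder model
    model_format="pytorch",                   # LoRA needs a pytorch-like format
    peft_model_path="/path/to/lora_adapter",  # placeholder adapter path
)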
6 changes: 6 additions & 0 deletions xinference/core/supervisor.py
@@ -714,6 +714,9 @@ async def launch_builtin_model(
request_limits: Optional[int] = None,
wait_ready: bool = True,
model_version: Optional[str] = None,
peft_model_path: Optional[str] = None,
image_lora_load_kwargs: Optional[Dict] = None,
image_lora_fuse_kwargs: Optional[Dict] = None,
**kwargs,
) -> str:
if model_uid is None:
@@ -751,6 +754,9 @@ async def _launch_one_model(_replica_model_uid):
model_type=model_type,
n_gpu=n_gpu,
request_limits=request_limits,
peft_model_path=peft_model_path,
image_lora_load_kwargs=image_lora_load_kwargs,
image_lora_fuse_kwargs=image_lora_fuse_kwargs,
**kwargs,
)
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
16 changes: 16 additions & 0 deletions xinference/core/worker.py
@@ -491,6 +491,9 @@ async def launch_builtin_model(
quantization: Optional[str],
model_type: str = "LLM",
n_gpu: Optional[Union[int, str]] = "auto",
peft_model_path: Optional[str] = None,
image_lora_load_kwargs: Optional[Dict] = None,
image_lora_fuse_kwargs: Optional[Dict] = None,
request_limits: Optional[int] = None,
**kwargs,
):
@@ -516,6 +519,16 @@ async def launch_builtin_model(
if isinstance(n_gpu, str) and n_gpu != "auto":
raise ValueError("Currently `n_gpu` only supports `auto`.")

if peft_model_path is not None:
if model_type in ("embedding", "rerank"):
raise ValueError(
f"PEFT adaptors cannot be applied to embedding or rerank models."
)
if model_type == "LLM" and model_format in ("ggufv2", "ggmlv3"):
raise ValueError(
f"PEFT adaptors can only be applied to pytorch-like models"
)

assert model_uid not in self._model_uid_to_model
self._check_model_is_valid(model_name, model_format)
assert self._supervisor_ref is not None
@@ -537,6 +550,9 @@
model_format,
model_size_in_billions,
quantization,
peft_model_path,
image_lora_load_kwargs,
image_lora_fuse_kwargs,
is_local_deployment,
**kwargs,
)
39 changes: 38 additions & 1 deletion xinference/deploy/cmdline.py
@@ -17,7 +17,7 @@
import os
import sys
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import click
from xoscar.utils import get_next_port
@@ -596,6 +596,26 @@ def list_model_registrations(
type=str,
help='The number of GPUs used by the model, default is "auto".',
)
@click.option(
"--peft-model-path",
default=None,
type=str,
help="PEFT model path.",
)
@click.option(
"--image-lora-load-kwargs",
"-ld",
"image_lora_load_kwargs",
type=(str, str),
multiple=True,
help="A key-value pair for loading the LoRA of an image model; may be given multiple times.",
)
@click.option(
"--image-lora-fuse-kwargs",
"-fd",
"image_lora_fuse_kwargs",
type=(str, str),
multiple=True,
help="A key-value pair for fusing the LoRA of an image model; may be given multiple times.",
)
@click.option(
"--trust-remote-code",
default=True,
@@ -614,6 +634,9 @@ def model_launch(
quantization: str,
replica: int,
n_gpu: str,
peft_model_path: Optional[str],
image_lora_load_kwargs: Optional[Tuple],
image_lora_fuse_kwargs: Optional[Tuple],
trust_remote_code: bool,
):
kwargs = {}
@@ -630,6 +653,17 @@
else:
_n_gpu = int(n_gpu)

image_lora_load_params = (
{k: handle_click_args_type(v) for k, v in dict(image_lora_load_kwargs).items()}
if image_lora_load_kwargs
else None
)
image_lora_fuse_params = (
{k: handle_click_args_type(v) for k, v in dict(image_lora_fuse_kwargs).items()}
if image_lora_fuse_kwargs
else None
)

endpoint = get_endpoint(endpoint)
model_size: Optional[Union[str, int]] = (
size_in_billions
@@ -648,6 +682,9 @@
quantization=quantization,
replica=replica,
n_gpu=_n_gpu,
peft_model_path=peft_model_path,
image_lora_load_kwargs=image_lora_load_params,
image_lora_fuse_kwargs=image_lora_fuse_params,
trust_remote_code=trust_remote_code,
**kwargs,
)
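Each repeated -ld/-fd flag contributes one key-value pair, and the pairs are folded into dicts before launch. A sketch of the equivalent image-model launch through the Python client — the model name and kwargs are illustrative, and the two dicts are ultimately forwarded to diffusers' load_lora_weights() and fuse_lora():

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint

# Launch an image model and fuse a LoRA into it at load time.
model_uid = client.launch_model(
    model_name="stable-diffusion-xl-base-1.0",  # illustrative model
    model_type="image",
    peft_model_path="/path/to/sdxl_lora",       # placeholder adapter path
    image_lora_load_kwargs={"weight_name": "pytorch_lora_weights.safetensors"},
    image_lora_fuse_kwargs={"lora_scale": 0.7},
)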
3 changes: 2 additions & 1 deletion xinference/deploy/docker/cpu.Dockerfile
@@ -56,7 +56,8 @@ RUN python -m pip install --upgrade -i "$PIP_INDEX" pip && \
controlnet_aux \
orjson \
auto-gptq \
optimum && \
optimum \
peft && \
pip install -i "$PIP_INDEX" -U chatglm-cpp && \
pip install -i "$PIP_INDEX" -U llama-cpp-python && \
cd /opt/inference && \
15 changes: 13 additions & 2 deletions xinference/model/core.py
@@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from .._compat import BaseModel

@@ -52,6 +52,9 @@ def create_model_instance(
model_format: Optional[str] = None,
model_size_in_billions: Optional[int] = None,
quantization: Optional[str] = None,
peft_model_path: Optional[str] = None,
image_lora_load_kwargs: Optional[Dict] = None,
image_lora_fuse_kwargs: Optional[Dict] = None,
is_local_deployment: bool = False,
**kwargs,
) -> Tuple[Any, ModelDescription]:
@@ -70,6 +73,7 @@
model_format,
model_size_in_billions,
quantization,
peft_model_path,
is_local_deployment,
**kwargs,
)
@@ -82,7 +86,14 @@
elif model_type == "image":
kwargs.pop("trust_remote_code", None)
return create_image_model_instance(
subpool_addr, devices, model_uid, model_name, **kwargs
subpool_addr,
devices,
model_uid,
model_name,
lora_model_path=peft_model_path,
lora_load_kwargs=image_lora_load_kwargs,
lora_fuse_kwargs=image_lora_fuse_kwargs,
**kwargs,
)
elif model_type == "rerank":
kwargs.pop("trust_remote_code", None)
18 changes: 16 additions & 2 deletions xinference/model/image/core.py
@@ -155,7 +155,14 @@ def get_cache_status(


def create_image_model_instance(
subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
subpool_addr: str,
devices: List[str],
model_uid: str,
model_name: str,
lora_model_path: Optional[str] = None,
lora_load_kwargs: Optional[Dict] = None,
lora_fuse_kwargs: Optional[Dict] = None,
**kwargs,
) -> Tuple[DiffusionModel, ImageModelDescription]:
model_spec = match_diffusion(model_name)
controlnet = kwargs.get("controlnet")
@@ -187,7 +194,14 @@ def create_image_model_instance(
else:
kwargs["controlnet"] = controlnet_model_paths
model_path = cache(model_spec)
model = DiffusionModel(model_uid, model_path, **kwargs)
model = DiffusionModel(
model_uid,
model_path,
lora_model_path=lora_model_path,
lora_load_kwargs=lora_load_kwargs,
lora_fuse_kwargs=lora_fuse_kwargs,
**kwargs,
)
model_description = ImageModelDescription(
subpool_addr, devices, model_spec, model_path=model_path
)
27 changes: 25 additions & 2 deletions xinference/model/image/stable_diffusion/core.py
@@ -21,7 +21,7 @@
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

from ....constants import XINFERENCE_IMAGE_DIR
from ....device_utils import move_model_to_available_device
@@ -32,14 +32,36 @@

class DiffusionModel:
def __init__(
self, model_uid: str, model_path: str, device: Optional[str] = None, **kwargs
self,
model_uid: str,
model_path: str,
device: Optional[str] = None,
lora_model_path: Optional[str] = None,
lora_load_kwargs: Optional[Dict] = None,
lora_fuse_kwargs: Optional[Dict] = None,
**kwargs,
):
self._model_uid = model_uid
self._model_path = model_path
self._device = device
self._model = None
self._lora_model_path = lora_model_path
self._lora_load_kwargs = lora_load_kwargs or {}
self._lora_fuse_kwargs = lora_fuse_kwargs or {}
self._kwargs = kwargs

def _apply_lora(self):
if self._lora_model_path is not None:
logger.info(
f"Loading the LoRA with load kwargs: {self._lora_load_kwargs}, fuse kwargs: {self._lora_fuse_kwargs}."
)
assert self._model is not None
self._model.load_lora_weights(
self._lora_model_path, **self._lora_load_kwargs
)
self._model.fuse_lora(**self._lora_fuse_kwargs)
logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")

def load(self):
# import torch
from diffusers import AutoPipelineForText2Image
@@ -61,6 +83,7 @@ def load(self):
self._model = move_model_to_available_device(self._model)
# Recommended if your computer has < 64 GB of RAM
self._model.enable_attention_slicing()
self._apply_lora()

def _call_model(
self,
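For reference, the load-then-fuse sequence in _apply_lora corresponds to this direct diffusers usage (a sketch with placeholder paths, not part of the commit):

from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained("/path/to/base-model")
# Load the adapter weights, then fuse them into the base weights so
# inference needs no separate adapter pass afterwards.
pipe.load_lora_weights("/path/to/lora", weight_name="pytorch_lora_weights.safetensors")
pipe.fuse_lora(lora_scale=0.7)  # lora_scale weights the adapter's contribution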
17 changes: 17 additions & 0 deletions xinference/model/llm/__init__.py
@@ -31,6 +31,7 @@
BUILTIN_LLM_PROMPT_STYLE,
BUILTIN_MODELSCOPE_LLM_FAMILIES,
LLM_CLASSES,
PEFT_SUPPORTED_CLASSES,
CustomLLMFamilyV1,
GgmlLLMSpecV1,
LLMFamilyV1,
@@ -95,6 +96,22 @@ def _install():
PytorchModel,
]
)
PEFT_SUPPORTED_CLASSES.extend(
[
BaichuanPytorchChatModel,
VicunaPytorchChatModel,
FalconPytorchChatModel,
ChatglmPytorchChatModel,
LlamaPytorchModel,
LlamaPytorchChatModel,
PytorchChatModel,
FalconPytorchModel,
Internlm2PytorchChatModel,
QwenVLChatModel,
YiVLChatModel,
PytorchModel,
]
)

json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
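PEFT_SUPPORTED_CLASSES only records which model classes may accept an adapter; the adapter loading itself lives in the PyTorch model code among the files not shown here and builds on the peft dependency added in setup.cfg. As a rough sketch of the underlying mechanism (illustrative paths, not the commit's exact code):

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Attach a trained LoRA adapter to a base causal language model.
base = AutoModelForCausalLM.from_pretrained("/path/to/base-model")
model = PeftModel.from_pretrained(base, "/path/to/lora_adapter")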
(The remaining 11 changed files are not shown.)
