
Create llamacpp_HF loader #3062

Merged (12 commits) on Jul 16, 2023
106 changes: 106 additions & 0 deletions modules/llamacpp_hf.py
@@ -0,0 +1,106 @@
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import llama_cpp
import numpy as np
import torch
from llama_cpp import Llama
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from modules import shared
from modules.llamacpp_model import LlamaCppModel
from modules.logging_colors import logger


class LlamacppHF(PreTrainedModel):
def __init__(self, model):
super().__init__(PretrainedConfig())
self.model = model
self.generation_config = GenerationConfig()
self.cache = None

    # These validation hooks are overridden as no-ops so that generate() accepts
    # this wrapper even though it is not a registered Transformers architecture.
    def _validate_model_class(self):
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        pass

def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {'input_ids': input_ids, **kwargs}

@property
def device(self) -> torch.device:
return torch.device(0)

def __call__(self, *args, **kwargs):
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
assert len(args) == 0, 'no *args should be passed to forward'
use_cache = kwargs.get('use_cache', True)
labels = kwargs.get('labels', None)
seq = kwargs['input_ids'][0].tolist()
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None

        # Make the forward call. When generating (no labels), reuse the llama.cpp
        # state if the new sequence merely extends the previously evaluated one
        # by a single token; otherwise reset and evaluate the whole prompt.
        seq_tensor = torch.tensor(seq)
        if labels is None:
            if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]):
                self.model.reset()
                self.model.eval(seq)
            else:
                self.model.eval([seq[-1]])

            logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device)
        else:
            self.model.reset()
            self.model.eval(seq)
            logits = torch.tensor(self.model.eval_logits)
            logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)

        # Update the cache only after the comparison above; assigning it earlier
        # would make the incremental path unreachable.
        self.cache = seq_tensor

# Based on transformers/models/llama/modeling_llama.py
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, logits.shape[-1])
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None, loss=loss)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        assert len(model_args) == 0 and len(kwargs) == 0, 'extra args are currently not supported'
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
if path.is_file():
model_file = path
else:
model_file = list(path.glob('*ggml*.bin'))[0]

logger.info(f"llama.cpp weights detected: {model_file}\n")
params = {
'model_path': str(model_file),
'n_ctx': shared.args.n_ctx,
'seed': int(shared.args.llama_cpp_seed),
'n_threads': shared.args.threads or None,
'n_batch': shared.args.n_batch,
'use_mmap': not shared.args.no_mmap,
'use_mlock': shared.args.mlock,
'low_vram': shared.args.low_vram,
'n_gpu_layers': shared.args.n_gpu_layers,
'logits_all': True,
}

model = Llama(**params)
return LlamacppHF(model)
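
For orientation, here is a minimal usage sketch of the wrapper defined above. The model and tokenizer paths are placeholders, and it assumes the webui's shared.args have already been parsed, since from_pretrained reads model_dir, n_ctx, threads and the other llama.cpp settings from there:

from transformers import AutoTokenizer

from modules.llamacpp_hf import LlamacppHF

# Placeholder paths: any transformers-format llama tokenizer and any GGML
# *.bin file under shared.args.model_dir will do.
tokenizer = AutoTokenizer.from_pretrained('models/oobabooga_llama-tokenizer', use_fast=False)
model = LlamacppHF.from_pretrained('my-llama-13b.ggmlv3.q4_K_M.bin')

# Because LlamacppHF subclasses PreTrainedModel, generate() and the standard
# Transformers sampling parameters work against the llama.cpp backend.
input_ids = tokenizer.encode('The llama is', return_tensors='pt')
output_ids = model.generate(input_ids, max_new_tokens=32, do_sample=True, temperature=0.7, top_p=0.9)
print(tokenizer.decode(output_ids[0]))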
11 changes: 11 additions & 0 deletions modules/loaders.py
@@ -38,6 +38,17 @@
'mlock',
'llama_cpp_seed',
],
'llamacpp_HF': [
'n_ctx',
'n_gpu_layers',
'n_batch',
'threads',
'no_mmap',
'low_vram',
'mlock',
'llama_cpp_seed',
'llamacpp_HF_info',
],
'Transformers': [
'cpu_memory',
'gpu_memory',
22 changes: 22 additions & 0 deletions modules/models.py
@@ -55,6 +55,7 @@ def load_model(model_name, loader=None):
'AutoGPTQ': AutoGPTQ_loader,
'GPTQ-for-LLaMa': GPTQ_loader,
'llama.cpp': llamacpp_loader,
'llamacpp_HF': llamacpp_HF_loader,
'FlexGen': flexgen_loader,
'RWKV': RWKV_loader,
'ExLlama': ExLlama_loader,
@@ -268,6 +269,27 @@ def llamacpp_loader(model_name):
return model, tokenizer


def llamacpp_HF_loader(model_name):
from modules.llamacpp_hf import LlamacppHF

for fname in ["oobabooga_llama-tokenizer", "llama-tokenizer"]:
path = Path(f'{shared.args.model_dir}/{fname}')
if path.exists():
break
else:
logger.error("Could not load the model because a tokenizer in transformers format was not found. Please download oobabooga/llama-tokenizer.")
return None, None

tokenizer = AutoTokenizer.from_pretrained(
path,
trust_remote_code=shared.args.trust_remote_code,
use_fast=False
)

model = LlamacppHF.from_pretrained(model_name)
return model, tokenizer


def GPTQ_loader(model_name):

# Monkey patch
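
The same loader can also be reached through load_model, which is how the UI invokes it. A minimal sketch with a placeholder model name (a folder or GGML file under shared.args.model_dir containing a *ggml*.bin):

from modules.models import load_model

# 'llamacpp_HF' matches the key added to load_func_map above.
model, tokenizer = load_model('my-llama-13b-ggml', loader='llamacpp_HF')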
2 changes: 2 additions & 0 deletions modules/shared.py
@@ -214,6 +214,8 @@ def fix_loader_name(name):
name = name.lower()
if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']:
return 'llama.cpp'
if name in ['llamacpp_hf', 'llama.cpp_hf', 'llama-cpp-hf', 'llamacpp-hf', 'llama.cpp-hf']:
return 'llamacpp_HF'
elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']:
return 'Transformers'
elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']:
3 changes: 2 additions & 1 deletion server.py
@@ -204,7 +204,7 @@ def create_model_menus():

with gr.Row():
with gr.Column():
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "AutoGPTQ", "llama.cpp", "ExLlama", "GPTQ-for-LLaMa"], value=None)
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "AutoGPTQ", "llama.cpp", "ExLlama", "llama.cpp_HF", "GPTQ-for-LLaMa"], value=None)
with gr.Box():
with gr.Row():
with gr.Column():
@@ -250,6 +250,7 @@ def create_model_menus():
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF is a wrapper that lets you use llama.cpp like a Transformers model, which means it can use the Transformers samplers. It works, but it\'s experimental and slow. Contributions are welcome.\n\nTo use it, make sure to first download oobabooga/llama-tokenizer under "Download custom model or LoRA".')

with gr.Column():
with gr.Row():
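
Finally, because from_pretrained sets logits_all=True and __call__ implements the labels branch, the wrapper can also report a loss, which makes perplexity measurement possible through the usual Transformers interface. A minimal sketch, again with placeholder paths and assuming an initialized shared.args:

import torch
from transformers import AutoTokenizer

from modules.llamacpp_hf import LlamacppHF

tokenizer = AutoTokenizer.from_pretrained('models/oobabooga_llama-tokenizer', use_fast=False)
model = LlamacppHF.from_pretrained('my-llama-13b.ggmlv3.q4_K_M.bin')  # placeholder file name

text = 'The quick brown fox jumps over the lazy dog.'
input_ids = tokenizer.encode(text, return_tensors='pt')
with torch.no_grad():
    output = model(input_ids=input_ids, labels=input_ids)

# output.loss is the mean cross-entropy over the shifted tokens, so its
# exponential is the perplexity of the sequence under the model.
print(torch.exp(output.loss).item())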