Configurable Chat Formats #711

Merged · 13 commits · Sep 29, 2023
57 changes: 38 additions & 19 deletions llama_cpp/llama.py
@@ -24,6 +24,7 @@
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+from . import llama_chat_format
 
 import numpy as np
 import numpy.typing as npt
@@ -243,6 +244,8 @@ def __init__(
         lora_path: Optional[str] = None,
         # Backend Params
         numa: bool = False,
+        # Chat Format Params
+        chat_format: str = "llama-2",
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -273,6 +276,7 @@ def __init__(
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
+            chat_format: String specifying the chat format to use when calling create_chat_completion.
             verbose: Print verbose output to stderr.
             kwargs: Unused keyword arguments (for additional backwards compatibility).
 
@@ -388,6 +392,8 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
+        self.chat_format = chat_format
+
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
@@ -1565,9 +1571,21 @@ def _convert_text_completion_chunks_to_chat(
                 ],
             }
 
+    def _convert_completion_to_chat(
+        self,
+        completion_or_chunks: Union[Completion, Iterator[CompletionChunk]],
+        stream: bool = False,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        if stream:
+            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_chunks_to_chat(chunks)
+        else:
+            completion: Completion = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_to_chat(completion)
+
     def create_chat_completion(
         self,
-        messages: List[ChatCompletionMessage],
+        messages: List[ChatCompletionRequestMessage],
         functions: Optional[List[ChatCompletionFunction]] = None,
         function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
@@ -1602,26 +1620,28 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = (
-            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
-        )
-        chat_history = "".join(
-            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
-            for message in messages
+
+        format = llama_chat_format.get_chat_format(self.chat_format)
+        result = format(
+            messages=messages,
         )
-        PROMPT = chat_history + "### Assistant:"
-        PROMPT_STOP = ["### Assistant:", "### Human:"]
-        completion_or_chunks = self(
-            prompt=PROMPT,
-            stop=PROMPT_STOP + stop,
+        prompt = result.prompt
+        if result.stop is not None:
+            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+            stop = stop + rstop
+
+        completion_or_chunks = self.create_completion(
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             stream=stream,
+            stop=stop,
             max_tokens=max_tokens,
-            repeat_penalty=repeat_penalty,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
@@ -1630,12 +1650,7 @@
             logits_processor=logits_processor,
             grammar=grammar,
         )
-        if stream:
-            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
-            return self._convert_text_completion_chunks_to_chat(chunks)
-        else:
-            completion: Completion = completion_or_chunks  # type: ignore
-            return self._convert_text_completion_to_chat(completion)
+        return self._convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore
 
     def __del__(self):
         if hasattr(self, "model") and self.model is not None:
@@ -1675,6 +1690,8 @@ def __getstate__(self):
             lora_path=self.lora_path,
             # Backend Params
             numa=self.numa,
+            # Chat Format Params
+            chat_format=self.chat_format,
             # Misc
             verbose=self.verbose,
         )
@@ -1708,6 +1725,8 @@ def __setstate__(self, state):
             lora_path=state["lora_path"],
             # Backend Params
             numa=state["numa"],
+            # Chat Format Params
+            chat_format=state["chat_format"],
             # Misc
             verbose=state["verbose"],
         )
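
For readers skimming the diff, here is a minimal usage sketch of what the change enables: pass chat_format when constructing Llama, then call create_chat_completion with OpenAI-style message dicts (the ChatCompletionRequestMessage shape used above). The model path is only a placeholder; "llama-2" is simply the new default value shown in the diff.

from llama_cpp import Llama

# Placeholder model path; any chat-tuned GGUF model works here.
llm = Llama(model_path="./models/llama-2-7b-chat.gguf", chat_format="llama-2")

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the planets in the solar system."},
    ],
)
print(response["choices"][0]["message"]["content"])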
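The diff only shows the consuming side of the new mechanism: llama_chat_format.get_chat_format(self.chat_format) must return a callable that accepts messages and produces an object exposing .prompt and, optionally, .stop. Below is a rough sketch of that contract; the names ChatFormatResult and format_basic are illustrative rather than the actual definitions in llama_chat_format.py (which is not part of this file's diff), and the rendering mirrors the hard-coded "### Human / ### Assistant" prompt that this change removes.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

@dataclass
class ChatFormatResult:
    # Shape inferred from how create_chat_completion uses the result above.
    prompt: str
    stop: Optional[Union[str, List[str]]] = None

def format_basic(messages: List[Dict[str, Any]], **kwargs: Any) -> ChatFormatResult:
    # Render OpenAI-style messages into a single prompt string,
    # reproducing the behavior of the removed hard-coded path.
    prompt = "".join(
        f'### {"Human" if m["role"] == "user" else "Assistant"}:{m["content"]}'
        for m in messages
    )
    return ChatFormatResult(
        prompt=prompt + "### Assistant:",
        stop=["### Assistant:", "### Human:"],
    )

result = format_basic(messages=[{"role": "user", "content": "Hi"}])
print(result.prompt)  # ### Human:Hi### Assistant:
print(result.stop)    # ['### Assistant:', '### Human:']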