Fix tokenization of chat interfaces (#1035)
Co-authored-by: Jack-Khuu <[email protected]>
vmpuri and Jack-Khuu authored Aug 19, 2024
1 parent 1566512 commit c7f56f2
Showing 3 changed files with 88 additions and 35 deletions.
api/api.py (48 changes: 27 additions & 21 deletions)
@@ -10,6 +10,8 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from build.utils import device_sync

from generate import Generator, GeneratorArgs
@@ -222,7 +224,6 @@ def __init__(self, *args, **kwargs):
"""

super().__init__(*args, **kwargs)
self.start_pos = 0
self.max_seq_length = (
self.model.config.max_seq_length
+ self.speculative_builder_args.speculate_k
@@ -257,20 +258,25 @@ def chunked_completion(self, completion_request: CompletionRequest):
CompletionResponseChunk objects in response to completion_request as tokens are generated.
"""
device_sync(device=self.builder_args.device)

# Initialize counters for chunk responses and encode the prompt.
id = str(uuid.uuid4())

idx = 0
buffer = []
encoded = self.encode_tokens(
completion_request.messages[-1].get("content"),
bos=True,
device=self.builder_args.device,
tokens = self.chat_formatter.encode_dialog_prompt(
dialog=[
{"role": message["role"], "content": message["content"]}
for message in completion_request.messages
]
)

encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
print(self.tokenizer.decode(tokens))

start_pos = 0

generator_args = GeneratorArgs(
completion_request.messages[-1].get("content"),
None,
max_new_tokens=(
int(completion_request.max_tokens)
if completion_request.max_tokens
@@ -279,33 +285,39 @@ def chunked_completion(self, completion_request: CompletionRequest):
encoded_prompt=encoded,
temperature=float(completion_request.temperature),
chat_mode=False,
sequential_prefill=True,
)

def callback(x, *, done_generating=False):
return self._callback(
x,
buffer=buffer,
buffer=None,
done_generating=done_generating,
)

device_sync(device=self.builder_args.device)

# Process each token, metrics tuple yielded by Generator.generate.
for y, _ in self.generate(
self.model,
encoded,
generator_args.max_new_tokens,
model=self.model,
prompt=encoded,
max_new_tokens=generator_args.max_new_tokens,
draft_model=self.draft_model,
speculate_k=generator_args.speculate_k,
chat_mode=generator_args.chat_mode,
callback=callback,
temperature=generator_args.temperature,
top_k=generator_args.top_k,
sequential_prefill=generator_args.sequential_prefill,
start_pos=self.start_pos,
start_pos=start_pos,
max_seq_length=self.max_seq_length,
seed=int(completion_request.seed),
):
if y is None:
continue
elif y.item() == self.tokenizer.eos_id:
# Stop generation if the EOS token is generated.
break

# Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
content = "".join(
@@ -330,7 +342,7 @@ def callback(x, *, done_generating=False):
system_fingerprint=self.system_fingerprint,
)
yield chunk_response
self.start_pos += y.size(0)
start_pos += y.size(0)
idx += 1

# Yield an ending chunk indicating the generation has completed.
@@ -369,10 +381,4 @@ def sync_completion(self, request: CompletionRequest):
)

def _callback(self, x, *, buffer, done_generating):
period_id = self.tokenizer.encode(".")[0]
buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
if (
self.is_llama3_model
and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
):
buffer = buffer[:-1] # drop the eot_id from the output buffer
pass
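
To make the new prompt flow in chunked_completion concrete, here is a minimal, self-contained sketch. _ToyTokenizer and _ToyChatFormatter are hypothetical stand-ins for torchchat's tokenizer and chat formatters (they are not part of this commit); only the shape of the flow mirrors the diff: format the entire dialog instead of just the last message, then build the int tensor handed to Generator.generate as encoded_prompt.

import torch


class _ToyTokenizer:
    """Hypothetical tokenizer, for illustration only."""

    def encode(self, text, bos=False, eos=False):
        # Fake ids; a real tokenizer returns vocabulary ids.
        return [hash(word) % 32000 for word in text.split()]


class _ToyChatFormatter:
    """Stand-in for Llama2ChatFormatter / Llama3ChatFormatter."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encode_dialog_prompt(self, dialog):
        tokens = []
        for message in dialog:
            tokens += self.tokenizer.encode(f"{message['role']}: {message['content']}")
        return tokens


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is torchchat?"},
]

formatter = _ToyChatFormatter(_ToyTokenizer())

# Before this commit, only messages[-1]["content"] was encoded.
old_tokens = formatter.tokenizer.encode(messages[-1]["content"], bos=True)

# After this commit, the whole dialog is formatted and converted to a tensor.
tokens = formatter.encode_dialog_prompt(
    dialog=[{"role": m["role"], "content": m["content"]} for m in messages]
)
encoded = torch.tensor(tokens, dtype=torch.int, device="cpu")
print(len(old_tokens), encoded.shape)
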
generate.py (67 changes: 55 additions & 12 deletions)
@@ -9,6 +9,8 @@
import os
import textwrap
import time

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
@@ -28,24 +30,33 @@
from cli import add_arguments_for_verb, arg_init, check_args
from utils.device_info import get_device_info

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


class ChatFormat:
class _ChatFormatter(ABC):
def __init__(self, tokenizer):
self.tokenizer = tokenizer

def encode_header(self, message) -> List[int]:
@abstractmethod
def encode_dialog_prompt(self, dialog) -> List[int]:
raise NotImplementedError()


class Llama3ChatFormatter(_ChatFormatter):
"""Format a chat prompt using special tokens to demarcate roles and messages.
Refer to the LLaMA3 documentation for more details https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
"""

def encode_header(self, role) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens

def encode_message(self, message) -> List[int]:
tokens = self.encode_header(message)
tokens = self.encode_header(message.role)
tokens.extend(
self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
)
@@ -62,9 +73,37 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
return tokens


B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


class Llama2ChatFormatter(_ChatFormatter):
def encode_dialog_prompt(self, dialog) -> List[int]:
tokens = self.tokenizer.encode(f"{B_INST} ")
first_message = True  # Tracks whether to prepend B_INST: the system prompt gets it, the first user message after a system prompt does not, every later user message does, and if there is no system prompt the first user message gets it.
for message in dialog:
content = message["content"].strip()
if message["role"] == "system":
encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
first_message = False
elif message["role"] == "user":
encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
f"{B_INST if first_message else ''} {content} {E_INST} "
)
first_message = True
elif message["role"] == "assistant":
encoded = self.tokenizer.encode(f"{content}\n\n") + [
self.tokenizer.eos_id()
]
tokens += encoded
return tokens


@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
prompt: Optional[str] = (
None # When passed into the Generator, this will be used as the system prompt
)
encoded_prompt: Optional[torch.Tensor] = None
chat_mode: bool = False
gui_mode: bool = False
@@ -188,7 +227,7 @@ def __init__(
))
# fmt: on
# raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")

self.system_prompt = generator_args.prompt
self.tokenizer = _initialize_tokenizer(self.tokenizer_args)

# Right now the assumption is only llama3 uses tiktokenizer and it
@@ -200,6 +239,11 @@
logging.debug(
"Llama3 model detected in chat mode. Using updated sentence schemas"
)
self.chat_formatter = (
Llama3ChatFormatter(self.tokenizer)
if self.is_llama3_model
else Llama2ChatFormatter(self.tokenizer)
)

self.builder_args.setup_caches = False
self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -641,8 +685,7 @@ def chat(
)
if get_system_prompt == "y" or get_system_prompt == "Y":
self.system_prompt = input("What is your system prompt? \n")
if self.is_llama3_model:
self.chat_formatter = ChatFormat(self.tokenizer)

else:
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
@@ -685,7 +728,7 @@ def chat(
prompt, bos=True, device=self.builder_args.device
)
else:
if self.system_prompt is not None:
if self.system_prompt:
encoded = self.chat_formatter.encode_dialog_prompt(
[
{"role": "system", "content": self.system_prompt},
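
The Llama2ChatFormatter added above is easiest to read at the string level. The sketch below mirrors its branching on roles, with "<s>" and "</s>" standing in for tokenizer.bos_id() and eos_id(); the real class returns token ids, not text. The Llama3ChatFormatter analogue (its encode_dialog_prompt body is truncated in this view) wraps each message as <|start_header_id|>role<|end_header_id|> followed by the content and <|eot_id|>, per the encode_header shown above and the linked Llama 3 prompt-format docs.

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


def llama2_prompt_text(dialog):
    # String-level mirror of Llama2ChatFormatter.encode_dialog_prompt;
    # "<s>" / "</s>" stand in for tokenizer.bos_id() / eos_id().
    out = f"{B_INST} "
    first_message = True
    for message in dialog:
        content = message["content"].strip()
        if message["role"] == "system":
            out += f"{B_SYS}\n{content}\n{E_SYS}"
            first_message = False
        elif message["role"] == "user":
            out += "<s>" + f"{B_INST if first_message else ''} {content} {E_INST} "
            first_message = True
        elif message["role"] == "assistant":
            out += f"{content}\n\n" + "</s>"
    return out


print(llama2_prompt_text([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
]))
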
server.py (8 changes: 6 additions & 2 deletions)
@@ -6,6 +6,10 @@

import json

import logging

logger = logging.getLogger(__name__)

from dataclasses import asdict
from typing import Dict, List, Union

@@ -21,7 +25,7 @@
OPENAI_API_VERSION = "v1"


def create_app(args):
def create_app(args): # noqa: C901
"""
Creates a flask app that can be used to serve the model as a chat API.
"""
@@ -69,7 +73,7 @@ def chunk_processor(chunked_completion_generator):
for chunk in chunked_completion_generator:
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
print(next_tok, end="")
print(next_tok, end="", flush=True)
yield json.dumps(_del_none(asdict(chunk)))

return Response(
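
For context on the flush=True change above, here is a standalone sketch of the chunk_processor pattern. The dataclasses are simplified stand-ins for torchchat's CompletionResponseChunk types, and server.py's _del_none cleanup is omitted; only the walrus check for a None delta and the immediate console flush follow the diff.

import json
from dataclasses import asdict, dataclass
from typing import List, Optional


@dataclass
class _Delta:
    content: Optional[str]


@dataclass
class _Choice:
    delta: _Delta


@dataclass
class _Chunk:
    choices: List[_Choice]


def chunk_processor(chunked_completion_generator):
    # Echo each token to the console as it arrives (flush=True so partial
    # lines show up immediately) and yield a JSON-serialized chunk for the
    # streaming HTTP response.
    for chunk in chunked_completion_generator:
        if (next_tok := chunk.choices[0].delta.content) is None:
            next_tok = ""
        print(next_tok, end="", flush=True)
        yield json.dumps(asdict(chunk))


chunks = [_Chunk([_Choice(_Delta("Hello"))]), _Chunk([_Choice(_Delta(None))])]
print(list(chunk_processor(chunks)))
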
