[JSON Mode] Constrained Sampling #175

Merged · 21 commits · Feb 8, 2024

Changes from all commits
5 changes: 3 additions & 2 deletions serve/benchmarks/benchmark_latency.py
@@ -34,8 +34,9 @@ def create_request(request_id):
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
logprobs = args.sampling_setting["logprobs"],
top_logprobs = args.sampling_setting["top_logprobs"],
logprobs=args.sampling_setting["logprobs"],
top_logprobs=args.sampling_setting["top_logprobs"],
json_schema=args.sampling_setting["json_schema"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
5 changes: 3 additions & 2 deletions serve/benchmarks/benchmark_throughput.py
@@ -139,8 +139,9 @@ def run_mlc(engine, requests, args) -> float:
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
logprobs = args.sampling_setting["logprobs"],
top_logprobs = args.sampling_setting["top_logprobs"],
logprobs=args.sampling_setting["logprobs"],
top_logprobs=args.sampling_setting["top_logprobs"],
json_schema=args.sampling_setting["json_schema"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
19 changes: 17 additions & 2 deletions serve/benchmarks/utils.py
@@ -1,4 +1,9 @@
"""Utils for benchmark scripts"""
from pydantic import BaseModel


class Output(BaseModel):
answer: str


def add_sampling_flags(parser):
@@ -17,6 +22,11 @@ def add_sampling_flags(parser):
action="store_true",
help="Apply top-p and top-k.",
)
parser.add_argument(
"--apply-json-mode",
action="store_true",
help="Apply json mode.",
)
parser.add_argument(
"--apply-all-sampling-params",
action="store_true",
@@ -26,13 +36,13 @@
"--logprobs",
action="store_true",
default=False,
help="Switch on logprobs output"
help="Switch on logprobs output",
)
parser.add_argument(
"--top-logprobs",
type=int,
default=5,
help="Number of top logprobs to output, limited by 5. Works only with logprobs true."
help="Number of top logprobs to output, limited by 5. Works only with logprobs true.",
)


@@ -47,12 +57,14 @@ def postproc_sampling_args(args):
"top_k": -1,
"logprobs": False,
"top_logprobs": 5,
"json_schema": None,
}

if args.apply_all_sampling_params:
args.apply_penalties = True
args.apply_logit_bias = True
args.apply_top_p_top_k = True
args.apply_json_mode = True

if args.apply_penalties:
args.sampling_setting["presence_penalty"] = 0.7
@@ -69,3 +81,6 @@ def postproc_sampling_args(args):
if args.logprobs:
args.sampling_setting["logprobs"] = True
args.sampling_setting["top_logprobs"] = args.top_logprobs

if args.apply_json_mode:
args.sampling_setting["json_schema"] = Output.model_json_schema()
3 changes: 2 additions & 1 deletion serve/mlc_serve/api/handler.py
@@ -71,7 +71,8 @@ def _get_sampling_params(
if request.logprobs:
sampling_params.top_logprobs = request.top_logprobs
sampling_params.logprobs = request.logprobs

if request.response_format and request.response_format.type == "json_object":
sampling_params.json_schema = request.response_format.response_schema
sampling_params.vocab_size = model_artifact_config.vocab_size
return sampling_params

12 changes: 10 additions & 2 deletions serve/mlc_serve/api/protocol.py
@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union, Any

from pydantic import BaseModel, Field

@@ -58,9 +58,16 @@ class ChatMessage(BaseModel):
content: str


class ChatResponseFormat(BaseModel):
type: str
response_schema: Optional[Dict[str, Any]] = Field(None, alias="schema")


class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[ChatMessage]] # according to openai chat completion spec, here should be only a list of ChatMessage
messages: Union[
str, List[ChatMessage]
] # according to openai chat completion spec, here should be only a list of ChatMessage
max_tokens: Optional[int] = None
temperature: float = 1.0
top_p: float = 1.0
@@ -75,6 +82,7 @@ class ChatCompletionRequest(BaseModel):
ignore_eos: Optional[bool] = False
logprobs: bool = False
top_logprobs: int = 0
response_format: Optional[ChatResponseFormat] = None


class ChatCompletionResponseChoice(BaseModel):
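With the new `response_format` field, a client can opt a chat completion request into JSON mode. A minimal client-side sketch, assuming the server exposes the usual OpenAI-compatible `/v1/chat/completions` route; the host, port, and model name are placeholders, not values taken from this PR:

```python
import requests

payload = {
    "model": "my-model",  # placeholder model name
    "messages": [{"role": "user", "content": "Answer in JSON."}],
    "response_format": {
        "type": "json_object",
        # Sent as "schema" on the wire; parsed into `response_schema`
        # via the Field(None, alias="schema") declaration above.
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        },
    },
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```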
94 changes: 94 additions & 0 deletions serve/mlc_serve/engine/constrained_sampling.py
@@ -0,0 +1,94 @@
import json
import math
from collections import defaultdict
from typing import DefaultDict, List

import torch

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from .base import SequenceId

class RegexLogitsProcessor:
def __init__(self, regex_string, tokenizer):
"""Compile the FSM that drives the regex-guided generation.

Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
An instance of `tokenizer`

"""
tokenizer = self.adapt_tokenizer(tokenizer)

fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm
self.fsm_state: DefaultDict[SequenceId, int] = defaultdict(int)

def __call__(
self, seq_id: SequenceId, input_ids: List[int], scores: torch.Tensor
) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""

if len(input_ids) == 0: # Initialize the fsm states
self.fsm_state = defaultdict(int)
else:
last_token = input_ids[-1]
self.fsm_state[seq_id] = self.fsm.next_state(
self.fsm_state[seq_id], last_token
)

allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
mask[allowed_tokens] = 0
biased_scores = scores + mask

return biased_scores

def adapt_tokenizer(self, tokenizer):
"""Adapt vLLM's tokenizer to use to compile the FSM.

The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.

"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

tokenizer.convert_token_to_string = convert_token_to_string

return tokenizer


class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema, tokenizer):
"""Compile the FSM that drives the JSON-guided generation.

Parameters
----------
schema
A JSON schema that encodes the structure we want the model to generate
tokenizer
An instance of `tokenizer`

"""
if isinstance(schema, dict):
schema = json.dumps(schema)
regex_string = build_regex_from_object(schema)
super().__init__(regex_string, tokenizer)
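A toy, self-contained illustration of the masking step in `__call__` above, using a fake 4-token vocabulary and a hand-picked allowed set in place of a real FSM:

```python
import math

import torch

scores = torch.tensor([1.2, -0.3, 0.7, 2.1])  # pretend the vocabulary has 4 tokens
allowed_tokens = [0, 2]  # stand-in for fsm.allowed_token_ids(state)

mask = torch.full((scores.shape[-1],), -math.inf)
mask[allowed_tokens] = 0
biased_scores = scores + mask

print(biased_scores)                         # tensor([1.2000, -inf, 0.7000, -inf])
print(torch.softmax(biased_scores, dim=-1))  # probability mass only on ids 0 and 2
```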

11 changes: 10 additions & 1 deletion serve/mlc_serve/engine/engine_common.py
@@ -32,6 +32,7 @@
)
from ..model.base import ModelArtifactConfig
from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
from .constrained_sampling import JSONLogitsProcessor

LOG = structlog.stdlib.get_logger(__name__)

@@ -240,7 +241,9 @@ def prepare_output(


def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
current_states: list[RequestState],
cache_manager: KVCacheManager,
tokenizer: TokenizerP,
) -> Tuple[list[RequestType], bool, int]:
requests: list[RequestType] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
@@ -289,6 +292,12 @@ def get_requests_to_process(
# TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
# Prometheus metric?
elif not state.is_prefilled:
# `JSONLogitsProcessor` needs to be created only once.
if state.sampling_params.json_schema is not None:
state.sampling_params.logits_processor = JSONLogitsProcessor(
state.sampling_params.json_schema, tokenizer._tokenizer
)

if (
state.num_sequences == 1
and state.generation_sequences[0].generated_token_ids
3 changes: 2 additions & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -2,7 +2,7 @@
Required interfaces for the actual inference capability in InferenceEngine.
"""
from dataclasses import dataclass
from typing import Optional, Protocol, Union, List, Sequence
from typing import Optional, Protocol, Union, List, Sequence, Any

from .base import (
ChatMessage,
@@ -168,6 +168,7 @@ def generate(


class Tokenizer(Protocol):
_tokenizer: Any
eos_token_id: int
skip_special_tokens: bool
all_special_ids: List[int]
4 changes: 3 additions & 1 deletion serve/mlc_serve/engine/sampling_params.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass
from enum import IntEnum
from functools import cached_property
from typing import Dict, Optional
from typing import Dict, Optional, Any

_SAMPLING_EPS = 1e-5
LOGPROB_TOP_K_MAX = 5
@@ -73,6 +73,8 @@ class SamplingParams:
# Currently, it is unclear what is the best way to fetch this info and
# check in `_verify_args` without this field. Follow-up when we have a better idea.
vocab_size = 32000
json_schema: Optional[Dict[str, Any]] = None
logits_processor: Optional[Any] = None

def __post_init__(self):
if self.logit_bias:
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/staging_engine_worker.py
@@ -329,7 +329,7 @@ def _adjust_batch(self):

def _get_requests_to_process(self):
requests, is_prompt_batch, token_counts = get_requests_to_process(
self.current_batch.values(), self.cache_manager
self.current_batch.values(), self.cache_manager, self.tokenizer
)

if is_prompt_batch:
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/sync_engine.py
@@ -143,7 +143,7 @@ def step(self) -> InferenceStepResult:
return InferenceStepResult(outputs)

requests, _, _ = get_requests_to_process(
list(self.current_batch.values()), self.cache_manager
list(self.current_batch.values()), self.cache_manager, self.tokenizer
)
results = self.text_generator.generate(requests, self.cache_manager.get_cache())
logger.debug("Finished text generation.")
12 changes: 12 additions & 0 deletions serve/mlc_serve/model/model_common.py
@@ -14,6 +14,7 @@
)
from ..engine.model_module import (
PrefillRequest,
DecodeRequest,
EvalMultiQueryRequest,
RequestType,
TextGenerationResult,
@@ -97,6 +98,17 @@ def sample_from_logits(
# synchronization point for sampling tensors
# wait until all the tensors are loaded on GPU
torch.cuda.current_stream().wait_stream(copy_stream)

# Logit processing for constraint sampling e.g., JSON Mode
for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
if request.sampling_params.logits_processor is not None:
cs_input_ids = (
request.token_ids if isinstance(request, DecodeRequest) else []
)
logits[i] = request.sampling_params.logits_processor(
sequence_id, cs_input_ids, logits[i]
)

logits = adjust_logits(logits, sampling_metadata, vocab_size)
outputs: List[TextGenerationResult] = []

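A compact sketch of how this loop is exercised, with stand-ins for the request objects (only the attributes the loop reads are mocked; the real `PrefillRequest`/`DecodeRequest` types carry much more):

```python
from types import SimpleNamespace

import torch


def fake_processor(seq_id, input_ids, scores):
    # Stand-in for JSONLogitsProcessor: a real processor would add a -inf
    # mask computed from the FSM state; here the scores pass through unchanged.
    return scores


# One request with JSON mode enabled, one without.
requests = [
    SimpleNamespace(
        token_ids=[5, 7],
        sampling_params=SimpleNamespace(logits_processor=fake_processor),
    ),
    SimpleNamespace(
        token_ids=[3],
        sampling_params=SimpleNamespace(logits_processor=None),
    ),
]
sequence_ids = [0, 1]
logits = torch.randn(len(requests), 8)  # [batch, vocab]

for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
    if request.sampling_params.logits_processor is not None:
        # In the real code, decode requests pass their generated token ids
        # and prefill requests pass []; both requests here mimic decode.
        logits[i] = request.sampling_params.logits_processor(
            sequence_id, request.token_ids, logits[i]
        )
```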
1 change: 1 addition & 0 deletions serve/mlc_serve/model/tokenizer.py
@@ -49,6 +49,7 @@ class HfTokenizerModule:
def __init__(self, model_artifact_path: Path):
hf_tokenizer = AutoTokenizer.from_pretrained(
model_artifact_path.joinpath("model"),
revision=None, tokenizer_revision=None,
trust_remote_code=False,
)
self.tokenizer = Tokenizer(hf_tokenizer)
2 changes: 1 addition & 1 deletion serve/mlc_serve/model/tvm_model.py
@@ -504,7 +504,7 @@ def init_tvm_model(
except tvm.error.InternalError:
raise RuntimeError(
f"Memory profiling failed with max_num_batched_tokens = "
"{engine_config.max_num_batched_tokens}."
"{engine_config.max_num_batched_tokens}."
)
else:
num_blocks = 500
3 changes: 3 additions & 0 deletions serve/pyproject.toml
@@ -9,11 +9,14 @@ python = ">=3.9"
fastapi = ">=0.103.1"
pydantic = ">=1.8.0"
prometheus-client = ">=0.18.0"
outlines = "0.0.23"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.2"
httpx_sse = "^0.3.1"
pytest-timeout = "^2.2.0"
cuda-python = "12.3.0"
pandas = "2.2.0"

[tool.setuptools]
packages = ["mlc_serve"]