[Performance] Support both xgrammar and outlines for constrained decoding #1752

Merged: 19 commits, merged Oct 25, 2024

Changes shown from 8 commits

Commits (19)
8bc804c
feat(xgrammar): support xgrammar as one of the grammar backends
DarkSharpness Oct 19, 2024
cae33a9
fix: fix wrongly clearing the vocab_mask of outlines
DarkSharpness Oct 19, 2024
1b17c72
minor: fix the format by running pre-commit
DarkSharpness Oct 19, 2024
b23f632
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 20, 2024
d93f76e
fix: set the object to error when import failed
DarkSharpness Oct 21, 2024
ee43065
minor: set the default grammar backend as outlines
DarkSharpness Oct 21, 2024
652ef54
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 21, 2024
83d1502
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 22, 2024
5ce813c
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 23, 2024
b8648dd
refactor(constrained): add a new abstraction for constrained decoding
DarkSharpness Oct 23, 2024
e615ce3
minor(constrained): set import failure object as None to pass type check
DarkSharpness Oct 24, 2024
cd59ed0
fix(constrained): use DummyType to avoid type failure in 'isinstance'
DarkSharpness Oct 24, 2024
d01e7af
fix(constrained): fix wrong parameter order in initializing bnf_cache
DarkSharpness Oct 24, 2024
e1de402
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 24, 2024
c07cd0d
minor: format the code using pre-commit
DarkSharpness Oct 24, 2024
8608c2b
fix(constrained): fix wrong jump-forward assertion
DarkSharpness Oct 25, 2024
cbdca83
minor: format the code using pre-commit
DarkSharpness Oct 25, 2024
bb0b28d
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 25, 2024
bed1f3d
Merge branch 'main' into xgrammar-outlines
merrymercy Oct 25, 2024
15 changes: 15 additions & 0 deletions python/sglang/srt/constrained/__init__.py
@@ -51,6 +51,18 @@ def build_regex_from_object(
    return build_regex_from_schema(schema, whitespace_pattern)


try:
    from xgrammar import (
        GrammarMatcher,
        GrammarMatcherInitContext,
        GrammarMatcherInitContextCache,
    )
except ImportError as e:
    GrammarMatcher = e
    GrammarMatcherInitContext = e
    GrammarMatcherInitContextCache = e


__all__ = [
    "RegexGuide",
    "FSMInfo",
@@ -60,4 +72,7 @@ def build_regex_from_object(
"disk_cache",
"disable_cache",
"make_byte_level_fsm",
"GrammarMatcher",
"GrammarMatcherInitContext",
"GrammarMatcherInitContextCache",
]
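Note on the fallback above: at this commit the ImportError instance is bound to the class names, which later commits (e615ce3, cd59ed0) rework, because `isinstance(x, GrammarMatcher)` raises a TypeError when `GrammarMatcher` is an exception object rather than a type. Below is a minimal sketch of the dummy-type workaround those commit messages describe; the names `is_bnf_matcher` and the dummy class body are illustrative assumptions, not the PR's code.

# Sketch (assumption, not the PR's exact code): bind a placeholder class on
# import failure so that isinstance() checks remain valid either way.
try:
    from xgrammar import GrammarMatcher
except ImportError:

    class GrammarMatcher:  # hypothetical dummy stand-in; never instantiated
        def __init__(self, *args, **kwargs):
            raise RuntimeError("xgrammar is not installed")


def is_bnf_matcher(obj) -> bool:
    # Safe whether or not xgrammar is installed: GrammarMatcher is always a type.
    return isinstance(obj, GrammarMatcher)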
61 changes: 61 additions & 0 deletions python/sglang/srt/constrained/bnf_cache.py
@@ -0,0 +1,61 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Cache for the compressed finite state machine."""

from typing import Tuple

from transformers import AutoTokenizer

from sglang.srt.constrained import (
GrammarMatcher,
GrammarMatcherInitContext,
GrammarMatcherInitContextCache,
)

MAX_ROLLBACK_TOKENS = 10


class BNFCache:
grammar_cache: GrammarMatcherInitContextCache

def __init__(
self,
tokenizer_path,
tokenizer_args_dict,
skip_tokenizer_init=False,
whitespace_patterns=None,
):
# TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
if skip_tokenizer_init:
return

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
self.grammar_cache = GrammarMatcherInitContextCache(
tokenizer_or_vocab=tokenizer
)

def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
key_type, key_string = key
if key_type == "json":
return self.grammar_cache.get_init_context_for_json_schema(key_string)
elif key_type == "regex":
raise ValueError(f"regex hasn't been supported by xgrammar yet")
else:
raise ValueError(f"Invalid key_type: {key_type}")

def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
ctx = self.get_context(key)
return GrammarMatcher(
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
)
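A hypothetical usage sketch of BNFCache follows; the tokenizer path and vocab size are illustrative assumptions, not values from the PR. At this commit only the "json" key type is supported ("regex" raises ValueError), and query compiles a GrammarMatcher bound to the model's vocabulary size.

# Sketch under assumed values: build a matcher for a JSON schema.
from sglang.srt.constrained.bnf_cache import BNFCache

cache = BNFCache(
    tokenizer_path="meta-llama/Llama-2-7b-hf",  # assumed model path
    tokenizer_args_dict={},
)
schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
matcher = cache.query(key=("json", schema), vocab_size=32000)  # GrammarMatcher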
116 changes: 110 additions & 6 deletions python/sglang/srt/managers/schedule_batch.py
@@ -37,7 +37,7 @@

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained import GrammarMatcher, RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -247,8 +247,11 @@ def __init__(
        self.embedding = None

        # Constrained decoding
-        self.regex_fsm: RegexGuide = None
+        self.regex_fsm: Optional[RegexGuide] = None
        self.regex_fsm_state: int = 0
+        self.regex_bnf: Optional[GrammarMatcher] = None
+
+        self.allow_jump_forward = False
        self.jump_forward_map: JumpForwardMap = None

        # For Qwen2-VL
@@ -353,7 +356,7 @@ def check_finished(self):
                self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                return

-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+    def jump_forward_and_retokenize_fsm(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
@@ -410,6 +413,72 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state):

        return True

    def jump_forward_and_retokenize_bnf(self, jump_forward_str):
        assert self.regex_bnf is not None, "should be a regex request"
        assert self.tokenizer is not None, "should have a tokenizer"

        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the overhead of decoding the surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # k = length of the longest common prefix of the old and new output ids
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == self.output_ids[i]:
                k = i + 1
            else:
                break

        # Roll back to the last token that is the same
        if k < len(old_output_ids):
            self.regex_bnf.rollback(len(old_output_ids) - k)

        for i in range(k, len(self.output_ids)):
            assert self.regex_bnf.accept_token(self.output_ids[i])

        if self.return_logprob:
            # For the fast-forward part's logprobs
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "

@@ -486,7 +555,7 @@ def init_new(
            model_config=model_config,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
-            has_regex=any(req.regex_fsm for req in reqs),
+            has_regex=any(req.regex_fsm or req.regex_bnf for req in reqs),
            device=req_to_token_pool.device,
        )

@@ -794,6 +863,7 @@ def check_for_jump_forward(self, pad_input_ids_func):

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
+                assert req.regex_fsm is not None and req.regex_bnf is None
                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
                    req.regex_fsm_state
                )
@@ -831,7 +901,7 @@ def check_for_jump_forward(self, pad_input_ids_func):
                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 decoding will not be corrupted
                    jump_forward_str = new_text + jump_forward_str
-                    if not req.jump_forward_and_retokenize(
+                    if not req.jump_forward_and_retokenize_fsm(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
@@ -852,6 +922,38 @@ def check_for_jump_forward(self, pad_input_ids_func):
                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

            if req.allow_jump_forward:
                assert req.regex_bnf is not None and req.regex_fsm is None
                jump_forward_str = req.regex_bnf.find_jump_forward_string()
                if len(jump_forward_str) > 1:
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize_bnf(jump_forward_str):
                        # Failed to jump forward, revert
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # Insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # Re-apply image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs
Expand Down Expand Up @@ -936,7 +1038,7 @@ def filter_batch(
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
-        self.has_regex = any(req.regex_fsm for req in self.reqs)
+        self.has_regex = any(req.regex_fsm or req.regex_bnf for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

@@ -981,11 +1083,13 @@ def get_model_worker_batch(self):

        if self.has_regex:
            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
+            self.sampling_info.regex_bnfs = [req.regex_bnf for req in self.reqs]
            self.sampling_info.regex_fsm_states = [
                req.regex_fsm_state for req in self.reqs
            ]
        else:
            self.sampling_info.regex_fsms = None
+            self.sampling_info.regex_bnfs = None

        global bid
        bid += 1
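To make the rollback-and-replay step in jump_forward_and_retokenize_bnf easier to follow, here is a minimal, self-contained sketch (not part of the PR): the stub class stands in for xgrammar's GrammarMatcher, and every name except rollback/accept_token is an illustrative assumption. Retokenizing after a jump-forward can change token ids that the matcher already accepted, so the matcher is rolled back to the longest common prefix of the old and new streams, then fed the replacement suffix.

# Sketch of the rollback-and-replay logic under the assumptions above.
from typing import List


class StubMatcher:
    """Illustrative stand-in for xgrammar's GrammarMatcher."""

    def __init__(self) -> None:
        self.accepted: List[int] = []

    def rollback(self, num_tokens: int) -> None:
        # Undo the last num_tokens accepted tokens.
        del self.accepted[len(self.accepted) - num_tokens :]

    def accept_token(self, token_id: int) -> bool:
        self.accepted.append(token_id)
        return True


def rollback_and_replay(matcher, old_ids: List[int], new_ids: List[int]) -> None:
    # k = length of the longest common prefix of the old and new token streams.
    k = 0
    for old_id, new_id in zip(old_ids, new_ids):
        if old_id != new_id:
            break
        k += 1
    if k < len(old_ids):
        matcher.rollback(len(old_ids) - k)  # drop the tokens that changed
    for token_id in new_ids[k:]:
        assert matcher.accept_token(token_id)  # replay the replacement suffix


matcher = StubMatcher()
for t in [5, 9, 12]:
    matcher.accept_token(t)
rollback_and_replay(matcher, old_ids=[5, 9, 12], new_ids=[5, 9, 31, 7])
assert matcher.accepted == [5, 9, 31, 7]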