From 8bc804cdbc7715504eb942461ce2d8374754c617 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Sat, 19 Oct 2024 12:12:46 +0000 Subject: [PATCH 01/13] feat(xgrammar): support xgrammar as one of the grammar backends --- python/sglang/srt/constrained/__init__.py | 16 +++ python/sglang/srt/constrained/bnf_cache.py | 61 +++++++++ python/sglang/srt/managers/schedule_batch.py | 116 +++++++++++++++++- python/sglang/srt/managers/scheduler.py | 89 ++++++++++---- .../srt/sampling/sampling_batch_info.py | 38 +++++- python/sglang/srt/server_args.py | 8 ++ 6 files changed, 293 insertions(+), 35 deletions(-) create mode 100644 python/sglang/srt/constrained/bnf_cache.py diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index c47c5c8dd58..c11a3dc12f4 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -51,6 +51,19 @@ def build_regex_from_object( return build_regex_from_schema(schema, whitespace_pattern) +try: + from xgrammar import ( + GrammarMatcher, + GrammarMatcherInitContext, + GrammarMatcherInitContextCache, + ) +except ImportError as e: + print( + f'\nError: {e}. Please install a new version of xgrammar by `pip install "xgrammar>=0.0.12"`\n' + ) + raise + + __all__ = [ "RegexGuide", "FSMInfo", @@ -60,4 +73,7 @@ def build_regex_from_object( "disk_cache", "disable_cache", "make_byte_level_fsm", + "GrammarMatcher", + "GrammarMatcherInitContext", + "GrammarMatcherInitContextCache", ] diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py new file mode 100644 index 00000000000..9b7ac7d4126 --- /dev/null +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -0,0 +1,61 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Cache for the compressed finite state machine.""" + +from transformers import AutoTokenizer +from typing import Tuple +from sglang.srt.constrained import ( + GrammarMatcher, + GrammarMatcherInitContext, + GrammarMatcherInitContextCache, +) + +MAX_ROLLBACK_TOKENS = 10 + + +class BNFCache(): + grammar_cache : GrammarMatcherInitContextCache + + def __init__( + self, + tokenizer_path, + tokenizer_args_dict, + skip_tokenizer_init=False, + whitespace_patterns=None, + ): + # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init + if skip_tokenizer_init: + return + + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, **tokenizer_args_dict + ) + self.grammar_cache = GrammarMatcherInitContextCache( + tokenizer_or_vocab=tokenizer + ) + + def get_context(self, key : Tuple[str, str]) -> GrammarMatcherInitContext: + key_type, key_string = key + if key_type == "json": + return self.grammar_cache.get_init_context_for_json_schema(key_string) + elif key_type == "regex": + raise ValueError(f"regex hasn't been supported by xgrammar yet") + else: + raise ValueError(f"Invalid key_type: {key_type}") + + def query(self, key : Tuple[str, str], vocab_size: int) -> GrammarMatcher: + ctx = self.get_context(key) + return GrammarMatcher( + ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size + ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 8e55fb1d74b..dd556f5d4ed 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -34,7 +34,7 @@ import torch from sglang.global_config import global_config -from sglang.srt.constrained import RegexGuide +from sglang.srt.constrained import GrammarMatcher, RegexGuide from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache @@ -232,8 +232,11 @@ def __init__( self.embedding = None # Constrained decoding - self.regex_fsm: RegexGuide = None + self.regex_fsm: Optional[RegexGuide] = None self.regex_fsm_state: int = 0 + self.regex_bnf: Optional[GrammarMatcher] = None + + self.allow_jump_forward = False self.jump_forward_map: JumpForwardMap = None # whether request reached finished condition @@ -335,7 +338,7 @@ def check_finished(self): self.finished_reason = FINISH_MATCHED_STR(matched=stop_str) return - def jump_forward_and_retokenize(self, jump_forward_str, next_state): + def jump_forward_and_retokenize_fsm(self, jump_forward_str, next_state): if self.origin_input_text is None: # Recovering text can only use unpadded ids self.origin_input_text = self.tokenizer.decode( @@ -392,6 +395,72 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): return True + def jump_forward_and_retokenize_bnf(self, jump_forward_str): + assert self.regex_bnf is not None, "should be a regex request" + assert self.tokenizer is not None, "should have a tokenizer" + + if self.origin_input_text is None: + # Recovering text can only use unpadded ids + self.origin_input_text = self.tokenizer.decode( + self.origin_input_ids_unpadded + ) + + all_text = self.origin_input_text + self.decoded_text + jump_forward_str + all_ids = self.tokenizer.encode(all_text) + if not all_ids: + logger.warning("Encoded all_text resulted in empty all_ids") + return False + + prompt_tokens = len(self.origin_input_ids_unpadded) + if prompt_tokens > len(all_ids): + logger.warning("prompt_tokens is larger than encoded 
all_ids") + return False + + if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]: + # TODO(lsyin): fix token fusion + logger.warning( + "Token fusion between input and output, try to avoid this by removing the space at the end of the input." + ) + return False + + old_output_ids = self.output_ids + self.output_ids = all_ids[prompt_tokens:] + self.decoded_text = self.decoded_text + jump_forward_str + self.surr_offset = prompt_tokens + self.read_offset = len(all_ids) + + # NOTE: A trick to reduce the surrouding tokens decoding overhead + for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET): + surr_text_ = self.tokenizer.decode( + all_ids[self.read_offset - i : self.read_offset] + ) + if not surr_text_.endswith("�"): + self.surr_offset = self.read_offset - i + break + + k = 0 + for i, old_id in enumerate(old_output_ids): + if old_id == self.output_ids[i]: + k = i + 1 + else: + break + + # rollback to the last token that is the same + if k < len(old_output_ids): + self.regex_bnf.rollback(len(old_output_ids) - k) + + for i in range(k, len(self.output_ids)): + assert self.regex_bnf.accept_token(self.output_ids[i]) + + if self.return_logprob: + # For fast-forward part's logprobs + self.output_token_logprobs = self.output_token_logprobs[:k] + self.output_top_logprobs = self.output_top_logprobs[:k] + self.logprob_start_len = prompt_tokens + k + self.last_update_decode_tokens = len(self.output_ids) - k + + return True + def __repr__(self): return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, " @@ -444,7 +513,7 @@ class ScheduleBatch: def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): return_logprob = any(req.return_logprob for req in reqs) has_stream = any(req.stream for req in reqs) - has_regex = any(req.regex_fsm for req in reqs) + has_regex = any(req.regex_fsm or req.regex_bnf for req in reqs) return cls( reqs=reqs, @@ -684,6 +753,7 @@ def check_for_jump_forward(self, pad_input_ids_func): for i, req in enumerate(self.reqs): if req.jump_forward_map is not None: + assert req.regex_fsm is not None and req.regex_bnf is None jump_forward_bytes = req.jump_forward_map.jump_forward_byte( req.regex_fsm_state ) @@ -721,7 +791,7 @@ def check_for_jump_forward(self, pad_input_ids_func): # Make the incrementally decoded text part of jump_forward_str # so that the UTF-8 will not corrupt jump_forward_str = new_text + jump_forward_str - if not req.jump_forward_and_retokenize( + if not req.jump_forward_and_retokenize_fsm( jump_forward_str, next_state ): req.output_ids = cur_output_ids @@ -742,6 +812,38 @@ def check_for_jump_forward(self, pad_input_ids_func): jump_forward_reqs.append(req) keep_indices.remove(i) + if req.allow_jump_forward: + assert req.regex_bnf is not None and req.regex_fsm is None + jump_forward_str = req.regex_bnf.find_jump_forward_string() + if len(jump_forward_str) > 1: + cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1] + cur_output_ids = req.output_ids + decode_res, new_text = req.get_next_inc_detokenization() + if not decode_res: + req.output_ids = cur_output_ids + continue + + jump_forward_str = new_text + jump_forward_str + if not req.jump_forward_and_retokenize_bnf(jump_forward_str): + # Failed to jump forward, revert + req.output_ids = cur_output_ids + continue + + # The decode status has diverged from detokenizer_manager + req.vid += 1 + + # insert the old request into tree_cache + self.tree_cache.cache_finished_req(req, cur_all_ids) + + # re-applying image padding + if req.image_inputs is not None: + 
req.origin_input_ids = pad_input_ids_func( + req.origin_input_ids_unpadded, req.image_inputs + ) + + jump_forward_reqs.append(req) + keep_indices.remove(i) + self.filter_batch(keep_indices=list(keep_indices)) return jump_forward_reqs @@ -802,7 +904,7 @@ def filter_batch( self.top_logprobs_nums = None self.has_stream = any(req.stream for req in self.reqs) - self.has_regex = any(req.regex_fsm for req in self.reqs) + self.has_regex = any(req.regex_fsm or req.regex_bnf for req in self.reqs) self.sampling_info.filter_batch(keep_indices, new_indices) @@ -845,11 +947,13 @@ def get_model_worker_batch(self): lora_paths = [req.lora_path for req in self.reqs] if self.has_regex: self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs] + self.sampling_info.regex_bnfs = [req.regex_bnf for req in self.reqs] self.sampling_info.regex_fsm_states = [ req.regex_fsm_state for req in self.reqs ] else: self.sampling_info.regex_fsms = None + self.sampling_info.regex_bnfs = None global bid bid += 1 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 16c43dd1694..d8a6e3247fa 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -29,6 +29,7 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.constrained.bnf_cache import BNFCache from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer @@ -211,16 +212,31 @@ def __init__( ) # Init the FSM cache for constrained generation + self.regex_fsm_cache = None + self.regex_bnf_cache = None + if not server_args.skip_tokenizer_init: - self.regex_fsm_cache = FSMCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, - ) + if server_args.grammar_backend == "xgrammar": + self.regex_bnf_cache = BNFCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + skip_tokenizer_init=server_args.skip_tokenizer_init, + whitespace_patterns=server_args.constrained_json_whitespace_pattern, + ) + else: + self.regex_fsm_cache = FSMCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + skip_tokenizer_init=server_args.skip_tokenizer_init, + constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, + ) + self.jump_forward_cache = JumpForwardCache() # Init new token estimation @@ -401,23 +417,40 @@ def handle_generate_request( # By default, only return the logprobs for output tokens req.logprob_start_len = len(recv_req.input_ids) - 1 - # Init regex FSM + # Init regex FSM or BNF if ( req.sampling_params.json_schema is not None or req.sampling_params.regex is not None ): - if req.sampling_params.json_schema is not None: - req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( - ("json", req.sampling_params.json_schema) - ) - elif req.sampling_params.regex is not None: - req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( - ("regex", req.sampling_params.regex) - ) - if not self.disable_regex_jump_forward: - req.jump_forward_map = 
self.jump_forward_cache.query( - computed_regex_string - ) + if self.regex_fsm_cache is not None: + # FSM cache + assert self.regex_bnf_cache is None + if req.sampling_params.json_schema is not None: + req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( + ("json", req.sampling_params.json_schema) + ) + elif req.sampling_params.regex is not None: + req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( + ("regex", req.sampling_params.regex) + ) + if not self.disable_regex_jump_forward: + req.jump_forward_map = self.jump_forward_cache.query( + computed_regex_string + ) + else: + # BNF cache + assert self.regex_bnf_cache is not None + vocab_size = self.model_config.vocab_size + if req.sampling_params.json_schema is not None: + req.regex_bnf = self.regex_bnf_cache.query( + ("json", req.sampling_params.json_schema), vocab_size + ) + elif req.sampling_params.regex is not None: + req.regex_bnf = self.regex_bnf_cache.query( + ("regex", req.sampling_params.regex), vocab_size + ) + if not self.disable_regex_jump_forward: + req.allow_jump_forward = True # Truncate prompts that are too long if len(req.origin_input_ids) >= self.max_req_input_len: @@ -789,9 +822,13 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): self.tree_cache.cache_unfinished_req(req) if req.regex_fsm is not None: + assert req.regex_bnf is None req.regex_fsm_state = req.regex_fsm.get_next_state( req.regex_fsm_state, next_token_ids[i] ) + if req.regex_bnf is not None: + assert req.regex_fsm is None + assert req.regex_bnf.accept_token(next_token_ids[i]) if req.return_logprob: logprob_pt += self.add_logprob_return_values( @@ -845,9 +882,13 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): req.check_finished() if req.regex_fsm is not None: + assert req.regex_bnf is None req.regex_fsm_state = req.regex_fsm.get_next_state( req.regex_fsm_state, next_token_id ) + if req.regex_bnf is not None: + assert req.regex_fsm is None + assert req.regex_bnf.accept_token(next_token_id) if req.finished(): self.cache_finished_req(req) @@ -1043,7 +1084,9 @@ def flush_cache(self): ): self.tree_cache.reset() self.tree_cache_metrics = {"total": 0, "hit": 0} - self.regex_fsm_cache.reset() + if self.regex_fsm_cache is not None: + self.regex_fsm_cache.reset() + # TODO(dark): reset the bnf cache self.req_to_token_pool.clear() self.token_to_kv_pool.clear() torch.cuda.empty_cache() diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index e6a593fccc7..f0722db2914 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -6,7 +6,7 @@ import torch import sglang.srt.sampling.penaltylib as penaltylib -from sglang.srt.constrained import RegexGuide +from sglang.srt.constrained import GrammarMatcher, RegexGuide if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -29,12 +29,15 @@ class SamplingBatchInfo: # Bias Tensors vocab_size: int logit_bias: torch.Tensor = None - vocab_mask: torch.Tensor = None + vocab_mask: Optional[torch.Tensor] = None # FSM states - regex_fsms: List[RegexGuide] = None + regex_fsms: Optional[List[Optional[RegexGuide]]] = None regex_fsm_states: List[int] = None + # BNF states + regex_bnfs: Optional[List[Optional[GrammarMatcher]]] = None + # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None linear_penalties: Optional[torch.Tensor] = None @@ -135,9 +138,8 @@ def update_penalties(self): ) 
self.linear_penalties = penalizer.apply(self.linear_penalties) - def update_regex_vocab_mask(self): - has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms) - if not has_regex: + def _update_regex_vocab_mask_fsm(self): + if not self.regex_fsms or not any(regex_fsm for regex_fsm in self.regex_fsms): self.vocab_mask = None return @@ -154,6 +156,30 @@ def update_regex_vocab_mask(self): regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens ] = 0 + def _update_regex_vocab_mask_bnf(self): + if not self.regex_bnfs or not any(regex_bnf for regex_bnf in self.regex_bnfs): + self.vocab_mask = None + return + + self.vocab_mask = torch.zeros( + len(self.temperatures), + self.vocab_size, + dtype=torch.bool, + device=self.device, + ) + for i, regex_bnf in enumerate(self.regex_bnfs): + if regex_bnf is not None: + # Note that this bitmask is a bitset, not bool + bitmask = regex_bnf.find_next_token_bitmask() + # Mask the tokens that are not allowed + self.vocab_mask[i][ + regex_bnf.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size) + ] = 1 + + def update_regex_vocab_mask(self): + self._update_regex_vocab_mask_fsm() + self._update_regex_vocab_mask_bnf() + def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): if self.penalizer_orchestrator: self.penalizer_orchestrator.filter(unfinished_indices, new_indices) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 722e30f6be4..a97746b5276 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -102,6 +102,7 @@ class ServerArgs: # Kernel backend attention_backend: Optional[str] = None sampling_backend: Optional[str] = None + grammar_backend: str = "xgrammar" # Optimization/debug options disable_flashinfer: bool = False @@ -527,6 +528,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.sampling_backend, help="Choose the kernels for sampling layers.", ) + parser.add_argument( + "--grammar-backend", + type=str, + choices=["xgrammar", "outlines"], + default=ServerArgs.grammar_backend, + help="Choose the backend for constrained decoding.", + ) # Optimization/debug options parser.add_argument( From cae33a9197283997d9c976eff79bb08b428ed74b Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Sat, 19 Oct 2024 12:14:43 +0000 Subject: [PATCH 02/13] fix: fix wrongly clearing the vocab_mask of outlines --- python/sglang/srt/sampling/sampling_batch_info.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index f0722db2914..653df0e6334 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -140,7 +140,6 @@ def update_penalties(self): def _update_regex_vocab_mask_fsm(self): if not self.regex_fsms or not any(regex_fsm for regex_fsm in self.regex_fsms): - self.vocab_mask = None return self.vocab_mask = torch.zeros( @@ -158,7 +157,6 @@ def _update_regex_vocab_mask_fsm(self): def _update_regex_vocab_mask_bnf(self): if not self.regex_bnfs or not any(regex_bnf for regex_bnf in self.regex_bnfs): - self.vocab_mask = None return self.vocab_mask = torch.zeros( @@ -177,6 +175,7 @@ def _update_regex_vocab_mask_bnf(self): ] = 1 def update_regex_vocab_mask(self): + self.vocab_mask = None self._update_regex_vocab_mask_fsm() self._update_regex_vocab_mask_bnf() From 1b17c7225e53da8882456c8bfdf39826899ec8a5 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: 
Sat, 19 Oct 2024 12:15:30 +0000 Subject: [PATCH 03/13] minor: fix the format by running pre-commit --- python/sglang/srt/constrained/bnf_cache.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py index 9b7ac7d4126..19765731bd6 100644 --- a/python/sglang/srt/constrained/bnf_cache.py +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -13,8 +13,10 @@ """Cache for the compressed finite state machine.""" -from transformers import AutoTokenizer from typing import Tuple + +from transformers import AutoTokenizer + from sglang.srt.constrained import ( GrammarMatcher, GrammarMatcherInitContext, @@ -24,8 +26,8 @@ MAX_ROLLBACK_TOKENS = 10 -class BNFCache(): - grammar_cache : GrammarMatcherInitContextCache +class BNFCache: + grammar_cache: GrammarMatcherInitContextCache def __init__( self, @@ -38,14 +40,12 @@ def __init__( if skip_tokenizer_init: return - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, **tokenizer_args_dict - ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) self.grammar_cache = GrammarMatcherInitContextCache( tokenizer_or_vocab=tokenizer ) - def get_context(self, key : Tuple[str, str]) -> GrammarMatcherInitContext: + def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext: key_type, key_string = key if key_type == "json": return self.grammar_cache.get_init_context_for_json_schema(key_string) @@ -54,7 +54,7 @@ def get_context(self, key : Tuple[str, str]) -> GrammarMatcherInitContext: else: raise ValueError(f"Invalid key_type: {key_type}") - def query(self, key : Tuple[str, str], vocab_size: int) -> GrammarMatcher: + def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher: ctx = self.get_context(key) return GrammarMatcher( ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size From d93f76e3fc92b2e9c9826c030944fc65e6522be8 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Mon, 21 Oct 2024 13:34:06 +0000 Subject: [PATCH 04/13] fix: set the object to error when import failed --- python/sglang/srt/constrained/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index c11a3dc12f4..ce6387f18e0 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -58,10 +58,9 @@ def build_regex_from_object( GrammarMatcherInitContextCache, ) except ImportError as e: - print( - f'\nError: {e}. 
Please install a new version of xgrammar by `pip install "xgrammar>=0.0.12"`\n' - ) - raise + GrammarMatcher = e + GrammarMatcherInitContext = e + GrammarMatcherInitContextCache = e __all__ = [ From ee4306592e2f7afc8889c36a7dd2db4e1f9a8e2a Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Mon, 21 Oct 2024 13:39:39 +0000 Subject: [PATCH 05/13] minor: set the default grammar backend as outlines --- python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/server_args.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2dec04cb443..41cb29fee02 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -241,7 +241,7 @@ def __init__( skip_tokenizer_init=server_args.skip_tokenizer_init, whitespace_patterns=server_args.constrained_json_whitespace_pattern, ) - else: + elif server_args.grammar_backend == "outlines": self.regex_fsm_cache = FSMCache( server_args.tokenizer_path, { diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a97746b5276..afe7bfc0a26 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -102,7 +102,7 @@ class ServerArgs: # Kernel backend attention_backend: Optional[str] = None sampling_backend: Optional[str] = None - grammar_backend: str = "xgrammar" + grammar_backend: Optional[str] = "outlines" # Optimization/debug options disable_flashinfer: bool = False From 5ce813cba6ba45641958421c170656ddd8cb46a3 Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Wed, 23 Oct 2024 15:24:07 +0000 Subject: [PATCH 06/13] Merge branch 'main' into xgrammar-outlines --- docs/en/benchmark_and_profiling.md | 6 +- .../layers/attention/flashinfer_backend.py | 10 +-- python/sglang/srt/layers/sampler.py | 81 ++++++++++--------- .../srt/managers/detokenizer_manager.py | 4 + python/sglang/srt/managers/io_struct.py | 10 +++ python/sglang/srt/managers/scheduler.py | 8 +- .../sglang/srt/managers/tokenizer_manager.py | 14 ++++ .../srt/managers/tp_worker_overlap_thread.py | 36 +++++---- python/sglang/srt/mem_cache/memory_pool.py | 27 ++++--- .../srt/model_executor/cuda_graph_runner.py | 8 +- python/sglang/srt/server.py | 12 +++ python/sglang/test/run_eval.py | 2 + test/srt/test_eval_accuracy_mini.py | 1 + test/srt/test_pytorch_sampling_backend.py | 3 +- test/srt/test_srt_endpoint.py | 4 + 15 files changed, 150 insertions(+), 76 deletions(-) diff --git a/docs/en/benchmark_and_profiling.md b/docs/en/benchmark_and_profiling.md index 77fbbfc1b64..c0f54957d1f 100644 --- a/docs/en/benchmark_and_profiling.md +++ b/docs/en/benchmark_and_profiling.md @@ -46,4 +46,8 @@ pip install nvtx import nvtx with nvtx.annotate("description", color="color"): # some critical code -``` \ No newline at end of file +``` + +## Other tips + +1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder. 
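
Usage sketch (not part of the diff): with the `--grammar-backend` flag introduced above, and its default later switched back to `outlines`, the xgrammar backend has to be selected explicitly, e.g. `python -m sglang.launch_server --model-path <model> --grammar-backend xgrammar`. The snippet below is an illustrative client-side request against such a locally launched server; the port, prompt, and schema are assumptions for the example, not taken from this patch.

```python
# Illustrative client request; assumes a local SGLang server on the default port,
# launched with --grammar-backend xgrammar (or outlines). Prompt/schema are placeholders.
import json

import requests

person_schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
)

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "Describe a person as JSON: ",
        "sampling_params": {
            "max_new_tokens": 64,
            "temperature": 0,
            # Compiled and enforced by the selected grammar backend.
            "json_schema": person_schema,
        },
    },
)
print(response.json()["text"])
```
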
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index e5e7ca29c90..c6b5393ee92 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -337,7 +337,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): def update( self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens ): - # Keep the signature for type checking, will be initialized during runtime + # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() def update_single_wrapper( @@ -432,8 +432,8 @@ def call_begin_forward( kv_start_idx, ): bs = len(req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) @@ -497,7 +497,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.update = self.update_single_wrapper def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens): - # Keep the signature for type checking, will be initialized during runtime + # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() def update_single_wrapper( @@ -589,8 +589,8 @@ def call_begin_forward( use_ragged, ): bs = len(req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, @@ -602,8 +602,8 @@ def call_begin_forward( self.max_context_len, ) + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) qo_indptr = qo_indptr[: bs + 1] - qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) # extend part if use_ragged: diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 454078d59c3..54fc47b736f 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -33,56 +33,61 @@ def forward( if isinstance(logits, LogitsProcessorOutput): logits = logits.next_token_logits - # Post process logits logits = logits.contiguous() - logits.div_(sampling_info.temperatures) - probs = torch.softmax(logits, dim=-1) - logits = None - del logits - if self.use_nan_detectioin and torch.any(torch.isnan(probs)): - logger.warning("Detected errors during sampling! NaN in the probability.") - probs = torch.where( - torch.isnan(probs), torch.full_like(probs, 1e-10), probs + if self.use_nan_detectioin and torch.any(torch.isnan(logits)): + logger.warning("Detected errors during sampling! 
NaN in the logits.") + logits = torch.where( + torch.isnan(logits), torch.full_like(logits, -1e5), logits ) if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling - batch_next_token_ids = torch.argmax(probs, -1) - elif global_server_args_dict["sampling_backend"] == "flashinfer": - max_top_k_round, batch_size = 32, probs.shape[0] - uniform_samples = torch.rand( - (max_top_k_round, batch_size), device=probs.device - ) - if sampling_info.need_min_p_sampling: - probs = top_k_renorm_prob(probs, sampling_info.top_ks) - probs = top_p_renorm_prob(probs, sampling_info.top_ps) - batch_next_token_ids, success = min_p_sampling_from_probs( - probs, uniform_samples, sampling_info.min_ps + batch_next_token_ids = torch.argmax(logits, -1) + else: + # Post process logits + logits.div_(sampling_info.temperatures) + probs = torch.softmax(logits, dim=-1) + logits = None + del logits + + if global_server_args_dict["sampling_backend"] == "flashinfer": + max_top_k_round, batch_size = 32, probs.shape[0] + uniform_samples = torch.rand( + (max_top_k_round, batch_size), device=probs.device ) - else: - batch_next_token_ids, success = top_k_top_p_sampling_from_probs( + if sampling_info.need_min_p_sampling: + probs = top_k_renorm_prob(probs, sampling_info.top_ks) + probs = top_p_renorm_prob(probs, sampling_info.top_ps) + batch_next_token_ids, success = min_p_sampling_from_probs( + probs, uniform_samples, sampling_info.min_ps + ) + else: + batch_next_token_ids, success = top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + sampling_info.top_ks, + sampling_info.top_ps, + filter_apply_order="joint", + ) + + if not torch.all(success): + logger.warning("Detected errors during sampling!") + batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + elif global_server_args_dict["sampling_backend"] == "pytorch": + # A slower fallback implementation with torch native operations. + batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( probs, - uniform_samples, sampling_info.top_ks, sampling_info.top_ps, - filter_apply_order="joint", + sampling_info.min_ps, + ) + else: + raise ValueError( + f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" ) - if not torch.all(success): - logger.warning("Detected errors during sampling!") - batch_next_token_ids = torch.zeros_like(batch_next_token_ids) - elif global_server_args_dict["sampling_backend"] == "pytorch": - # Here we provide a slower fallback implementation. - batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( - probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps - ) - else: - raise ValueError( - f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" - ) - - return batch_next_token_ids + return batch_next_token_ids.to(torch.int32) def top_k_top_p_min_p_sampling_from_probs_torch( diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 4ae31ecc8bf..d0d399363f3 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -27,6 +27,7 @@ BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, + GetMemPoolSizeReqOutput, UpdateWeightReqOutput, ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN @@ -111,6 +112,9 @@ def event_loop(self): # If it is a weight update request, no detokenization is needed. 
self.send_to_tokenizer.send_pyobj(recv_obj) continue + elif isinstance(recv_obj, GetMemPoolSizeReqOutput): + self.send_to_tokenizer.send_pyobj(recv_obj) + continue elif self.tokenizer is None: # If the tokenizer is skipped, no detokenization is needed self.send_to_tokenizer.send_pyobj(recv_obj) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9625ff44ebf..2cdc3f47851 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -353,3 +353,13 @@ class AbortReq: class ProfileReq(Enum): START_PROFILE = 1 STOP_PROFILE = 2 + + +@dataclass +class GetMemPoolSizeReq: + pass + + +@dataclass +class GetMemPoolSizeReqOutput: + size: int diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 865536d2abe..8573534a191 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -39,6 +39,8 @@ BatchEmbeddingOut, BatchTokenIDOut, FlushCacheReq, + GetMemPoolSizeReq, + GetMemPoolSizeReqOutput, ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -379,6 +381,10 @@ def process_input_requests(self, recv_reqs: List): self.start_profile() else: self.stop_profile() + elif isinstance(recv_req, GetMemPoolSizeReq): + self.send_to_detokenizer.send_pyobj( + GetMemPoolSizeReqOutput(self.max_total_num_tokens) + ) else: raise ValueError(f"Invalid request: {recv_req}") @@ -449,7 +455,7 @@ def handle_generate_request( req.allow_jump_forward = True # Truncate prompts that are too long - if len(req.origin_input_ids) >= self.max_req_input_len: + if len(req.origin_input_ids) > self.max_req_input_len: logger.warning( "Request length is longer than the KV cache pool size or " "the max context length. Truncated!!!" 
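
Illustration (not part of the diff): the sampler refactor above argmaxes the raw logits when every request is greedy, and otherwise applies top-k/top-p/min-p filtering, with a torch-native fallback when `--sampling-backend pytorch` is set. A minimal sketch of that kind of filtering follows; it is an assumption-laden stand-in, not the actual `top_k_top_p_min_p_sampling_from_probs_torch`.

```python
# Minimal, illustrative torch-native top-k / top-p / min-p filtering over a batch
# of probabilities. This is a sketch, not the SGLang implementation.
import torch


def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps):
    # probs: [batch, vocab] softmax outputs; top_ks/top_ps/min_ps: [batch]
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum_probs = sorted_probs.cumsum(dim=-1)

    # top-p: drop a token once the mass accumulated before it exceeds top_p
    mask = (cum_probs - sorted_probs) > top_ps.unsqueeze(-1)
    # top-k: drop tokens ranked at or beyond top_k
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(sorted_probs)
    mask |= ranks >= top_ks.unsqueeze(-1)
    # min-p: drop tokens whose probability is below min_p * max probability
    mask |= sorted_probs < min_ps.unsqueeze(-1) * sorted_probs[:, :1]

    filtered = sorted_probs.masked_fill(mask, 0.0)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(filtered, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)
```

The point of the restructuring in the patch is that the all-greedy path skips the temperature division and softmax entirely: argmax over the logits selects the same token as argmax over the normalized probabilities.
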
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index fc9e2351980..875239a941e 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -46,6 +46,8 @@ EmbeddingReqInput, FlushCacheReq, GenerateReqInput, + GetMemPoolSizeReq, + GetMemPoolSizeReqOutput, ProfileReq, RewardReqInput, TokenizedEmbeddingReqInput, @@ -531,6 +533,15 @@ def stop_profile(self): req = ProfileReq.STOP_PROFILE self.send_to_scheduler.send_pyobj(req) + async def get_memory_pool_size(self): + if self.to_create_loop: + self.create_handle_loop() + + req = GetMemPoolSizeReq() + self.send_to_scheduler.send_pyobj(req) + self.mem_pool_size = asyncio.Future() + return await self.mem_pool_size + async def update_weights( self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None ): @@ -590,6 +601,9 @@ async def handle_loop(self): if isinstance(recv_obj, UpdateWeightReqOutput): self.model_update_result.set_result(recv_obj) continue + elif isinstance(recv_obj, GetMemPoolSizeReqOutput): + self.mem_pool_size.set_result(recv_obj) + continue assert isinstance( recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 8b27d2a69a9..8032915e7b0 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -32,6 +32,15 @@ logger = logging.getLogger(__name__) +@torch.compile(dynamic=True) +def resolve_future_token_ids(input_ids, future_token_ids_map): + input_ids[:] = torch.where( + input_ids < 0, + future_token_ids_map[torch.clamp(-input_ids, min=0)], + input_ids, + ) + + class TpModelWorkerClient: """A tensor parallel model worker.""" @@ -99,33 +108,25 @@ def forward_thread_func_(self): # Resolve future tokens in the input input_ids = model_worker_batch.input_ids - input_ids[:] = torch.where( - input_ids < 0, - self.future_token_ids_map[torch.clamp(-input_ids, min=0)], - input_ids, - ) + resolve_future_token_ids(input_ids, self.future_token_ids_map) # Run forward logits_output, next_token_ids = self.worker.forward_batch_generation( model_worker_batch ) - self.launch_event.set() # Update the future token ids map bs = len(model_worker_batch.seq_lens) - future_next_token_ids = torch.arange( - -(future_token_ids_ct + bs), - -(future_token_ids_ct), - dtype=torch.int32, - device=self.device, - ) - self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to( - torch.int32 - ) + self.future_token_ids_map[ + future_token_ids_ct + 1 : future_token_ids_ct + bs + 1 + ] = next_token_ids + # Copy results to the CPU next_token_ids = next_token_ids.to("cpu", non_blocking=True) copy_event = torch.cuda.Event(blocking=True) copy_event.record() + + self.launch_event.set() self.copy_queue.put((copy_event, next_token_ids)) def copy_thread_func(self): @@ -149,8 +150,9 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): # Allocate output future objects bs = len(model_worker_batch.seq_lens) future_next_token_ids = torch.arange( - -(self.future_token_ids_ct + bs), - -(self.future_token_ids_ct), + -(self.future_token_ids_ct + 1), + -(self.future_token_ids_ct + 1 + bs), + -1, dtype=torch.int32, device=self.device, ) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 4277862a7eb..181ac7eefe9 100644 --- 
a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -51,7 +51,7 @@ def __init__(self, size: int, max_context_len: int, device: str, use_records: bo self.write = self.write_without_records def write(self, indices, values): - # Keep the signature for type checking, will be initialized during runtime + # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() def available_size(self): @@ -221,16 +221,21 @@ def set_kv_buffer( cache_v: torch.Tensor, ): layer_id = layer.layer_id - if cache_k.dtype != self.dtype: - cache_k = cache_k.to(self.dtype) - if cache_v.dtype != self.dtype: - cache_v = cache_v.to(self.dtype) - if self.store_dtype != self.dtype: - self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) - self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) - else: - self.k_buffer[layer_id][loc] = cache_k - self.v_buffer[layer_id][loc] = cache_v + copy_two_array( + loc, + self.k_buffer[layer_id], + cache_k, + self.v_buffer[layer_id], + cache_v, + self.dtype, + self.store_dtype, + ) + + +@torch.compile(dynamic=True) +def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype): + dst_1[loc] = src_1.to(dtype).view(store_dtype) + dst_2[loc] = src_2.to(dtype).view(store_dtype) class MLATokenToKVPool(BaseTokenToKVPool): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index ffa77ec4c90..b859df35888 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -92,6 +92,11 @@ def set_torch_compile_config(): torch._dynamo.config.accumulated_cache_size_limit = 1024 +@torch.compile(dynamic=True) +def clamp_position(seq_lens): + return torch.clamp((seq_lens - 1), min=0).to(torch.int64) + + class CudaGraphRunner: """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" @@ -112,7 +117,6 @@ def __init__(self, model_runner: "ModelRunner"): self.capture_bs = list(range(1, 32)) + [64, 128] else: self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] - self.capture_bs = [ bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size ] @@ -253,7 +257,7 @@ def run_once(): encoder_lens=encoder_lens, return_logprob=False, top_logprobs_nums=[0] * bs, - positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), + positions=clamp_position(seq_lens), ) return forward(input_ids, forward_batch.positions, forward_batch) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index ceb2d55c281..8912c5583a6 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -172,6 +172,18 @@ async def stop_profile(): ) +@app.api_route("/get_memory_pool_size", methods=["GET", "POST"]) +async def get_memory_pool_size(): + """Get the memory pool size in number of tokens""" + try: + ret = await tokenizer_manager.get_memory_pool_size() + return ret.size + except Exception as e: + return JSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + @app.post("/update_weights") async def update_weights(obj: UpdateWeightReqInput, request: Request): """Update the weights inplace without re-launching the server.""" diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index 51b32ca01b3..fe88171ce27 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -67,6 +67,7 @@ def run_eval(args): model=args.model, 
max_tokens=2048, base_url=base_url, + temperature=getattr(args, "temperature", 0.0), ) # Run eval @@ -119,6 +120,7 @@ def run_eval(args): parser.add_argument("--eval-name", type=str, default="mmlu") parser.add_argument("--num-examples", type=int) parser.add_argument("--num-threads", type=int, default=512) + parser.add_argument("--temperature", type=float, default=0.0) args = parser.parse_args() run_eval(args) diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py index 6ddd97d9405..ee977a63681 100644 --- a/test/srt/test_eval_accuracy_mini.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -31,6 +31,7 @@ def test_mmlu(self): eval_name="mmlu", num_examples=64, num_threads=32, + temperature=0.1, ) metrics = run_eval(args) diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py index 5507182a731..ee06de8fae4 100644 --- a/test/srt/test_pytorch_sampling_backend.py +++ b/test/srt/test_pytorch_sampling_backend.py @@ -23,7 +23,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--sampling-backend", "pytorch"], + other_args=["--sampling-backend", "pytorch", "--disable-radix-cache"], ) @classmethod @@ -37,6 +37,7 @@ def test_mmlu(self): eval_name="mmlu", num_examples=64, num_threads=32, + temperature=0.1, ) metrics = run_eval(args) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 9a0a37c607b..c4c8e844d6f 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -119,6 +119,10 @@ def test_logprob_start_len(self): [x[-1] for x in res["meta_info"]["output_token_logprobs"]] ) + def test_get_memory_pool_size(self): + response = requests.post(self.base_url + "/get_memory_pool_size") + assert isinstance(response.json(), int) + if __name__ == "__main__": unittest.main() From b8648dd1d6b5241ac07e4c93907f3e3d329c969a Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Wed, 23 Oct 2024 17:27:31 +0000 Subject: [PATCH 07/13] refactor(constrained): add a new abstraction for constrained decoding --- python/sglang/srt/constrained/grammar.py | 167 +++++++++++++++++ python/sglang/srt/managers/schedule_batch.py | 169 +++--------------- python/sglang/srt/managers/scheduler.py | 103 +++-------- .../srt/sampling/sampling_batch_info.py | 47 +---- 4 files changed, 226 insertions(+), 260 deletions(-) create mode 100644 python/sglang/srt/constrained/grammar.py diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py new file mode 100644 index 00000000000..2a51938857c --- /dev/null +++ b/python/sglang/srt/constrained/grammar.py @@ -0,0 +1,167 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Cache for the compressed finite state machine.""" +import logging +import torch + +from typing import Union, Tuple, List, Optional +from sglang.srt.constrained import GrammarMatcher, RegexGuide + +from sglang.srt.constrained.bnf_cache import BNFCache +from sglang.srt.constrained.fsm_cache import FSMCache +from sglang.srt.constrained.jump_forward import JumpForwardMap +from sglang.srt.constrained.jump_forward import JumpForwardCache + +# from sglang.srt.managers.schedule_batch import Req + +logger = logging.getLogger(__name__) + +INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 + +class JumpHelper(): + data : Union[List, str] + state : int + suffix_ids : List[int] + + def __init__(self, data : Union[List, str], state : int = -1, suffix_ids = []) -> None: + self.data = data + self.state = state + self.suffix_ids = suffix_ids + + def can_jump(self): + return len(self.data) > 0 + +class Grammar(): + grammar : Union[GrammarMatcher, Tuple[RegexGuide, int]] + jump : Union[bool, JumpForwardMap, None] + def __init__( + self, + grammar : Union[GrammarMatcher, Tuple[RegexGuide, int]], + jump : Union[bool, JumpForwardMap, None] + ) -> None: + self.grammar = grammar + self.jump = jump + + def accept_token(self, token : int): + if isinstance(self.grammar, GrammarMatcher): + assert self.grammar.accept_token(token) + else: + guide, state = self.grammar + self.grammar = guide, guide.get_next_state(state, token) + + def try_jump(self, tokenizer) -> JumpHelper: + if isinstance(self.grammar, GrammarMatcher): + return JumpHelper(self.grammar.find_jump_forward_string()) + elif isinstance(self.grammar, Tuple): + assert isinstance(self.jump, JumpForwardMap) + _, state = self.grammar + jump_forward_bytes = self.jump.jump_forward_byte(state) + if jump_forward_bytes is None or len(jump_forward_bytes) == 0: + return JumpHelper("") # can't jump + + # preprocess the jump forward string + suffix_bytes = [] + continuation_range = range(0x80, 0xC0) + cur_state = state + while ( + len(jump_forward_bytes) + and jump_forward_bytes[0][0] in continuation_range + ): + # continuation bytes + byte_edge = jump_forward_bytes.pop(0) + suffix_bytes.append(byte_edge[0]) + cur_state = byte_edge[1] + + suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes] + suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens) + return JumpHelper(suffix_ids, cur_state, suffix_bytes) + else: + return JumpHelper("") # can't jump + + def jump_forward_str_state(self, helper : JumpHelper) -> Tuple[str, int]: + if isinstance(helper.data, str): + return helper.data, -1 + else: + assert isinstance(self.jump, JumpForwardMap) + return self.jump.jump_forward_symbol(helper.state) + + def jump_and_retokenize(self, old_output_ids : List[int], new_output_ids : List[int], next_state : int): + if isinstance(self.grammar, GrammarMatcher): + k = 0 + for i, old_id in enumerate(old_output_ids): + if old_id == new_output_ids[i]: + k = i + 1 + else: + break + + # rollback to the last token that is the same + if k < len(old_output_ids): + self.grammar.rollback(len(old_output_ids) - k) + + for i in range(k, len(new_output_ids)): + assert self.grammar.accept_token(new_output_ids[i]) + else: + self.grammar = self.grammar[0], next_state + + def fill_vocab_mask(self, vocab_mask : torch.Tensor, vocab_size : int): + if isinstance(self.grammar, GrammarMatcher): + # Note that this bitmask is a bitset, not bool + bitmask = self.grammar.find_next_token_bitmask() + # Mask the tokens that are not allowed + vocab_mask[ + 
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size) + ] = 1 + else: + guide, state = self.grammar + vocab_mask.fill_(1) + vocab_mask[ + guide.get_next_instruction(state).tokens + ] = 0 + +class GrammarCache(): + grammar_cache : Union[BNFCache, FSMCache] + jump_cache : Union[bool, JumpForwardCache, None] + + def __init__( + self, + tokenizer_path, + tokenizer_args_dict, + enable=True, + skip_tokenizer_init=False, + whitespace_patterns=None, + backend=None, + allow_jump=False + ): + if backend == "xgrammar": + self.grammar_cache = BNFCache(tokenizer_path, tokenizer_args_dict, enable, skip_tokenizer_init) + else: + assert backend == "outlines" + self.grammar_cache = FSMCache(tokenizer_path, tokenizer_args_dict, enable, skip_tokenizer_init, whitespace_patterns) + + def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar: + if isinstance(self.grammar_cache, BNFCache): + assert not isinstance(self.jump_cache, JumpForwardCache) + return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache) + else: + jump_map = None + guide, regex = self.grammar_cache.query(key) + if isinstance(self.jump_cache, JumpForwardCache): + jump_map = self.jump_cache.query(regex) + return Grammar((guide, 0), jump_map) + + def reset(self): + if isinstance(self.grammar_cache, FSMCache): + self.grammar_cache.reset() + if isinstance(self.jump_cache, JumpForwardCache): + self.jump_cache.reset() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 73f27514e61..ba3f263dbe6 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -37,8 +37,7 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.constrained import GrammarMatcher, RegexGuide -from sglang.srt.constrained.jump_forward import JumpForwardMap +from sglang.srt.constrained.grammar import Grammar from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -247,12 +246,7 @@ def __init__( self.embedding = None # Constrained decoding - self.regex_fsm: Optional[RegexGuide] = None - self.regex_fsm_state: int = 0 - self.regex_bnf: Optional[GrammarMatcher] = None - - self.allow_jump_forward = False - self.jump_forward_map: JumpForwardMap = None + self.grammar: Optional[Grammar] = None # For Qwen2-VL self.mrope_position_delta = [] # use mutable object @@ -356,7 +350,9 @@ def check_finished(self): self.finished_reason = FINISH_MATCHED_STR(matched=stop_str) return - def jump_forward_and_retokenize_fsm(self, jump_forward_str, next_state): + def jump_forward_and_retokenize(self, jump_forward_str, next_state): + assert self.grammar is not None and self.tokenizer is not None + if self.origin_input_text is None: # Recovering text can only use unpadded ids self.origin_input_text = self.tokenizer.decode( @@ -396,7 +392,8 @@ def jump_forward_and_retokenize_fsm(self, jump_forward_str, next_state): self.surr_offset = self.read_offset - i break - self.regex_fsm_state = next_state + # update the inner state of the grammar + self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state) if self.return_logprob: # For fast-forward part's logprobs @@ -413,72 +410,6 @@ def jump_forward_and_retokenize_fsm(self, jump_forward_str, next_state): return True - def jump_forward_and_retokenize_bnf(self, jump_forward_str): - assert 
self.regex_bnf is not None, "should be a regex request" - assert self.tokenizer is not None, "should have a tokenizer" - - if self.origin_input_text is None: - # Recovering text can only use unpadded ids - self.origin_input_text = self.tokenizer.decode( - self.origin_input_ids_unpadded - ) - - all_text = self.origin_input_text + self.decoded_text + jump_forward_str - all_ids = self.tokenizer.encode(all_text) - if not all_ids: - logger.warning("Encoded all_text resulted in empty all_ids") - return False - - prompt_tokens = len(self.origin_input_ids_unpadded) - if prompt_tokens > len(all_ids): - logger.warning("prompt_tokens is larger than encoded all_ids") - return False - - if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]: - # TODO(lsyin): fix token fusion - logger.warning( - "Token fusion between input and output, try to avoid this by removing the space at the end of the input." - ) - return False - - old_output_ids = self.output_ids - self.output_ids = all_ids[prompt_tokens:] - self.decoded_text = self.decoded_text + jump_forward_str - self.surr_offset = prompt_tokens - self.read_offset = len(all_ids) - - # NOTE: A trick to reduce the surrouding tokens decoding overhead - for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET): - surr_text_ = self.tokenizer.decode( - all_ids[self.read_offset - i : self.read_offset] - ) - if not surr_text_.endswith("�"): - self.surr_offset = self.read_offset - i - break - - k = 0 - for i, old_id in enumerate(old_output_ids): - if old_id == self.output_ids[i]: - k = i + 1 - else: - break - - # rollback to the last token that is the same - if k < len(old_output_ids): - self.regex_bnf.rollback(len(old_output_ids) - k) - - for i in range(k, len(self.output_ids)): - assert self.regex_bnf.accept_token(self.output_ids[i]) - - if self.return_logprob: - # For fast-forward part's logprobs - self.output_token_logprobs = self.output_token_logprobs[:k] - self.output_top_logprobs = self.output_top_logprobs[:k] - self.logprob_start_len = prompt_tokens + k - self.last_update_decode_tokens = len(self.output_ids) - k - - return True - def __repr__(self): return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, " @@ -532,8 +463,8 @@ class ScheduleBatch: # Stream has_stream: bool = False - # Has regex - has_regex: bool = False + # Has grammar + has_grammar: bool = False # device device: str = "cuda" @@ -541,7 +472,7 @@ class ScheduleBatch: @classmethod def init_new( cls, - reqs, + reqs: List[Req], req_to_token_pool, token_to_kv_pool, tree_cache, @@ -555,7 +486,7 @@ def init_new( model_config=model_config, return_logprob=any(req.return_logprob for req in reqs), has_stream=any(req.stream for req in reqs), - has_regex=any(req.regex_fsm or req.regex_bnf for req in reqs), + has_grammar=any(req.grammar for req in reqs), device=req_to_token_pool.device, ) @@ -862,27 +793,10 @@ def check_for_jump_forward(self, pad_input_ids_func): keep_indices = set(i for i in range(len(self.reqs))) for i, req in enumerate(self.reqs): - if req.jump_forward_map is not None: - assert req.regex_fsm is not None and req.regex_bnf is None - jump_forward_bytes = req.jump_forward_map.jump_forward_byte( - req.regex_fsm_state - ) - if jump_forward_bytes is not None and len(jump_forward_bytes) > 1: - suffix_bytes = [] - continuation_range = range(0x80, 0xC0) - cur_state = req.regex_fsm_state - while ( - len(jump_forward_bytes) - and jump_forward_bytes[0][0] in continuation_range - ): - # continuation bytes - byte_edge = jump_forward_bytes.pop(0) - suffix_bytes.append(byte_edge[0]) 
- cur_state = byte_edge[1] - - suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes] - suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens) - + if req.grammar is not None: + jump_helper = req.grammar.try_jump(req.tokenizer) + if jump_helper.can_jump(): + suffix_ids = jump_helper.suffix_ids # Current ids, for cache and revert cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1] cur_output_ids = req.output_ids @@ -896,12 +810,10 @@ def check_for_jump_forward(self, pad_input_ids_func): ( jump_forward_str, next_state, - ) = req.jump_forward_map.jump_forward_symbol(cur_state) + ) = req.grammar.jump_forward_str_state(jump_helper) - # Make the incrementally decoded text part of jump_forward_str - # so that the UTF-8 will not corrupt jump_forward_str = new_text + jump_forward_str - if not req.jump_forward_and_retokenize_fsm( + if not req.jump_forward_and_retokenize( jump_forward_str, next_state ): req.output_ids = cur_output_ids @@ -922,38 +834,6 @@ def check_for_jump_forward(self, pad_input_ids_func): jump_forward_reqs.append(req) keep_indices.remove(i) - if req.allow_jump_forward: - assert req.regex_bnf is not None and req.regex_fsm is None - jump_forward_str = req.regex_bnf.find_jump_forward_string() - if len(jump_forward_str) > 1: - cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1] - cur_output_ids = req.output_ids - decode_res, new_text = req.get_next_inc_detokenization() - if not decode_res: - req.output_ids = cur_output_ids - continue - - jump_forward_str = new_text + jump_forward_str - if not req.jump_forward_and_retokenize_bnf(jump_forward_str): - # Failed to jump forward, revert - req.output_ids = cur_output_ids - continue - - # The decode status has diverged from detokenizer_manager - req.vid += 1 - - # insert the old request into tree_cache - self.tree_cache.cache_finished_req(req, cur_all_ids) - - # re-applying image padding - if req.image_inputs is not None: - req.origin_input_ids = pad_input_ids_func( - req.origin_input_ids_unpadded, req.image_inputs - ) - - jump_forward_reqs.append(req) - keep_indices.remove(i) - self.filter_batch(keep_indices=list(keep_indices)) return jump_forward_reqs @@ -1038,7 +918,7 @@ def filter_batch( self.top_logprobs_nums = None self.has_stream = any(req.stream for req in self.reqs) - self.has_regex = any(req.regex_fsm or req.regex_bnf for req in self.reqs) + self.has_grammar = any(req.grammar for req in self.reqs) self.sampling_info.filter_batch(keep_indices, new_indices) @@ -1071,7 +951,7 @@ def merge_batch(self, other: "ScheduleBatch"): self.return_logprob = self.return_logprob or other.return_logprob self.has_stream = self.has_stream or other.has_stream - self.has_regex = self.has_regex or other.has_regex + self.has_grammar = self.has_grammar or other.has_grammar def get_model_worker_batch(self): if self.forward_mode.is_decode(): @@ -1081,15 +961,10 @@ def get_model_worker_batch(self): extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens - if self.has_regex: - self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs] - self.sampling_info.regex_bnfs = [req.regex_bnf for req in self.reqs] - self.sampling_info.regex_fsm_states = [ - req.regex_fsm_state for req in self.reqs - ] + if self.has_grammar: + self.sampling_info.grammars = [req.grammar for req in self.reqs] else: - self.sampling_info.regex_fsms = None - self.sampling_info.regex_bnfs = None + self.sampling_info.grammars = None global bid bid += 1 diff --git 
a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8573534a191..0dfd4528cce 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -29,9 +29,7 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.constrained.bnf_cache import BNFCache -from sglang.srt.constrained.fsm_cache import FSMCache -from sglang.srt.constrained.jump_forward import JumpForwardCache +from sglang.srt.constrained.grammar import GrammarCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( @@ -227,32 +225,20 @@ def __init__( ) # Init the FSM cache for constrained generation - self.regex_fsm_cache = None - self.regex_bnf_cache = None + self.grammar_cache = None if not server_args.skip_tokenizer_init: - if server_args.grammar_backend == "xgrammar": - self.regex_bnf_cache = BNFCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - whitespace_patterns=server_args.constrained_json_whitespace_pattern, - ) - elif server_args.grammar_backend == "outlines": - self.regex_fsm_cache = FSMCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, - ) - - self.jump_forward_cache = JumpForwardCache() + self.grammar_cache = GrammarCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + skip_tokenizer_init=server_args.skip_tokenizer_init, + whitespace_patterns=server_args.constrained_json_whitespace_pattern, + backend=server_args.grammar_backend, + allow_jump=not server_args.disable_regex_jump_forward, + ) # Init new token estimation assert ( @@ -424,35 +410,16 @@ def handle_generate_request( req.sampling_params.json_schema is not None or req.sampling_params.regex is not None ): - if self.regex_fsm_cache is not None: - # FSM cache - assert self.regex_bnf_cache is None - if req.sampling_params.json_schema is not None: - req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( - ("json", req.sampling_params.json_schema) - ) - elif req.sampling_params.regex is not None: - req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( - ("regex", req.sampling_params.regex) - ) - if not self.disable_regex_jump_forward: - req.jump_forward_map = self.jump_forward_cache.query( - computed_regex_string - ) - else: - # BNF cache - assert self.regex_bnf_cache is not None - vocab_size = self.model_config.vocab_size - if req.sampling_params.json_schema is not None: - req.regex_bnf = self.regex_bnf_cache.query( - ("json", req.sampling_params.json_schema), vocab_size - ) - elif req.sampling_params.regex is not None: - req.regex_bnf = self.regex_bnf_cache.query( - ("regex", req.sampling_params.regex), vocab_size - ) - if not self.disable_regex_jump_forward: - req.allow_jump_forward = True + assert self.grammar_cache is not None + if req.sampling_params.json_schema is not None: + req.grammar = self.grammar_cache.query( + ("json", req.sampling_params.json_schema), + 
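                    # the model vocab size is passed so the grammar backend can size
                    # its token-level bitmask to the full vocabulary.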
self.model_config.vocab_size, + ) + elif req.sampling_params.regex is not None: + req.grammar = self.grammar_cache.query( + ("regex", req.sampling_params.regex), self.model_config.vocab_size + ) # Truncate prompts that are too long if len(req.origin_input_ids) > self.max_req_input_len: @@ -830,14 +797,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): elif not batch.decoding_reqs or req not in batch.decoding_reqs: self.tree_cache.cache_unfinished_req(req) - if req.regex_fsm is not None: - assert req.regex_bnf is None - req.regex_fsm_state = req.regex_fsm.get_next_state( - req.regex_fsm_state, next_token_ids[i] - ) - if req.regex_bnf is not None: - assert req.regex_fsm is None - assert req.regex_bnf.accept_token(next_token_ids[i]) + if req.grammar is not None: + req.grammar.accept_token(next_token_ids[i]) if req.return_logprob: logprob_pt += self.add_logprob_return_values( @@ -892,14 +853,8 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): req.output_ids.append(next_token_id) req.check_finished() - if req.regex_fsm is not None: - assert req.regex_bnf is None - req.regex_fsm_state = req.regex_fsm.get_next_state( - req.regex_fsm_state, next_token_id - ) - if req.regex_bnf is not None: - assert req.regex_fsm is None - assert req.regex_bnf.accept_token(next_token_id) + if req.grammar is not None: + req.grammar.accept_token(next_token_id) if req.finished(): self.tree_cache.cache_finished_req(req) @@ -1097,8 +1052,8 @@ def flush_cache(self): ): self.tree_cache.reset() self.tree_cache_metrics = {"total": 0, "hit": 0} - if self.regex_fsm_cache is not None: - self.regex_fsm_cache.reset() + if self.grammar_cache is not None: + self.grammar_cache.reset() # TODO(dark): reset the bnf cache self.req_to_token_pool.clear() self.token_to_kv_pool.clear() diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 5c33eaa6476..6afd48cc8a1 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -6,7 +6,7 @@ import torch import sglang.srt.sampling.penaltylib as penaltylib -from sglang.srt.constrained import GrammarMatcher, RegexGuide +from sglang.srt.constrained.grammar import Grammar if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -31,12 +31,7 @@ class SamplingBatchInfo: logit_bias: torch.Tensor = None vocab_mask: Optional[torch.Tensor] = None - # FSM states - regex_fsms: Optional[List[Optional[RegexGuide]]] = None - regex_fsm_states: List[int] = None - - # BNF states - regex_bnfs: Optional[List[Optional[GrammarMatcher]]] = None + grammars: Optional[List[Optional[Grammar]]] = None # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None @@ -138,25 +133,9 @@ def update_penalties(self): ) self.linear_penalties = penalizer.apply(self.linear_penalties) - def _update_regex_vocab_mask_fsm(self): - if not self.regex_fsms or not any(regex_fsm for regex_fsm in self.regex_fsms): - return - - self.vocab_mask = torch.zeros( - len(self.temperatures), - self.vocab_size, - dtype=torch.bool, - device=self.device, - ) - for i, regex_fsm in enumerate(self.regex_fsms): - if regex_fsm is not None: - self.vocab_mask[i].fill_(1) - self.vocab_mask[i][ - regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens - ] = 0 - - def _update_regex_vocab_mask_bnf(self): - if not self.regex_bnfs or not any(regex_bnf for regex_bnf in self.regex_bnfs): + def update_regex_vocab_mask(self): + if not 
self.grammars or not any(grammar for grammar in self.grammars): + self.vocab_mask = None return self.vocab_mask = torch.zeros( @@ -165,19 +144,9 @@ def _update_regex_vocab_mask_bnf(self): dtype=torch.bool, device=self.device, ) - for i, regex_bnf in enumerate(self.regex_bnfs): - if regex_bnf is not None: - # Note that this bitmask is a bitset, not bool - bitmask = regex_bnf.find_next_token_bitmask() - # Mask the tokens that are not allowed - self.vocab_mask[i][ - regex_bnf.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size) - ] = 1 - - def update_regex_vocab_mask(self): - self.vocab_mask = None - self._update_regex_vocab_mask_fsm() - self._update_regex_vocab_mask_bnf() + for i, grammar in enumerate(self.grammars): + if grammar is not None: + grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size) def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): if self.penalizer_orchestrator: From e615ce3b22523047c93f8618b331ad481ed649c1 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Fri, 25 Oct 2024 00:00:07 +0800 Subject: [PATCH 08/13] minor(constrained): set import failure object as None to pass type check --- python/sglang/srt/constrained/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index ce6387f18e0..8660084edde 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -58,9 +58,9 @@ def build_regex_from_object( GrammarMatcherInitContextCache, ) except ImportError as e: - GrammarMatcher = e - GrammarMatcherInitContext = e - GrammarMatcherInitContextCache = e + GrammarMatcher = None + GrammarMatcherInitContext = None + GrammarMatcherInitContextCache = None __all__ = [ From cd59ed0eb458c0f0a32c93539e647fd7ea695a2e Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Fri, 25 Oct 2024 00:23:07 +0800 Subject: [PATCH 09/13] fix(constrained): use DummyType to avoid type failure in 'isinstance' --- python/sglang/srt/constrained/__init__.py | 9 ++++---- python/sglang/srt/constrained/grammar.py | 25 ++++++++++++++--------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index 8660084edde..18251567a73 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -58,10 +58,11 @@ def build_regex_from_object( GrammarMatcherInitContextCache, ) except ImportError as e: - GrammarMatcher = None - GrammarMatcherInitContext = None - GrammarMatcherInitContextCache = None - + class Dummy: + pass + GrammarMatcher = Dummy + GrammarMatcherInitContext = Dummy + GrammarMatcherInitContextCache = Dummy __all__ = [ "RegexGuide", diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py index 2a51938857c..ed6a0267e31 100644 --- a/python/sglang/srt/constrained/grammar.py +++ b/python/sglang/srt/constrained/grammar.py @@ -29,6 +29,9 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 +class XGrammarJump(): + pass + class JumpHelper(): data : Union[List, str] state : int @@ -43,15 +46,15 @@ def can_jump(self): return len(self.data) > 0 class Grammar(): - grammar : Union[GrammarMatcher, Tuple[RegexGuide, int]] - jump : Union[bool, JumpForwardMap, None] + grammar : Union[GrammarMatcher, Tuple[RegexGuide, int]] + jump_map : Union[XGrammarJump, JumpForwardMap, None] def __init__( self, - grammar : 
Union[GrammarMatcher, Tuple[RegexGuide, int]], - jump : Union[bool, JumpForwardMap, None] + grammar : Union[GrammarMatcher, Tuple[RegexGuide, int]], + jump_map : Union[XGrammarJump, JumpForwardMap, None] ) -> None: self.grammar = grammar - self.jump = jump + self.jump_map = jump_map def accept_token(self, token : int): if isinstance(self.grammar, GrammarMatcher): @@ -64,9 +67,9 @@ def try_jump(self, tokenizer) -> JumpHelper: if isinstance(self.grammar, GrammarMatcher): return JumpHelper(self.grammar.find_jump_forward_string()) elif isinstance(self.grammar, Tuple): - assert isinstance(self.jump, JumpForwardMap) + assert isinstance(self.jump_map, JumpForwardMap) _, state = self.grammar - jump_forward_bytes = self.jump.jump_forward_byte(state) + jump_forward_bytes = self.jump_map.jump_forward_byte(state) if jump_forward_bytes is None or len(jump_forward_bytes) == 0: return JumpHelper("") # can't jump @@ -93,8 +96,8 @@ def jump_forward_str_state(self, helper : JumpHelper) -> Tuple[str, int]: if isinstance(helper.data, str): return helper.data, -1 else: - assert isinstance(self.jump, JumpForwardMap) - return self.jump.jump_forward_symbol(helper.state) + assert isinstance(self.jump_map, JumpForwardMap) + return self.jump_map.jump_forward_symbol(helper.state) def jump_and_retokenize(self, old_output_ids : List[int], new_output_ids : List[int], next_state : int): if isinstance(self.grammar, GrammarMatcher): @@ -131,7 +134,7 @@ def fill_vocab_mask(self, vocab_mask : torch.Tensor, vocab_size : int): class GrammarCache(): grammar_cache : Union[BNFCache, FSMCache] - jump_cache : Union[bool, JumpForwardCache, None] + jump_cache : Union[XGrammarJump, JumpForwardCache, None] def __init__( self, @@ -145,9 +148,11 @@ def __init__( ): if backend == "xgrammar": self.grammar_cache = BNFCache(tokenizer_path, tokenizer_args_dict, enable, skip_tokenizer_init) + self.jump_cache = XGrammarJump() if allow_jump else None else: assert backend == "outlines" self.grammar_cache = FSMCache(tokenizer_path, tokenizer_args_dict, enable, skip_tokenizer_init, whitespace_patterns) + self.jump_cache = JumpForwardCache() if allow_jump else None def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar: if isinstance(self.grammar_cache, BNFCache): From d01e7afe550627ea4813a5411525294cb7b641ef Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Fri, 25 Oct 2024 00:47:33 +0800 Subject: [PATCH 10/13] fix(constrained): fix wrong parameter order in initing bnf_cache --- python/sglang/srt/constrained/bnf_cache.py | 2 +- python/sglang/srt/constrained/grammar.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py index 19765731bd6..8feb496e626 100644 --- a/python/sglang/srt/constrained/bnf_cache.py +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -34,7 +34,7 @@ def __init__( tokenizer_path, tokenizer_args_dict, skip_tokenizer_init=False, - whitespace_patterns=None, + whitespace_patterns=None ): # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init if skip_tokenizer_init: diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py index ed6a0267e31..3f99099981f 100644 --- a/python/sglang/srt/constrained/grammar.py +++ b/python/sglang/srt/constrained/grammar.py @@ -140,18 +140,28 @@ def __init__( self, tokenizer_path, tokenizer_args_dict, - enable=True, skip_tokenizer_init=False, whitespace_patterns=None, backend=None, 
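        # allow_jump toggles jump-forward decoding: it attaches an XGrammarJump for
        # the xgrammar backend or a JumpForwardCache for the outlines backend.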
allow_jump=False ): if backend == "xgrammar": - self.grammar_cache = BNFCache(tokenizer_path, tokenizer_args_dict, enable, skip_tokenizer_init) + self.grammar_cache = BNFCache( + tokenizer_path=tokenizer_path, + tokenizer_args_dict=tokenizer_args_dict, + skip_tokenizer_init=skip_tokenizer_init, + whitespace_patterns=whitespace_patterns + ) self.jump_cache = XGrammarJump() if allow_jump else None else: assert backend == "outlines" - self.grammar_cache = FSMCache(tokenizer_path, tokenizer_args_dict, enable, skip_tokenizer_init, whitespace_patterns) + self.grammar_cache = FSMCache( + tokenizer_path=tokenizer_path, + tokenizer_args_dict=tokenizer_args_dict, + skip_tokenizer_init=skip_tokenizer_init, + constrained_json_whitespace_pattern=whitespace_patterns, + enable=True + ) self.jump_cache = JumpForwardCache() if allow_jump else None def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar: From c07cd0da1968d6a917caea241d5e77ee5bebfc3e Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Thu, 24 Oct 2024 17:10:22 +0000 Subject: [PATCH 11/13] minor: format the code using pre-commit --- python/sglang/srt/constrained/__init__.py | 2 + python/sglang/srt/constrained/bnf_cache.py | 2 +- python/sglang/srt/constrained/grammar.py | 66 ++++++++++++---------- 3 files changed, 38 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index 18251567a73..a8708dfea71 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -58,8 +58,10 @@ def build_regex_from_object( GrammarMatcherInitContextCache, ) except ImportError as e: + class Dummy: pass + GrammarMatcher = Dummy GrammarMatcherInitContext = Dummy GrammarMatcherInitContextCache = Dummy diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py index 8feb496e626..19765731bd6 100644 --- a/python/sglang/srt/constrained/bnf_cache.py +++ b/python/sglang/srt/constrained/bnf_cache.py @@ -34,7 +34,7 @@ def __init__( tokenizer_path, tokenizer_args_dict, skip_tokenizer_init=False, - whitespace_patterns=None + whitespace_patterns=None, ): # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init if skip_tokenizer_init: diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py index 3f99099981f..37526128d0a 100644 --- a/python/sglang/srt/constrained/grammar.py +++ b/python/sglang/srt/constrained/grammar.py @@ -13,15 +13,14 @@ """Cache for the compressed finite state machine.""" import logging +from typing import List, Optional, Tuple, Union + import torch -from typing import Union, Tuple, List, Optional from sglang.srt.constrained import GrammarMatcher, RegexGuide - from sglang.srt.constrained.bnf_cache import BNFCache from sglang.srt.constrained.fsm_cache import FSMCache -from sglang.srt.constrained.jump_forward import JumpForwardMap -from sglang.srt.constrained.jump_forward import JumpForwardCache +from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap # from sglang.srt.managers.schedule_batch import Req @@ -29,15 +28,17 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 -class XGrammarJump(): + +class XGrammarJump: pass -class JumpHelper(): - data : Union[List, str] - state : int - suffix_ids : List[int] - def __init__(self, data : Union[List, str], state : int = -1, suffix_ids = []) -> None: +class JumpHelper: + data: Union[List, str] + state: int + suffix_ids: List[int] + + def __init__(self, data: Union[List, 
str], state: int = -1, suffix_ids=[]) -> None: self.data = data self.state = state self.suffix_ids = suffix_ids @@ -45,18 +46,20 @@ def __init__(self, data : Union[List, str], state : int = -1, suffix_ids = []) - def can_jump(self): return len(self.data) > 0 -class Grammar(): - grammar : Union[GrammarMatcher, Tuple[RegexGuide, int]] - jump_map : Union[XGrammarJump, JumpForwardMap, None] + +class Grammar: + grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]] + jump_map: Union[XGrammarJump, JumpForwardMap, None] + def __init__( - self, - grammar : Union[GrammarMatcher, Tuple[RegexGuide, int]], - jump_map : Union[XGrammarJump, JumpForwardMap, None] + self, + grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]], + jump_map: Union[XGrammarJump, JumpForwardMap, None], ) -> None: self.grammar = grammar self.jump_map = jump_map - def accept_token(self, token : int): + def accept_token(self, token: int): if isinstance(self.grammar, GrammarMatcher): assert self.grammar.accept_token(token) else: @@ -71,7 +74,7 @@ def try_jump(self, tokenizer) -> JumpHelper: _, state = self.grammar jump_forward_bytes = self.jump_map.jump_forward_byte(state) if jump_forward_bytes is None or len(jump_forward_bytes) == 0: - return JumpHelper("") # can't jump + return JumpHelper("") # can't jump # preprocess the jump forward string suffix_bytes = [] @@ -90,16 +93,18 @@ def try_jump(self, tokenizer) -> JumpHelper: suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens) return JumpHelper(suffix_ids, cur_state, suffix_bytes) else: - return JumpHelper("") # can't jump + return JumpHelper("") # can't jump - def jump_forward_str_state(self, helper : JumpHelper) -> Tuple[str, int]: + def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]: if isinstance(helper.data, str): return helper.data, -1 else: assert isinstance(self.jump_map, JumpForwardMap) return self.jump_map.jump_forward_symbol(helper.state) - def jump_and_retokenize(self, old_output_ids : List[int], new_output_ids : List[int], next_state : int): + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ): if isinstance(self.grammar, GrammarMatcher): k = 0 for i, old_id in enumerate(old_output_ids): @@ -117,7 +122,7 @@ def jump_and_retokenize(self, old_output_ids : List[int], new_output_ids : List[ else: self.grammar = self.grammar[0], next_state - def fill_vocab_mask(self, vocab_mask : torch.Tensor, vocab_size : int): + def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int): if isinstance(self.grammar, GrammarMatcher): # Note that this bitmask is a bitset, not bool bitmask = self.grammar.find_next_token_bitmask() @@ -128,13 +133,12 @@ def fill_vocab_mask(self, vocab_mask : torch.Tensor, vocab_size : int): else: guide, state = self.grammar vocab_mask.fill_(1) - vocab_mask[ - guide.get_next_instruction(state).tokens - ] = 0 + vocab_mask[guide.get_next_instruction(state).tokens] = 0 + -class GrammarCache(): - grammar_cache : Union[BNFCache, FSMCache] - jump_cache : Union[XGrammarJump, JumpForwardCache, None] +class GrammarCache: + grammar_cache: Union[BNFCache, FSMCache] + jump_cache: Union[XGrammarJump, JumpForwardCache, None] def __init__( self, @@ -143,14 +147,14 @@ def __init__( skip_tokenizer_init=False, whitespace_patterns=None, backend=None, - allow_jump=False + allow_jump=False, ): if backend == "xgrammar": self.grammar_cache = BNFCache( tokenizer_path=tokenizer_path, tokenizer_args_dict=tokenizer_args_dict, skip_tokenizer_init=skip_tokenizer_init, - 
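                # forwarded for parity with the outlines cache; the xgrammar-backed
                # BNFCache accepts whitespace_patterns but does not act on it yet.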
whitespace_patterns=whitespace_patterns + whitespace_patterns=whitespace_patterns, ) self.jump_cache = XGrammarJump() if allow_jump else None else: @@ -160,7 +164,7 @@ def __init__( tokenizer_args_dict=tokenizer_args_dict, skip_tokenizer_init=skip_tokenizer_init, constrained_json_whitespace_pattern=whitespace_patterns, - enable=True + enable=True, ) self.jump_cache = JumpForwardCache() if allow_jump else None From 8608c2b6d44101b037911deee5e72401f5c242e7 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Fri, 25 Oct 2024 11:18:59 +0800 Subject: [PATCH 12/13] fix(constrained): fix wrong jump-forward assertion --- python/sglang/srt/constrained/grammar.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py index 37526128d0a..c2b8082424e 100644 --- a/python/sglang/srt/constrained/grammar.py +++ b/python/sglang/srt/constrained/grammar.py @@ -38,7 +38,7 @@ class JumpHelper: state: int suffix_ids: List[int] - def __init__(self, data: Union[List, str], state: int = -1, suffix_ids=[]) -> None: + def __init__(self, data: Union[List, str]="", state: int = -1, suffix_ids=[]) -> None: self.data = data self.state = state self.suffix_ids = suffix_ids @@ -67,14 +67,16 @@ def accept_token(self, token: int): self.grammar = guide, guide.get_next_state(state, token) def try_jump(self, tokenizer) -> JumpHelper: - if isinstance(self.grammar, GrammarMatcher): + if isinstance(self.jump_map, XGrammarJump): + assert isinstance(self.grammar, GrammarMatcher) return JumpHelper(self.grammar.find_jump_forward_string()) - elif isinstance(self.grammar, Tuple): - assert isinstance(self.jump_map, JumpForwardMap) + elif isinstance(self.jump_map, JumpForwardMap): + assert isinstance(self.grammar, Tuple) + _, state = self.grammar jump_forward_bytes = self.jump_map.jump_forward_byte(state) if jump_forward_bytes is None or len(jump_forward_bytes) == 0: - return JumpHelper("") # can't jump + return JumpHelper() # can't jump # preprocess the jump forward string suffix_bytes = [] @@ -93,7 +95,7 @@ def try_jump(self, tokenizer) -> JumpHelper: suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens) return JumpHelper(suffix_ids, cur_state, suffix_bytes) else: - return JumpHelper("") # can't jump + return JumpHelper() # can't jump def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]: if isinstance(helper.data, str): From cbdca833668d8cf10d3527290297b9dca85fbf1d Mon Sep 17 00:00:00 2001 From: DarkSharpness Date: Fri, 25 Oct 2024 03:21:39 +0000 Subject: [PATCH 13/13] minor: format the code using pre-commit --- python/sglang/srt/constrained/grammar.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py index c2b8082424e..0281539b89c 100644 --- a/python/sglang/srt/constrained/grammar.py +++ b/python/sglang/srt/constrained/grammar.py @@ -38,7 +38,9 @@ class JumpHelper: state: int suffix_ids: List[int] - def __init__(self, data: Union[List, str]="", state: int = -1, suffix_ids=[]) -> None: + def __init__( + self, data: Union[List, str] = "", state: int = -1, suffix_ids=[] + ) -> None: self.data = data self.state = state self.suffix_ids = suffix_ids @@ -95,7 +97,7 @@ def try_jump(self, tokenizer) -> JumpHelper: suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens) return JumpHelper(suffix_ids, cur_state, suffix_bytes) else: - return JumpHelper() # can't jump + return JumpHelper() # 
can't jump def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]: if isinstance(helper.data, str):
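            # a plain string means the helper came from the xgrammar matcher's
            # find_jump_forward_string(); there is no FSM state to report, so the
            # sentinel -1 is returned alongside the text.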