Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Catch syntax error of Regex Guide to avoid crash #1521

Merged
merged 10 commits into from
Sep 28, 2024
10 changes: 9 additions & 1 deletion python/sglang/srt/constrained/fsm_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
"""

"""Cache for the compressed finite state machine."""
import logging

from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.json_schema import build_regex_from_schema
from transformers import AutoTokenizer

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_tool_cache import BaseToolCache

logger = logging.getLogger(__name__)


class FSMCache(BaseToolCache):
def __init__(
Expand Down Expand Up @@ -76,5 +80,9 @@ def init_value(self, key):
regex = key_string
else:
raise ValueError(f"Invalid key_type: {key_type}")

try:
parse_pattern(regex)
except InvalidSyntax as e:
logger.warning(f"skip invalid regex guide: {regex}")
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
return None, regex
return RegexGuide(regex, self.outlines_tokenizer), regex
17 changes: 15 additions & 2 deletions python/sglang/srt/constrained/jump_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
"""

import dataclasses
import logging
from collections import defaultdict

import interegular
import outlines.caching
from interegular import InvalidSyntax

from sglang.srt.constrained import (
FSMInfo,
Expand All @@ -34,6 +36,8 @@

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class JumpEdge:
Expand All @@ -47,7 +51,12 @@ class JumpForwardMap:
def __init__(self, regex_string):
@disk_cache()
def _init_state_to_jump_forward(regex_string):
regex_pattern = interegular.parse_pattern(regex_string)
try:
regex_pattern = interegular.parse_pattern(regex_string)
except InvalidSyntax as e:
logger.warning(f"skip invalid regex: {regex_string}")
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
self.state_to_jump_forward = None
return

byte_fsm = make_byte_level_fsm(
regex_pattern.to_fsm().reduce(), keep_utf8=True
Expand Down Expand Up @@ -165,7 +174,11 @@ def __init__(self):
super().__init__()

def init_value(self, regex):
    """Build the jump-forward map that will be cached for ``regex``.

    Returns the freshly constructed ``JumpForwardMap`` when its
    ``state_to_jump_forward`` attribute is truthy, and ``None`` otherwise
    (for example when that attribute was left as ``None`` because the
    regex could not be parsed).
    """
    jump_map = JumpForwardMap(regex)
    # A falsy transition table means there is nothing usable to cache.
    return jump_map if jump_map.state_to_jump_forward else None


def test_main(regex_string):
Expand Down
Loading