From 5d09ca5735462eacc36a0b0aed7f4108c3d33f2f Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Fri, 11 Oct 2024 06:26:20 -0700
Subject: [PATCH] Fix constrained decoding (#1634)

---
 python/sglang/srt/managers/schedule_batch.py |  2 ++
 test/srt/test_json_constrained.py            | 24 ++++++++++++++++----
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 92f998b2c30..156e830d185 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -810,6 +810,8 @@ def get_model_worker_batch(self):
             self.sampling_info.regex_fsm_states = [
                 req.regex_fsm_state for req in self.reqs
             ]
+        else:
+            self.sampling_info.regex_fsms = None
 
         return ModelWorkerBatch(
             forward_mode=self.forward_mode,
diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py
index d3abc70a44f..12cd5167613 100644
--- a/test/srt/test_json_constrained.py
+++ b/test/srt/test_json_constrained.py
@@ -1,5 +1,6 @@
 import json
 import unittest
+from concurrent.futures import ThreadPoolExecutor
 
 import openai
 import requests
@@ -27,13 +28,18 @@ def setUpClass(cls):
                 "required": ["name", "population"],
             }
         )
-        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=["--max-running-requests", "10"],
+        )
 
     @classmethod
     def tearDownClass(cls):
         kill_child_process(cls.process.pid)
 
-    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
+    def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
         response = requests.post(
             self.base_url + "/generate",
             json={
@@ -43,7 +49,7 @@ def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
                     "max_new_tokens": 128,
                     "n": n,
                     "stop_token_ids": [119690],
-                    "json_schema": self.json_schema,
+                    "json_schema": json_schema,
                 },
                 "stream": False,
                 "return_logprob": return_logprob,
@@ -53,6 +59,10 @@ def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
         )
         print(json.dumps(response.json()))
         print("=" * 100)
+
+        if not json_schema:
+            return
+
         try:
             js_obj = json.loads(response.json()["text"])
         except (TypeError, json.decoder.JSONDecodeError):
@@ -61,7 +71,7 @@ def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
         assert isinstance(js_obj["population"], int)
 
     def test_json_generate(self):
-        self.run_decode()
+        self.run_decode(json_schema=self.json_schema)
 
     def test_json_openai(self):
         client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
@@ -89,6 +99,12 @@ def test_json_openai(self):
         assert isinstance(js_obj["name"], str)
         assert isinstance(js_obj["population"], int)
 
+    def test_mix_json_and_other(self):
+        json_schemas = [None, None, self.json_schema, self.json_schema] * 10
+
+        with ThreadPoolExecutor(len(json_schemas)) as executor:
+            list(executor.map(self.run_decode, json_schemas))
+
 
 if __name__ == "__main__":
     unittest.main()