From 5977e9d16d450d1d2c6653721d5e1ab251657863 Mon Sep 17 00:00:00 2001
From: lzhang
Date: Tue, 13 Aug 2024 20:21:43 +0800
Subject: [PATCH 1/6] Fix create_abort_task, GenerateReqInput does not have rids.

---
 python/sglang/srt/managers/tokenizer_manager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index e1bfbc7e670..d5fbfe05d3b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -507,7 +507,7 @@ async def abort_request():
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
-                for rid in obj.rids:
+                for rid in obj.rid:
                     self.abort_request(rid)
 
         background_tasks = BackgroundTasks()

From 04fae70563dd3ef444e8e6e6262b4bfc11859d61 Mon Sep 17 00:00:00 2001
From: lzhang
Date: Wed, 14 Aug 2024 12:01:17 +0800
Subject: [PATCH 2/6] set input_ids instead of text if skip_tokenizer_init is set.

---
 python/sglang/srt/server.py | 30 +++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 7331425fae9..35083079046 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -533,13 +533,22 @@ async def async_generate(
         prompt: str,
         sampling_params: Optional[Dict] = None,
     ):
-        json_data = {
-            "text": prompt,
-            "sampling_params": sampling_params,
-            "stream": True,
-        }
+        if self.server_args.skip_tokenizer_init:
+            json_data = {
+                "input_ids": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        else:
+            json_data = {
+                "text": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
         pos = 0
 
+        print(json_data)
+        print(self.generate_url)
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
             async with session.post(self.generate_url, json=json_data) as response:
@@ -549,10 +558,13 @@ async def async_generate(
                         if chunk == "data: [DONE]\n\n":
                             break
                         data = json.loads(chunk[5:].strip("\n"))
-                        cur = data["text"][pos:]
-                        if cur:
-                            yield cur
-                        pos += len(cur)
+                        if hasattr(data, 'text'):
+                            cur = data["text"][pos:]
+                            if cur:
+                                yield cur
+                            pos += len(cur)
+                        else:
+                            yield data
 
     add_request = async_generate

From 3e97ba69cae724e1f6e8749e590c9386f26d7f67 Mon Sep 17 00:00:00 2001
From: lzhang
Date: Wed, 14 Aug 2024 12:48:03 +0800
Subject: [PATCH 3/6] Remove unnecessary print.

---
 python/sglang/srt/server.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 35083079046..7a7bd25a3f8 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -547,8 +547,6 @@ async def async_generate(
         pos = 0
 
-        print(json_data)
-        print(self.generate_url)
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
             async with session.post(self.generate_url, json=json_data) as response:

From 5cd2f79856c83ccaf25b935a6d9cd76dc688376e Mon Sep 17 00:00:00 2001
From: lzhang
Date: Wed, 14 Aug 2024 12:49:28 +0800
Subject: [PATCH 4/6] Fix lint.

---
 python/sglang/srt/server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 7a7bd25a3f8..8f735ac0c74 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -556,7 +556,7 @@ async def async_generate(
                         if chunk == "data: [DONE]\n\n":
                             break
                         data = json.loads(chunk[5:].strip("\n"))
-                        if hasattr(data, 'text'):
+                        if hasattr(data, "text"):
                             cur = data["text"][pos:]
                             if cur:
                                 yield cur

From de4be50de34d741dea550d6e6f10c8bcd966bbd4 Mon Sep 17 00:00:00 2001
From: lzhang
Date: Tue, 20 Aug 2024 20:44:05 +0800
Subject: [PATCH 5/6] Another way to convert between stop_strs and stop arg.

---
 python/sglang/srt/sampling_params.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py
index 6a8823cc4de..8a52e5cd76c 100644
--- a/python/sglang/srt/sampling_params.py
+++ b/python/sglang/srt/sampling_params.py
@@ -123,3 +123,17 @@ def normalize(self, tokenizer):
                 else:
                     stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
+
+    def to_srt_kwargs(self):
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "stop": self.stop_strs,
+            "stop_token_ids": self.stop_token_ids,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "ignore_eos": self.ignore_eos,
+            "regex": self.regex,
+        }

From 9a44fd8693d66c83ad6ef44e494e6534a12cf1be Mon Sep 17 00:00:00 2001
From: lzhang
Date: Wed, 21 Aug 2024 16:59:07 +0800
Subject: [PATCH 6/6] also need convert stop_token_ids as list for json.dumps.

---
 python/sglang/srt/sampling_params.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py
index 8a52e5cd76c..712827d7929 100644
--- a/python/sglang/srt/sampling_params.py
+++ b/python/sglang/srt/sampling_params.py
@@ -128,7 +128,7 @@ def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
             "stop": self.stop_strs,
-            "stop_token_ids": self.stop_token_ids,
+            "stop_token_ids": list(self.stop_token_ids),
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,
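
The sketch below is not part of the patch series; it is a minimal client-side illustration of how the pieces above fit together. The "text"/"input_ids" switch comes from PATCH 2/6 and to_srt_kwargs() from PATCH 5/6 and 6/6; the helper name build_generate_payload, the prompt, and the sampling values are assumptions made up for the example.

import json

from sglang.srt.sampling_params import SamplingParams


def build_generate_payload(prompt_or_ids, skip_tokenizer_init=False):
    # Mirror the branch added in PATCH 2/6: a server launched with
    # skip_tokenizer_init expects token ids under "input_ids", otherwise
    # the raw prompt goes under "text".
    key = "input_ids" if skip_tokenizer_init else "text"
    params = SamplingParams(max_new_tokens=32, temperature=0.7, stop=["###"])
    return {
        key: prompt_or_ids,
        # to_srt_kwargs() (PATCH 5/6) flattens the object into plain Python
        # types; PATCH 6/6 converts stop_token_ids to a list so json.dumps
        # accepts the result.
        "sampling_params": params.to_srt_kwargs(),
        "stream": False,
    }


if __name__ == "__main__":
    print(json.dumps(build_generate_payload("Say hello."), indent=2))

Posting the payload to a running server's /generate endpoint (as the patched async_generate does via aiohttp) is left out to keep the sketch dependency-free.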
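
A related sketch, also hypothetical, drives the streaming path that PATCH 2/6 modifies through the Runtime wrapper. With a normally initialized tokenizer the generator yields incremental text; when skip_tokenizer_init is set it takes the new else branch and yields the parsed response objects instead. It assumes Runtime is importable from the top-level sglang package, and the model path, prompt, and sampling values are placeholders.

import asyncio

import sglang as sgl


async def main():
    # Launching a Runtime needs a local GPU and model; the path is a placeholder.
    runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    try:
        async for piece in runtime.async_generate(
            "The capital of France is", {"max_new_tokens": 16}
        ):
            # Incremental text in the default path; whole response objects
            # when the server was started with skip_tokenizer_init.
            print(piece, end="", flush=True)
    finally:
        runtime.shutdown()


asyncio.run(main())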