From cdc8d607524a9cf663d2319ff452168d99645e39 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 17 Aug 2024 14:37:52 -0700 Subject: [PATCH] Improve the code style: more comments and remove useless packages (#1139) --- .../srt/managers/detokenizer_manager.py | 4 +- python/sglang/srt/managers/io_struct.py | 37 ++++++++++++++----- python/sglang/srt/server.py | 1 - 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 08ccfd5cef0..12511ac44e5 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -17,7 +17,6 @@ import asyncio import dataclasses -import inspect from typing import List import uvloop @@ -126,8 +125,6 @@ async def handle_loop(self): spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], ) - # Trim stop str - # TODO(lmzheng): handle the case where multiple stop strs are hit output_strs = [] for i in range(bs): s = self.decode_status[recv_obj.rids[i]] @@ -144,6 +141,7 @@ async def handle_loop(self): output_strs.append(s.decoded_text + new_text) + # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): pos = output_strs[i].find(recv_obj.finished_reason[i].matched) if pos != -1: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 2d12505ae4e..82f280b6062 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -22,8 +22,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Union -import torch - from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling_params import SamplingParams @@ -43,9 +41,9 @@ class GenerateReqInput: rid: Optional[Union[List[str], str]] = None # Whether to return logprobs. return_logprob: Optional[Union[List[bool], bool]] = None - # The start location of the prompt for return_logprob. + # If return logprobs, the start location in the prompt for returning logprobs. logprob_start_len: Optional[Union[List[int], int]] = None - # The number of top logprobs to return. + # If return logprobs, the number of top logprobs to return at each position. top_logprobs_num: Optional[Union[List[int], int]] = None # Whether to detokenize tokens in text in the returned logprobs. return_text_in_logprobs: bool = False @@ -155,16 +153,27 @@ def post_init(self): @dataclass class TokenizedGenerateReqInput: + # The request id rid: str + # The input text input_text: str + # The input token ids input_ids: List[int] + # The pixel values for input images pixel_values: List[float] + # The hash of input images image_hash: int + # The image size image_size: List[int] + # The sampling parameters sampling_params: SamplingParams + # Whether to return the logprobs return_logprob: bool + # If return logprobs, the start location in the prompt for returning logprobs. logprob_start_len: int + # If return logprobs, the number of top logprobs to return at each position. top_logprobs_num: int + # Whether to stream output stream: bool @@ -215,15 +224,21 @@ def post_init(self): @dataclass class TokenizedEmbeddingReqInput: + # The request id rid: str + # The input text input_text: str + # The input token ids input_ids: List[int] + # Dummy sampling params for compatibility sampling_params: SamplingParams @dataclass class BatchTokenIDOut: + # The request id rids: List[str] + # The version id to sync decode status with in detokenizer_manager vids: List[int] decoded_texts: List[str] decode_ids: List[int] @@ -236,17 +251,25 @@ class BatchTokenIDOut: @dataclass class BatchStrOut: + # The request id rids: List[str] + # The output decoded strings output_strs: List[str] + # The meta info meta_info: List[Dict] + # The finish reason finished_reason: List[BaseFinishReason] @dataclass class BatchEmbeddingOut: + # The request id rids: List[str] + # The output embedding embeddings: List[List[float]] + # The meta info meta_info: List[Dict] + # The finish reason finished_reason: List[BaseFinishReason] @@ -257,9 +280,5 @@ class FlushCacheReq: @dataclass class AbortReq: + # The request id rid: str - - -@dataclass -class DetokenizeReqInput: - input_ids: List[int] diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 6bbf3050aef..9028c12309b 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -34,7 +34,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import aiohttp -import psutil import requests import uvicorn import uvloop