Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support stop_token_ids in sglang API #1092

Merged
merged 6 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand Down Expand Up @@ -98,6 +99,7 @@ def gen(
name,
max_tokens,
stop,
stop_token_ids,
temperature,
top_p,
top_k,
Expand All @@ -117,6 +119,7 @@ def gen_int(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand All @@ -132,6 +135,7 @@ def gen_int(
name,
max_tokens,
stop,
stop_token_ids,
temperature,
top_p,
top_k,
Expand All @@ -151,6 +155,7 @@ def gen_string(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand All @@ -166,6 +171,7 @@ def gen_string(
name,
max_tokens,
stop,
stop_token_ids,
temperature,
top_p,
top_k,
Expand Down
6 changes: 4 additions & 2 deletions python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
SglConstantText,
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
Expand Down Expand Up @@ -181,8 +180,10 @@ def __init__(
num_api_spec_tokens=None,
use_thread=True,
):
from sglang.lang.backend.base_backend import BaseBackend

self.sid = uuid.uuid4().hex
self.backend = backend
self.backend: BaseBackend = backend
self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para
self.stream = stream
Expand Down Expand Up @@ -658,6 +659,7 @@ def _resolve_sampling_params(self, sampling_params):
for item in [
"max_new_tokens",
"stop",
"stop_token_ids",
"temperature",
"top_p",
"top_k",
Expand Down
11 changes: 10 additions & 1 deletion python/sglang/lang/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
class SglSamplingParams:
max_new_tokens: int = 128
stop: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
Expand All @@ -37,6 +38,7 @@ def clone(self):
return SglSamplingParams(
self.max_new_tokens,
self.stop,
self.stop_token_ids,
self.temperature,
self.top_p,
self.top_k,
Expand Down Expand Up @@ -108,6 +110,7 @@ def to_srt_kwargs(self):
return {
"max_new_tokens": self.max_new_tokens,
"stop": self.stop,
"stop_token_ids": self.stop_token_ids,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
Expand Down Expand Up @@ -141,7 +144,8 @@ def run(
self,
*args,
max_new_tokens: int = 128,
stop: Union[str, List[str]] = (),
stop: Union[str, List[str]] = [],
stop_token_ids: Optional[List[int]] = [],
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
Expand All @@ -161,6 +165,7 @@ def run(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
top_k=top_k,
Expand All @@ -181,6 +186,7 @@ def run_batch(
*,
max_new_tokens: int = 128,
stop: Union[str, List[str]] = (),
stop_token_ids: Optional[List[int]] = [],
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
Expand Down Expand Up @@ -218,6 +224,7 @@ def run_batch(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
top_k=top_k,
Expand Down Expand Up @@ -397,6 +404,7 @@ def __init__(
name: Optional[str] = None,
max_new_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand All @@ -416,6 +424,7 @@ def __init__(
self.sampling_params = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
top_k=top_k,
Expand Down
10 changes: 6 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,12 @@ def check_finished(self):
return

last_token_id = self.output_ids[-1]
if self.tokenizer is None:
matched_eos = last_token_id in self.sampling_params.stop_token_ids
else:
matched_eos = last_token_id == self.tokenizer.eos_token_id

matched_eos = last_token_id in self.sampling_params.stop_token_ids

if self.tokenizer is not None:
matched_eos |= last_token_id == self.tokenizer.eos_token_id

if matched_eos and not self.sampling_params.ignore_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
return
Expand Down
11 changes: 7 additions & 4 deletions python/sglang/test/test_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,16 @@ def decode_json(s):
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR

s += "Generate a JSON object to describe the basic city information of Paris.\n"
s += "Here are the JSON object:\n"

# NOTE: we recommend using dtype gen or whole regex string to control the output

with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
s += "}"

ret = decode_json.run(temperature=0.0)
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_moe_serving_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_default_without_radix_cache(self):

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
# A100 (PCIE) performance
assert res["output_throughput"] > 940
assert res["output_throughput"] > 930

def test_default_with_chunked_prefill(self):
res = self.run_test(
Expand Down
Loading