diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
new file mode 100644
index 00000000000..3c57e1144c0
--- /dev/null
+++ b/python/sglang/bench_offline_throughput.py
@@ -0,0 +1,305 @@
+"""
+Benchmark the throughput of the offline LLM engine.
+This script does not launch a server.
+It accepts the same arguments as launch_server.py, plus additional benchmark arguments.
+
+# Usage
+## ShareGPT dataset with default args
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct
+
+## Random dataset with default args
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random
+
+## Shared prefix dataset with default args
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name generated-shared-prefix
+
+## ShareGPT dataset on the runtime backend
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --backend runtime
+"""
+
+import argparse
+import dataclasses
+import json
+import logging
+import random
+import time
+from typing import List, Tuple
+
+import numpy as np
+
+from sglang.api import Engine
+from sglang.bench_serving import (
+    get_dataset,
+    get_tokenizer,
+    sample_random_requests,
+    set_ulimit,
+)
+from sglang.srt.server import Runtime
+from sglang.srt.server_args import ServerArgs
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    backend: str = "engine"
+    result_filename: str = ""
+    dataset_name: str = "sharegpt"
+    dataset_path: str = ""
+    num_prompts: int = 1000
+    sharegpt_output_len: int = 256
+    random_input_len: int = 256
+    random_output_len: int = 256
+    random_range_ratio: float = 0.0
+    gen_num_groups: int = 8
+    gen_prompts_per_group: int = 16
+    gen_system_prompt_len: int = 128
+    gen_question_len: int = 256
+    disable_ignore_eos: bool = False
+    seed: int = 1
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--backend", type=str, default=BenchArgs.backend)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
+        parser.add_argument(
+            "--dataset-name",
+            type=str,
+            default="sharegpt",
+            choices=["sharegpt", "random", "generated-shared-prefix"],
+            help="Name of the dataset to benchmark on.",
+        )
+        parser.add_argument(
+            "--dataset-path", type=str, default="", help="Path to the dataset."
+        )
+        parser.add_argument(
+            "--num-prompts",
+            type=int,
+            default=BenchArgs.num_prompts,
+            help="Number of prompts to process. Default is 1000.",
+        )
+        parser.add_argument(
+            "--sharegpt-output-len",
+            type=int,
+            default=BenchArgs.sharegpt_output_len,
+            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
+        )
+        parser.add_argument(
+            "--random-input-len",
+            type=int,
+            default=BenchArgs.random_input_len,
+            help="Number of input tokens per request, used only for the random dataset.",
+        )
+        parser.add_argument(
+            "--random-output-len",
+            type=int,
+            default=BenchArgs.random_output_len,
+            help="Number of output tokens per request, used only for the random dataset.",
+        )
+        parser.add_argument(
+            "--random-range-ratio",
+            type=float,
+            default=BenchArgs.random_range_ratio,
+            help="Range of sampled ratio of input/output length, "
+            "used only for the random dataset.",
+        )
+        parser.add_argument(
+            "--gen-num-groups",
+            type=int,
+            default=BenchArgs.gen_num_groups,
+            help="Number of groups with a shared prefix, used "
+            "only for the generated-shared-prefix dataset.",
+        )
+        parser.add_argument(
+            "--gen-prompts-per-group",
+            type=int,
+            default=BenchArgs.gen_prompts_per_group,
+            help="Number of prompts per shared-prefix group, used "
+            "only for the generated-shared-prefix dataset.",
+        )
+        parser.add_argument(
+            "--gen-system-prompt-len",
+            type=int,
+            default=BenchArgs.gen_system_prompt_len,
+            help="System prompt length, used only for the generated-shared-prefix dataset.",
+        )
+        parser.add_argument(
+            "--gen-question-len",
+            type=int,
+            default=BenchArgs.gen_question_len,
+            help="Question length, used only for the generated-shared-prefix dataset.",
+        )
+        parser.add_argument(
+            "--disable-ignore-eos",
+            action="store_true",
+            default=BenchArgs.disable_ignore_eos,
+            help="Disable ignoring the EOS token, i.e., stop generation at EOS.",
+        )
+        parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # Use the default value's type to cast the CLI args into the correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+def throughput_test_once(
+    backend_name: str,
+    backend,
+    reqs: List[Tuple[str, int, int]],
+    ignore_eos: bool,
+):
+    measurement_results = {
+        "backend": backend_name,
+        "successful_requests": len(reqs),
+        "total_latency": -1,
+        "total_input_tokens": sum(r[1] for r in reqs),
+        "total_output_tokens": -1,
+        "request_throughput": -1,
+        "input_throughput": -1,
+        "output_throughput": -1,
+        "total_throughput": -1,
+    }
+
+    prompt = [r[0] for r in reqs]
+    sampling_params = [
+        {
+            "temperature": 0,
+            "max_new_tokens": r[2],
+            "ignore_eos": ignore_eos,
+        }
+        for r in reqs
+    ]
+
+    st = time.perf_counter()
+    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
+    latency = time.perf_counter() - st
+
+    if backend_name == "runtime":
+        # The runtime backend returns a JSON string; parse it into a list of outputs.
+        gen_out = json.loads(gen_out)
+
+    measurement_results["total_latency"] = latency
+    measurement_results["total_output_tokens"] = sum(
+        o["meta_info"]["completion_tokens"] for o in gen_out
+    )
+    measurement_results["request_throughput"] = (
+        measurement_results["successful_requests"] / latency
+    )
+    measurement_results["input_throughput"] = (
+        measurement_results["total_input_tokens"] / latency
+    )
+    measurement_results["output_throughput"] = (
+        measurement_results["total_output_tokens"] / latency
+    )
+    measurement_results["total_throughput"] = (
+        measurement_results["total_input_tokens"]
+        + measurement_results["total_output_tokens"]
+    ) / latency
+
+    return measurement_results
+
+
+def throughput_test(
+    server_args: ServerArgs,
+    bench_args: BenchArgs,
+):
+    if bench_args.backend == "engine":
+        backend = Engine(**dataclasses.asdict(server_args))
+        if not backend:
+            raise ValueError("Please provide valid engine arguments")
ValueError("Please provide valid engine arguments") + elif bench_args.backend == "runtime": + backend = Runtime(**dataclasses.asdict(server_args)) + else: + raise ValueError('Please set backend to either "engine" or "runtime"') + + tokenizer_id = server_args.model_path + tokenizer = get_tokenizer(tokenizer_id) + + # Set global environmnets + set_ulimit() + random.seed(bench_args.seed) + np.random.seed(bench_args.seed) + + input_requests = get_dataset(bench_args, tokenizer) + + warmup_requests = sample_random_requests( + input_len=20, + output_len=4, + num_prompts=2, + range_ratio=0.8, + tokenizer=tokenizer, + dataset_path=bench_args.dataset_path, + ) + + # Warm up + throughput_test_once( + backend_name=bench_args.backend, + backend=backend, + reqs=warmup_requests, + ignore_eos=not bench_args.disable_ignore_eos, + ) + + result = throughput_test_once( + backend_name=bench_args.backend, + backend=backend, + reqs=input_requests, + ignore_eos=not bench_args.disable_ignore_eos, + ) + + if bench_args.result_filename: + with open(bench_args.result_filename, "a") as fout: + fout.write(json.dumps(result) + "\n") + + print( + "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=") + ) + print("{:<40} {:<10}".format("Backend:", result["backend"])) + print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"])) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"])) + print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"])) + print( + "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"]) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", result["request_throughput"] + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", result["input_throughput"] + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", result["output_throughput"] + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total token throughput (tok/s):", result["total_throughput"] + ) + ) + print("=" * 50) + + return result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + BenchArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + bench_args = BenchArgs.from_cli_args(args) + + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format="%(message)s", + ) + + throughput_test(server_args, bench_args) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index c0cf946ede9..68c67241302 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -421,6 +421,37 @@ def get_tokenizer( ) +def get_dataset(args, tokenizer): + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + elif args.dataset_name == "generated-shared-prefix": + input_requests = sample_generated_shared_prefix_requests( + num_groups=args.gen_num_groups, + prompts_per_group=args.gen_prompts_per_group, + system_prompt_len=args.gen_system_prompt_len, + 
+            question_len=args.gen_question_len,
+            output_len=args.gen_output_len,
+            tokenizer=tokenizer,
+        )
+    else:
+        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+    return input_requests
+
+
 ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_sglang_generate,
     "sglang-native": async_request_sglang_generate,
@@ -443,6 +474,8 @@ class BenchmarkMetrics:
     input_throughput: float
     output_throughput: float
     output_throughput_retokenized: float
+    total_throughput: float
+    total_throughput_retokenized: float
     mean_ttft_ms: float
     median_ttft_ms: float
     std_ttft_ms: float
@@ -590,7 +623,6 @@ def sample_random_requests(
             (data["conversations"][0]["value"], data["conversations"][1]["value"])
             for data in dataset
         ]
-
         # Shuffle the dataset.
         random.shuffle(dataset)
 
@@ -764,6 +796,9 @@ def calculate_metrics(
         input_throughput=total_input / dur_s,
         output_throughput=sum(output_lens) / dur_s,
         output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
+        total_throughput=(total_input + sum(output_lens)) / dur_s,
+        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
+        / dur_s,
         mean_ttft_ms=np.mean(ttfts or 0)
         * 1000,  # ttfts is empty if streaming is not supported by backend
         median_ttft_ms=np.median(ttfts or 0) * 1000,
@@ -881,6 +916,11 @@ async def benchmark(
             "Output token throughput (tok/s):", metrics.output_throughput
         )
     )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Total token throughput (tok/s):", metrics.total_throughput
+        )
+    )
     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
     print(
         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
     )
@@ -1098,35 +1138,7 @@ def run_benchmark(args_: argparse.Namespace):
 
     tokenizer = get_tokenizer(tokenizer_id)
 
-    if args.dataset_name == "sharegpt":
-        assert args.random_input_len is None and args.random_output_len is None
-        input_requests = sample_sharegpt_requests(
-            dataset_path=args.dataset_path,
-            num_requests=args.num_prompts,
-            tokenizer=tokenizer,
-            fixed_output_len=args.sharegpt_output_len,
-        )
-    elif args.dataset_name == "random":
-        assert args.random_input_len is not None and args.random_output_len is not None
-        input_requests = sample_random_requests(
-            input_len=args.random_input_len,
-            output_len=args.random_output_len,
-            num_prompts=args.num_prompts,
-            range_ratio=args.random_range_ratio,
-            tokenizer=tokenizer,
-            dataset_path=args.dataset_path,
-        )
-    elif args.dataset_name == "generated-shared-prefix":
-        input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
-            tokenizer=tokenizer,
-        )
-    else:
-        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+    input_requests = get_dataset(args, tokenizer)
 
     if not args.multi:
         return asyncio.run(
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index e27bb1bb97b..e4798877a83 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -768,7 +768,7 @@ def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
         prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index a375c2900d5..33232f50b41 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -11,7 +11,9 @@
 import torch
 
 import sglang as sgl
+from sglang.bench_offline_throughput import BenchArgs, throughput_test
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.server_args import ServerArgs
 from sglang.test.few_shot_gsm8k_engine import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
@@ -152,6 +154,14 @@ def test_6_engine_runtime_encode_consistency(self):
 
         self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
 
+    def test_7_engine_offline_throughput(self):
+        server_args = ServerArgs(
+            model_path=DEFAULT_MODEL_NAME_FOR_TEST,
+        )
+        bench_args = BenchArgs(num_prompts=100)
+        result = throughput_test(server_args=server_args, bench_args=bench_args)
+        self.assertTrue(result["total_throughput"] > 3000)
+
 
 if __name__ == "__main__":
     unittest.main()
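Note (not part of the patch): the new benchmark can also be driven programmatically, mirroring the added test_7_engine_offline_throughput; a minimal sketch follows, where the model path is only an example taken from the script's docstring.

from sglang.bench_offline_throughput import BenchArgs, throughput_test
from sglang.srt.server_args import ServerArgs

# Configure the in-process engine backend; ServerArgs accepts the same fields as launch_server.py.
server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3-8B-Instruct")

# Benchmark 100 prompts from the default ShareGPT dataset on the "engine" backend.
bench_args = BenchArgs(backend="engine", num_prompts=100)

# throughput_test returns the same dict that --result-filename would record.
result = throughput_test(server_args=server_args, bench_args=bench_args)
print(result["total_throughput"])  # total (input + output) tokens per second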