diff --git a/python/pyproject.toml b/python/pyproject.toml
index 963c8253fda..b59ef852bda 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
 
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart"]
+       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index c4ffce634bd..a3bf9158c49 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -1,13 +1,13 @@
 """
 Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
 
-# Usage (latency test):
+# Usage (latency test) with dummy weights:
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
 
 # Usage (correctness test):
 python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
-### Reference output:
+### Reference output (of the correctness test above, can be GPU dependent):
 prefill logits (first half) tensor([[-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
         [-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
         [ -9.1875, -10.2500,   2.7109,  ...,  -4.3359,  -4.0664,  -4.1328]],
@@ -31,7 +31,9 @@
 import logging
 import multiprocessing
 import time
+from typing import Tuple
 
+import jsonlines
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -47,25 +49,34 @@
 
 @dataclasses.dataclass
 class BenchArgs:
-    batch_size: int = 1
+    batch_size: Tuple[int] = (1,)
     input_len: int = 1024
     output_len: int = 4
+    result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
        parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        return cls(**{attr: getattr(args, attr) for attr in attrs})
+        # use the default value's type to cast the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
 
 
 def load_model(server_args, tp_rank):
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs(bench_args, tokenizer):
+def prepare_inputs_for_correctness_test(bench_args, tokenizer):
     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
     return input_ids, reqs
 
 
-def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+def prepare_extend_inputs_for_correctness_test(
+    bench_args, input_ids, reqs, model_runner
+):
     for i in range(len(reqs)):
         req = reqs[i]
         req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +142,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     return reqs
 
 
-def prepare_synthetic_inputs(bench_args, tokenizer):
-    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
+    input_ids = np.ones((batch_size, input_len), dtype=np.int32)
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -179,7 +192,7 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
 
     if bench_args.cut_len > 0:
         # Prefill
@@ -187,7 +200,9 @@ def correctness_test(
         rank_print("prefill logits (first half)", next_token_logits)
 
     # Prepare extend inputs
-    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+    reqs = prepare_extend_inputs_for_correctness_test(
+        bench_args, input_ids, reqs, model_runner
+    )
 
     # Extend
     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -218,8 +233,13 @@ def latency_test(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
 
+    # To make this PR easier to review, for now, only use the first element of the batch_size tuple.
+    bench_args.batch_size = bench_args.batch_size[0]
+
     # Prepare inputs
-    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+    reqs = prepare_synthetic_inputs_for_latency_test(
+        bench_args.batch_size, bench_args.input_len
+    )
 
     def clear():
         model_runner.req_to_token_pool.clear()
@@ -227,6 +247,11 @@ def clear():
 
     @torch.inference_mode()
     def run_once(output_len):
+        measurement_results = {
+            "batch_size": bench_args.batch_size,
+            "output_len": output_len,
+        }
+
         # Prefill
         torch.cuda.synchronize()
         tot_latency = 0
@@ -239,6 +264,8 @@ def run_once(output_len):
         rank_print(
             f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["prefill_latency"] = prefill_latency
+        measurement_results["prefill_throughput"] = throughput
 
         # Decode
         for i in range(output_len):
@@ -258,6 +285,8 @@ def run_once(output_len):
         rank_print(
             f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
         )
+        measurement_results["avg_decode_latency"] = avg_decode_latency
+        measurement_results["avg_decode_throughput"] = avg_decode_throughput
 
         throughput = (
             (bench_args.input_len + bench_args.output_len)
@@ -267,13 +296,22 @@ def run_once(output_len):
         rank_print(
             f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
         )
+        measurement_results["total_latency"] = tot_latency
+        measurement_results["total_throughput"] = throughput
+        return measurement_results
 
     # Warm up
     run_once(4)
     clear()
 
     # Run again
-    run_once(bench_args.output_len)
+    result_list = []
+    result_list.append(run_once(bench_args.output_len))
+
+    # Write results in jsonlines format.
+    if bench_args.result_filename:
+        with jsonlines.open(bench_args.result_filename, "a") as f:
+            f.write_all(result_list)
 
 
 def main(server_args, bench_args):
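A note for reviewers on the from_cli_args change: with nargs="+", argparse returns a Python list for --batch-size, and the new code casts each parsed value with the type of the corresponding dataclass default, which turns that list into a tuple. A minimal, self-contained sketch of the idea (the _Args class and the example argv are hypothetical stand-ins, not part of this patch):

    import argparse
    import dataclasses

    @dataclasses.dataclass
    class _Args:  # hypothetical stand-in for BenchArgs
        batch_size: tuple = (1,)
        input_len: int = 1024

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, nargs="+", default=_Args.batch_size)
    parser.add_argument("--input-len", type=int, default=_Args.input_len)
    ns = parser.parse_args(["--batch-size", "1", "8", "16"])

    # Same idea as from_cli_args: cast each argparse value with the default's type.
    fields = [(f.name, type(f.default)) for f in dataclasses.fields(_Args)]
    print(_Args(**{name: cast(getattr(ns, name)) for name, cast in fields}))
    # -> _Args(batch_size=(1, 8, 16), input_len=1024)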
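With --result-filename set, latency_test appends one JSON object per line (the measurement_results dict built in run_once), so results from repeated runs accumulate in the same file. A small sketch of reading them back with the same jsonlines dependency added above; the file name results.jsonl is only an example:

    import jsonlines

    # Each line is one run's measurement_results dict.
    with jsonlines.open("results.jsonl") as reader:
        for record in reader:
            print(
                f"bs={record['batch_size']} "
                f"prefill={record['prefill_latency']:.5f}s "
                f"avg_decode={record['avg_decode_latency']:.5f}s "
                f"total_throughput={record['total_throughput']:.2f} token/s"
            )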