latency test enhancement - part 1 #909

Merged · 4 commits · Aug 4, 2024
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [

[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart"]
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
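The new jsonlines dependency backs the result logging added to bench_latency.py below: jsonlines.open plus write_all appends one JSON object per line. A minimal sketch of that pattern (not part of the PR; the file name and record contents are illustrative):

import jsonlines

# Append result dicts to a JSON Lines file, one object per line,
# mirroring the pattern used at the end of latency_test below.
results = [{"batch_size": 1, "output_len": 8, "prefill_latency": 0.053}]
with jsonlines.open("latency_results.jsonl", "a") as f:
    f.write_all(results)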
66 changes: 52 additions & 14 deletions python/sglang/bench_latency.py
@@ -1,13 +1,13 @@
"""
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.

# Usage (latency test):
# Usage (latency test) with dummy weights:
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy

# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

### Reference output:
### Reference output (of the correctness test above; may be GPU dependent):
prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
@@ -31,7 +31,9 @@
import logging
import multiprocessing
import time
from typing import Tuple

import jsonlines
import numpy as np
import torch
import torch.distributed as dist
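Together with the imports above, the CLI below gains a multi-value --batch-size and a --result-filename option for JSON Lines output. An example invocation combining the flags from the usage docstring with the new ones (values are illustrative; as noted later in this diff, only the first batch size is exercised in part 1):

python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy --batch-size 1 8 16 --input-len 1024 --output-len 8 --result-filename latency_results.jsonl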
@@ -47,25 +49,34 @@

@dataclasses.dataclass
class BenchArgs:
batch_size: int = 1
batch_size: Tuple[int] = (1,)
input_len: int = 1024
output_len: int = 4
result_filename: str = ""
correctness_test: bool = False
# This is only used for correctness test
cut_len: int = 4

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
parser.add_argument(
"--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
)
parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
parser.add_argument(
"--result-filename", type=str, default=BenchArgs.result_filename
)
parser.add_argument("--correctness-test", action="store_true")
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
# Use the default value's type to cast the args into the correct types.
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
return cls(
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
)


def load_model(server_args, tp_rank):
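The revised from_cli_args recovers each field's intended type from its dataclass default, so the list that argparse produces for a multi-value --batch-size is cast back to a tuple. A standalone sketch of the same pattern (illustrative class and field names, not the PR's code):

import argparse
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class DemoArgs:
    batch_size: Tuple[int] = (1,)
    input_len: int = 1024

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # type(f.default) is tuple for batch_size and int for input_len, so the
        # argparse list becomes a tuple and the int passes through unchanged.
        attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        return cls(**{name: cast(getattr(args, name)) for name, cast in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", default=DemoArgs.batch_size)
parser.add_argument("--input-len", type=int, default=DemoArgs.input_len)
print(DemoArgs.from_cli_args(parser.parse_args(["--batch-size", "1", "8", "16"])))
# -> DemoArgs(batch_size=(1, 8, 16), input_len=1024)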
@@ -93,7 +104,7 @@ def load_model(server_args, tp_rank):
return model_runner, tokenizer


def prepare_inputs(bench_args, tokenizer):
def prepare_inputs_for_correctness_test(bench_args, tokenizer):
prompts = [
"The capital of France is",
"The capital of the United Kindom is",
@@ -119,7 +130,9 @@ def prepare_inputs(bench_args, tokenizer):
return input_ids, reqs


def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
def prepare_extend_inputs_for_correctness_test(
bench_args, input_ids, reqs, model_runner
):
for i in range(len(reqs)):
req = reqs[i]
req.input_ids += input_ids[i][bench_args.cut_len :]
@@ -129,8 +142,8 @@ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
return reqs


def prepare_synthetic_inputs(bench_args, tokenizer):
input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
input_ids = np.ones((batch_size, input_len), dtype=np.int32)
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=BenchArgs.output_len,
@@ -179,15 +192,17 @@ def correctness_test(
model_runner, tokenizer = load_model(server_args, tp_rank)

# Prepare inputs
input_ids, reqs = prepare_inputs(bench_args, tokenizer)
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)

if bench_args.cut_len > 0:
# Prefill
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print("prefill logits (first half)", next_token_logits)

# Prepare extend inputs
reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
reqs = prepare_extend_inputs_for_correctness_test(
bench_args, input_ids, reqs, model_runner
)

# Extend
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
@@ -218,15 +233,25 @@ def latency_test(
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
)

# To make this PR easier to review, for now, only use the first element of the batch_size tuple.
bench_args.batch_size = bench_args.batch_size[0]

# Prepare inputs
reqs = prepare_synthetic_inputs(bench_args, tokenizer)
reqs = prepare_synthetic_inputs_for_latency_test(
bench_args.batch_size, bench_args.input_len
)

def clear():
model_runner.req_to_token_pool.clear()
model_runner.token_to_kv_pool.clear()

@torch.inference_mode()
def run_once(output_len):
measurement_results = {
"batch_size": bench_args.batch_size,
"output_len": output_len,
}

# Prefill
torch.cuda.synchronize()
tot_latency = 0
@@ -239,6 +264,8 @@ def run_once(output_len):
rank_print(
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
measurement_results["prefill_latency"] = prefill_latency
measurement_results["prefill_throughput"] = throughput

# Decode
for i in range(output_len):
@@ -258,6 +285,8 @@ def run_once(output_len):
rank_print(
f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
)
measurement_results["avg_decode_latency"] = avg_decode_latency
measurement_results["avg_decode_throughput"] = avg_decode_throughput

throughput = (
(bench_args.input_len + bench_args.output_len)
@@ -267,13 +296,22 @@ def run_once(output_len):
rank_print(
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
)
measurement_results["total_latency"] = tot_latency
measurement_results["total_throughput"] = throughput
return measurement_results

# Warm up
run_once(4)
clear()

# Run again
run_once(bench_args.output_len)
result_list = []
result_list.append(run_once(bench_args.output_len))

# Write results in jsonlines format.
if bench_args.result_filename:
with jsonlines.open(bench_args.result_filename, "a") as f:
f.write_all(result_list)


def main(server_args, bench_args):
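Because each run appends one JSON object per line to --result-filename, records from repeated invocations accumulate in the same file and can be compared afterwards. A hedged sketch of reading them back (the file name is illustrative; the keys are the measurement_results fields added above):

import jsonlines

# Print a one-line summary for every record appended by bench_latency.py runs.
with jsonlines.open("latency_results.jsonl") as reader:
    for record in reader:
        print(
            f"batch_size={record['batch_size']} "
            f"output_len={record['output_len']} "
            f"prefill={record['prefill_latency']:.4f}s "
            f"decode_avg={record['avg_decode_latency']:.5f}s "
            f"total_throughput={record['total_throughput']:.1f} token/s"
        )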