diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md
index fe8fc5260b6..329dad33609 100644
--- a/docs/references/benchmark_and_profiling.md
+++ b/docs/references/benchmark_and_profiling.md
@@ -56,3 +56,23 @@ with nvtx.annotate("description", color="color"):
 
 ## Other tips
 1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder.
+
+## Profile with PyTorch Profiler
+- To profile a server:
+```bash
+# set trace path
+export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
+# start server
+python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct
+
+# send profiling requests from the client
+python -m sglang.bench_serving --backend sglang --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile
+```
+
+Traces can be visualized using https://ui.perfetto.dev/.
+
+- To profile offline:
+```bash
+export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
+python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-fraction-static=0.8
+```
diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index f1c4e8f9e18..70fbb9add51 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -14,6 +14,7 @@
 import dataclasses
 import json
 import logging
+import os
 import random
 import time
 from typing import Dict, List, Optional, Tuple
@@ -27,7 +28,7 @@
     sample_random_requests,
     set_ulimit,
 )
-from sglang.srt.server import Runtime
+from sglang.srt.server import Runtime, start_profile, stop_profile
 from sglang.srt.server_args import ServerArgs
 
 
@@ -52,6 +53,7 @@ class BenchArgs:
     seed: int = 1
     skip_warmup: bool = False
     do_not_exit: bool = False
+    profile: bool = False
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -156,6 +158,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
         )
+        parser.add_argument(
+            "--profile",
+            action="store_true",
+            help="Use Torch Profiler. SGLANG_TORCH_PROFILER_DIR must be set "
+            "to enable the profiler.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -169,6 +177,7 @@ def throughput_test_once(
     reqs: List[Tuple[str, int, int]],
     ignore_eos: bool,
     extra_request_body: Dict,
+    profile: bool,
 ):
     measurement_results = {
         "backend": backend_name,
@@ -194,7 +203,15 @@
     ]
 
     st = time.perf_counter()
+    if profile:
+        start_profile()
+
     gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
+
+    if profile:
+        stop_profile()
+        monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
+
     latency = time.perf_counter() - st
 
     if backend_name == "runtime":
@@ -221,6 +238,43 @@
     return measurement_results
 
 
+def monitor_trace_file(directory, interval=1):
+    """Block until a new trace file appears in `directory` and stops growing."""
+    print(f"Monitoring {directory} for new trace files...")
+
+    known_files = set(os.listdir(directory))
+
+    while True:
+        flag = False
+        time.sleep(interval)
+        current_files = set(os.listdir(directory))
+
+        new_files = current_files - known_files
+        for new_file in new_files:
+            new_file_path = os.path.join(directory, new_file)
+            print(f"New file detected: {new_file}")
+
+            # Poll the file size; once it stops growing, the profiler has
+            # finished writing the trace and monitoring can stop.
+            previous_size = 0
+            while True:
+                try:
+                    current_size = os.path.getsize(new_file_path)
+                except FileNotFoundError:
+                    print(f"File {new_file} is no longer accessible.")
+                    break
+
+                if current_size > previous_size:
+                    previous_size = current_size
+                else:
+                    flag = True
+                    break
+
+                time.sleep(interval)
+        if flag:
+            break
+
+
 def throughput_test(
     server_args: ServerArgs,
     bench_args: BenchArgs,
@@ -268,6 +322,7 @@
             reqs=warmup_requests,
             ignore_eos=not bench_args.disable_ignore_eos,
             extra_request_body=extra_request_body,
+            profile=False,
         )
 
     logging.info("\nBenchmark...")
@@ -277,6 +332,7 @@
         reqs=input_requests,
         ignore_eos=not bench_args.disable_ignore_eos,
         extra_request_body=extra_request_body,
+        profile=bench_args.profile,
     )
 
     if bench_args.result_filename:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 0d303c2cb14..a4753a13458 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -169,9 +169,19 @@ async def flush_cache():
     )
 
 
+def start_profile():
+    """Start profiling."""
+    tokenizer_manager.start_profile()
+
+
+def stop_profile():
+    """Stop profiling."""
+    tokenizer_manager.stop_profile()
+
+
 @app.get("/start_profile")
 @app.post("/start_profile")
-async def start_profile():
+async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
     return Response(
@@ -182,7 +192,7 @@
 
 @app.get("/stop_profile")
 @app.post("/stop_profile")
-async def stop_profile():
+async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
     return Response(
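
The rename above keeps the `/start_profile` and `/stop_profile` HTTP routes intact, so a running server can still be profiled by hand rather than through the `--profile` flag. A minimal sketch of that flow, assuming the server was launched with `SGLANG_TORCH_PROFILER_DIR` set and listens on sglang's default port 30000 (adjust both to your deployment):

```bash
# Begin capturing a Torch Profiler trace on the running server.
curl -X POST http://localhost:30000/start_profile

# Drive some load while the profiler is recording.
python -m sglang.bench_serving --backend sglang --num-prompts 10

# Stop capturing; the trace is written under SGLANG_TORCH_PROFILER_DIR.
curl -X POST http://localhost:30000/stop_profile
```

These endpoints invoke the same pair of `tokenizer_manager` calls that the new module-level `start_profile`/`stop_profile` helpers wrap for the offline benchmark path.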