Add accuracy test to CI: MMLU #882

Merged 10 commits on Aug 2, 2024
Changes from all commits
4 changes: 2 additions & 2 deletions .github/workflows/e2e-test.yml
@@ -18,7 +18,7 @@ concurrency:
cancel-in-progress: true

jobs:
-  pr-e2e-test:
+  e2e-test:
runs-on: self-hosted

env:
@@ -38,7 +38,7 @@ jobs:
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
pip install --upgrade transformers

-      - name: Benchmark Serving
+      - name: Benchmark Serving Throughput
run: |
cd /data/zhyncs/venv && source ./bin/activate && cd -
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8413 --disable-radix-cache &
7 changes: 7 additions & 0 deletions .github/workflows/unit-test.yml
@@ -59,3 +59,10 @@ jobs:

cd test/srt
python3 test_openai_server.py

- name: Test Accuracy
run: |
cd /data/zhyncs/venv && source ./bin/activate && cd -

cd test/srt
python3 test_eval_accuracy.py
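Note: `test/srt/test_eval_accuracy.py` itself is not part of this diff. A rough sketch of what such a test could look like, assuming it launches a server the same way the e2e workflow does and then drives the MMLU eval added in `python/sglang/test/run_eval.py` below; the model name, port, startup wait, example count, and score threshold here are illustrative assumptions, not values taken from the PR.

```python
# Hypothetical sketch of test/srt/test_eval_accuracy.py (not shown in this diff):
# start a server, run the MMLU eval added in this PR, and assert a minimum score.
# Model, port, startup wait, example count, and the 0.6 threshold are placeholders.
import subprocess
import time
from types import SimpleNamespace

from sglang.test.run_eval import run_eval  # module added by this PR (see below)

MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # assumed; taken from the e2e workflow
PORT = 30000

server = subprocess.Popen(
    ["python3", "-m", "sglang.launch_server", "--model", MODEL, "--port", str(PORT)]
)
try:
    time.sleep(300)  # crude startup wait; a real test would poll the endpoint instead
    metrics = run_eval(
        SimpleNamespace(
            base_url=None,
            host="127.0.0.1",
            port=PORT,
            model=MODEL,
            eval_name="mmlu",
            num_examples=32,
            num_threads=32,
        )
    )
    assert metrics["score"] >= 0.6, f"MMLU score too low: {metrics['score']:.3f}"
finally:
    server.terminate()
```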
8 changes: 3 additions & 5 deletions python/sglang/bench_serving.py
@@ -21,7 +21,7 @@
import time
import traceback
import warnings
-from argparse import ArgumentParser as FlexibleArgumentParser
+from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):


if __name__ == "__main__":
-    parser = FlexibleArgumentParser(
-        description="Benchmark the online serving throughput."
-    )
+    parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
-        required=True,
choices=list(ASYNC_REQUEST_FUNCS.keys()),
default="sglang",
help="Must specify a backend, depending on the LLM Inference Engine.",
)
parser.add_argument(
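With this change, `--backend` becomes optional and defaults to `sglang`, while still being validated against the registered request functions. A minimal standalone sketch of the resulting argparse behavior; `ASYNC_REQUEST_FUNCS` is stubbed here for illustration (the real mapping lives in `bench_serving.py`).

```python
# Standalone illustration of the new --backend default. ASYNC_REQUEST_FUNCS is a
# stub; in bench_serving.py it maps backend names to async request functions.
from argparse import ArgumentParser

ASYNC_REQUEST_FUNCS = {"sglang": None, "vllm": None}  # stub for illustration

parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
    "--backend",
    type=str,
    choices=list(ASYNC_REQUEST_FUNCS.keys()),
    default="sglang",
)

print(parser.parse_args([]).backend)                     # -> sglang
print(parser.parse_args(["--backend", "vllm"]).backend)  # -> vllm
```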
99 changes: 99 additions & 0 deletions python/sglang/test/run_eval.py
@@ -0,0 +1,99 @@
"""
Usage:
python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
"""

import argparse
import json
import os
import time

from sglang.test.simple_eval_common import (
ChatCompletionSampler,
download_dataset,
make_report,
set_ulimit,
)
from sglang.test.simple_eval_mmlu import MMLUEval


def run_eval(args):
if "OPENAI_API_KEY" not in os.environ:
os.environ["OPENAI_API_KEY"] = "EMPTY"

base_url = (
f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
)

if args.eval_name == "mmlu":
dataset_path = "mmlu.csv"

if not os.path.exists(dataset_path):
download_dataset(
dataset_path,
"https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
)
eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
else:
raise ValueError(f"Invalid eval name: {args.eval_name}")

sampler = ChatCompletionSampler(
model=args.model,
max_tokens=2048,
base_url=base_url,
)

# Run eval
tic = time.time()
result = eval_obj(sampler)
latency = time.time() - tic

# Dump reports
metrics = result.metrics | {"score": result.score}
file_stem = f"mmlu_{sampler.model.replace('/', '_')}"
report_filename = f"/tmp/{file_stem}.html"
print(f"Writing report to {report_filename}")
with open(report_filename, "w") as fh:
fh.write(make_report(result))
print(metrics)
result_filename = f"/tmp/{file_stem}.json"
with open(result_filename, "w") as f:
f.write(json.dumps(metrics, indent=2))
print(f"Writing results to {result_filename}")

# Print results
print(f"Total latency: {latency:.3f} s")
print(f"Score: {metrics['score']:.3f}")

return metrics


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base-url",
type=str,
default=None,
help="Server or API base url if not using http host and port.",
)
parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
)
parser.add_argument(
"--port",
type=int,
help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
)
parser.add_argument(
"--model",
type=str,
help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
)
parser.add_argument("--eval-name", type=str, default="mmlu")
parser.add_argument("--num-examples", type=int)
parser.add_argument("--num-threads", type=int, default=64)
set_ulimit()
args = parser.parse_args()

run_eval(args)
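As a usage note, the script writes an HTML report and a JSON metrics file under `/tmp`. A small sketch of how that JSON could be consumed afterwards, for example as a score gate in CI; the model name and the 0.6 threshold are illustrative assumptions, and the path mirrors the `result_filename` pattern above.

```python
# Read the metrics JSON that run_eval writes to /tmp and apply a score gate.
# The file name mirrors f"/tmp/mmlu_{model.replace('/', '_')}.json" from above;
# the model name and the 0.6 threshold are placeholder assumptions.
import json

model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
result_filename = f"/tmp/mmlu_{model.replace('/', '_')}.json"

with open(result_filename) as f:
    metrics = json.load(f)

print(f"MMLU score: {metrics['score']:.3f}")
assert metrics["score"] >= 0.6, "accuracy regression: MMLU score below threshold"
```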