From 044740e18281524a4f9198406b044dd8271b5e4f Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 02:41:03 +0000
Subject: [PATCH 01/10] eval mmlu

---
 python/sglang/bench_serving.py             |   5 +-
 python/sglang/test/eval_common.py          | 452 +++++++++++++++++++++
 python/sglang/test/eval_mmlu.py            | 164 ++++++++
 python/sglang/test/test_conversation.py    |  46 ---
 python/sglang/test/test_openai_protocol.py |  51 ---
 5 files changed, 619 insertions(+), 99 deletions(-)
 create mode 100644 python/sglang/test/eval_common.py
 create mode 100644 python/sglang/test/eval_mmlu.py
 delete mode 100644 python/sglang/test/test_conversation.py
 delete mode 100644 python/sglang/test/test_openai_protocol.py

diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index b52e114fd70..dc733012aa5 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -21,7 +21,7 @@
 import time
 import traceback
 import warnings
-from argparse import ArgumentParser as FlexibleArgumentParser
+from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,7 +868,7 @@ def set_ulimit(target_soft_limit=65535):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser(
+    parser = ArgumentParser(
         description="Benchmark the online serving throughput."
     )
     parser.add_argument(
@@ -876,6 +876,7 @@ def set_ulimit(target_soft_limit=65535):
         type=str,
         required=True,
         choices=list(ASYNC_REQUEST_FUNCS.keys()),
+        default="sglang",
         help="Must specify a backend, depending on the LLM Inference Engine.",
     )
     parser.add_argument(
diff --git a/python/sglang/test/eval_common.py b/python/sglang/test/eval_common.py
new file mode 100644
index 00000000000..06bfac46dbb
--- /dev/null
+++ b/python/sglang/test/eval_common.py
@@ -0,0 +1,452 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+from dataclasses import dataclass, field
+from typing import Any
+import resource
+
+Message = dict[str, Any]  # keys role, content
+MessageList = list[Message]
+
+
+class SamplerBase:
+    """
+    Base class for defining a sampling model, which can be evaluated
+    or used as part of the grading process.
+    """
+
+    def __call__(self, message_list: MessageList) -> str:
+        raise NotImplementedError()
+
+
+@dataclass
+class EvalResult:
+    """
+    Result of running an evaluation (usually consisting of many samples)
+    """
+
+    score: float | None  # top-line metric
+    metrics: dict[str, float] | None  # other metrics
+    htmls: list[str]  # strings of valid HTML
+    convos: list[MessageList]  # sampled conversations
+
+
+@dataclass
+class SingleEvalResult:
+    """
+    Result of evaluating a single sample
+    """
+
+    score: float | None
+    metrics: dict[str, float] = field(default_factory=dict)
+    html: str | None = None
+    convo: MessageList | None = None  # sampled conversation
+
+
+class Eval:
+    """
+    Base class for defining an evaluation.
+    """
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        raise NotImplementedError()
+
+
+import base64
+import time
+from typing import Any
+
+import openai
+from openai import OpenAI
+
+
+OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
+OPENAI_SYSTEM_MESSAGE_CHATGPT = (
+    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
+    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
+)
+
+
+class ChatCompletionSampler(SamplerBase):
+    """
+    Sample from OpenAI's chat completion API
+    """
+
+    def __init__(
+        self,
+        base_url: str = None,
+        model: str | None = None,
+        system_message: str | None = None,
+        temperature: float = 0.0,
+        max_tokens: int = 2048,
+    ):
+        self.client = OpenAI(base_url=base_url)
+
+        if model is None:
+            model = self.client.models.list().data[0].id
+
+        self.model = model
+        self.system_message = system_message
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.image_format = "url"
+
+    def _handle_image(
+        self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768
+    ):
+        new_image = {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:image/{format};{encoding},{image}",
+            },
+        }
+        return new_image
+
+    def _handle_text(self, text: str):
+        return {"type": "text", "text": text}
+
+    def _pack_message(self, role: str, content: Any):
+        return {"role": str(role), "content": content}
+
+    def __call__(self, message_list: MessageList) -> str:
+        if self.system_message:
+            message_list = [self._pack_message("system", self.system_message)] + message_list
+        trial = 0
+        while True:
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=message_list,
+                    temperature=self.temperature,
+                    max_tokens=self.max_tokens,
+                )
+                return response.choices[0].message.content
+            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
+            except openai.BadRequestError as e:
+                print("Bad Request Error", e)
+                return ""
+            except Exception as e:
+                exception_backoff = 2**trial  # exponential backoff
+                print(
+                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
+                    e,
+                )
+                time.sleep(exception_backoff)
+                trial += 1
+            # unknown error shall throw exception
+
+
+import os
+from collections import defaultdict
+from multiprocessing.pool import ThreadPool
+from typing import Any
+
+import jinja2
+import numpy as np
+from tqdm import tqdm
+
+
+QUERY_TEMPLATE_MULTICHOICE = """
+Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
+
+{Question}
+
+A) {A}
+B) {B}
+C) {C}
+D) {D}
+""".strip()
+
+ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
+ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
+
+
+EQUALITY_TEMPLATE = r"""
+Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
+
+Examples:
+
+    Expression 1: $2x+3$
+    Expression 2: $3+2x$
+
+Yes
+
+    Expression 1: 3/2
+    Expression 2: 1.5
+
+Yes
+
+    Expression 1: $x^2+2x+1$
+    Expression 2: $y^2+2y+1$
+
+No
+
+    Expression 1: $x^2+2x+1$
+    Expression 2: $(x+1)^2$
+
+Yes
+
+    Expression 1: 3245/5
+    Expression 2: 649
+
+No
+(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
+
+    Expression 1: 2/(-3)
+    Expression 2: -2/3
+
+Yes
+(trivial simplifications are allowed)
+
+    Expression 1: 72 degrees
+    Expression 2: 72
+
+Yes
+(give benefit of the doubt to units)
+
+    Expression 1: 64
+    Expression 2: 64 square feet
+
+Yes
+(give benefit of the doubt to units)
+
+---
+
+YOUR TASK
+
+
+Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
+
+    Expression 1: %(expression1)s
+    Expression 2: %(expression2)s
+""".strip()
+
+
+HTML_JINJA = """
+<h3>Prompt conversation</h3>
+{% for message in prompt_messages %}
+{{ message_to_html(message) | safe }}
+{% endfor %}
+<h3>Sampled message</h3>
+{{ message_to_html(next_message) | safe }}
+<h3>Results</h3>
+<p>Correct Answer: {{ correct_answer }}</p>
+<p>Extracted Answer: {{ extracted_answer }}</p>
+<p>Score: {{ score }}</p>
+"""
+
+
+def format_multichoice_question(row):
+    return QUERY_TEMPLATE_MULTICHOICE.format(**row)
+
+
+def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
+    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
+    response = sampler([dict(content=prompt, role="user")])
+    return response.lower().strip() == "yes"
+
+
+def _compute_stat(values: list, stat: str):
+    if stat == "mean":
+        return np.mean(values)
+    elif stat == "std":
+        return np.std(values)
+    elif stat == "min":
+        return np.min(values)
+    elif stat == "max":
+        return np.max(values)
+    else:
+        raise ValueError(f"Unknown {stat =}")
+
+
+def aggregate_results(
+    single_eval_results: list[SingleEvalResult],
+    default_stats: tuple[str] = ("mean", "std"),
+    name2stats: dict[str, tuple[str]] | None = None,
+) -> EvalResult:
+    """
+    Aggregate results from multiple evaluations into a single EvalResult.
+    """
+    name2stats = name2stats or {}
+    name2values = defaultdict(list)
+    htmls = []
+    convos = []
+    for single_eval_result in single_eval_results:
+        for name, value in single_eval_result.metrics.items():
+            name2values[name].append(value)
+        if single_eval_result.score is not None:
+            name2values["score"].append(single_eval_result.score)
+        htmls.append(single_eval_result.html)
+        convos.append(single_eval_result.convo)
+    final_metrics = {}
+    for name, values in name2values.items():
+        stats = name2stats.get(name, default_stats)
+        for stat in stats:
+            key = name if stat == "mean" else f"{name}:{stat}"
+            final_metrics[key] = _compute_stat(values, stat)
+    return EvalResult(
+        score=final_metrics.pop("score", None), metrics=final_metrics, htmls=htmls, convos=convos
+    )
+
+
+def map_with_progress(f: callable, xs: list[Any], num_threads: int = 50):
+    """
+    Apply f to each element of xs, using a ThreadPool, and show progress.
+    """
+    if os.getenv("debug"):
+        return list(map(f, tqdm(xs, total=len(xs))))
+    else:
+        with ThreadPool(min(num_threads, len(xs))) as pool:
+            return list(tqdm(pool.imap(f, xs), total=len(xs)))
+
+
+jinja_env = jinja2.Environment(
+    loader=jinja2.BaseLoader(),
+    undefined=jinja2.StrictUndefined,
+    autoescape=jinja2.select_autoescape(["html", "xml"]),
+)
+_message_template = """
+<div class="message {{ role }}">
+    <div class="role">
+    {{ role }} 
+    {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
+    </div>
+    <div class="content">
+    <pre>{{ content }}</pre>
+    </div>
+</div>
+"""
+
+
+def message_to_html(message: Message) -> str:
+    """
+    Generate HTML snippet (inside a <div>) for a message.
+    """
+    return jinja_env.from_string(_message_template).render(
+        role=message["role"], content=message["content"], variant=message.get("variant", None)
+    )
+
+
+jinja_env.globals["message_to_html"] = message_to_html
+
+
+_report_template = """<!DOCTYPE html>
+<html>
+    <head>
+        <style>
+            .message {
+                padding: 8px 16px;
+                margin-bottom: 8px;
+                border-radius: 4px;
+            }
+            .message.user {
+                background-color: #B2DFDB;
+                color: #00695C;
+            }
+            .message.assistant {
+                background-color: #B39DDB;
+                color: #4527A0;
+            }
+            .message.system {
+                background-color: #EEEEEE;
+                color: #212121;
+            }
+            .role {
+                font-weight: bold;
+                margin-bottom: 4px;
+            }
+            .variant {
+                color: #795548;
+            }
+            table, th, td {
+                border: 1px solid black;
+            }
+            pre {
+                white-space: pre-wrap;
+            }
+        </style>
+    </head>
+    <body>
+    {% if metrics %}
+    <h1>Metrics</h1>
+    <table>
+    <tr>
+        <th>Metric</th>
+        <th>Value</th>
+    </tr>
+    <tr>
+        <td><b>Score</b></td>
+        <td>{{ score | float | round(3) }}</td>
+    </tr>
+    {% for name, value in metrics.items() %}
+    <tr>
+        <td>{{ name }}</td>
+        <td>{{ value }}</td>
+    </tr>
+    {% endfor %}
+    </table>
+    {% endif %}
+    <h1>Examples</h1>
+    {% for html in htmls %}
+    {{ html | safe }}
+    <hr>
+    {% endfor %}
+    </body>
+</html>
+"""
+
+
+def make_report(eval_result: EvalResult) -> str:
+    """
+    Create a standalone HTML report from an EvalResult.
+    """
+    return jinja_env.from_string(_report_template).render(
+        score=eval_result.score,
+        metrics=eval_result.metrics,
+        htmls=eval_result.htmls,
+    )
+
+
+def make_report_from_example_htmls(htmls: list[str]):
+    """
+    Create a standalone HTML report from a list of example htmls
+    """
+    return jinja_env.from_string(_report_template).render(score=None, metrics={}, htmls=htmls)
+
+
+import requests
+from tqdm import tqdm
+
+def download_dataset(path, url):
+    print(f"Downloading dataset {path} from {url}")
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        total_size = int(response.headers.get("content-length", 0))
+        block_size = 8192
+
+        with open(path, "wb") as f, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar:
+            for data in response.iter_content(block_size):
+                size = f.write(data)
+                progress_bar.update(size)
+
+        print(f"Dataset downloaded and saved to {path}")
+    except requests.RequestException as e:
+        raise Exception(f"Failed to download dataset: {e}")
+
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            print(f"Fail to set RLIMIT_NOFILE: {e}")
+
diff --git a/python/sglang/test/eval_mmlu.py b/python/sglang/test/eval_mmlu.py
new file mode 100644
index 00000000000..ac039e00336
--- /dev/null
+++ b/python/sglang/test/eval_mmlu.py
@@ -0,0 +1,164 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Massive Multitask Language Understanding
+Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2009.03300
+"""
+
+import argparse
+import os
+import random
+import re
+import time
+
+import pandas
+
+from sglang.test.eval_common import (ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question, Eval, EvalResult, SamplerBase, SingleEvalResult, download_dataset, ChatCompletionSampler, map_with_progress, jinja_env, aggregate_results, set_ulimit)
+
+subject2category = {
+    "abstract_algebra": "stem",
+    "anatomy": "other",
+    "astronomy": "stem",
+    "business_ethics": "other",
+    "clinical_knowledge": "other",
+    "college_biology": "stem",
+    "college_chemistry": "stem",
+    "college_computer_science": "stem",
+    "college_mathematics": "stem",
+    "college_medicine": "other",
+    "college_physics": "stem",
+    "computer_security": "stem",
+    "conceptual_physics": "stem",
+    "econometrics": "social_sciences",
+    "electrical_engineering": "stem",
+    "elementary_mathematics": "stem",
+    "formal_logic": "humanities",
+    "global_facts": "other",
+    "high_school_biology": "stem",
+    "high_school_chemistry": "stem",
+    "high_school_computer_science": "stem",
+    "high_school_european_history": "humanities",
+    "high_school_geography": "social_sciences",
+    "high_school_government_and_politics": "social_sciences",
+    "high_school_macroeconomics": "social_sciences",
+    "high_school_mathematics": "stem",
+    "high_school_microeconomics": "social_sciences",
+    "high_school_physics": "stem",
+    "high_school_psychology": "social_sciences",
+    "high_school_statistics": "stem",
+    "high_school_us_history": "humanities",
+    "high_school_world_history": "humanities",
+    "human_aging": "other",
+    "human_sexuality": "social_sciences",
+    "international_law": "humanities",
+    "jurisprudence": "humanities",
+    "logical_fallacies": "humanities",
+    "machine_learning": "stem",
+    "management": "other",
+    "marketing": "other",
+    "medical_genetics": "other",
+    "miscellaneous": "other",
+    "moral_disputes": "humanities",
+    "moral_scenarios": "humanities",
+    "nutrition": "other",
+    "philosophy": "humanities",
+    "prehistory": "humanities",
+    "professional_accounting": "other",
+    "professional_law": "humanities",
+    "professional_medicine": "other",
+    "professional_psychology": "social_sciences",
+    "public_relations": "social_sciences",
+    "security_studies": "social_sciences",
+    "sociology": "social_sciences",
+    "us_foreign_policy": "social_sciences",
+    "virology": "other",
+    "world_religions": "humanities",
+}
+
+
+class MMLUEval(Eval):
+    def __init__(self, filename: str, num_examples: int | None = None):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        if num_examples:
+            examples = random.Random(0).sample(examples, num_examples)
+        self.examples = examples
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            prompt_messages = [
+                sampler._pack_message(content=format_multichoice_question(row), role="user")
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = 1.0 if extracted_answer == row["Answer"] else 0.0
+            html = jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=row["Answer"],
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            category = subject2category.get(row["Subject"], "other")
+            return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo)
+
+        results = map_with_progress(fn, self.examples)
+        return aggregate_results(results)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset-path", type=str, default="mmlu.csv", help="Path to the dataset."
+    )
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+    )
+    parser.add_argument(
+        "--num-examples",
+        type=int,
+        help="The number of examples."
+    )
+    set_ulimit()
+    args = parser.parse_args()
+
+    base_url = f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+
+    if not os.path.exists(args.dataset_path):
+        download_dataset(args.dataset_path, "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv")
+    eval_obj = MMLUEval(args.dataset_path, num_examples=args.num_examples)
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=2048,
+        base_url=base_url,
+    )
+
+    tic = time.time()
+
+    result = eval_obj(sampler)
+    metrics = result.metrics | {"score": result.score}
+
+    latency = time.time() - tic
+    score = metrics["score"]
+
+    print(f"Total latency: {latency:.3f} s")
+    print(f"Score: {score:.3f}")
diff --git a/python/sglang/test/test_conversation.py b/python/sglang/test/test_conversation.py
deleted file mode 100644
index e6d9f396aa7..00000000000
--- a/python/sglang/test/test_conversation.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.managers.openai_api.protocol import (
-    ChatCompletionMessageContentImagePart,
-    ChatCompletionMessageContentImageURL,
-    ChatCompletionMessageContentTextPart,
-    ChatCompletionMessageGenericParam,
-    ChatCompletionMessageUserParam,
-    ChatCompletionRequest,
-)
-
-
-def test_chat_completion_to_conv_image():
-    """Test that we can convert a chat image request to a convo"""
-    request = ChatCompletionRequest(
-        model="default",
-        messages=[
-            ChatCompletionMessageGenericParam(
-                role="system", content="You are a helpful AI assistant"
-            ),
-            ChatCompletionMessageUserParam(
-                role="user",
-                content=[
-                    ChatCompletionMessageContentTextPart(
-                        type="text", text="Describe this image"
-                    ),
-                    ChatCompletionMessageContentImagePart(
-                        type="image_url",
-                        image_url=ChatCompletionMessageContentImageURL(
-                            url="https://someurl.com"
-                        ),
-                    ),
-                ],
-            ),
-        ],
-    )
-    conv = generate_chat_conv(request, "vicuna_v1.1")
-    assert conv.messages == [
-        ["USER", "Describe this image<image>"],
-        ["ASSISTANT", None],
-    ]
-    assert conv.system_message == "You are a helpful AI assistant"
-    assert conv.image_data == ["https://someurl.com"]
-
-
-if __name__ == "__main__":
-    test_chat_completion_to_conv_image()
diff --git a/python/sglang/test/test_openai_protocol.py b/python/sglang/test/test_openai_protocol.py
deleted file mode 100644
index cade4728cba..00000000000
--- a/python/sglang/test/test_openai_protocol.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from sglang.srt.managers.openai_api.protocol import (
-    ChatCompletionMessageContentImagePart,
-    ChatCompletionMessageContentImageURL,
-    ChatCompletionMessageContentTextPart,
-    ChatCompletionMessageGenericParam,
-    ChatCompletionMessageUserParam,
-    ChatCompletionRequest,
-)
-
-
-def test_chat_completion_request_image():
-    """Test that Chat Completion Requests with images can be converted."""
-
-    image_request = {
-        "model": "default",
-        "messages": [
-            {"role": "system", "content": "You are a helpful AI assistant"},
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": "Describe this image"},
-                    {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
-                ],
-            },
-        ],
-        "temperature": 0,
-        "max_tokens": 64,
-    }
-    request = ChatCompletionRequest(**image_request)
-    assert len(request.messages) == 2
-    assert request.messages[0] == ChatCompletionMessageGenericParam(
-        role="system", content="You are a helpful AI assistant"
-    )
-    assert request.messages[1] == ChatCompletionMessageUserParam(
-        role="user",
-        content=[
-            ChatCompletionMessageContentTextPart(
-                type="text", text="Describe this image"
-            ),
-            ChatCompletionMessageContentImagePart(
-                type="image_url",
-                image_url=ChatCompletionMessageContentImageURL(
-                    url="https://someurl.com"
-                ),
-            ),
-        ],
-    )
-
-
-if __name__ == "__main__":
-    test_chat_completion_request_image()

From 3fe7abacdf5bdf3b0c810e919ae4dfbc03c0cbe7 Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 02:44:19 +0000
Subject: [PATCH 02/10] lint

---
 python/sglang/bench_serving.py    |  5 +---
 python/sglang/test/eval_common.py | 30 +++++++++++++++++-------
 python/sglang/test/eval_mmlu.py   | 39 +++++++++++++++++++++++--------
 3 files changed, 51 insertions(+), 23 deletions(-)

diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index dc733012aa5..253aab355dd 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -868,13 +868,10 @@ def set_ulimit(target_soft_limit=65535):
 
 
 if __name__ == "__main__":
-    parser = ArgumentParser(
-        description="Benchmark the online serving throughput."
-    )
+    parser = ArgumentParser(description="Benchmark the online serving throughput.")
     parser.add_argument(
         "--backend",
         type=str,
-        required=True,
         choices=list(ASYNC_REQUEST_FUNCS.keys()),
         default="sglang",
         help="Must specify a backend, depending on the LLM Inference Engine.",
diff --git a/python/sglang/test/eval_common.py b/python/sglang/test/eval_common.py
index 06bfac46dbb..64632c45a94 100644
--- a/python/sglang/test/eval_common.py
+++ b/python/sglang/test/eval_common.py
@@ -1,8 +1,8 @@
 # Adapted from https://github.com/openai/simple-evals/
 
+import resource
 from dataclasses import dataclass, field
 from typing import Any
-import resource
 
 Message = dict[str, Any]  # keys role, content
 MessageList = list[Message]
@@ -58,7 +58,6 @@ def __call__(self, sampler: SamplerBase) -> EvalResult:
 import openai
 from openai import OpenAI
 
-
 OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
 OPENAI_SYSTEM_MESSAGE_CHATGPT = (
     "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
@@ -91,7 +90,11 @@ def __init__(
         self.image_format = "url"
 
     def _handle_image(
-        self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768
+        self,
+        image: str,
+        encoding: str = "base64",
+        format: str = "png",
+        fovea: int = 768,
     ):
         new_image = {
             "type": "image_url",
@@ -109,7 +112,9 @@ def _pack_message(self, role: str, content: Any):
 
     def __call__(self, message_list: MessageList) -> str:
         if self.system_message:
-            message_list = [self._pack_message("system", self.system_message)] + message_list
+            message_list = [
+                self._pack_message("system", self.system_message)
+            ] + message_list
         trial = 0
         while True:
             try:
@@ -144,7 +149,6 @@ def __call__(self, message_list: MessageList) -> str:
 import numpy as np
 from tqdm import tqdm
 
-
 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
 
@@ -284,7 +288,10 @@ def aggregate_results(
             key = name if stat == "mean" else f"{name}:{stat}"
             final_metrics[key] = _compute_stat(values, stat)
     return EvalResult(
-        score=final_metrics.pop("score", None), metrics=final_metrics, htmls=htmls, convos=convos
+        score=final_metrics.pop("score", None),
+        metrics=final_metrics,
+        htmls=htmls,
+        convos=convos,
     )
 
 
@@ -322,7 +329,9 @@ def message_to_html(message: Message) -> str:
     Generate HTML snippet (inside a <div>) for a message.
     """
     return jinja_env.from_string(_message_template).render(
-        role=message["role"], content=message["content"], variant=message.get("variant", None)
+        role=message["role"],
+        content=message["content"],
+        variant=message.get("variant", None),
     )
 
 
@@ -410,12 +419,15 @@ def make_report_from_example_htmls(htmls: list[str]):
     """
     Create a standalone HTML report from a list of example htmls
     """
-    return jinja_env.from_string(_report_template).render(score=None, metrics={}, htmls=htmls)
+    return jinja_env.from_string(_report_template).render(
+        score=None, metrics={}, htmls=htmls
+    )
 
 
 import requests
 from tqdm import tqdm
 
+
 def download_dataset(path, url):
     print(f"Downloading dataset {path} from {url}")
     try:
@@ -440,6 +452,7 @@ def download_dataset(path, url):
     except requests.RequestException as e:
         raise Exception(f"Failed to download dataset: {e}")
 
+
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -449,4 +462,3 @@ def set_ulimit(target_soft_limit=65535):
             resource.setrlimit(resource_type, (target_soft_limit, current_hard))
         except ValueError as e:
             print(f"Fail to set RLIMIT_NOFILE: {e}")
-
diff --git a/python/sglang/test/eval_mmlu.py b/python/sglang/test/eval_mmlu.py
index ac039e00336..7b594511fa1 100644
--- a/python/sglang/test/eval_mmlu.py
+++ b/python/sglang/test/eval_mmlu.py
@@ -14,7 +14,21 @@
 
 import pandas
 
-from sglang.test.eval_common import (ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question, Eval, EvalResult, SamplerBase, SingleEvalResult, download_dataset, ChatCompletionSampler, map_with_progress, jinja_env, aggregate_results, set_ulimit)
+from sglang.test.eval_common import (
+    ANSWER_PATTERN_MULTICHOICE,
+    HTML_JINJA,
+    ChatCompletionSampler,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    aggregate_results,
+    download_dataset,
+    format_multichoice_question,
+    jinja_env,
+    map_with_progress,
+    set_ulimit,
+)
 
 subject2category = {
     "abstract_algebra": "stem",
@@ -88,7 +102,9 @@ def __init__(self, filename: str, num_examples: int | None = None):
     def __call__(self, sampler: SamplerBase) -> EvalResult:
         def fn(row: dict):
             prompt_messages = [
-                sampler._pack_message(content=format_multichoice_question(row), role="user")
+                sampler._pack_message(
+                    content=format_multichoice_question(row), role="user"
+                )
             ]
             response_text = sampler(prompt_messages)
             match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
@@ -103,7 +119,9 @@ def fn(row: dict):
             )
             convo = prompt_messages + [dict(content=response_text, role="assistant")]
             category = subject2category.get(row["Subject"], "other")
-            return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo)
+            return SingleEvalResult(
+                html=html, score=score, metrics={category: score}, convo=convo
+            )
 
         results = map_with_progress(fn, self.examples)
         return aggregate_results(results)
@@ -133,18 +151,19 @@ def fn(row: dict):
         type=str,
         help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
     )
-    parser.add_argument(
-        "--num-examples",
-        type=int,
-        help="The number of examples."
-    )
+    parser.add_argument("--num-examples", type=int, help="The number of examples.")
     set_ulimit()
     args = parser.parse_args()
 
-    base_url = f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+    base_url = (
+        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+    )
 
     if not os.path.exists(args.dataset_path):
-        download_dataset(args.dataset_path, "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv")
+        download_dataset(
+            args.dataset_path,
+            "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+        )
     eval_obj = MMLUEval(args.dataset_path, num_examples=args.num_examples)
     sampler = ChatCompletionSampler(
         model=args.model,

From 88d5fad15be57edff5415cf81b51d7f9ae7455ac Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 03:09:44 +0000
Subject: [PATCH 03/10] fix style

---
 python/sglang/test/run_eval.py                | 80 ++++++++++++++++
 .../{eval_common.py => simple_eval_common.py} | 48 ++++------
 .../{eval_mmlu.py => simple_eval_mmlu.py}     | 95 +++----------------
 3 files changed, 111 insertions(+), 112 deletions(-)
 create mode 100644 python/sglang/test/run_eval.py
 rename python/sglang/test/{eval_common.py => simple_eval_common.py} (99%)
 rename python/sglang/test/{eval_mmlu.py => simple_eval_mmlu.py} (57%)

diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
new file mode 100644
index 00000000000..1f856bf3f9c
--- /dev/null
+++ b/python/sglang/test/run_eval.py
@@ -0,0 +1,80 @@
+import argparse
+import os
+import time
+import json
+
+from sglang.test.simple_eval_mmlu import MMLUEval
+from sglang.test.simple_eval_common import ChatCompletionSampler, set_ulimit, make_report
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+    )
+    parser.add_argument("--eval-name", type=str, default="mmlu")
+    parser.add_argument("--num-examples", type=int)
+    parser.add_argument("--num-threads", type=int, default=64)
+    set_ulimit()
+    args = parser.parse_args()
+
+    base_url = (
+        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+    )
+
+    if args.eval_name == "mmlu":
+        dataset_path = "mmlu.csv"
+
+        if not os.path.exists(dataset_path):
+            download_dataset(
+                dataset_path,
+                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+            )
+        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+    else:
+        raise ValueError(f"Invalid eval name: {args.eval_name}")
+
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=2048,
+        base_url=base_url,
+    )
+
+    # Run eval
+    tic = time.time()
+    result = eval_obj(sampler)
+    latency = time.time() - tic
+
+    # Dump reports
+    metrics = result.metrics | {"score": result.score}
+    file_stem = f"mmlu_{sampler.model.replace('/', '_')}"
+    report_filename = f"/tmp/{file_stem}.html"
+    print(f"Writing report to {report_filename}")
+    with open(report_filename, "w") as fh:
+        fh.write(make_report(result))
+    metrics = result.metrics | {"score": result.score}
+    print(metrics)
+    result_filename = f"/tmp/{file_stem}.json"
+    with open(result_filename, "w") as f:
+        f.write(json.dumps(metrics, indent=2))
+    print(f"Writing results to {result_filename}")
+
+    # Print results
+    print(f"Total latency: {latency:.3f} s")
+    print(f"Score: {metrics['score']:.3f}")
diff --git a/python/sglang/test/eval_common.py b/python/sglang/test/simple_eval_common.py
similarity index 99%
rename from python/sglang/test/eval_common.py
rename to python/sglang/test/simple_eval_common.py
index 64632c45a94..75c26f0f037 100644
--- a/python/sglang/test/eval_common.py
+++ b/python/sglang/test/simple_eval_common.py
@@ -1,9 +1,28 @@
 # Adapted from https://github.com/openai/simple-evals/
 
+import base64
+import os
 import resource
+import time
+from collections import defaultdict
 from dataclasses import dataclass, field
+from multiprocessing.pool import ThreadPool
 from typing import Any
 
+import jinja2
+import numpy as np
+import openai
+import requests
+from openai import OpenAI
+from tqdm import tqdm
+
+OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
+OPENAI_SYSTEM_MESSAGE_CHATGPT = (
+    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
+    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
+)
+
+
 Message = dict[str, Any]  # keys role, content
 MessageList = list[Message]
 
@@ -51,20 +70,6 @@ def __call__(self, sampler: SamplerBase) -> EvalResult:
         raise NotImplementedError()
 
 
-import base64
-import time
-from typing import Any
-
-import openai
-from openai import OpenAI
-
-OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
-OPENAI_SYSTEM_MESSAGE_CHATGPT = (
-    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
-    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
-)
-
-
 class ChatCompletionSampler(SamplerBase):
     """
     Sample from OpenAI's chat completion API
@@ -140,15 +145,6 @@ def __call__(self, message_list: MessageList) -> str:
             # unknown error shall throw exception
 
 
-import os
-from collections import defaultdict
-from multiprocessing.pool import ThreadPool
-from typing import Any
-
-import jinja2
-import numpy as np
-from tqdm import tqdm
-
 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
 
@@ -295,7 +291,7 @@ def aggregate_results(
     )
 
 
-def map_with_progress(f: callable, xs: list[Any], num_threads: int = 50):
+def map_with_progress(f: callable, xs: list[Any], num_threads: int):
     """
     Apply f to each element of xs, using a ThreadPool, and show progress.
     """
@@ -424,10 +420,6 @@ def make_report_from_example_htmls(htmls: list[str]):
     )
 
 
-import requests
-from tqdm import tqdm
-
-
 def download_dataset(path, url):
     print(f"Downloading dataset {path} from {url}")
     try:
diff --git a/python/sglang/test/eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py
similarity index 57%
rename from python/sglang/test/eval_mmlu.py
rename to python/sglang/test/simple_eval_mmlu.py
index 7b594511fa1..6467f7760dc 100644
--- a/python/sglang/test/eval_mmlu.py
+++ b/python/sglang/test/simple_eval_mmlu.py
@@ -6,29 +6,15 @@
 https://arxiv.org/abs/2009.03300
 """
 
-import argparse
-import os
 import random
 import re
-import time
 
 import pandas
 
-from sglang.test.eval_common import (
-    ANSWER_PATTERN_MULTICHOICE,
-    HTML_JINJA,
-    ChatCompletionSampler,
-    Eval,
-    EvalResult,
-    SamplerBase,
-    SingleEvalResult,
-    aggregate_results,
-    download_dataset,
-    format_multichoice_question,
-    jinja_env,
-    map_with_progress,
-    set_ulimit,
-)
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question
+from sglang.test.simple_eval_common import Eval, EvalResult, SamplerBase, SingleEvalResult
+
 
 subject2category = {
     "abstract_algebra": "stem",
@@ -92,25 +78,24 @@
 
 
 class MMLUEval(Eval):
-    def __init__(self, filename: str, num_examples: int | None = None):
+    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
         df = pandas.read_csv(filename)
         examples = [row.to_dict() for _, row in df.iterrows()]
         if num_examples:
             examples = random.Random(0).sample(examples, num_examples)
         self.examples = examples
+        self.num_threads = num_threads
 
     def __call__(self, sampler: SamplerBase) -> EvalResult:
         def fn(row: dict):
             prompt_messages = [
-                sampler._pack_message(
-                    content=format_multichoice_question(row), role="user"
-                )
+                sampler._pack_message(content=format_multichoice_question(row), role="user")
             ]
             response_text = sampler(prompt_messages)
             match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
             extracted_answer = match.group(1) if match else None
             score = 1.0 if extracted_answer == row["Answer"] else 0.0
-            html = jinja_env.from_string(HTML_JINJA).render(
+            html = common.jinja_env.from_string(HTML_JINJA).render(
                 prompt_messages=prompt_messages,
                 next_message=dict(content=response_text, role="assistant"),
                 score=score,
@@ -119,65 +104,7 @@ def fn(row: dict):
             )
             convo = prompt_messages + [dict(content=response_text, role="assistant")]
             category = subject2category.get(row["Subject"], "other")
-            return SingleEvalResult(
-                html=html, score=score, metrics={category: score}, convo=convo
-            )
-
-        results = map_with_progress(fn, self.examples)
-        return aggregate_results(results)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--dataset-path", type=str, default="mmlu.csv", help="Path to the dataset."
-    )
-    parser.add_argument(
-        "--base-url",
-        type=str,
-        default=None,
-        help="Server or API base url if not using http host and port.",
-    )
-    parser.add_argument(
-        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
-    )
-    parser.add_argument(
-        "--port",
-        type=int,
-        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
-    )
-    parser.add_argument(
-        "--model",
-        type=str,
-        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
-    )
-    parser.add_argument("--num-examples", type=int, help="The number of examples.")
-    set_ulimit()
-    args = parser.parse_args()
-
-    base_url = (
-        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
-    )
-
-    if not os.path.exists(args.dataset_path):
-        download_dataset(
-            args.dataset_path,
-            "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
-        )
-    eval_obj = MMLUEval(args.dataset_path, num_examples=args.num_examples)
-    sampler = ChatCompletionSampler(
-        model=args.model,
-        max_tokens=2048,
-        base_url=base_url,
-    )
-
-    tic = time.time()
-
-    result = eval_obj(sampler)
-    metrics = result.metrics | {"score": result.score}
-
-    latency = time.time() - tic
-    score = metrics["score"]
+            return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo)
 
-    print(f"Total latency: {latency:.3f} s")
-    print(f"Score: {score:.3f}")
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)

From af51e277579e8e24c08987d4024c7fd10fff9356 Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 03:11:49 +0000
Subject: [PATCH 04/10] update

---
 python/sglang/test/run_eval.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index 1f856bf3f9c..afa5ca38d7e 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -1,10 +1,14 @@
+"""
+Usage:
+python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
+"""
 import argparse
 import os
 import time
 import json
 
 from sglang.test.simple_eval_mmlu import MMLUEval
-from sglang.test.simple_eval_common import ChatCompletionSampler, set_ulimit, make_report
+from sglang.test.simple_eval_common import ChatCompletionSampler, set_ulimit, make_report, download_dataset
 
 
 if __name__ == "__main__":

From 496552fec3dd7b13fa28befc74573bf25db2cd93 Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 03:12:35 +0000
Subject: [PATCH 05/10] lint

---
 python/sglang/test/run_eval.py         | 11 ++++++++---
 python/sglang/test/simple_eval_mmlu.py | 20 +++++++++++++++-----
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index afa5ca38d7e..430e00d2834 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -2,14 +2,19 @@
 Usage:
 python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
 """
+
 import argparse
+import json
 import os
 import time
-import json
 
+from sglang.test.simple_eval_common import (
+    ChatCompletionSampler,
+    download_dataset,
+    make_report,
+    set_ulimit,
+)
 from sglang.test.simple_eval_mmlu import MMLUEval
-from sglang.test.simple_eval_common import ChatCompletionSampler, set_ulimit, make_report, download_dataset
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py
index 6467f7760dc..3c0287510cb 100644
--- a/python/sglang/test/simple_eval_mmlu.py
+++ b/python/sglang/test/simple_eval_mmlu.py
@@ -12,9 +12,15 @@
 import pandas
 
 from sglang.test import simple_eval_common as common
-from sglang.test.simple_eval_common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question
-from sglang.test.simple_eval_common import Eval, EvalResult, SamplerBase, SingleEvalResult
-
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN_MULTICHOICE,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    format_multichoice_question,
+)
 
 subject2category = {
     "abstract_algebra": "stem",
@@ -89,7 +95,9 @@ def __init__(self, filename: str, num_examples: int | None, num_threads: int):
     def __call__(self, sampler: SamplerBase) -> EvalResult:
         def fn(row: dict):
             prompt_messages = [
-                sampler._pack_message(content=format_multichoice_question(row), role="user")
+                sampler._pack_message(
+                    content=format_multichoice_question(row), role="user"
+                )
             ]
             response_text = sampler(prompt_messages)
             match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
@@ -104,7 +112,9 @@ def fn(row: dict):
             )
             convo = prompt_messages + [dict(content=response_text, role="assistant")]
             category = subject2category.get(row["Subject"], "other")
-            return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo)
+            return SingleEvalResult(
+                html=html, score=score, metrics={category: score}, convo=convo
+            )
 
         results = common.map_with_progress(fn, self.examples, self.num_threads)
         return common.aggregate_results(results)

From 4250c24e24e193bb2d8f99afb5db72335ecaed4a Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 03:30:31 +0000
Subject: [PATCH 06/10] test mmlu cot accuracy

---
 .github/workflows/e2e-test.yml   |  4 +--
 .github/workflows/unit-test.yml  |  7 ++++
 python/sglang/test/run_eval.py   | 59 ++++++++++++++++++--------------
 python/sglang/test/test_utils.py | 30 ++++++++++++++++
 test/srt/test_eval_accuracy.py   | 43 +++++++++++++++++++++++
 test/srt/test_openai_server.py   | 36 +++----------------
 6 files changed, 119 insertions(+), 60 deletions(-)
 create mode 100644 test/srt/test_eval_accuracy.py

diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml
index 9630ca71838..7b59054fe12 100644
--- a/.github/workflows/e2e-test.yml
+++ b/.github/workflows/e2e-test.yml
@@ -18,7 +18,7 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
-  pr-e2e-test:
+  e2e-test:
     runs-on: self-hosted
 
     env:
@@ -38,7 +38,7 @@ jobs:
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
         pip install --upgrade transformers
 
-    - name: Benchmark Serving
+    - name: Benchmark Serving Throughput
       run: |
         cd /data/zhyncs/venv && source ./bin/activate && cd -
         python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8413 --disable-radix-cache &
diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index dc464fa8cc2..f1c069ea5b6 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -59,3 +59,10 @@ jobs:
 
         cd test/srt
         python3 test_openai_server.py
+
+    - name: Test Accuracy
+      run: |
+        cd /data/zhyncs/venv && source ./bin/activate && cd -
+
+        cd test/srt
+        python3 test_eval_accuracy.py
diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index 430e00d2834..6433f49474d 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -16,33 +16,8 @@
 )
 from sglang.test.simple_eval_mmlu import MMLUEval
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--base-url",
-        type=str,
-        default=None,
-        help="Server or API base url if not using http host and port.",
-    )
-    parser.add_argument(
-        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
-    )
-    parser.add_argument(
-        "--port",
-        type=int,
-        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
-    )
-    parser.add_argument(
-        "--model",
-        type=str,
-        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
-    )
-    parser.add_argument("--eval-name", type=str, default="mmlu")
-    parser.add_argument("--num-examples", type=int)
-    parser.add_argument("--num-threads", type=int, default=64)
-    set_ulimit()
-    args = parser.parse_args()
 
+def run_eval(args):
     base_url = (
         f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
     )
@@ -87,3 +62,35 @@
     # Print results
     print(f"Total latency: {latency:.3f} s")
     print(f"Score: {metrics['score']:.3f}")
+
+    return metrics
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+    )
+    parser.add_argument("--eval-name", type=str, default="mmlu")
+    parser.add_argument("--num-examples", type=int)
+    parser.add_argument("--num-threads", type=int, default=64)
+    set_ulimit()
+    args = parser.parse_args()
+
+    run_eval(args)
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index af7f3765ef4..5cb2c4f56c6 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -1,6 +1,8 @@
 """Common utilities for testing and benchmarking"""
 
 import asyncio
+import subprocess
+import time
 from functools import partial
 
 import numpy as np
@@ -379,3 +381,31 @@ def func(*args, **kwargs):
             raise
 
     return func
+
+
+def popen_launch_server(model, port, timeout, *args):
+    command = [
+        "python3",
+        "-m",
+        "sglang.launch_server",
+        "--model-path",
+        model,
+        "--host",
+        "localhost",
+        "--port",
+        str(port),
+        *args
+    ]
+    process = subprocess.Popen(command, stdout=None, stderr=None)
+    base_url = f"http://localhost:{port}/v1"
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            response = requests.get(f"{base_url}/models")
+            if response.status_code == 200:
+                return process
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError("Server failed to start within the timeout period.")
diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py
new file mode 100644
index 00000000000..a54679efa71
--- /dev/null
+++ b/test/srt/test_eval_accuracy.py
@@ -0,0 +1,43 @@
+import json
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import popen_launch_server
+from sglang.test.run_eval import run_eval
+
+
+class TestAccuracy(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        port = 30000
+
+        cls.model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        cls.base_url = f"http://localhost:{port}"
+        cls.process = popen_launch_server(cls.model, port, timeout=300)
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=20,
+            num_threads=20,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
+
+    # t = TestAccuracy()
+    # t.setUpClass()
+    # t.test_mmlu()
+    # t.tearDownClass()
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index e15c2ba88ff..6c03a40db36 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -1,47 +1,21 @@
 import json
-import subprocess
-import time
 import unittest
 
 import openai
-import requests
 
 from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import popen_launch_server
 
 
 class TestOpenAIServer(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
         port = 30000
-        timeout = 300
-
-        command = [
-            "python3",
-            "-m",
-            "sglang.launch_server",
-            "--model-path",
-            model,
-            "--host",
-            "localhost",
-            "--port",
-            str(port),
-        ]
-        cls.process = subprocess.Popen(command, stdout=None, stderr=None)
+
+        cls.model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
         cls.base_url = f"http://localhost:{port}/v1"
-        cls.model = model
-
-        start_time = time.time()
-        while time.time() - start_time < timeout:
-            try:
-                response = requests.get(f"{cls.base_url}/models")
-                if response.status_code == 200:
-                    return
-            except requests.RequestException:
-                pass
-            time.sleep(10)
-        raise TimeoutError("Server failed to start within the timeout period.")
+        cls.process = popen_launch_server(cls.model, port, timeout=300)
 
     @classmethod
     def tearDownClass(cls):
@@ -178,8 +152,6 @@ def run_chat_completion_stream(self, logprobs):
 
         is_first = True
         for response in generator:
-            print(response)
-
             data = response.choices[0].delta
             if is_first:
                 data.role == "assistant"

From 13b7e45fda53de80a246f402dd9ce097cd4baef6 Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 03:33:01 +0000
Subject: [PATCH 07/10] lint

---
 python/sglang/test/test_utils.py | 2 +-
 test/srt/test_eval_accuracy.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 5cb2c4f56c6..f0cdf69c751 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -394,7 +394,7 @@ def popen_launch_server(model, port, timeout, *args):
         "localhost",
         "--port",
         str(port),
-        *args
+        *args,
     ]
     process = subprocess.Popen(command, stdout=None, stderr=None)
     base_url = f"http://localhost:{port}/v1"
diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py
index a54679efa71..436cc18bffe 100644
--- a/test/srt/test_eval_accuracy.py
+++ b/test/srt/test_eval_accuracy.py
@@ -3,8 +3,8 @@
 from types import SimpleNamespace
 
 from sglang.srt.utils import kill_child_process
-from sglang.test.test_utils import popen_launch_server
 from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import popen_launch_server
 
 
 class TestAccuracy(unittest.TestCase):

From 6f373feae4fa535231698b91b5f293602e29b350 Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 03:49:52 +0000
Subject: [PATCH 08/10] add test_srt_endpoint

---
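Note: the new test_srt_endpoint.py exercises the native /generate endpoint directly,
without going through the OpenAI-compatible layer. A minimal sketch of the request it
issues, assuming a server already launched on port 30000 via popen_launch_server:

    import requests

    # Mirrors the payload used by TestSRTEndpoint.run_decode below.
    response = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 32, "n": 1},
            "stream": False,
            "return_logprob": True,
            "top_logprobs_num": 3,
            "return_text_in_logprobs": True,
            "logprob_start_len": 0,
        },
    )
    print(response.json())
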
 python/sglang/test/run_eval.py   |  3 ++
 python/sglang/test/test_utils.py |  2 ++
 test/lang/test_srt_backend.py    |  3 +-
 test/srt/test_eval_accuracy.py   |  4 +--
 test/srt/test_openai_server.py   |  4 +--
 test/srt/test_srt_endpoint.py    | 61 ++++++++++++++++++++++++++++++++
 6 files changed, 72 insertions(+), 5 deletions(-)
 create mode 100644 test/srt/test_srt_endpoint.py

diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index 6433f49474d..3729ef7ab26 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -18,6 +18,9 @@
 
 
 def run_eval(args):
+    if "OPENAI_API_KEY" not in os.environ:
+        os.environ["OPENAI_API_KEY"] = "EMPTY"
+
     base_url = (
         f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
     )
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index f0cdf69c751..4348b57e9b0 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -13,6 +13,8 @@
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import get_exception_traceback
 
+MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
     assert url is not None
diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py
index f9d79ed290e..7accd349f81 100644
--- a/test/lang/test_srt_backend.py
+++ b/test/lang/test_srt_backend.py
@@ -14,6 +14,7 @@
     test_stream,
     test_tool_use,
 )
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST
 
 
 class TestSRTBackend(unittest.TestCase):
@@ -21,7 +22,7 @@ class TestSRTBackend(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+        cls.backend = sgl.Runtime(model_path=MODEL_NAME_FOR_TEST)
         sgl.set_default_backend(cls.backend)
 
     @classmethod
diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py
index 436cc18bffe..d392dc4c066 100644
--- a/test/srt/test_eval_accuracy.py
+++ b/test/srt/test_eval_accuracy.py
@@ -4,7 +4,7 @@
 
 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
-from sglang.test.test_utils import popen_launch_server
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
 
 
 class TestAccuracy(unittest.TestCase):
@@ -13,7 +13,7 @@ class TestAccuracy(unittest.TestCase):
     def setUpClass(cls):
         port = 30000
 
-        cls.model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        cls.model = MODEL_NAME_FOR_TEST
         cls.base_url = f"http://localhost:{port}"
         cls.process = popen_launch_server(cls.model, port, timeout=300)
 
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index 6c03a40db36..76a105a6254 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -4,7 +4,7 @@
 import openai
 
 from sglang.srt.utils import kill_child_process
-from sglang.test.test_utils import popen_launch_server
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
 
 
 class TestOpenAIServer(unittest.TestCase):
@@ -13,7 +13,7 @@ class TestOpenAIServer(unittest.TestCase):
     def setUpClass(cls):
         port = 30000
 
-        cls.model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        cls.model = MODEL_NAME_FOR_TEST
         cls.base_url = f"http://localhost:{port}/v1"
         cls.process = popen_launch_server(cls.model, port, timeout=300)
 
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
new file mode 100644
index 00000000000..72b9e753dd0
--- /dev/null
+++ b/test/srt/test_srt_endpoint.py
@@ -0,0 +1,61 @@
+import json
+import unittest
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import popen_launch_server, MODEL_NAME_FOR_TEST
+
+
+class TestSRTEndpoint(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        port = 30000
+
+        cls.model = MODEL_NAME_FOR_TEST
+        cls.base_url = f"http://localhost:{port}"
+        cls.process = popen_launch_server(cls.model, port, timeout=300)
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def run_decode(self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0 if n == 1 else 0.5,
+                    "max_new_tokens": 32,
+                    "n": n,
+                },
+                "stream": False,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "return_text_in_logprobs": return_text,
+                "logprob_start_len": 0,
+            },
+        )
+        print(json.dumps(response.json()))
+        print("=" * 100)
+
+    def test_simple_decode(self):
+        self.run_decode()
+
+    def test_parallel_sample(self):
+        self.run_decode(n=3)
+
+    def test_logprob(self):
+        for top_logprobs_num in [0, 3]:
+            for return_text in [True, False]:
+                self.run_decode(
+                    return_logprob=True,
+                    top_logprobs_num=top_logprobs_num,
+                    return_text=return_text,
+                )
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")

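Note: with the OPENAI_API_KEY fallback above, the MMLU check can also be driven
programmatically; a sketch mirroring test_eval_accuracy.py (the attribute names are
exactly those the test passes to run_eval):

    from types import SimpleNamespace

    from sglang.test.run_eval import run_eval

    args = SimpleNamespace(
        base_url="http://localhost:30000",
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        eval_name="mmlu",
        num_examples=20,
        num_threads=20,
    )
    metrics = run_eval(args)  # OPENAI_API_KEY falls back to "EMPTY" if unset
    print(metrics["score"])
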
From 208ab71d480154fc2350def23f547fd22abdae92 Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 03:50:46 +0000
Subject: [PATCH 09/10] rename test/srt/old to test/srt/deprecated

---
 test/srt/{old => deprecated}/test_curl.sh                     | 0
 test/srt/{old => deprecated}/test_flashinfer.py               | 0
 test/srt/{old => deprecated}/test_httpserver_classify.py      | 0
 test/srt/{old => deprecated}/test_httpserver_concurrent.py    | 0
 test/srt/{old => deprecated}/test_httpserver_decode.py        | 0
 test/srt/{old => deprecated}/test_httpserver_decode_stream.py | 0
 test/srt/{old => deprecated}/test_httpserver_llava.py         | 0
 test/srt/{old => deprecated}/test_httpserver_reuse.py         | 0
 test/srt/{old => deprecated}/test_jump_forward.py             | 0
 test/srt/{old => deprecated}/test_openai_server.py            | 0
 test/srt/{old => deprecated}/test_robust.py                   | 0
 11 files changed, 0 insertions(+), 0 deletions(-)
 rename test/srt/{old => deprecated}/test_curl.sh (100%)
 rename test/srt/{old => deprecated}/test_flashinfer.py (100%)
 rename test/srt/{old => deprecated}/test_httpserver_classify.py (100%)
 rename test/srt/{old => deprecated}/test_httpserver_concurrent.py (100%)
 rename test/srt/{old => deprecated}/test_httpserver_decode.py (100%)
 rename test/srt/{old => deprecated}/test_httpserver_decode_stream.py (100%)
 rename test/srt/{old => deprecated}/test_httpserver_llava.py (100%)
 rename test/srt/{old => deprecated}/test_httpserver_reuse.py (100%)
 rename test/srt/{old => deprecated}/test_jump_forward.py (100%)
 rename test/srt/{old => deprecated}/test_openai_server.py (100%)
 rename test/srt/{old => deprecated}/test_robust.py (100%)

diff --git a/test/srt/old/test_curl.sh b/test/srt/deprecated/test_curl.sh
similarity index 100%
rename from test/srt/old/test_curl.sh
rename to test/srt/deprecated/test_curl.sh
diff --git a/test/srt/old/test_flashinfer.py b/test/srt/deprecated/test_flashinfer.py
similarity index 100%
rename from test/srt/old/test_flashinfer.py
rename to test/srt/deprecated/test_flashinfer.py
diff --git a/test/srt/old/test_httpserver_classify.py b/test/srt/deprecated/test_httpserver_classify.py
similarity index 100%
rename from test/srt/old/test_httpserver_classify.py
rename to test/srt/deprecated/test_httpserver_classify.py
diff --git a/test/srt/old/test_httpserver_concurrent.py b/test/srt/deprecated/test_httpserver_concurrent.py
similarity index 100%
rename from test/srt/old/test_httpserver_concurrent.py
rename to test/srt/deprecated/test_httpserver_concurrent.py
diff --git a/test/srt/old/test_httpserver_decode.py b/test/srt/deprecated/test_httpserver_decode.py
similarity index 100%
rename from test/srt/old/test_httpserver_decode.py
rename to test/srt/deprecated/test_httpserver_decode.py
diff --git a/test/srt/old/test_httpserver_decode_stream.py b/test/srt/deprecated/test_httpserver_decode_stream.py
similarity index 100%
rename from test/srt/old/test_httpserver_decode_stream.py
rename to test/srt/deprecated/test_httpserver_decode_stream.py
diff --git a/test/srt/old/test_httpserver_llava.py b/test/srt/deprecated/test_httpserver_llava.py
similarity index 100%
rename from test/srt/old/test_httpserver_llava.py
rename to test/srt/deprecated/test_httpserver_llava.py
diff --git a/test/srt/old/test_httpserver_reuse.py b/test/srt/deprecated/test_httpserver_reuse.py
similarity index 100%
rename from test/srt/old/test_httpserver_reuse.py
rename to test/srt/deprecated/test_httpserver_reuse.py
diff --git a/test/srt/old/test_jump_forward.py b/test/srt/deprecated/test_jump_forward.py
similarity index 100%
rename from test/srt/old/test_jump_forward.py
rename to test/srt/deprecated/test_jump_forward.py
diff --git a/test/srt/old/test_openai_server.py b/test/srt/deprecated/test_openai_server.py
similarity index 100%
rename from test/srt/old/test_openai_server.py
rename to test/srt/deprecated/test_openai_server.py
diff --git a/test/srt/old/test_robust.py b/test/srt/deprecated/test_robust.py
similarity index 100%
rename from test/srt/old/test_robust.py
rename to test/srt/deprecated/test_robust.py

From 0134e5f8faa05e42880bc3c0e929124026472a8a Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Fri, 2 Aug 2024 03:57:22 +0000
Subject: [PATCH 10/10] lint: sort imports and wrap run_decode signature in test_srt_endpoint

---
 test/srt/test_srt_endpoint.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 72b9e753dd0..3454678586b 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -1,10 +1,11 @@
 import json
 import unittest
+
 import requests
 
 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
-from sglang.test.test_utils import popen_launch_server, MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
 
 
 class TestSRTEndpoint(unittest.TestCase):
@@ -21,7 +22,9 @@ def setUpClass(cls):
     def tearDownClass(cls):
         kill_child_process(cls.process.pid)
 
-    def run_decode(self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
+    def run_decode(
+        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+    ):
         response = requests.post(
             self.base_url + "/generate",
             json={