From 044740e18281524a4f9198406b044dd8271b5e4f Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 02:41:03 +0000 Subject: [PATCH 01/10] eval mmlu --- python/sglang/bench_serving.py | 5 +- python/sglang/test/eval_common.py | 452 +++++++++++++++++++++ python/sglang/test/eval_mmlu.py | 164 ++++++++ python/sglang/test/test_conversation.py | 46 --- python/sglang/test/test_openai_protocol.py | 51 --- 5 files changed, 619 insertions(+), 99 deletions(-) create mode 100644 python/sglang/test/eval_common.py create mode 100644 python/sglang/test/eval_mmlu.py delete mode 100644 python/sglang/test/test_conversation.py delete mode 100644 python/sglang/test/test_openai_protocol.py diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index b52e114fd70..dc733012aa5 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -21,7 +21,7 @@ import time import traceback import warnings -from argparse import ArgumentParser as FlexibleArgumentParser +from argparse import ArgumentParser from dataclasses import dataclass, field from datetime import datetime from typing import AsyncGenerator, List, Optional, Tuple, Union @@ -868,7 +868,7 @@ def set_ulimit(target_soft_limit=65535): if __name__ == "__main__": - parser = FlexibleArgumentParser( + parser = ArgumentParser( description="Benchmark the online serving throughput." ) parser.add_argument( @@ -876,6 +876,7 @@ def set_ulimit(target_soft_limit=65535): type=str, required=True, choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", help="Must specify a backend, depending on the LLM Inference Engine.", ) parser.add_argument( diff --git a/python/sglang/test/eval_common.py b/python/sglang/test/eval_common.py new file mode 100644 index 00000000000..06bfac46dbb --- /dev/null +++ b/python/sglang/test/eval_common.py @@ -0,0 +1,452 @@ +# Adapted from https://github.com/openai/simple-evals/ + +from dataclasses import dataclass, field +from typing import Any +import resource + +Message = dict[str, Any] # keys role, content +MessageList = list[Message] + + +class SamplerBase: + """ + Base class for defining a sampling model, which can be evaluated, + or used as part of the grading process. + """ + + def __call__(self, message_list: MessageList) -> str: + raise NotImplementedError() + + +@dataclass +class EvalResult: + """ + Result of running an evaluation (usually consisting of many samples) + """ + + score: float | None # top-line metric + metrics: dict[str, float] | None # other metrics + htmls: list[str] # strings of valid HTML + convos: list[MessageList] # sampled conversations + + +@dataclass +class SingleEvalResult: + """ + Result of evaluating a single sample + """ + + score: float | None + metrics: dict[str, float] = field(default_factory=dict) + html: str | None = None + convo: MessageList | None = None # sampled conversation + + +class Eval: + """ + Base class for defining an evaluation. + """ + + def __call__(self, sampler: SamplerBase) -> EvalResult: + raise NotImplementedError() + + +import base64 +import time +from typing import Any + +import openai +from openai import OpenAI + + +OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." +OPENAI_SYSTEM_MESSAGE_CHATGPT = ( + "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." 
+ + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01" +) + + +class ChatCompletionSampler(SamplerBase): + """ + Sample from OpenAI's chat completion API + """ + + def __init__( + self, + base_url: str = None, + model: str | None = None, + system_message: str | None = None, + temperature: float = 0.0, + max_tokens: int = 2048, + ): + self.client = OpenAI(base_url=base_url) + + if model is None: + model = self.client.models.list().data[0].id + + self.model = model + self.system_message = system_message + self.temperature = temperature + self.max_tokens = max_tokens + self.image_format = "url" + + def _handle_image( + self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768 + ): + new_image = { + "type": "image_url", + "image_url": { + "url": f"data:image/{format};{encoding},{image}", + }, + } + return new_image + + def _handle_text(self, text: str): + return {"type": "text", "text": text} + + def _pack_message(self, role: str, content: Any): + return {"role": str(role), "content": content} + + def __call__(self, message_list: MessageList) -> str: + if self.system_message: + message_list = [self._pack_message("system", self.system_message)] + message_list + trial = 0 + while True: + try: + response = self.client.chat.completions.create( + model=self.model, + messages=message_list, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + return response.choices[0].message.content + # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU + except openai.BadRequestError as e: + print("Bad Request Error", e) + return "" + except Exception as e: + exception_backoff = 2**trial # expontial back off + print( + f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", + e, + ) + time.sleep(exception_backoff) + trial += 1 + # unknown error shall throw exception + + +import os +from collections import defaultdict +from multiprocessing.pool import ThreadPool +from typing import Any + +import jinja2 +import numpy as np +from tqdm import tqdm + + +QUERY_TEMPLATE_MULTICHOICE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + +ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" +ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)" + + +EQUALITY_TEMPLATE = r""" +Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications + +Examples: + + Expression 1: $2x+3$ + Expression 2: $3+2x$ + +Yes + + Expression 1: 3/2 + Expression 2: 1.5 + +Yes + + Expression 1: $x^2+2x+1$ + Expression 2: $y^2+2y+1$ + +No + + Expression 1: $x^2+2x+1$ + Expression 2: $(x+1)^2$ + +Yes + + Expression 1: 3245/5 + Expression 2: 649 + +No +(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications) + + Expression 1: 2/(-3) + Expression 2: -2/3 + +Yes +(trivial simplifications are allowed) + + Expression 1: 72 degrees + Expression 2: 72 + +Yes +(give benefit of the doubt to units) + + Expression 1: 64 + Expression 2: 64 square feet + +Yes +(give benefit of the doubt to units) + +--- + +YOUR TASK + + +Respond with only "Yes" or "No" (without quotes). Do not include a rationale. 
+ + Expression 1: %(expression1)s + Expression 2: %(expression2)s +""".strip() + + +HTML_JINJA = """ +<h3>Prompt conversation</h3> +{% for message in prompt_messages %} +{{ message_to_html(message) | safe }} +{% endfor %} +<h3>Sampled message</h3> +{{ message_to_html(next_message) | safe }} +<h3>Results</h3> +<p>Correct Answer: {{ correct_answer }}</p> +<p>Extracted Answer: {{ extracted_answer }}</p> +<p>Score: {{ score }}</p> +""" + + +def format_multichoice_question(row): + return QUERY_TEMPLATE_MULTICHOICE.format(**row) + + +def check_equality(sampler: SamplerBase, expr1: str, expr2: str): + prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2} + response = sampler([dict(content=prompt, role="user")]) + return response.lower().strip() == "yes" + + +def _compute_stat(values: list, stat: str): + if stat == "mean": + return np.mean(values) + elif stat == "std": + return np.std(values) + elif stat == "min": + return np.min(values) + elif stat == "max": + return np.max(values) + else: + raise ValueError(f"Unknown {stat =}") + + +def aggregate_results( + single_eval_results: list[SingleEvalResult], + default_stats: tuple[str] = ("mean", "std"), + name2stats: dict[str, tuple[str]] | None = None, +) -> EvalResult: + """ + Aggregate results from multiple evaluations into a single EvalResult. + """ + name2stats = name2stats or {} + name2values = defaultdict(list) + htmls = [] + convos = [] + for single_eval_result in single_eval_results: + for name, value in single_eval_result.metrics.items(): + name2values[name].append(value) + if single_eval_result.score is not None: + name2values["score"].append(single_eval_result.score) + htmls.append(single_eval_result.html) + convos.append(single_eval_result.convo) + final_metrics = {} + for name, values in name2values.items(): + stats = name2stats.get(name, default_stats) + for stat in stats: + key = name if stat == "mean" else f"{name}:{stat}" + final_metrics[key] = _compute_stat(values, stat) + return EvalResult( + score=final_metrics.pop("score", None), metrics=final_metrics, htmls=htmls, convos=convos + ) + + +def map_with_progress(f: callable, xs: list[Any], num_threads: int = 50): + """ + Apply f to each element of xs, using a ThreadPool, and show progress. + """ + if os.getenv("debug"): + return list(map(f, tqdm(xs, total=len(xs)))) + else: + with ThreadPool(min(num_threads, len(xs))) as pool: + return list(tqdm(pool.imap(f, xs), total=len(xs))) + + +jinja_env = jinja2.Environment( + loader=jinja2.BaseLoader(), + undefined=jinja2.StrictUndefined, + autoescape=jinja2.select_autoescape(["html", "xml"]), +) +_message_template = """ +<div class="message {{ role }}"> + <div class="role"> + {{ role }} + {% if variant %}<span class="variant">({{ variant }})</span>{% endif %} + </div> + <div class="content"> + <pre>{{ content }}</pre> + </div> +</div> +""" + + +def message_to_html(message: Message) -> str: + """ + Generate HTML snippet (inside a <div>) for a message. 
+ """ + return jinja_env.from_string(_message_template).render( + role=message["role"], content=message["content"], variant=message.get("variant", None) + ) + + +jinja_env.globals["message_to_html"] = message_to_html + + +_report_template = """<!DOCTYPE html> +<html> + <head> + <style> + .message { + padding: 8px 16px; + margin-bottom: 8px; + border-radius: 4px; + } + .message.user { + background-color: #B2DFDB; + color: #00695C; + } + .message.assistant { + background-color: #B39DDB; + color: #4527A0; + } + .message.system { + background-color: #EEEEEE; + color: #212121; + } + .role { + font-weight: bold; + margin-bottom: 4px; + } + .variant { + color: #795548; + } + table, th, td { + border: 1px solid black; + } + pre { + white-space: pre-wrap; + } + </style> + </head> + <body> + {% if metrics %} + <h1>Metrics</h1> + <table> + <tr> + <th>Metric</th> + <th>Value</th> + </tr> + <tr> + <td><b>Score</b></td> + <td>{{ score | float | round(3) }}</td> + </tr> + {% for name, value in metrics.items() %} + <tr> + <td>{{ name }}</td> + <td>{{ value }}</td> + </tr> + {% endfor %} + </table> + {% endif %} + <h1>Examples</h1> + {% for html in htmls %} + {{ html | safe }} + <hr> + {% endfor %} + </body> +</html> +""" + + +def make_report(eval_result: EvalResult) -> str: + """ + Create a standalone HTML report from an EvalResult. + """ + return jinja_env.from_string(_report_template).render( + score=eval_result.score, + metrics=eval_result.metrics, + htmls=eval_result.htmls, + ) + + +def make_report_from_example_htmls(htmls: list[str]): + """ + Create a standalone HTML report from a list of example htmls + """ + return jinja_env.from_string(_report_template).render(score=None, metrics={}, htmls=htmls) + + +import requests +from tqdm import tqdm + +def download_dataset(path, url): + print(f"Downloading dataset {path} from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") + diff --git a/python/sglang/test/eval_mmlu.py b/python/sglang/test/eval_mmlu.py new file mode 100644 index 00000000000..ac039e00336 --- /dev/null +++ b/python/sglang/test/eval_mmlu.py @@ -0,0 +1,164 @@ +# Adapted from https://github.com/openai/simple-evals/ + +""" +Measuring Massive Multitask Language Understanding +Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt +https://arxiv.org/abs/2009.03300 +""" + +import argparse +import os +import random +import re +import time + +import pandas + +from sglang.test.eval_common import (ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question, Eval, EvalResult, SamplerBase, SingleEvalResult, download_dataset, ChatCompletionSampler, map_with_progress, jinja_env, aggregate_results, 
set_ulimit) + +subject2category = { + "abstract_algebra": "stem", + "anatomy": "other", + "astronomy": "stem", + "business_ethics": "other", + "clinical_knowledge": "other", + "college_biology": "stem", + "college_chemistry": "stem", + "college_computer_science": "stem", + "college_mathematics": "stem", + "college_medicine": "other", + "college_physics": "stem", + "computer_security": "stem", + "conceptual_physics": "stem", + "econometrics": "social_sciences", + "electrical_engineering": "stem", + "elementary_mathematics": "stem", + "formal_logic": "humanities", + "global_facts": "other", + "high_school_biology": "stem", + "high_school_chemistry": "stem", + "high_school_computer_science": "stem", + "high_school_european_history": "humanities", + "high_school_geography": "social_sciences", + "high_school_government_and_politics": "social_sciences", + "high_school_macroeconomics": "social_sciences", + "high_school_mathematics": "stem", + "high_school_microeconomics": "social_sciences", + "high_school_physics": "stem", + "high_school_psychology": "social_sciences", + "high_school_statistics": "stem", + "high_school_us_history": "humanities", + "high_school_world_history": "humanities", + "human_aging": "other", + "human_sexuality": "social_sciences", + "international_law": "humanities", + "jurisprudence": "humanities", + "logical_fallacies": "humanities", + "machine_learning": "stem", + "management": "other", + "marketing": "other", + "medical_genetics": "other", + "miscellaneous": "other", + "moral_disputes": "humanities", + "moral_scenarios": "humanities", + "nutrition": "other", + "philosophy": "humanities", + "prehistory": "humanities", + "professional_accounting": "other", + "professional_law": "humanities", + "professional_medicine": "other", + "professional_psychology": "social_sciences", + "public_relations": "social_sciences", + "security_studies": "social_sciences", + "sociology": "social_sciences", + "us_foreign_policy": "social_sciences", + "virology": "other", + "world_religions": "humanities", +} + + +class MMLUEval(Eval): + def __init__(self, filename: str, num_examples: int | None = None): + df = pandas.read_csv(filename) + examples = [row.to_dict() for _, row in df.iterrows()] + if num_examples: + examples = random.Random(0).sample(examples, num_examples) + self.examples = examples + + def __call__(self, sampler: SamplerBase) -> EvalResult: + def fn(row: dict): + prompt_messages = [ + sampler._pack_message(content=format_multichoice_question(row), role="user") + ] + response_text = sampler(prompt_messages) + match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) + extracted_answer = match.group(1) if match else None + score = 1.0 if extracted_answer == row["Answer"] else 0.0 + html = jinja_env.from_string(HTML_JINJA).render( + prompt_messages=prompt_messages, + next_message=dict(content=response_text, role="assistant"), + score=score, + correct_answer=row["Answer"], + extracted_answer=extracted_answer, + ) + convo = prompt_messages + [dict(content=response_text, role="assistant")] + category = subject2category.get(row["Subject"], "other") + return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo) + + results = map_with_progress(fn, self.examples) + return aggregate_results(results) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset-path", type=str, default="mmlu.csv", help="Path to the dataset." 
+ ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument( + "--num-examples", + type=int, + help="The number of examples." + ) + set_ulimit() + args = parser.parse_args() + + base_url = f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" + + if not os.path.exists(args.dataset_path): + download_dataset(args.dataset_path, "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv") + eval_obj = MMLUEval(args.dataset_path, num_examples=args.num_examples) + sampler = ChatCompletionSampler( + model=args.model, + max_tokens=2048, + base_url=base_url, + ) + + tic = time.time() + + result = eval_obj(sampler) + metrics = result.metrics | {"score": result.score} + + latency = time.time() - tic + score = metrics["score"] + + print(f"Total latency: {latency:.3f} s") + print(f"Score: {score:.3f}") diff --git a/python/sglang/test/test_conversation.py b/python/sglang/test/test_conversation.py deleted file mode 100644 index e6d9f396aa7..00000000000 --- a/python/sglang/test/test_conversation.py +++ /dev/null @@ -1,46 +0,0 @@ -from sglang.srt.conversation import generate_chat_conv -from sglang.srt.managers.openai_api.protocol import ( - ChatCompletionMessageContentImagePart, - ChatCompletionMessageContentImageURL, - ChatCompletionMessageContentTextPart, - ChatCompletionMessageGenericParam, - ChatCompletionMessageUserParam, - ChatCompletionRequest, -) - - -def test_chat_completion_to_conv_image(): - """Test that we can convert a chat image request to a convo""" - request = ChatCompletionRequest( - model="default", - messages=[ - ChatCompletionMessageGenericParam( - role="system", content="You are a helpful AI assistant" - ), - ChatCompletionMessageUserParam( - role="user", - content=[ - ChatCompletionMessageContentTextPart( - type="text", text="Describe this image" - ), - ChatCompletionMessageContentImagePart( - type="image_url", - image_url=ChatCompletionMessageContentImageURL( - url="https://someurl.com" - ), - ), - ], - ), - ], - ) - conv = generate_chat_conv(request, "vicuna_v1.1") - assert conv.messages == [ - ["USER", "Describe this image<image>"], - ["ASSISTANT", None], - ] - assert conv.system_message == "You are a helpful AI assistant" - assert conv.image_data == ["https://someurl.com"] - - -if __name__ == "__main__": - test_chat_completion_to_conv_image() diff --git a/python/sglang/test/test_openai_protocol.py b/python/sglang/test/test_openai_protocol.py deleted file mode 100644 index cade4728cba..00000000000 --- a/python/sglang/test/test_openai_protocol.py +++ /dev/null @@ -1,51 +0,0 @@ -from sglang.srt.managers.openai_api.protocol import ( - ChatCompletionMessageContentImagePart, - ChatCompletionMessageContentImageURL, - ChatCompletionMessageContentTextPart, - ChatCompletionMessageGenericParam, - ChatCompletionMessageUserParam, - ChatCompletionRequest, -) - - -def test_chat_completion_request_image(): - """Test that Chat Completion Requests with images can be converted.""" - - image_request = { - "model": "default", - "messages": [ - 
{"role": "system", "content": "You are a helpful AI assistant"}, - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe this image"}, - {"type": "image_url", "image_url": {"url": "https://someurl.com"}}, - ], - }, - ], - "temperature": 0, - "max_tokens": 64, - } - request = ChatCompletionRequest(**image_request) - assert len(request.messages) == 2 - assert request.messages[0] == ChatCompletionMessageGenericParam( - role="system", content="You are a helpful AI assistant" - ) - assert request.messages[1] == ChatCompletionMessageUserParam( - role="user", - content=[ - ChatCompletionMessageContentTextPart( - type="text", text="Describe this image" - ), - ChatCompletionMessageContentImagePart( - type="image_url", - image_url=ChatCompletionMessageContentImageURL( - url="https://someurl.com" - ), - ), - ], - ) - - -if __name__ == "__main__": - test_chat_completion_request_image() From 3fe7abacdf5bdf3b0c810e919ae4dfbc03c0cbe7 Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 02:44:19 +0000 Subject: [PATCH 02/10] lint --- python/sglang/bench_serving.py | 5 +--- python/sglang/test/eval_common.py | 30 +++++++++++++++++------- python/sglang/test/eval_mmlu.py | 39 +++++++++++++++++++++++-------- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index dc733012aa5..253aab355dd 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -868,13 +868,10 @@ def set_ulimit(target_soft_limit=65535): if __name__ == "__main__": - parser = ArgumentParser( - description="Benchmark the online serving throughput." - ) + parser = ArgumentParser(description="Benchmark the online serving throughput.") parser.add_argument( "--backend", type=str, - required=True, choices=list(ASYNC_REQUEST_FUNCS.keys()), default="sglang", help="Must specify a backend, depending on the LLM Inference Engine.", diff --git a/python/sglang/test/eval_common.py b/python/sglang/test/eval_common.py index 06bfac46dbb..64632c45a94 100644 --- a/python/sglang/test/eval_common.py +++ b/python/sglang/test/eval_common.py @@ -1,8 +1,8 @@ # Adapted from https://github.com/openai/simple-evals/ +import resource from dataclasses import dataclass, field from typing import Any -import resource Message = dict[str, Any] # keys role, content MessageList = list[Message] @@ -58,7 +58,6 @@ def __call__(self, sampler: SamplerBase) -> EvalResult: import openai from openai import OpenAI - OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." OPENAI_SYSTEM_MESSAGE_CHATGPT = ( "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." @@ -91,7 +90,11 @@ def __init__( self.image_format = "url" def _handle_image( - self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768 + self, + image: str, + encoding: str = "base64", + format: str = "png", + fovea: int = 768, ): new_image = { "type": "image_url", @@ -109,7 +112,9 @@ def _pack_message(self, role: str, content: Any): def __call__(self, message_list: MessageList) -> str: if self.system_message: - message_list = [self._pack_message("system", self.system_message)] + message_list + message_list = [ + self._pack_message("system", self.system_message) + ] + message_list trial = 0 while True: try: @@ -144,7 +149,6 @@ def __call__(self, message_list: MessageList) -> str: import numpy as np from tqdm import tqdm - QUERY_TEMPLATE_MULTICHOICE = """ Answer the following multiple choice question. 
The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. @@ -284,7 +288,10 @@ def aggregate_results( key = name if stat == "mean" else f"{name}:{stat}" final_metrics[key] = _compute_stat(values, stat) return EvalResult( - score=final_metrics.pop("score", None), metrics=final_metrics, htmls=htmls, convos=convos + score=final_metrics.pop("score", None), + metrics=final_metrics, + htmls=htmls, + convos=convos, ) @@ -322,7 +329,9 @@ def message_to_html(message: Message) -> str: Generate HTML snippet (inside a <div>) for a message. """ return jinja_env.from_string(_message_template).render( - role=message["role"], content=message["content"], variant=message.get("variant", None) + role=message["role"], + content=message["content"], + variant=message.get("variant", None), ) @@ -410,12 +419,15 @@ def make_report_from_example_htmls(htmls: list[str]): """ Create a standalone HTML report from a list of example htmls """ - return jinja_env.from_string(_report_template).render(score=None, metrics={}, htmls=htmls) + return jinja_env.from_string(_report_template).render( + score=None, metrics={}, htmls=htmls + ) import requests from tqdm import tqdm + def download_dataset(path, url): print(f"Downloading dataset {path} from {url}") try: @@ -440,6 +452,7 @@ def download_dataset(path, url): except requests.RequestException as e: raise Exception(f"Failed to download dataset: {e}") + def set_ulimit(target_soft_limit=65535): resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) @@ -449,4 +462,3 @@ def set_ulimit(target_soft_limit=65535): resource.setrlimit(resource_type, (target_soft_limit, current_hard)) except ValueError as e: print(f"Fail to set RLIMIT_NOFILE: {e}") - diff --git a/python/sglang/test/eval_mmlu.py b/python/sglang/test/eval_mmlu.py index ac039e00336..7b594511fa1 100644 --- a/python/sglang/test/eval_mmlu.py +++ b/python/sglang/test/eval_mmlu.py @@ -14,7 +14,21 @@ import pandas -from sglang.test.eval_common import (ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question, Eval, EvalResult, SamplerBase, SingleEvalResult, download_dataset, ChatCompletionSampler, map_with_progress, jinja_env, aggregate_results, set_ulimit) +from sglang.test.eval_common import ( + ANSWER_PATTERN_MULTICHOICE, + HTML_JINJA, + ChatCompletionSampler, + Eval, + EvalResult, + SamplerBase, + SingleEvalResult, + aggregate_results, + download_dataset, + format_multichoice_question, + jinja_env, + map_with_progress, + set_ulimit, +) subject2category = { "abstract_algebra": "stem", @@ -88,7 +102,9 @@ def __init__(self, filename: str, num_examples: int | None = None): def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): prompt_messages = [ - sampler._pack_message(content=format_multichoice_question(row), role="user") + sampler._pack_message( + content=format_multichoice_question(row), role="user" + ) ] response_text = sampler(prompt_messages) match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) @@ -103,7 +119,9 @@ def fn(row: dict): ) convo = prompt_messages + [dict(content=response_text, role="assistant")] category = subject2category.get(row["Subject"], "other") - return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo) + return SingleEvalResult( + html=html, score=score, metrics={category: score}, convo=convo + ) results = map_with_progress(fn, self.examples) return aggregate_results(results) @@ 
-133,18 +151,19 @@ def fn(row: dict): type=str, help="Name or path of the model. If not set, the default model will request /v1/models for conf.", ) - parser.add_argument( - "--num-examples", - type=int, - help="The number of examples." - ) + parser.add_argument("--num-examples", type=int, help="The number of examples.") set_ulimit() args = parser.parse_args() - base_url = f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" + base_url = ( + f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" + ) if not os.path.exists(args.dataset_path): - download_dataset(args.dataset_path, "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv") + download_dataset( + args.dataset_path, + "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + ) eval_obj = MMLUEval(args.dataset_path, num_examples=args.num_examples) sampler = ChatCompletionSampler( model=args.model, From 88d5fad15be57edff5415cf81b51d7f9ae7455ac Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 03:09:44 +0000 Subject: [PATCH 03/10] fix style --- python/sglang/test/run_eval.py | 80 ++++++++++++++++ .../{eval_common.py => simple_eval_common.py} | 48 ++++------ .../{eval_mmlu.py => simple_eval_mmlu.py} | 95 +++---------------- 3 files changed, 111 insertions(+), 112 deletions(-) create mode 100644 python/sglang/test/run_eval.py rename python/sglang/test/{eval_common.py => simple_eval_common.py} (99%) rename python/sglang/test/{eval_mmlu.py => simple_eval_mmlu.py} (57%) diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py new file mode 100644 index 00000000000..1f856bf3f9c --- /dev/null +++ b/python/sglang/test/run_eval.py @@ -0,0 +1,80 @@ +import argparse +import os +import time +import json + +from sglang.test.simple_eval_mmlu import MMLUEval +from sglang.test.simple_eval_common import ChatCompletionSampler, set_ulimit, make_report + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. 
If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument("--eval-name", type=str, default="mmlu") + parser.add_argument("--num-examples", type=int) + parser.add_argument("--num-threads", type=int, default=64) + set_ulimit() + args = parser.parse_args() + + base_url = ( + f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" + ) + + if args.eval_name == "mmlu": + dataset_path = "mmlu.csv" + + if not os.path.exists(dataset_path): + download_dataset( + dataset_path, + "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + ) + eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads) + else: + raise ValueError(f"Invalid eval name: {args.eval_name}") + + sampler = ChatCompletionSampler( + model=args.model, + max_tokens=2048, + base_url=base_url, + ) + + # Run eval + tic = time.time() + result = eval_obj(sampler) + latency = time.time() - tic + + # Dump reports + metrics = result.metrics | {"score": result.score} + file_stem = f"mmlu_{sampler.model.replace('/', '_')}" + report_filename = f"/tmp/{file_stem}.html" + print(f"Writing report to {report_filename}") + with open(report_filename, "w") as fh: + fh.write(make_report(result)) + metrics = result.metrics | {"score": result.score} + print(metrics) + result_filename = f"/tmp/{file_stem}.json" + with open(result_filename, "w") as f: + f.write(json.dumps(metrics, indent=2)) + print(f"Writing results to {result_filename}") + + # Print results + print(f"Total latency: {latency:.3f} s") + print(f"Score: {metrics['score']:.3f}") diff --git a/python/sglang/test/eval_common.py b/python/sglang/test/simple_eval_common.py similarity index 99% rename from python/sglang/test/eval_common.py rename to python/sglang/test/simple_eval_common.py index 64632c45a94..75c26f0f037 100644 --- a/python/sglang/test/eval_common.py +++ b/python/sglang/test/simple_eval_common.py @@ -1,9 +1,28 @@ # Adapted from https://github.com/openai/simple-evals/ +import base64 +import os import resource +import time +from collections import defaultdict from dataclasses import dataclass, field +from multiprocessing.pool import ThreadPool from typing import Any +import jinja2 +import numpy as np +import openai +import requests +from openai import OpenAI +from tqdm import tqdm + +OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." +OPENAI_SYSTEM_MESSAGE_CHATGPT = ( + "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." + + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01" +) + + Message = dict[str, Any] # keys role, content MessageList = list[Message] @@ -51,20 +70,6 @@ def __call__(self, sampler: SamplerBase) -> EvalResult: raise NotImplementedError() -import base64 -import time -from typing import Any - -import openai -from openai import OpenAI - -OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." -OPENAI_SYSTEM_MESSAGE_CHATGPT = ( - "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." 
- + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01" -) - - class ChatCompletionSampler(SamplerBase): """ Sample from OpenAI's chat completion API @@ -140,15 +145,6 @@ def __call__(self, message_list: MessageList) -> str: # unknown error shall throw exception -import os -from collections import defaultdict -from multiprocessing.pool import ThreadPool -from typing import Any - -import jinja2 -import numpy as np -from tqdm import tqdm - QUERY_TEMPLATE_MULTICHOICE = """ Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. @@ -295,7 +291,7 @@ def aggregate_results( ) -def map_with_progress(f: callable, xs: list[Any], num_threads: int = 50): +def map_with_progress(f: callable, xs: list[Any], num_threads: int): """ Apply f to each element of xs, using a ThreadPool, and show progress. """ @@ -424,10 +420,6 @@ def make_report_from_example_htmls(htmls: list[str]): ) -import requests -from tqdm import tqdm - - def download_dataset(path, url): print(f"Downloading dataset {path} from {url}") try: diff --git a/python/sglang/test/eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py similarity index 57% rename from python/sglang/test/eval_mmlu.py rename to python/sglang/test/simple_eval_mmlu.py index 7b594511fa1..6467f7760dc 100644 --- a/python/sglang/test/eval_mmlu.py +++ b/python/sglang/test/simple_eval_mmlu.py @@ -6,29 +6,15 @@ https://arxiv.org/abs/2009.03300 """ -import argparse -import os import random import re -import time import pandas -from sglang.test.eval_common import ( - ANSWER_PATTERN_MULTICHOICE, - HTML_JINJA, - ChatCompletionSampler, - Eval, - EvalResult, - SamplerBase, - SingleEvalResult, - aggregate_results, - download_dataset, - format_multichoice_question, - jinja_env, - map_with_progress, - set_ulimit, -) +from sglang.test import simple_eval_common as common +from sglang.test.simple_eval_common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question +from sglang.test.simple_eval_common import Eval, EvalResult, SamplerBase, SingleEvalResult + subject2category = { "abstract_algebra": "stem", @@ -92,25 +78,24 @@ class MMLUEval(Eval): - def __init__(self, filename: str, num_examples: int | None = None): + def __init__(self, filename: str, num_examples: int | None, num_threads: int): df = pandas.read_csv(filename) examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: examples = random.Random(0).sample(examples, num_examples) self.examples = examples + self.num_threads = num_threads def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): prompt_messages = [ - sampler._pack_message( - content=format_multichoice_question(row), role="user" - ) + sampler._pack_message(content=format_multichoice_question(row), role="user") ] response_text = sampler(prompt_messages) match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) extracted_answer = match.group(1) if match else None score = 1.0 if extracted_answer == row["Answer"] else 0.0 - html = jinja_env.from_string(HTML_JINJA).render( + html = common.jinja_env.from_string(HTML_JINJA).render( prompt_messages=prompt_messages, next_message=dict(content=response_text, role="assistant"), score=score, @@ -119,65 +104,7 @@ def fn(row: dict): ) convo = prompt_messages + [dict(content=response_text, role="assistant")] category = subject2category.get(row["Subject"], "other") - return SingleEvalResult( - html=html, score=score, 
metrics={category: score}, convo=convo - ) - - results = map_with_progress(fn, self.examples) - return aggregate_results(results) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dataset-path", type=str, default="mmlu.csv", help="Path to the dataset." - ) - parser.add_argument( - "--base-url", - type=str, - default=None, - help="Server or API base url if not using http host and port.", - ) - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." - ) - parser.add_argument( - "--port", - type=int, - help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", - ) - parser.add_argument( - "--model", - type=str, - help="Name or path of the model. If not set, the default model will request /v1/models for conf.", - ) - parser.add_argument("--num-examples", type=int, help="The number of examples.") - set_ulimit() - args = parser.parse_args() - - base_url = ( - f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" - ) - - if not os.path.exists(args.dataset_path): - download_dataset( - args.dataset_path, - "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", - ) - eval_obj = MMLUEval(args.dataset_path, num_examples=args.num_examples) - sampler = ChatCompletionSampler( - model=args.model, - max_tokens=2048, - base_url=base_url, - ) - - tic = time.time() - - result = eval_obj(sampler) - metrics = result.metrics | {"score": result.score} - - latency = time.time() - tic - score = metrics["score"] + return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo) - print(f"Total latency: {latency:.3f} s") - print(f"Score: {score:.3f}") + results = common.map_with_progress(fn, self.examples, self.num_threads) + return common.aggregate_results(results) From af51e277579e8e24c08987d4024c7fd10fff9356 Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 03:11:49 +0000 Subject: [PATCH 04/10] update --- python/sglang/test/run_eval.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index 1f856bf3f9c..afa5ca38d7e 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -1,10 +1,14 @@ +""" +Usage: +python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10 +""" import argparse import os import time import json from sglang.test.simple_eval_mmlu import MMLUEval -from sglang.test.simple_eval_common import ChatCompletionSampler, set_ulimit, make_report +from sglang.test.simple_eval_common import ChatCompletionSampler, set_ulimit, make_report, download_dataset if __name__ == "__main__": From 496552fec3dd7b13fa28befc74573bf25db2cd93 Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 03:12:35 +0000 Subject: [PATCH 05/10] lint --- python/sglang/test/run_eval.py | 11 ++++++++--- python/sglang/test/simple_eval_mmlu.py | 20 +++++++++++++++----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index afa5ca38d7e..430e00d2834 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -2,14 +2,19 @@ Usage: python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10 """ + import argparse +import json import os import time -import json +from sglang.test.simple_eval_common import ( + ChatCompletionSampler, + 
download_dataset, + make_report, + set_ulimit, +) from sglang.test.simple_eval_mmlu import MMLUEval -from sglang.test.simple_eval_common import ChatCompletionSampler, set_ulimit, make_report, download_dataset - if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py index 6467f7760dc..3c0287510cb 100644 --- a/python/sglang/test/simple_eval_mmlu.py +++ b/python/sglang/test/simple_eval_mmlu.py @@ -12,9 +12,15 @@ import pandas from sglang.test import simple_eval_common as common -from sglang.test.simple_eval_common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question -from sglang.test.simple_eval_common import Eval, EvalResult, SamplerBase, SingleEvalResult - +from sglang.test.simple_eval_common import ( + ANSWER_PATTERN_MULTICHOICE, + HTML_JINJA, + Eval, + EvalResult, + SamplerBase, + SingleEvalResult, + format_multichoice_question, +) subject2category = { "abstract_algebra": "stem", @@ -89,7 +95,9 @@ def __init__(self, filename: str, num_examples: int | None, num_threads: int): def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): prompt_messages = [ - sampler._pack_message(content=format_multichoice_question(row), role="user") + sampler._pack_message( + content=format_multichoice_question(row), role="user" + ) ] response_text = sampler(prompt_messages) match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) @@ -104,7 +112,9 @@ def fn(row: dict): ) convo = prompt_messages + [dict(content=response_text, role="assistant")] category = subject2category.get(row["Subject"], "other") - return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo) + return SingleEvalResult( + html=html, score=score, metrics={category: score}, convo=convo + ) results = common.map_with_progress(fn, self.examples, self.num_threads) return common.aggregate_results(results) From 4250c24e24e193bb2d8f99afb5db72335ecaed4a Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 03:30:31 +0000 Subject: [PATCH 06/10] test mmlu cot accuracy --- .github/workflows/e2e-test.yml | 4 +-- .github/workflows/unit-test.yml | 7 ++++ python/sglang/test/run_eval.py | 59 ++++++++++++++++++-------------- python/sglang/test/test_utils.py | 30 ++++++++++++++++ test/srt/test_eval_accuracy.py | 43 +++++++++++++++++++++++ test/srt/test_openai_server.py | 36 +++---------------- 6 files changed, 119 insertions(+), 60 deletions(-) create mode 100644 test/srt/test_eval_accuracy.py diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 9630ca71838..7b59054fe12 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -18,7 +18,7 @@ concurrency: cancel-in-progress: true jobs: - pr-e2e-test: + e2e-test: runs-on: self-hosted env: @@ -38,7 +38,7 @@ jobs: pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall pip install --upgrade transformers - - name: Benchmark Serving + - name: Benchmark Serving Throughput run: | cd /data/zhyncs/venv && source ./bin/activate && cd - python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8413 --disable-radix-cache & diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index dc464fa8cc2..f1c069ea5b6 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -59,3 +59,10 @@ jobs: cd test/srt python3 test_openai_server.py + + - name: Test Accuracy + run: | + cd 
/data/zhyncs/venv && source ./bin/activate && cd - + + cd test/srt + python3 test_eval_accuracy.py diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index 430e00d2834..6433f49474d 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -16,33 +16,8 @@ ) from sglang.test.simple_eval_mmlu import MMLUEval -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--base-url", - type=str, - default=None, - help="Server or API base url if not using http host and port.", - ) - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." - ) - parser.add_argument( - "--port", - type=int, - help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", - ) - parser.add_argument( - "--model", - type=str, - help="Name or path of the model. If not set, the default model will request /v1/models for conf.", - ) - parser.add_argument("--eval-name", type=str, default="mmlu") - parser.add_argument("--num-examples", type=int) - parser.add_argument("--num-threads", type=int, default=64) - set_ulimit() - args = parser.parse_args() +def run_eval(args): base_url = ( f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" ) @@ -87,3 +62,35 @@ # Print results print(f"Total latency: {latency:.3f} s") print(f"Score: {metrics['score']:.3f}") + + return metrics + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. 
If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument("--eval-name", type=str, default="mmlu") + parser.add_argument("--num-examples", type=int) + parser.add_argument("--num-threads", type=int, default=64) + set_ulimit() + args = parser.parse_args() + + run_eval(args) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index af7f3765ef4..5cb2c4f56c6 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1,6 +1,8 @@ """Common utilities for testing and benchmarking""" import asyncio +import subprocess +import time from functools import partial import numpy as np @@ -379,3 +381,31 @@ def func(*args, **kwargs): raise return func + + +def popen_launch_server(model, port, timeout, *args): + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + "localhost", + "--port", + str(port), + *args + ] + process = subprocess.Popen(command, stdout=None, stderr=None) + base_url = f"http://localhost:{port}/v1" + + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"{base_url}/models") + if response.status_code == 200: + return process + except requests.RequestException: + pass + time.sleep(10) + raise TimeoutError("Server failed to start within the timeout period.") diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py new file mode 100644 index 00000000000..a54679efa71 --- /dev/null +++ b/test/srt/test_eval_accuracy.py @@ -0,0 +1,43 @@ +import json +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import popen_launch_server +from sglang.test.run_eval import run_eval + + +class TestAccuracy(unittest.TestCase): + + @classmethod + def setUpClass(cls): + port = 30000 + + cls.model = "meta-llama/Meta-Llama-3.1-8B-Instruct" + cls.base_url = f"http://localhost:{port}" + cls.process = popen_launch_server(cls.model, port, timeout=300) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=20, + num_threads=20, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestAccuracy() + # t.setUpClass() + # t.test_mmlu() + # t.tearDownClass() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index e15c2ba88ff..6c03a40db36 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -1,47 +1,21 @@ import json -import subprocess -import time import unittest import openai -import requests from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import popen_launch_server class TestOpenAIServer(unittest.TestCase): @classmethod def setUpClass(cls): - model = "meta-llama/Meta-Llama-3.1-8B-Instruct" port = 30000 - timeout = 300 - - command = [ - "python3", - "-m", - "sglang.launch_server", - "--model-path", - model, - "--host", - "localhost", - "--port", - str(port), - ] - cls.process = subprocess.Popen(command, stdout=None, stderr=None) + + cls.model = "meta-llama/Meta-Llama-3.1-8B-Instruct" cls.base_url = f"http://localhost:{port}/v1" - cls.model = model - - start_time = time.time() - while time.time() - start_time < timeout: - try: - response = requests.get(f"{cls.base_url}/models") - if response.status_code == 200: - 
return - except requests.RequestException: - pass - time.sleep(10) - raise TimeoutError("Server failed to start within the timeout period.") + cls.process = popen_launch_server(cls.model, port, timeout=300) @classmethod def tearDownClass(cls): @@ -178,8 +152,6 @@ def run_chat_completion_stream(self, logprobs): is_first = True for response in generator: - print(response) - data = response.choices[0].delta if is_first: data.role == "assistant" From 13b7e45fda53de80a246f402dd9ce097cd4baef6 Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 03:33:01 +0000 Subject: [PATCH 07/10] lint --- python/sglang/test/test_utils.py | 2 +- test/srt/test_eval_accuracy.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 5cb2c4f56c6..f0cdf69c751 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -394,7 +394,7 @@ def popen_launch_server(model, port, timeout, *args): "localhost", "--port", str(port), - *args + *args, ] process = subprocess.Popen(command, stdout=None, stderr=None) base_url = f"http://localhost:{port}/v1" diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py index a54679efa71..436cc18bffe 100644 --- a/test/srt/test_eval_accuracy.py +++ b/test/srt/test_eval_accuracy.py @@ -3,8 +3,8 @@ from types import SimpleNamespace from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import popen_launch_server from sglang.test.run_eval import run_eval +from sglang.test.test_utils import popen_launch_server class TestAccuracy(unittest.TestCase): From 6f373feae4fa535231698b91b5f293602e29b350 Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 03:49:52 +0000 Subject: [PATCH 08/10] add test_srt_endpoint --- python/sglang/test/run_eval.py | 3 ++ python/sglang/test/test_utils.py | 2 ++ test/lang/test_srt_backend.py | 3 +- test/srt/test_eval_accuracy.py | 4 +-- test/srt/test_openai_server.py | 4 +-- test/srt/test_srt_endpoint.py | 61 ++++++++++++++++++++++++++++++++ 6 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 test/srt/test_srt_endpoint.py diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index 6433f49474d..3729ef7ab26 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -18,6 +18,9 @@ def run_eval(args): + if "OPENAI_API_KEY" not in os.environ: + os.environ["OPENAI_API_KEY"] = "EMPTY" + base_url = ( f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" ) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index f0cdf69c751..4348b57e9b0 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -13,6 +13,8 @@ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.utils import get_exception_traceback +MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct" + def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): assert url is not None diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index f9d79ed290e..7accd349f81 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -14,6 +14,7 @@ test_stream, test_tool_use, ) +from sglang.test.test_utils import MODEL_NAME_FOR_TEST class TestSRTBackend(unittest.TestCase): @@ -21,7 +22,7 @@ class TestSRTBackend(unittest.TestCase): @classmethod def setUpClass(cls): - cls.backend 
= sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct") + cls.backend = sgl.Runtime(model_path=MODEL_NAME_FOR_TEST) sgl.set_default_backend(cls.backend) @classmethod diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py index 436cc18bffe..d392dc4c066 100644 --- a/test/srt/test_eval_accuracy.py +++ b/test/srt/test_eval_accuracy.py @@ -4,7 +4,7 @@ from sglang.srt.utils import kill_child_process from sglang.test.run_eval import run_eval -from sglang.test.test_utils import popen_launch_server +from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server class TestAccuracy(unittest.TestCase): @@ -13,7 +13,7 @@ class TestAccuracy(unittest.TestCase): def setUpClass(cls): port = 30000 - cls.model = "meta-llama/Meta-Llama-3.1-8B-Instruct" + cls.model = MODEL_NAME_FOR_TEST cls.base_url = f"http://localhost:{port}" cls.process = popen_launch_server(cls.model, port, timeout=300) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 6c03a40db36..76a105a6254 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -4,7 +4,7 @@ import openai from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import popen_launch_server +from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server class TestOpenAIServer(unittest.TestCase): @@ -13,7 +13,7 @@ class TestOpenAIServer(unittest.TestCase): def setUpClass(cls): port = 30000 - cls.model = "meta-llama/Meta-Llama-3.1-8B-Instruct" + cls.model = MODEL_NAME_FOR_TEST cls.base_url = f"http://localhost:{port}/v1" cls.process = popen_launch_server(cls.model, port, timeout=300) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py new file mode 100644 index 00000000000..72b9e753dd0 --- /dev/null +++ b/test/srt/test_srt_endpoint.py @@ -0,0 +1,61 @@ +import json +import unittest +import requests + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import popen_launch_server, MODEL_NAME_FOR_TEST + + +class TestSRTEndpoint(unittest.TestCase): + + @classmethod + def setUpClass(cls): + port = 30000 + + cls.model = MODEL_NAME_FOR_TEST + cls.base_url = f"http://localhost:{port}" + cls.process = popen_launch_server(cls.model, port, timeout=300) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_decode(self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 32, + "n": n, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": return_text, + "logprob_start_len": 0, + }, + ) + print(json.dumps(response.json())) + print("=" * 100) + + def test_simple_decode(self): + self.run_decode() + + def test_parallel_sample(self): + self.run_decode(n=3) + + def test_logprob(self): + for top_logprobs_num in [0, 3]: + for return_text in [True, False]: + self.run_decode( + return_logprob=True, + top_logprobs_num=top_logprobs_num, + return_text=return_text, + ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") From 208ab71d480154fc2350def23f547fd22abdae92 Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 03:50:46 +0000 Subject: [PATCH 09/10] rename --- test/srt/{old => deprecated}/test_curl.sh | 0 
test/srt/{old => deprecated}/test_flashinfer.py | 0 test/srt/{old => deprecated}/test_httpserver_classify.py | 0 test/srt/{old => deprecated}/test_httpserver_concurrent.py | 0 test/srt/{old => deprecated}/test_httpserver_decode.py | 0 test/srt/{old => deprecated}/test_httpserver_decode_stream.py | 0 test/srt/{old => deprecated}/test_httpserver_llava.py | 0 test/srt/{old => deprecated}/test_httpserver_reuse.py | 0 test/srt/{old => deprecated}/test_jump_forward.py | 0 test/srt/{old => deprecated}/test_openai_server.py | 0 test/srt/{old => deprecated}/test_robust.py | 0 11 files changed, 0 insertions(+), 0 deletions(-) rename test/srt/{old => deprecated}/test_curl.sh (100%) rename test/srt/{old => deprecated}/test_flashinfer.py (100%) rename test/srt/{old => deprecated}/test_httpserver_classify.py (100%) rename test/srt/{old => deprecated}/test_httpserver_concurrent.py (100%) rename test/srt/{old => deprecated}/test_httpserver_decode.py (100%) rename test/srt/{old => deprecated}/test_httpserver_decode_stream.py (100%) rename test/srt/{old => deprecated}/test_httpserver_llava.py (100%) rename test/srt/{old => deprecated}/test_httpserver_reuse.py (100%) rename test/srt/{old => deprecated}/test_jump_forward.py (100%) rename test/srt/{old => deprecated}/test_openai_server.py (100%) rename test/srt/{old => deprecated}/test_robust.py (100%) diff --git a/test/srt/old/test_curl.sh b/test/srt/deprecated/test_curl.sh similarity index 100% rename from test/srt/old/test_curl.sh rename to test/srt/deprecated/test_curl.sh diff --git a/test/srt/old/test_flashinfer.py b/test/srt/deprecated/test_flashinfer.py similarity index 100% rename from test/srt/old/test_flashinfer.py rename to test/srt/deprecated/test_flashinfer.py diff --git a/test/srt/old/test_httpserver_classify.py b/test/srt/deprecated/test_httpserver_classify.py similarity index 100% rename from test/srt/old/test_httpserver_classify.py rename to test/srt/deprecated/test_httpserver_classify.py diff --git a/test/srt/old/test_httpserver_concurrent.py b/test/srt/deprecated/test_httpserver_concurrent.py similarity index 100% rename from test/srt/old/test_httpserver_concurrent.py rename to test/srt/deprecated/test_httpserver_concurrent.py diff --git a/test/srt/old/test_httpserver_decode.py b/test/srt/deprecated/test_httpserver_decode.py similarity index 100% rename from test/srt/old/test_httpserver_decode.py rename to test/srt/deprecated/test_httpserver_decode.py diff --git a/test/srt/old/test_httpserver_decode_stream.py b/test/srt/deprecated/test_httpserver_decode_stream.py similarity index 100% rename from test/srt/old/test_httpserver_decode_stream.py rename to test/srt/deprecated/test_httpserver_decode_stream.py diff --git a/test/srt/old/test_httpserver_llava.py b/test/srt/deprecated/test_httpserver_llava.py similarity index 100% rename from test/srt/old/test_httpserver_llava.py rename to test/srt/deprecated/test_httpserver_llava.py diff --git a/test/srt/old/test_httpserver_reuse.py b/test/srt/deprecated/test_httpserver_reuse.py similarity index 100% rename from test/srt/old/test_httpserver_reuse.py rename to test/srt/deprecated/test_httpserver_reuse.py diff --git a/test/srt/old/test_jump_forward.py b/test/srt/deprecated/test_jump_forward.py similarity index 100% rename from test/srt/old/test_jump_forward.py rename to test/srt/deprecated/test_jump_forward.py diff --git a/test/srt/old/test_openai_server.py b/test/srt/deprecated/test_openai_server.py similarity index 100% rename from test/srt/old/test_openai_server.py rename to 
test/srt/deprecated/test_openai_server.py diff --git a/test/srt/old/test_robust.py b/test/srt/deprecated/test_robust.py similarity index 100% rename from test/srt/old/test_robust.py rename to test/srt/deprecated/test_robust.py From 0134e5f8faa05e42880bc3c0e929124026472a8a Mon Sep 17 00:00:00 2001 From: Ying Sheng <sqy1415@gmail.com> Date: Fri, 2 Aug 2024 03:57:22 +0000 Subject: [PATCH 10/10] lint --- test/srt/test_srt_endpoint.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 72b9e753dd0..3454678586b 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -1,10 +1,11 @@ import json import unittest + import requests from sglang.srt.utils import kill_child_process from sglang.test.run_eval import run_eval -from sglang.test.test_utils import popen_launch_server, MODEL_NAME_FOR_TEST +from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server class TestSRTEndpoint(unittest.TestCase): @@ -21,7 +22,9 @@ def setUpClass(cls): def tearDownClass(cls): kill_child_process(cls.process.pid) - def run_decode(self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1): + def run_decode( + self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1 + ): response = requests.post( self.base_url + "/generate", json={
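
Usage sketch for the series above: once applied, the MMLU eval can be driven from the command line (python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10, as documented in run_eval.py) or programmatically, following the same pattern as test/srt/test_eval_accuracy.py. The snippet below is a minimal sketch, not an applied hunk; it assumes the Meta-Llama-3.1-8B-Instruct checkpoint referenced by MODEL_NAME_FOR_TEST is available locally and that port 30000 is free.

# Minimal sketch: launch a local server, run the MMLU eval against it, and clean up.
# Mirrors test/srt/test_eval_accuracy.py; the model and port are assumptions.
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server

port = 30000
process = popen_launch_server(MODEL_NAME_FOR_TEST, port, timeout=300)
try:
    args = SimpleNamespace(
        base_url=f"http://localhost:{port}",
        model=MODEL_NAME_FOR_TEST,
        eval_name="mmlu",
        num_examples=20,
        num_threads=20,
    )
    # run_eval downloads mmlu.csv if needed and returns result.metrics | {"score": ...}
    metrics = run_eval(args)
    print(f"MMLU score: {metrics['score']:.3f}")
finally:
    kill_child_process(process.pid)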