Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unit tests for the frontend language part #872

Merged
merged 8 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: lint
name: Lint

on: [push, pull_request]

Expand Down
1 change: 0 additions & 1 deletion python/sglang/lang/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def to_litellm_kwargs(self):
"stop": self.stop or None,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
}
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def shutdown(self):
parent.wait(timeout=5)
self.pid = None

def cache_prefix(self, prefix: str):
    """Forward ``prefix`` to ``self.endpoint.cache_prefix``.

    Whatever the endpoint call returns is discarded; this method
    returns ``None``.
    """
    endpoint = self.endpoint
    endpoint.cache_prefix(prefix)

def get_tokenizer(self):
return get_tokenizer(
self.server_args.tokenizer_path,
Expand Down
15 changes: 9 additions & 6 deletions python/sglang/test/test_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,14 @@ def decode_json(s):
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n"
s += ' "country": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
s += ' "timezone": ' + sgl.gen(regex=REGEX_STRING) + "\n"
s += ' "country": ' + sgl.gen(regex=REGEX_STRING) + "\n"
s += "}"

ret = decode_json.run()
ret = decode_json.run(temperature=0.0)
try:
js_obj = json.loads(ret["json_output"])
except json.decoder.JSONDecodeError:
print(ret["json_output"])
print("JSONDecodeError", ret["json_output"])
raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
Expand All @@ -141,8 +140,12 @@ def decode_json(s):
s += ' "timezone": ' + sgl.gen(dtype=str) + "\n"
s += "}"

ret = decode_json.run()
js_obj = json.loads(ret["json_output"])
ret = decode_json.run(max_new_tokens=64)
try:
js_obj = json.loads(ret["json_output"])
except json.decoder.JSONDecodeError:
print("JSONDecodeError", ret["json_output"])
raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)

Expand Down
16 changes: 15 additions & 1 deletion test/lang/run_all.py → test/lang/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

from sglang.utils import run_with_timeout

# Named groups of test files that the --suite option can select.
suites = {
    "minimal": [
        "test_openai_backend.py",
        "test_srt_backend.py",
    ],
}


def run_unittest_files(files, args):
for filename in files:
Expand Down Expand Up @@ -45,9 +49,19 @@ def run_one_file():
default=1000,
help="The time limit for running one file in seconds.",
)
# Offer every named suite plus the catch-all "all"; the first suite
# defined in `suites` is the default.
suite_names = list(suites)
arg_parser.add_argument(
    "--suite",
    type=str,
    default=suite_names[0],
    choices=suite_names + ["all"],
    help="The suite to run",
)
args = arg_parser.parse_args()

files = glob.glob("**/test_*.py", recursive=True)
# "all" discovers every test_*.py below the current working directory;
# any other value is looked up in `suites`.
files = (
    glob.glob("**/test_*.py", recursive=True)
    if args.suite == "all"
    else suites[args.suite]
)

tic = time.time()
success = run_unittest_files(files, args)
Expand Down
13 changes: 5 additions & 8 deletions test/lang/test_anthropic_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@

class TestAnthropicBackend(unittest.TestCase):
backend = None
chat_backend = None

def setUp(self):
cls = type(self)

if cls.backend is None:
cls.backend = Anthropic("claude-3-haiku-20240307")
set_default_backend(cls.backend)
@classmethod
def setUpClass(cls):
cls.backend = Anthropic("claude-3-haiku-20240307")
set_default_backend(cls.backend)

def test_mt_bench(self):
test_mt_bench()
Expand All @@ -30,5 +27,5 @@ def test_stream(self):

# global_config.verbosity = 2
# t = TestAnthropicBackend()
# t.setUp()
# t.setUpClass()
# t.test_mt_bench()
20 changes: 8 additions & 12 deletions test/lang/test_bind_cache.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
"""
Usage:
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
python3 test_bind_cache.py
"""

import unittest

import sglang as sgl
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint


class TestBind(unittest.TestCase):
backend = None

def setUp(self):
cls = type(self)
@classmethod
def setUpClass(cls):
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
sgl.set_default_backend(cls.backend)

if cls.backend is None:
cls.backend = RuntimeEndpoint(base_url="http://localhost:30000")
@classmethod
def tearDownClass(cls):
cls.backend.shutdown()

def test_bind(self):
@sgl.function
Expand Down Expand Up @@ -54,5 +50,5 @@ def few_shot_qa(s, prompt, question):
unittest.main(warnings="ignore")

# t = TestBind()
# t.setUp()
# t.setUpClass()
# t.test_cache()
11 changes: 4 additions & 7 deletions test/lang/test_litellm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@


class TestAnthropicBackend(unittest.TestCase):
backend = None
chat_backend = None

def setUp(self):
cls = type(self)

if cls.backend is None:
cls.backend = LiteLLM("gpt-3.5-turbo")
set_default_backend(cls.backend)
@classmethod
def setUpClass(cls):
cls.chat_backend = LiteLLM("gpt-3.5-turbo")
set_default_backend(cls.chat_backend)

def test_mt_bench(self):
test_mt_bench()
Expand Down
40 changes: 19 additions & 21 deletions test/lang/test_openai_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,68 +20,66 @@


class TestOpenAIBackend(unittest.TestCase):
backend = None
instruct_backend = None
chat_backend = None
chat_vision_backend = None

def setUp(self):
cls = type(self)

if cls.backend is None:
cls.backend = OpenAI("gpt-3.5-turbo-instruct")
cls.chat_backend = OpenAI("gpt-3.5-turbo")
cls.chat_vision_backend = OpenAI("gpt-4-turbo")
@classmethod
def setUpClass(cls):
cls.instruct_backend = OpenAI("gpt-3.5-turbo-instruct")
cls.chat_backend = OpenAI("gpt-3.5-turbo")
cls.chat_vision_backend = OpenAI("gpt-4-turbo")

def test_few_shot_qa(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_few_shot_qa()

def test_mt_bench(self):
set_default_backend(self.chat_backend)
test_mt_bench()

def test_select(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_select(check_answer=True)

def test_decode_int(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_decode_int()

def test_decode_json(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_decode_json()

def test_expert_answer(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_expert_answer()

def test_tool_use(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_tool_use()

def test_react(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_react()

def test_parallel_decoding(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_parallel_decoding()

def test_parallel_encoding(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_parallel_encoding()

def test_image_qa(self):
set_default_backend(self.chat_vision_backend)
test_image_qa()

def test_stream(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_stream()

def test_completion_speculative(self):
set_default_backend(self.backend)
set_default_backend(self.instruct_backend)
test_completion_speculative()

def test_chat_completion_speculative(self):
Expand All @@ -96,5 +94,5 @@ def test_chat_completion_speculative(self):

# global_config.verbosity = 2
# t = TestOpenAIBackend()
# t.setUp()
# t.test_chat_completion_speculative()
# t.setUpClass()
# t.test_stream()
28 changes: 10 additions & 18 deletions test/lang/test_srt_backend.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
"""
Usage:
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
python3 test_srt_backend.py
"""

import json
import unittest

Expand All @@ -15,8 +9,6 @@
test_few_shot_qa,
test_mt_bench,
test_parallel_decoding,
test_parallel_encoding,
test_react,
test_regex,
test_select,
test_stream,
Expand All @@ -27,12 +19,14 @@
class TestSRTBackend(unittest.TestCase):
backend = None

def setUp(self):
cls = type(self)
@classmethod
def setUpClass(cls):
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
sgl.set_default_backend(cls.backend)

if cls.backend is None:
cls.backend = sgl.RuntimeEndpoint(base_url="http://localhost:30000")
sgl.set_default_backend(cls.backend)
@classmethod
def tearDownClass(cls):
cls.backend.shutdown()

def test_few_shot_qa(self):
test_few_shot_qa()
Expand Down Expand Up @@ -64,9 +58,6 @@ def test_stream(self):
def test_regex(self):
test_regex()

# def test_parallel_encoding(self):
# test_parallel_encoding(check_answer=False)


if __name__ == "__main__":
unittest.main(warnings="ignore")
Expand All @@ -75,5 +66,6 @@ def test_regex(self):

# global_config.verbosity = 2
# t = TestSRTBackend()
# t.setUp()
# t.test_regex()
# t.setUpClass()
# t.test_few_shot_qa()
# t.tearDownClass()
14 changes: 6 additions & 8 deletions test/lang/test_vertexai_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@ class TestVertexAIBackend(unittest.TestCase):
chat_backend = None
chat_vision_backend = None

def setUp(self):
cls = type(self)

if cls.backend is None:
cls.backend = VertexAI("gemini-pro")
cls.chat_backend = VertexAI("gemini-pro")
cls.chat_vision_backend = VertexAI("gemini-pro-vision")
@classmethod
def setUpClass(cls):
cls.backend = VertexAI("gemini-pro")
cls.chat_backend = VertexAI("gemini-pro")
cls.chat_vision_backend = VertexAI("gemini-pro-vision")

def test_few_shot_qa(self):
set_default_backend(self.backend)
Expand Down Expand Up @@ -61,5 +59,5 @@ def test_stream(self):

# global_config.verbosity = 2
# t = TestVertexAIBackend()
# t.setUp()
# t.setUpClass()
# t.test_stream()
Loading