Skip to content

Commit

Permalink
move run_grpc_server to grpc module
Browse files (browse the repository at this point in the history)
  • Loading branch information
dtrifiro committed Jun 20, 2024
1 parent 89ed62f commit d1479e8
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 50 deletions.
30 changes: 2 additions & 28 deletions src/vllm_tgis_adapter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.usage.usage_lib import UsageContext

from .grpc import start_grpc_server
from .grpc import run_grpc_server
from .logging import init_logger
from .tgis_utils.args import EnvVarArgumentParser, add_tgis_args, postprocess_tgis_args

Expand Down Expand Up @@ -228,32 +228,6 @@ async def run_http_server(
await server.shutdown()


async def run_grpc_server(
    engine: AsyncLLMEngine,
    *,
    disable_log_stats: bool,
) -> None:
    """Run the gRPC server until this coroutine is cancelled.

    Starts the server via ``start_grpc_server`` and then idles. When the
    surrounding task is cancelled, the server is stopped with a 30 second
    grace period and the coroutine returns normally (the ``CancelledError``
    is intentionally not re-raised).

    Note: ``args`` is not a parameter here; it is read from the enclosing
    module scope and must be set before this coroutine runs.

    Args:
        engine: Engine handed to ``start_grpc_server``; its
            ``do_log_stats()`` coroutine is awaited periodically unless
            ``disable_log_stats`` is true.
        disable_log_stats: When true, no periodic stats logging task is
            started.
    """

    async def _force_log() -> None:
        # Ask the engine to log its stats every 10 seconds, forever.
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    # Check the precondition before spawning anything, so a failure here
    # does not leave a background task behind.
    assert args is not None

    # Keep a strong reference to the background task: the event loop only
    # holds weak references, so an unreferenced task may be garbage
    # collected before it finishes.
    stats_task = None if disable_log_stats else asyncio.create_task(_force_log())

    try:
        server = await start_grpc_server(engine, args)

        try:
            while True:
                await asyncio.sleep(60)
        except asyncio.CancelledError:
            print("Gracefully stopping gRPC server")  # noqa: T201
            await server.stop(30)  # TODO configurable grace
            await server.wait_for_termination()
    finally:
        # Do not leave the stats loop running once the server is gone (or
        # failed to start).
        if stats_task is not None:
            stats_task.cancel()


if __name__ == "__main__":
# convert to our custom env var arg parser
parser = EnvVarArgumentParser(parser=make_arg_parser())
Expand Down Expand Up @@ -322,7 +296,7 @@ async def _force_log() -> None:
run_http_server(engine, args, model_config)
)
grpc_server_task = event_loop.create_task(
run_grpc_server(engine, disable_log_stats=engine_args.disable_log_stats)
run_grpc_server(engine, args, disable_log_stats=engine_args.disable_log_stats)
)

def signal_handler() -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/vllm_tgis_adapter/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .grpc_server import start_grpc_server
from .grpc_server import run_grpc_server
28 changes: 28 additions & 0 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import inspect
import time
import uuid
Expand Down Expand Up @@ -804,3 +805,30 @@ async def start_grpc_server(
logger.info("gRPC Server started at %s", listen_on)

return server


async def run_grpc_server(
    engine: AsyncLLMEngine,
    args: argparse.Namespace,
    *,
    disable_log_stats: bool,
) -> None:
    """Run the gRPC server until this coroutine is cancelled.

    Starts the server via ``start_grpc_server`` and then idles. When the
    surrounding task is cancelled, the server is stopped with a 30 second
    grace period and the coroutine returns normally (the ``CancelledError``
    is intentionally not re-raised).

    Args:
        engine: Engine handed to ``start_grpc_server``; its
            ``do_log_stats()`` coroutine is awaited periodically unless
            ``disable_log_stats`` is true.
        args: Parsed command-line namespace forwarded to
            ``start_grpc_server``. Must not be ``None``.
        disable_log_stats: When true, no periodic stats logging task is
            started.
    """

    async def _force_log() -> None:
        # Ask the engine to log its stats every 10 seconds, forever.
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    # Check the precondition before spawning anything, so a failure here
    # does not leave a background task behind.
    assert args is not None

    # Keep a strong reference to the background task: the event loop only
    # holds weak references, so an unreferenced task may be garbage
    # collected before it finishes.
    stats_task = None if disable_log_stats else asyncio.create_task(_force_log())

    try:
        server = await start_grpc_server(engine, args)

        try:
            while True:
                await asyncio.sleep(60)
        except asyncio.CancelledError:
            print("Gracefully stopping gRPC server")  # noqa: T201
            await server.stop(30)  # TODO configurable grace
            await server.wait_for_termination()
    finally:
        # Do not leave the stats loop running once the server is gone (or
        # failed to start).
        if stats_task is not None:
            stats_task.cancel()
31 changes: 12 additions & 19 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding

from vllm_tgis_adapter.grpc.grpc_server import TextGenerationService, start_grpc_server
import vllm_tgis_adapter
from vllm_tgis_adapter.__main__ import run_http_server
from vllm_tgis_adapter.grpc import run_grpc_server
from vllm_tgis_adapter.grpc.grpc_server import TextGenerationService
from vllm_tgis_adapter.healthcheck import health_check
from vllm_tgis_adapter.tgis_utils.args import (
EnvVarArgumentParser,
Expand Down Expand Up @@ -90,7 +93,7 @@ def grpc_server_url(grpc_server_thread_port):


@pytest.fixture()
def grpc_server(engine, args, grpc_server_url):
def _grpc_server(engine, args, grpc_server_url):
"""Spins up the grpc server in a background thread."""

def _health_check():
Expand All @@ -101,34 +104,24 @@ def _health_check():
service=TextGenerationService.SERVICE_NAME,
)

global server # noqa: PLW0602

loop = asyncio.new_event_loop()

async def run_server():
global server # noqa: PLW0603

server = await start_grpc_server(engine, args)
while server._server.is_running(): # noqa: SLF001
await asyncio.sleep(1)
global task # noqa: PLW0602

def target():
loop.run_until_complete(run_server())
global task # noqa: PLW0603

task = loop.create_task(run_grpc_server(engine, args, disable_log_stats=False))
loop.run_until_complete(task)

t = threading.Thread(target=target)
t.start()

async def stop():
global server # noqa: PLW0602

await server.stop(grace=None)
await server.wait_for_termination()

try:
wait_until(_health_check)
yield server
yield
finally:
loop.create_task(stop()) # noqa: RUF006
task.cancel()
t.join()


Expand Down
7 changes: 5 additions & 2 deletions tests/test_grpc_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
def test_startup(grpc_server):
import pytest


@pytest.mark.usefixtures("_grpc_server")
def test_startup():
"""Test that the grpc_server fixture starts up properly."""
assert grpc_server._server.is_running() # noqa: SLF001

0 comments on commit d1479e8

Please sign in to comment.