FEAT: Prometheus metrics exporter (#906)
codingl2k1 authored Jan 19, 2024
1 parent 145ab64 commit c1e1c5a
Showing 18 changed files with 536 additions and 33 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python.yaml
@@ -130,6 +130,7 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pip install -U xoscar
${{ env.SELF_HOST_PYTHON }} -m pip install -U "python-jose[cryptography]"
${{ env.SELF_HOST_PYTHON }} -m pip install -U "passlib[bcrypt]"
${{ env.SELF_HOST_PYTHON }} -m pip install -U "aioprometheus[starlette]"
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_stable_diffusion.py
1 change: 1 addition & 0 deletions setup.cfg
@@ -44,6 +44,7 @@ install_requires =
openai>1 # For typing
python-jose[cryptography]
passlib[bcrypt]
aioprometheus[starlette]

[options.packages.find]
exclude =
8 changes: 8 additions & 0 deletions xinference/api/restful_api.py
@@ -27,6 +27,8 @@
import gradio as gr
import pydantic
import xoscar as xo
from aioprometheus import REGISTRY, MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
from fastapi import (
APIRouter,
FastAPI,
@@ -389,7 +391,13 @@ def serve(self, logging_conf: Optional[dict] = None):
else None,
)

# Clear the global Registry for the MetricsMiddleware; otherwise the
# MetricsMiddleware would register duplicated metrics if a port conflict
# causes this serve method to run more than once.
REGISTRY.clear()
self._app.add_middleware(MetricsMiddleware)
self._app.include_router(self._router)
self._app.add_route("/metrics", metrics)

# Check that all the routes return Response.
# This is to avoid `jsonable_encoder` performance issue:
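With MetricsMiddleware registered and the /metrics route mounted, the RESTful API serves its request metrics in the Prometheus exposition format. A minimal scrape sketch, assuming the endpoint runs on localhost:9997 (host and port are placeholders, not taken from this diff):

```python
import requests  # any HTTP client works; requests is used here for brevity

# Hypothetical endpoint address; substitute the host/port the API was started with.
resp = requests.get("http://localhost:9997/metrics", timeout=5)
resp.raise_for_status()

# The body is plain-text Prometheus exposition data; print the HELP lines and
# samples, skipping the TYPE comments to keep the output short.
for line in resp.text.splitlines():
    if line.startswith("# HELP") or not line.startswith("#"):
        print(line)
```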
4 changes: 3 additions & 1 deletion xinference/client/tests/test_client.py
@@ -588,7 +588,9 @@ def setup_cluster():
from ...deploy.local import run_in_subprocess as supervisor_run_in_subprocess

supervisor_address = f"localhost:{xo.utils.get_next_port()}"
local_cluster = supervisor_run_in_subprocess(supervisor_address, TEST_LOGGING_CONF)
local_cluster = supervisor_run_in_subprocess(
supervisor_address, None, None, TEST_LOGGING_CONF
)

if not health_check(address=supervisor_address, max_attempts=20, sleep_interval=1):
raise RuntimeError("Supervisor is not available after multiple attempts")
6 changes: 5 additions & 1 deletion xinference/conftest.py
@@ -144,7 +144,11 @@ async def _start_test_cluster(
SupervisorActor, address=address, uid=SupervisorActor.uid()
)
await start_worker_components(
address=address, supervisor_address=address, main_pool=pool
address=address,
supervisor_address=address,
main_pool=pool,
metrics_exporter_host=None,
metrics_exporter_port=None,
)
await pool.join()
except asyncio.CancelledError:
83 changes: 83 additions & 0 deletions xinference/core/metrics.py
@@ -0,0 +1,83 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio

import uvicorn
from aioprometheus import Counter, Gauge
from aioprometheus.asgi.starlette import metrics
from fastapi import FastAPI
from fastapi.responses import RedirectResponse

DEFAULT_METRICS_SERVER_LOG_LEVEL = "warning"


generate_throughput = Gauge(
"xinference:generate_tokens_per_s", "Generate throughput in tokens/s."
)
# Latency
time_to_first_token = Gauge(
"xinference:time_to_first_token_ms", "First token latency in ms."
)
# Tokens counter
input_tokens_total_counter = Counter(
"xinference:input_tokens_total_counter", "Total number of input tokens."
)
output_tokens_total_counter = Counter(
"xinference:output_tokens_total_counter", "Total number of output tokens."
)


def record_metrics(name, op, kwargs):
collector = globals().get(name)
getattr(collector, op)(**kwargs)


def launch_metrics_export_server(q, host=None, port=None):
app = FastAPI()
app.add_route("/metrics", metrics)

@app.get("/")
async def root():
response = RedirectResponse(url="/metrics")
return response

async def main():
if host is not None and port is not None:
config = uvicorn.Config(
app, host=host, port=port, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
)
elif host is not None:
config = uvicorn.Config(
app, host=host, port=0, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
)
elif port is not None:
config = uvicorn.Config(
app, port=port, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
)
else:
config = uvicorn.Config(app, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL)

server = uvicorn.Server(config)
task = asyncio.create_task(server.serve())

while not server.started and not task.done():
await asyncio.sleep(0.1)

for server in server.servers:
for socket in server.sockets:
q.put(socket.getsockname())
await task

asyncio.run(main())
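The new module keeps the collectors at module level, exposes record_metrics(name, op, kwargs) as a small dispatcher (it looks the collector up in globals() and calls, e.g., Gauge.set(labels=..., value=...)), and launch_metrics_export_server reports the socket it actually bound back through the queue, which matters when port 0 is requested. A minimal usage sketch, assuming the exporter runs in a daemon thread of the same process so the module-level collectors are shared; the real wiring lives in worker.py, which is not shown in this excerpt:

```python
import queue
import threading

from xinference.core.metrics import launch_metrics_export_server, record_metrics

# The exporter puts the (host, port) it bound onto this queue once the
# uvicorn server has started.
address_queue: queue.Queue = queue.Queue()

threading.Thread(
    target=launch_metrics_export_server,
    args=(address_queue, "127.0.0.1", 0),  # port 0: let the OS pick a free port
    daemon=True,
).start()

host, port = address_queue.get(timeout=10)
print(f"metrics exporter listening on http://{host}:{port}/metrics")

# Update a collector by name: record_metrics resolves "time_to_first_token"
# to the module-level Gauge and calls Gauge.set(labels=..., value=...).
record_metrics(
    "time_to_first_token",
    "set",
    {"labels": {"model": "demo-model"}, "value": 123.4},
)
```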
156 changes: 148 additions & 8 deletions xinference/core/model.py
@@ -17,6 +17,7 @@
import inspect
import json
import os
import time
import types
import weakref
from typing import (
@@ -34,7 +35,9 @@
import xoscar as xo

if TYPE_CHECKING:
from .worker import WorkerActor
from ..model.llm.core import LLM
from ..model.core import ModelDescription
import PIL

import logging
Expand Down Expand Up @@ -140,13 +143,23 @@ async def __pre_destroy__(self):
gc.collect()
torch.cuda.empty_cache()

def __init__(self, model: "LLM", request_limits: Optional[int] = None):
def __init__(
self,
worker_address: str,
model: "LLM",
model_description: Optional["ModelDescription"] = None,
request_limits: Optional[int] = None,
):
super().__init__()
from ..model.llm.pytorch.core import PytorchModel
from ..model.llm.pytorch.spec_model import SpeculativeModel
from ..model.llm.vllm.core import VLLMModel

self._worker_address = worker_address
self._model = model
self._model_description = (
model_description.to_dict() if model_description else {}
)
self._request_limits = request_limits

self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {}
@@ -156,7 +169,65 @@ def __init__(self, model: "LLM", request_limits: Optional[int] = None):
if isinstance(self._model, (PytorchModel, SpeculativeModel, VLLMModel))
else asyncio.locks.Lock()
)
self._worker_ref = None
self._serve_count = 0
self._metrics_labels = {
"type": self._model_description.get("model_type", "unknown"),
"model": self.model_uid(),
"node": self._worker_address,
"format": self._model_description.get("model_format", "unknown"),
"quantization": self._model_description.get("quantization", "none"),
}
self._loop: Optional[asyncio.AbstractEventLoop] = None

async def __post_create__(self):
self._loop = asyncio.get_running_loop()

async def _record_completion_metrics(
self, duration, completion_tokens, prompt_tokens
):
coros = []
if completion_tokens > 0:
coros.append(
self.record_metrics(
"output_tokens_total_counter",
"add",
{
"labels": self._metrics_labels,
"value": completion_tokens,
},
)
)
if prompt_tokens > 0:
coros.append(
self.record_metrics(
"input_tokens_total_counter",
"add",
{"labels": self._metrics_labels, "value": prompt_tokens},
)
)
if completion_tokens > 0:
generate_throughput = completion_tokens / duration
coros.append(
self.record_metrics(
"generate_throughput",
"set",
{
"labels": self._metrics_labels,
"value": generate_throughput,
},
)
)
await asyncio.gather(*coros)

async def _get_worker_ref(self) -> xo.ActorRefType["WorkerActor"]:
from .worker import WorkerActor

if self._worker_ref is None:
self._worker_ref = await xo.actor_ref(
address=self._worker_address, uid=WorkerActor.uid()
)
return self._worker_ref

def is_vllm_backend(self) -> bool:
from ..model.llm.vllm.core import VLLMModel
@@ -178,19 +249,46 @@ def model_uid(self):
)

def _to_json_generator(self, gen: types.GeneratorType):
start_time = time.time()
time_to_first_token = None
final_usage = None
try:
for v in gen:
if time_to_first_token is None:
time_to_first_token = (time.time() - start_time) * 1000
final_usage = v.pop("usage", None)
v = dict(data=json.dumps(v))
yield sse_starlette.sse.ensure_bytes(v, None)
except OutOfMemoryError:
logger.exception(
"Model actor is out of memory, model id: %s", self.model_uid()
)
os._exit(1)
finally:
if self._loop is not None and time_to_first_token is not None:
coro = self.record_metrics(
"time_to_first_token",
"set",
{"labels": self._metrics_labels, "value": time_to_first_token},
)
asyncio.run_coroutine_threadsafe(coro, loop=self._loop)
if self._loop is not None and final_usage is not None:
coro = self._record_completion_metrics(
time.time() - start_time,
completion_tokens=final_usage["completion_tokens"],
prompt_tokens=final_usage["prompt_tokens"],
)
asyncio.run_coroutine_threadsafe(coro, loop=self._loop)

async def _to_json_async_gen(self, gen: types.AsyncGeneratorType):
start_time = time.time()
time_to_first_token = None
final_usage = None
try:
async for v in gen:
if time_to_first_token is None:
time_to_first_token = (time.time() - start_time) * 1000
final_usage = v.pop("usage", None)
v = await asyncio.to_thread(json.dumps, v)
v = dict(data=v) # noqa: F821
yield await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
@@ -199,6 +297,25 @@ async def _to_json_async_gen(self, gen: types.AsyncGeneratorType):
"Model actor is out of memory, model id: %s", self.model_uid()
)
os._exit(1)
finally:
coros = []
if time_to_first_token is not None:
coros.append(
self.record_metrics(
"time_to_first_token",
"set",
{"labels": self._metrics_labels, "value": time_to_first_token},
)
)
if final_usage is not None:
coros.append(
self._record_completion_metrics(
time.time() - start_time,
completion_tokens=final_usage["completion_tokens"],
prompt_tokens=final_usage["prompt_tokens"],
)
)
await asyncio.gather(*coros)

@oom_check
async def _call_wrapper(self, fn: Callable, *args, **kwargs):
@@ -245,13 +362,32 @@ async def generate(self, prompt: str, *args, **kwargs):
@request_limit
@xo.generator
async def chat(self, prompt: str, *args, **kwargs):
if hasattr(self._model, "chat"):
return await self._call_wrapper(self._model.chat, prompt, *args, **kwargs)
if hasattr(self._model, "async_chat"):
return await self._call_wrapper(
self._model.async_chat, prompt, *args, **kwargs
)
raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
start_time = time.time()
response = None
try:
if hasattr(self._model, "chat"):
response = await self._call_wrapper(
self._model.chat, prompt, *args, **kwargs
)
return response
if hasattr(self._model, "async_chat"):
response = await self._call_wrapper(
self._model.async_chat, prompt, *args, **kwargs
)
return response
raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
finally:
# For the non-streaming result.
if response is not None and isinstance(response, dict):
usage = response["usage"]
# Some backends may not provide a valid usage; we just skip them.
completion_tokens = usage["completion_tokens"]
prompt_tokens = usage["prompt_tokens"]
await self._record_completion_metrics(
time.time() - start_time,
completion_tokens,
prompt_tokens,
)

@log_async(logger=logger)
@request_limit
@@ -341,3 +477,7 @@ async def image_to_image(
raise AttributeError(
f"Model {self._model.model_spec} is not for creating image."
)

async def record_metrics(self, name, op, kwargs):
worker_ref = await self._get_worker_ref()
await worker_ref.record_metrics(name, op, kwargs)
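Both streaming paths read the final chunk's OpenAI-style usage dict and hand the token counts to the labeled collectors; the synchronous generator is iterated in a worker thread, so it cannot await the actor's record_metrics directly and instead schedules the coroutine on the actor's loop (captured in __post_create__) via asyncio.run_coroutine_threadsafe. A sketch of the collector updates these record_metrics calls ultimately boil down to, with made-up label values and token counts (the real path goes through the WorkerActor rather than touching the collectors directly):

```python
from xinference.core.metrics import (
    generate_throughput,
    input_tokens_total_counter,
    output_tokens_total_counter,
)

# Labels mirror ModelActor._metrics_labels; the values below are illustrative.
labels = {
    "type": "LLM",
    "model": "my-model-1-0",     # replica model uid (hypothetical)
    "node": "127.0.0.1:34567",   # worker address (hypothetical)
    "format": "pytorch",
    "quantization": "none",
}

# An OpenAI-style usage payload as found in the final streamed chunk.
usage = {"prompt_tokens": 12, "completion_tokens": 48, "total_tokens": 60}
duration_s = 1.6  # wall-clock generation time (made up)

input_tokens_total_counter.add(labels, usage["prompt_tokens"])
output_tokens_total_counter.add(labels, usage["completion_tokens"])
generate_throughput.set(labels, usage["completion_tokens"] / duration_s)
```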
5 changes: 5 additions & 0 deletions xinference/core/supervisor.py
@@ -23,6 +23,7 @@

from ..core import ModelActor
from ..core.status_guard import InstanceInfo, LaunchStatus
from .metrics import record_metrics
from .resource import ResourceStatus
from .utils import (
build_replica_model_uid,
@@ -750,3 +751,7 @@ async def report_worker_status(
self._worker_status[worker_address] = WorkerStatus(
update_time=time.time(), status=status
)

@staticmethod
def record_metrics(name, op, kwargs):
record_metrics(name, op, kwargs)
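The supervisor simply relays record_metrics calls to the collectors defined in core/metrics.py; ModelActor.record_metrics above does the same through its WorkerActor reference. The worker-side counterpart is part of this commit but not loaded in this view, so the sketch below is an assumption modeled on the supervisor implementation (the actor base class, method placement, and signature are guesses):

```python
# Hypothetical worker-side relay, mirroring SupervisorActor.record_metrics;
# the actual xinference/core/worker.py changes are not shown in this excerpt.
import xoscar as xo

from .metrics import record_metrics


class WorkerActor(xo.Actor):  # base class is a guess, for illustration only
    @staticmethod
    def record_metrics(name, op, kwargs):
        # Update the collectors living in this worker process; the metrics
        # export server launched alongside them serves these collectors on
        # its /metrics endpoint.
        record_metrics(name, op, kwargs)
```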
(The remaining changed files in this commit are not shown in this view.)
