From 3cfe8ae7a0b8a32a152c473c21eab02650634cdd Mon Sep 17 00:00:00 2001 From: Leo Liao <93932709+LeoLiao123@users.noreply.github.com> Date: Sun, 19 Jan 2025 00:45:14 +0800 Subject: [PATCH] [Refactor] fix ruff rule B039: mutable-contextvar-default (#49854) ## Why are these changes needed? Replace mutable default value with None to prevent shared state across ContextVar.get() calls ref : https://docs.astral.sh/ruff/rules/mutable-contextvar-default/ ## Related issue number #47991 --------- Signed-off-by: LeoLiao123 --- pyproject.toml | 1 - python/ray/serve/_private/default_impl.py | 2 +- python/ray/serve/_private/logging_utils.py | 2 +- python/ray/serve/_private/proxy.py | 2 +- python/ray/serve/_private/router.py | 2 +- python/ray/serve/api.py | 2 +- python/ray/serve/context.py | 17 +++++++++++++++-- python/ray/serve/metrics.py | 2 +- python/ray/serve/tests/test_http_headers.py | 12 ++++++------ python/ray/serve/tests/test_logging.py | 4 ++-- .../serve/tests/test_replica_request_context.py | 4 ++-- .../serve/tests/test_replica_sync_methods.py | 4 ++-- 12 files changed, 33 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1618ad50b69f..517e1d7c2193 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ ignore = [ "B026", "B027", "B035", - "B039", "B904", "C416", "C419", diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index cb404dc43789..6325077d0989 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -93,7 +93,7 @@ def create_init_handle_options(**kwargs): def get_request_metadata(init_options, handle_options): - _request_context = ray.serve.context._serve_request_context.get() + _request_context = ray.serve.context._get_serve_request_context() request_protocol = RequestProtocol.UNDEFINED if init_options and init_options._source == DeploymentHandleSource.PROXY: diff --git a/python/ray/serve/_private/logging_utils.py b/python/ray/serve/_private/logging_utils.py index 5081829670bc..d8058f9419c6 100644 --- a/python/ray/serve/_private/logging_utils.py +++ b/python/ray/serve/_private/logging_utils.py @@ -84,7 +84,7 @@ class ServeContextFilter(logging.Filter): """ def filter(self, record): - request_context = ray.serve.context._serve_request_context.get() + request_context = ray.serve.context._get_serve_request_context() if request_context.route: setattr(record, SERVE_LOG_ROUTE, request_context.route) if request_context.request_id: diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index 6b018e78d142..a00327eb154d 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -456,7 +456,7 @@ async def proxy_request(self, proxy_request: ProxyRequest) -> ResponseGenerator: latency_ms = (time.time() - start_time) * 1000.0 if response_handler_info.should_record_access_log: - request_context = ray.serve.context._serve_request_context.get() + request_context = ray.serve.context._get_serve_request_context() logger.info( access_log_msg( method=proxy_request.method, diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index d6b67611c1ed..2462c000b271 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -620,7 +620,7 @@ async def assign_request( # Keep track of requests that have been sent out to replicas if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE: - _request_context = ray.serve.context._serve_request_context.get() + _request_context = ray.serve.context._get_serve_request_context() request_id: str = _request_context.request_id self._metrics_manager.inc_num_running_requests_for_replica( replica_id diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 506aee8b8587..ca68fbc1d7ab 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -725,7 +725,7 @@ def get_multiplexed_model_id() -> str: def my_deployment_function(request): assert serve.get_multiplexed_model_id() == "model_1" """ - _request_context = ray.serve.context._serve_request_context.get() + _request_context = ray.serve.context._get_serve_request_context() return _request_context.multiplexed_model_id diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py index e87e052e4200..8febf452df26 100644 --- a/python/ray/serve/context.py +++ b/python/ray/serve/context.py @@ -181,10 +181,22 @@ class _RequestContext: _serve_request_context = contextvars.ContextVar( - "Serve internal request context variable", default=_RequestContext() + "Serve internal request context variable", default=None ) +def _get_serve_request_context(): + """Get the current request context. + + Returns: + The current request context + """ + + if _serve_request_context.get() is None: + _serve_request_context.set(_RequestContext()) + return _serve_request_context.get() + + def _set_request_context( route: str = "", request_id: str = "", @@ -195,7 +207,8 @@ def _set_request_context( """Set the request context. If the value is not set, the current context value will be used.""" - current_request_context = _serve_request_context.get() + current_request_context = _get_serve_request_context() + _serve_request_context.set( _RequestContext( route=route or current_request_context.route, diff --git a/python/ray/serve/metrics.py b/python/ray/serve/metrics.py index 44316a45f573..bce6e4944004 100644 --- a/python/ray/serve/metrics.py +++ b/python/ray/serve/metrics.py @@ -60,7 +60,7 @@ def _add_serve_metric_default_tags(default_tags: Dict[str, str]): def _add_serve_context_tag_values(tag_keys: Tuple, tags: Dict[str, str]): """Add serve context tag values to the metric tags""" - _request_context = ray.serve.context._serve_request_context.get() + _request_context = ray.serve.context._get_serve_request_context() if ROUTE_TAG in tag_keys and ROUTE_TAG not in tags: tags[ROUTE_TAG] = _request_context.route diff --git a/python/ray/serve/tests/test_http_headers.py b/python/ray/serve/tests/test_http_headers.py index b85e9816264d..d1c36c8a9b12 100644 --- a/python/ray/serve/tests/test_http_headers.py +++ b/python/ray/serve/tests/test_http_headers.py @@ -20,7 +20,7 @@ def test_request_id_header_by_default(serve_instance): @serve.deployment class Model: def __call__(self): - request_id = ray.serve.context._serve_request_context.get().request_id + request_id = ray.serve.context._get_serve_request_context().request_id return request_id serve.run(Model.bind()) @@ -52,7 +52,7 @@ def test_basic(self, serve_instance): @serve.deployment class Model: def __call__(self) -> int: - request_id = ray.serve.context._serve_request_context.get().request_id + request_id = ray.serve.context._get_serve_request_context().request_id assert request_id == "123-234" return 1 @@ -67,7 +67,7 @@ def test_fastapi(self, serve_instance): class Model: @app.get("/") def say_hi(self) -> int: - request_id = ray.serve.context._serve_request_context.get().request_id + request_id = ray.serve.context._get_serve_request_context().request_id assert request_id == "123-234" return 1 @@ -78,7 +78,7 @@ def test_starlette_resp(self, serve_instance): @serve.deployment class Model: def __call__(self) -> int: - request_id = ray.serve.context._serve_request_context.get().request_id + request_id = ray.serve.context._get_serve_request_context().request_id assert request_id == "123-234" return starlette.responses.Response("1", media_type="application/json") @@ -94,7 +94,7 @@ def test_set_request_id_headers_with_two_attributes(serve_instance): @serve.deployment class Model: def __call__(self): - request_id = ray.serve.context._serve_request_context.get().request_id + request_id = ray.serve.context._get_serve_request_context().request_id return request_id serve.run(Model.bind()) @@ -128,7 +128,7 @@ def test_reuse_request_id(serve_instance): class MyFastAPIDeployment: @app.post("/hello") def root(self, user_input: Dict[str, str]) -> Dict[str, str]: - request_id = ray.serve.context._serve_request_context.get().request_id + request_id = ray.serve.context._get_serve_request_context().request_id return { "app_name": user_input["app_name"], "serve_context_request_id": request_id, diff --git a/python/ray/serve/tests/test_logging.py b/python/ray/serve/tests/test_logging.py index 64841ec344ca..4e16d074c012 100644 --- a/python/ray/serve/tests/test_logging.py +++ b/python/ray/serve/tests/test_logging.py @@ -341,7 +341,7 @@ def test_context_information_in_logging(serve_and_ray_shutdown, json_log_format) ) def fn(*args): logger.info("user func") - request_context = ray.serve.context._serve_request_context.get() + request_context = ray.serve.context._get_serve_request_context() return { "request_id": request_context.request_id, "route": request_context.route, @@ -362,7 +362,7 @@ def fn(*args): class Model: def __call__(self, req: starlette.requests.Request): logger.info("user log message from class method") - request_context = ray.serve.context._serve_request_context.get() + request_context = ray.serve.context._get_serve_request_context() return { "request_id": request_context.request_id, "route": request_context.route, diff --git a/python/ray/serve/tests/test_replica_request_context.py b/python/ray/serve/tests/test_replica_request_context.py index b9cc1ecda9aa..101bcd7e95fe 100644 --- a/python/ray/serve/tests/test_replica_request_context.py +++ b/python/ray/serve/tests/test_replica_request_context.py @@ -6,11 +6,11 @@ from fastapi.responses import PlainTextResponse from ray import serve -from ray.serve.context import _serve_request_context +from ray.serve.context import _get_serve_request_context def _get_request_context_route() -> str: - return _serve_request_context.get().route + return _get_serve_request_context().route class TestHTTPRoute: diff --git a/python/ray/serve/tests/test_replica_sync_methods.py b/python/ray/serve/tests/test_replica_sync_methods.py index d6485704138f..43c7a14d829c 100644 --- a/python/ray/serve/tests/test_replica_sync_methods.py +++ b/python/ray/serve/tests/test_replica_sync_methods.py @@ -83,7 +83,7 @@ class D: @fastapi_app.get("/") def root(self): return PlainTextResponse( - serve.context._serve_request_context.get().request_id + serve.context._get_serve_request_context().request_id ) else: @@ -92,7 +92,7 @@ def root(self): class D: def __call__(self) -> str: return PlainTextResponse( - serve.context._serve_request_context.get().request_id + serve.context._get_serve_request_context().request_id ) serve.run(D.bind())