Skip to content

Commit

Permalink
[Refactor] fix ruff rule B039: mutable-contextvar-default (#49854)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

Replace mutable default value with None to prevent shared state across
ContextVar.get() calls

ref : https://docs.astral.sh/ruff/rules/mutable-contextvar-default/

## Related issue number

#47991 

---------

Signed-off-by: LeoLiao123 <[email protected]>
  • Loading branch information
LeoLiao123 authored Jan 18, 2025
1 parent 252e78c commit 3cfe8ae
Show file tree
Hide file tree
Showing 12 changed files with 33 additions and 21 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ ignore = [
"B026",
"B027",
"B035",
"B039",
"B904",
"C416",
"C419",
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/_private/default_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def create_init_handle_options(**kwargs):


def get_request_metadata(init_options, handle_options):
_request_context = ray.serve.context._serve_request_context.get()
_request_context = ray.serve.context._get_serve_request_context()

request_protocol = RequestProtocol.UNDEFINED
if init_options and init_options._source == DeploymentHandleSource.PROXY:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/_private/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class ServeContextFilter(logging.Filter):
"""

def filter(self, record):
request_context = ray.serve.context._serve_request_context.get()
request_context = ray.serve.context._get_serve_request_context()
if request_context.route:
setattr(record, SERVE_LOG_ROUTE, request_context.route)
if request_context.request_id:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/_private/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ async def proxy_request(self, proxy_request: ProxyRequest) -> ResponseGenerator:

latency_ms = (time.time() - start_time) * 1000.0
if response_handler_info.should_record_access_log:
request_context = ray.serve.context._serve_request_context.get()
request_context = ray.serve.context._get_serve_request_context()
logger.info(
access_log_msg(
method=proxy_request.method,
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ async def assign_request(

# Keep track of requests that have been sent out to replicas
if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE:
_request_context = ray.serve.context._serve_request_context.get()
_request_context = ray.serve.context._get_serve_request_context()
request_id: str = _request_context.request_id
self._metrics_manager.inc_num_running_requests_for_replica(
replica_id
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ def get_multiplexed_model_id() -> str:
def my_deployment_function(request):
assert serve.get_multiplexed_model_id() == "model_1"
"""
_request_context = ray.serve.context._serve_request_context.get()
_request_context = ray.serve.context._get_serve_request_context()
return _request_context.multiplexed_model_id


Expand Down
17 changes: 15 additions & 2 deletions python/ray/serve/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,22 @@ class _RequestContext:


_serve_request_context = contextvars.ContextVar(
"Serve internal request context variable", default=_RequestContext()
"Serve internal request context variable", default=None
)


def _get_serve_request_context():
"""Get the current request context.
Returns:
The current request context
"""

if _serve_request_context.get() is None:
_serve_request_context.set(_RequestContext())
return _serve_request_context.get()


def _set_request_context(
route: str = "",
request_id: str = "",
Expand All @@ -195,7 +207,8 @@ def _set_request_context(
"""Set the request context. If the value is not set,
the current context value will be used."""

current_request_context = _serve_request_context.get()
current_request_context = _get_serve_request_context()

_serve_request_context.set(
_RequestContext(
route=route or current_request_context.route,
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _add_serve_metric_default_tags(default_tags: Dict[str, str]):
def _add_serve_context_tag_values(tag_keys: Tuple, tags: Dict[str, str]):
"""Add serve context tag values to the metric tags"""

_request_context = ray.serve.context._serve_request_context.get()
_request_context = ray.serve.context._get_serve_request_context()
if ROUTE_TAG in tag_keys and ROUTE_TAG not in tags:
tags[ROUTE_TAG] = _request_context.route

Expand Down
12 changes: 6 additions & 6 deletions python/ray/serve/tests/test_http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_request_id_header_by_default(serve_instance):
@serve.deployment
class Model:
def __call__(self):
request_id = ray.serve.context._serve_request_context.get().request_id
request_id = ray.serve.context._get_serve_request_context().request_id
return request_id

serve.run(Model.bind())
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_basic(self, serve_instance):
@serve.deployment
class Model:
def __call__(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
request_id = ray.serve.context._get_serve_request_context().request_id
assert request_id == "123-234"
return 1

Expand All @@ -67,7 +67,7 @@ def test_fastapi(self, serve_instance):
class Model:
@app.get("/")
def say_hi(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
request_id = ray.serve.context._get_serve_request_context().request_id
assert request_id == "123-234"
return 1

Expand All @@ -78,7 +78,7 @@ def test_starlette_resp(self, serve_instance):
@serve.deployment
class Model:
def __call__(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
request_id = ray.serve.context._get_serve_request_context().request_id
assert request_id == "123-234"
return starlette.responses.Response("1", media_type="application/json")

Expand All @@ -94,7 +94,7 @@ def test_set_request_id_headers_with_two_attributes(serve_instance):
@serve.deployment
class Model:
def __call__(self):
request_id = ray.serve.context._serve_request_context.get().request_id
request_id = ray.serve.context._get_serve_request_context().request_id
return request_id

serve.run(Model.bind())
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_reuse_request_id(serve_instance):
class MyFastAPIDeployment:
@app.post("/hello")
def root(self, user_input: Dict[str, str]) -> Dict[str, str]:
request_id = ray.serve.context._serve_request_context.get().request_id
request_id = ray.serve.context._get_serve_request_context().request_id
return {
"app_name": user_input["app_name"],
"serve_context_request_id": request_id,
Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_context_information_in_logging(serve_and_ray_shutdown, json_log_format)
)
def fn(*args):
logger.info("user func")
request_context = ray.serve.context._serve_request_context.get()
request_context = ray.serve.context._get_serve_request_context()
return {
"request_id": request_context.request_id,
"route": request_context.route,
Expand All @@ -362,7 +362,7 @@ def fn(*args):
class Model:
def __call__(self, req: starlette.requests.Request):
logger.info("user log message from class method")
request_context = ray.serve.context._serve_request_context.get()
request_context = ray.serve.context._get_serve_request_context()
return {
"request_id": request_context.request_id,
"route": request_context.route,
Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/tests/test_replica_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from fastapi.responses import PlainTextResponse

from ray import serve
from ray.serve.context import _serve_request_context
from ray.serve.context import _get_serve_request_context


def _get_request_context_route() -> str:
return _serve_request_context.get().route
return _get_serve_request_context().route


class TestHTTPRoute:
Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/tests/test_replica_sync_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class D:
@fastapi_app.get("/")
def root(self):
return PlainTextResponse(
serve.context._serve_request_context.get().request_id
serve.context._get_serve_request_context().request_id
)

else:
Expand All @@ -92,7 +92,7 @@ def root(self):
class D:
def __call__(self) -> str:
return PlainTextResponse(
serve.context._serve_request_context.get().request_id
serve.context._get_serve_request_context().request_id
)

serve.run(D.bind())
Expand Down

0 comments on commit 3cfe8ae

Please sign in to comment.