Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with middleware args passing #2752

Merged
merged 2 commits into from
Nov 14, 2024

Conversation

uriyyo
Copy link
Contributor

@uriyyo uriyyo commented Nov 14, 2024

Summary

A feature introduced in #2381 doesn't allow to pass args inside middleware, here is an example:

from starlette.applications import Starlette
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.testclient import TestClient


class Middleware:
    def __init__(self, app: ASGIApp, arg: str) -> None:
        self.app = app
        self.arg = arg

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        self.app(scope, receive, send)


app = Starlette()
app.add_middleware(Middleware, "foo")

client = TestClient(app)
response = client.get("/")
TypeError: Middleware.__init__() got multiple values for argument 'app'

Also, make type annotation for middleware less strict, remove restriction for middleware to be class, and make it callable that returns ASGI application, it's a fully backward-compatible change.

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.testclient import TestClient


def forbid_http_middleware(app: ASGIApp) -> ASGIApp:
    async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":
            _response = JSONResponse({"detail": "Forbidden"}, status_code=403)
            await _response(scope, receive, send)
        else:
            await app(scope, receive, send)

    return middleware


app = Starlette()
app.add_middleware(forbid_http_middleware)

client = TestClient(app)
response = client.get("/")

Checklist

  • I understand that this PR may be closed in case there was no previous discussion. (This doesn't apply to typos!)
  • I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
  • I've updated the documentation accordingly.

@@ -38,5 +36,6 @@ def __repr__(self) -> str:
class_name = self.__class__.__name__
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
cls_name = self.cls.__name__ # type: ignore[attr-defined]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe to do? I guess the point is that the self.cls isn't necessarily a class, it may be a function, but even functions have __name__ defined. Maybe we should do getattr(self.cls, "__name__", "") or something to make sure we don't error in any case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(alternatively we could add __name__ to the protocol above 😆 )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adriangb Thanks for suggestion, changed it to be getattr with default value call.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Are there any actual changes to the reprs? I'm not seeing why that code had to change (aside from what I pointed out in this comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There were no actual changes to the repr. I changed it only because of mypy error.

@uriyyo uriyyo requested a review from adriangb November 14, 2024 21:56
@adriangb adriangb merged commit 427a8dc into encode:master Nov 14, 2024
6 checks passed
@adriangb
Copy link
Member

Thanks!

@uriyyo uriyyo deleted the bugfix/middleware-args branch November 15, 2024 16:41
@arni-xalgo-io
Copy link

Upgrading from 0.41.2 to 0.41.3 gives me the following type error

REDACTED/api.py:104: error: Missing positional arguments "scope", "receive", "send" in call to "add_middleware" of "Starlette"  [call-arg]
REDACTED/api.py:104: error: Argument 1 to "add_middleware" of "Starlette" has incompatible type "Type[ExceptionHandlerMiddleware]"; expected "_MiddlewareFactory[[MutableMapping[str, Any], Callable[[], Awaitable[MutableMapping[str, Any]]], Callable[[MutableMapping[str, Any]], Awaitable[None]]]]"  [arg-type]
REDACTED/api.py:104: note: Following member(s) of "ExceptionHandlerMiddleware" have conflicts:
REDACTED/api.py:104: note:     Expected:
REDACTED/api.py:104: note:         def __call__(app: Callable[[MutableMapping[str, Any], Callable[[], Awaitable[MutableMapping[str, Any]]], Callable[[MutableMapping[str, Any]], Awaitable[None]]], Awaitable[None]], scope: MutableMapping[str, Any], receive: Callable[[], Awaitable[MutableMapping[str, Any]]], send: Callable[[MutableMapping[str, Any]], Awaitable[None]]) -> Callable[[MutableMapping[str, Any], Callable[[], Awaitable[MutableMapping[str, Any]]], Callable[[MutableMapping[str, Any]], Awaitable[None]]], Awaitable[None]]
REDACTED/api.py:104: note:     Got:
REDACTED/api.py:104: note:         def __call__(self: ExceptionHandlerMiddleware, scope: MutableMapping[str, Any], receive: Callable[[], Awaitable[MutableMapping[str, Any]]], send: Callable[[MutableMapping[str, Any]], Awaitable[None]]) -> Coroutine[Any, Any, None]
Found 2 errors in 1 file (checked 44 source files)

✕ mypy failed.

for the following line of code

app.add_middleware(ExceptionHandlerMiddleware)

where app is a FastAPI app.

The middleware is like this

from starlette.types import ASGIApp, Receive, Scope, Send

class ExceptionHandlerMiddleware:
    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Implementation here
        ...

I'm guessing it's a regression resulting from this PR.

@uriyyo
Copy link
Contributor Author

uriyyo commented Nov 19, 2024

@arni-xalgo-io There is discussion about this issue - #2757

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants