Skip to content

Commit

Permalink
Added support for WebSocket in TestClient (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
Randomneo authored Jan 16, 2025
1 parent dacfe6b commit d09b82c
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 45 deletions.
21 changes: 21 additions & 0 deletions blacksheep/testing/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
from typing import Optional

from blacksheep.contents import Content
from blacksheep.server.application import Application
from blacksheep.server.responses import Response
from blacksheep.testing.simulator import AbstractTestSimulator, TestSimulator
from blacksheep.testing.websocket import TestWebSocket

from .helpers import CookiesType, HeadersType, QueryType

Expand Down Expand Up @@ -157,3 +159,22 @@ async def trace(
content=None,
cookies=cookies,
)

def websocket_connect(
self,
path: str,
headers: HeadersType = None,
query: QueryType = None,
cookies: CookiesType = None,
) -> TestWebSocket:
return self._test_simulator.websocket_connect(
path=path,
headers=headers,
query=query,
content=None,
cookies=cookies,
)

async def websocket_all_closed(self):
await asyncio.gather(*self._test_simulator.websocket_tasks)
self._test_simulator.websocket_tasks = []
38 changes: 38 additions & 0 deletions blacksheep/testing/simulator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import abc
import asyncio
from typing import Dict, Optional

from blacksheep.contents import Content
from blacksheep.messages import Request
from blacksheep.server.application import Application
from blacksheep.server.responses import Response
from blacksheep.testing.helpers import get_example_scope
from blacksheep.testing.websocket import TestWebSocket

from .helpers import CookiesType, HeadersType, QueryType

Expand Down Expand Up @@ -66,6 +68,17 @@ async def send_request(
Then you can define an own TestClient, with the custom logic.
"""

@abc.abstractmethod
async def websocket_connect(
self,
path,
headers: HeadersType = None,
query: QueryType = None,
content: Optional[Content] = None,
cookies: CookiesType = None,
) -> TestWebSocket:
"""Entrypoint for WebSocket"""


class TestSimulator(AbstractTestSimulator):
"""Base Test simulator class
Expand All @@ -76,6 +89,7 @@ class TestSimulator(AbstractTestSimulator):

def __init__(self, app: Application):
self.app = app
self.websocket_tasks = []
self._is_started_app()

async def send_request(
Expand Down Expand Up @@ -110,6 +124,30 @@ async def send_request(

return response

def websocket_connect(
self,
path: str,
headers: HeadersType = None,
query: QueryType = None,
content: Optional[Content] = None,
cookies: CookiesType = None,
) -> TestWebSocket:
scope = _create_scope("GET_WS", path, headers, query, cookies=cookies)
scope["type"] = "websocket"
test_websocket = TestWebSocket()

self.websocket_tasks.append(
asyncio.create_task(
self.app(
scope,
test_websocket._receive,
test_websocket._send,
),
),
)

return test_websocket

def _is_started_app(self):
assert (
self.app.started
Expand Down
68 changes: 68 additions & 0 deletions blacksheep/testing/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

import asyncio
from typing import Any, AnyStr, MutableMapping

from blacksheep.settings.json import json_settings


class TestWebSocket:
def __init__(self):
self.send_queue = asyncio.Queue()
self.receive_queue = asyncio.Queue()

async def _send(self, data: MutableMapping[str, Any]) -> None:
await self.send_queue.put(data)

async def _receive(self) -> MutableMapping[str, AnyStr]:
return await self.receive_queue.get()

async def send(self, data: MutableMapping[str, Any]) -> None:
await self.receive_queue.put(data)

async def send_text(self, data: str) -> None:
await self.send({"type": "websocket.send", "text": data})

async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})

async def send_json(self, data: MutableMapping[Any, Any]) -> None:
await self.send(
{
"type": "websocket.send",
"text": json_settings.dumps(data),
}
)

async def receive(self) -> MutableMapping[str, AnyStr]:
return await self.send_queue.get()

async def receive_text(self) -> str:
data = await self.receive()
assert data["type"] == "websocket.send"
return data["text"]

async def receive_bytes(self) -> bytes:
data = await self.receive()
assert data["type"] == "websocket.send"
return data["bytes"]

async def receive_json(self) -> MutableMapping[str, Any]:
data = await self.receive()
assert data["type"] == "websocket.send"
return json_settings.loads(data["text"])

async def __aenter__(self) -> TestWebSocket:
await self.send({"type": "websocket.connect"})
received = await self.receive()
assert received.get("type") == "websocket.accept"
return self

async def __aexit__(self, exc_type, exc_value, exc_tb) -> None:
await self.send(
{
"type": "websocket.disconnect",
"code": 1000,
"reason": "TestWebSocket context closed",
},
)
129 changes: 84 additions & 45 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
WebSocketState,
format_reason,
)
from blacksheep.testing import TestClient
from blacksheep.testing.messages import MockReceive, MockSend
from tests.utils.application import FakeApplication

Expand Down Expand Up @@ -322,18 +323,19 @@ async def test_websocket_raises_for_receive_when_closed_by_client(example_scope)


@pytest.mark.asyncio
async def test_application_handling_websocket_request_not_found(example_scope):
async def test_application_handling_websocket_request_not_found():
"""
If a client tries to open a WebSocket connection on an endpoint that is not handled,
the application returns an ASGI message to close the connection.
"""
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive()
await app.start()

await app(example_scope, mock_receive, mock_send)
client = TestClient(app)
test_websocket = client.websocket_connect("/ws")
await test_websocket.send({"type": "websocket.connect"})
close_message = await test_websocket.receive()

close_message = mock_send.messages[0]
assert close_message == {"type": "websocket.close", "reason": None, "code": 1000}


Expand All @@ -344,8 +346,6 @@ async def test_application_handling_proper_websocket_request():
the application websocket handler is called.
"""
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive([{"type": "websocket.connect"}])

@app.router.ws("/ws/{foo}")
async def websocket_handler(websocket, foo):
Expand All @@ -358,21 +358,17 @@ async def websocket_handler(websocket, foo):
await websocket.accept()

await app.start()
await app(
{"type": "websocket", "path": "/ws/001", "query_string": "", "headers": []},
mock_receive,
mock_send,
)
client = TestClient(app)
async with client.websocket_connect("/ws/001"):
pass


@pytest.mark.asyncio
async def test_application_handling_proper_websocket_request_with_query():
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive([{"type": "websocket.connect"}])

@app.router.ws("/ws/{foo}")
async def websocket_handler(websocket, foo, from_query: int):
async def websocket_handler(websocket: WebSocket, foo, from_query: int):
assert isinstance(websocket, WebSocket)
assert websocket.application_state == WebSocketState.CONNECTING
assert websocket.client_state == WebSocketState.CONNECTING
Expand All @@ -383,41 +379,27 @@ async def websocket_handler(websocket, foo, from_query: int):
await websocket.accept()

await app.start()
await app(
{
"type": "websocket",
"path": "/ws/001",
"query_string": b"from_query=200",
"headers": [],
},
mock_receive,
mock_send,
)
client = TestClient(app)
async with client.websocket_connect("/ws/001", query="from_query=200"):
pass


@pytest.mark.asyncio
async def test_application_handling_proper_websocket_request_header_binding(
example_scope,
):
async def test_application_handling_proper_websocket_request_header_binding():
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive([{"type": "websocket.connect"}])

class UpgradeHeader(FromHeader[str]):
name = "Upgrade"

called = False

@app.router.ws("/ws")
async def websocket_handler(connect_header: UpgradeHeader):
async def websocket_handler(websocket: WebSocket, connect_header: UpgradeHeader):
assert connect_header.value == "websocket"

nonlocal called
called = True
await websocket.accept()

await app.start()
await app(example_scope, mock_receive, mock_send)
assert called is True
client = TestClient(app)
async with client.websocket_connect("/ws", headers={"upgrade": "websocket"}):
pass


@pytest.mark.asyncio
Expand All @@ -426,8 +408,6 @@ async def test_application_websocket_binding_by_type_annotation():
This test verifies that the WebSocketBinder can bind a WebSocket by type annotation.
"""
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive([{"type": "websocket.connect"}])

@app.router.ws("/ws")
async def websocket_handler(my_ws: WebSocket):
Expand All @@ -438,11 +418,9 @@ async def websocket_handler(my_ws: WebSocket):
await my_ws.accept()

await app.start()
await app(
{"type": "websocket", "path": "/ws", "query_string": "", "headers": []},
mock_receive,
mock_send,
)
client = TestClient(app)
async with client.websocket_connect("/ws"):
pass


@pytest.mark.asyncio
Expand All @@ -466,6 +444,67 @@ async def websocket_handler(my_ws: WebSocket):
assert await route.handler(...) is None


@pytest.mark.asyncio
async def test_testwebsocket_closing():
"""
This test verifies that websocket.disconnect is sent by TestWebSocket
"""
app = FakeApplication()
disconnected = False

@app.router.ws("/ws")
async def websocket_handler(my_ws: WebSocket):
await my_ws.accept()
try:
await my_ws.receive()
except WebSocketDisconnectError:
nonlocal disconnected
disconnected = True

await app.start()
client = TestClient(app)
async with client.websocket_connect("/ws"):
pass
await client.websocket_all_closed()
assert disconnected is True


@pytest.mark.asyncio
async def test_testwebsocket_send_receive_methods():
"""
This test verifies that TestWebSocket sends and receives different formats
"""
app = FakeApplication()

@app.router.ws("/ws")
async def websocket_handler(my_ws: WebSocket):
await my_ws.accept()
await my_ws.send_text("send")
await my_ws.send_json({"send": "test"})
await my_ws.send_bytes(b"send")
received = await my_ws.receive_text()
assert received == "received"
received = await my_ws.receive_json()
assert received == {"received": "test"}
received = await my_ws.receive_bytes()
assert received == b"received"
await my_ws.close()

await app.start()
client = TestClient(app)
async with client.websocket_connect("/ws") as ws:
received = await ws.receive_text()
assert received == "send"
received = await ws.receive_json()
assert received == {"send": "test"}
received = await ws.receive_bytes()
assert received == b"send"

await ws.send_text("received")
await ws.send_json({"received": "test"})
await ws.send_text(b"received")


LONG_REASON = "WRY" * 41
QIN = "秦" # Qyn dynasty in Chinese, 3 bytes.
TOO_LONG_REASON = QIN * 42
Expand Down

0 comments on commit d09b82c

Please sign in to comment.