From cb9e4e9e04765adfbc5862e7ab9c94fd56bd8afc Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Thu, 13 Sep 2018 14:47:58 -0400 Subject: [PATCH 1/3] Introduce a few very basic tests. This is preparation for introducing AsyncResource as a base class. --- .gitignore | 1 + README.md | 11 ++++++--- pytest.ini | 2 ++ setup.py | 4 ++-- tests/test_connection.py | 52 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 pytest.ini create mode 100644 tests/test_connection.py diff --git a/.gitignore b/.gitignore index 2da908c..8e13e34 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.pytest_cache __pycache__ dist examples/fake.* diff --git a/README.md b/README.md index 492b92c..55ee5f0 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ If you want to help develop `trio-websocket`, clone [the repository](https://github.com/hyperiongray/trio-websocket) and run this command from the repository root: - pip install --editable . + pip install --editable .[dev] ## Sample client @@ -39,7 +39,11 @@ example client sends a text message and then disconnects. trio.run(main) -A more detailed example is in `examples/client.py`. +A more detailed example is in `examples/client.py`. **Note:** if you want to run +this example client with SSL, you'll need to install the `trustme` module from +PyPI (installed automtically if you used the `[dev]` extras when installing +`trio-websocket`) and then generate a self-signed certificate by running +`example/generate-cert.py`. ## Sample server @@ -64,7 +68,8 @@ to each incoming message with an identical outgoing message. trio.run(main) -A longer example is in `examples/server.py`. +A longer example is in `examples/server.py`. **See the note above about using +SSL with the example client.** ## Integration Testing with Autobahn diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..5f4a13a --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +trio_mode = true diff --git a/setup.py b/setup.py index a289667..4a7228b 100644 --- a/setup.py +++ b/setup.py @@ -28,9 +28,9 @@ ], keywords='websocket client server trio', packages=find_packages(exclude=['docs', 'examples', 'tests']), - install_requires=['trio', 'trustme', 'wsproto'], + install_requires=['trio', 'wsaccel', 'wsproto'], extras_require={ - 'wsaccel': ['wsaccel'], + 'dev': ['pytest', 'pytest-trio', 'trustme'], }, project_urls={ 'Bug Reports': 'https://github.com/HyperionGray/trio-websocket/issues', diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..b45f195 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,52 @@ +import pytest +from trio_websocket import ConnectionClosed, WebSocketClient, WebSocketServer +import trio + + +import logging +logging.basicConfig(level=logging.DEBUG) + + +@pytest.fixture +async def echo_server(nursery): + async def handler(conn): + try: + msg = await conn.get_message() + await conn.send_message(msg) + except ConnectionClosed: + pass + server = WebSocketServer(handler, 'localhost', 0, ssl_context=None) + await nursery.start(server.listen) + yield server + + +def client_for_server(server): + ''' Create a client configured to connect to ``server``. ''' + return WebSocketClient('localhost', server.port, 'resource', use_ssl=False) + + +async def test_client_send_and_receive(echo_server, nursery): + client = client_for_server(echo_server) + conn = await client.connect(nursery) + await conn.send_message('This is a test message.') + received_msg = await conn.get_message() + assert received_msg == 'This is a test message.' + await conn.close() + + +async def test_client_default_close(echo_server, nursery): + client = client_for_server(echo_server) + conn = await client.connect(nursery) + assert conn.closed is None + await conn.close() + assert conn.closed.code == 1000 + assert conn.closed.reason is None + + +async def test_client_nondefault_close(echo_server, nursery): + client = client_for_server(echo_server) + conn = await client.connect(nursery) + assert conn.closed is None + await conn.close(code=1001, reason='test reason') + assert conn.closed.code == 1001 + assert conn.closed.reason == 'test reason' From 90c352279fb76b3894e7fa66d985fc511e457e68 Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Thu, 13 Sep 2018 23:13:59 -0400 Subject: [PATCH 2/3] WebSocketConnection now inherits from AsyncResource --- examples/client.py | 49 +++++----- tests/test_connection.py | 26 +++--- trio_websocket/__init__.py | 186 ++++++++++++++++++++++++------------- 3 files changed, 162 insertions(+), 99 deletions(-) diff --git a/examples/client.py b/examples/client.py index 3abf5f5..afdae4a 100644 --- a/examples/client.py +++ b/examples/client.py @@ -64,29 +64,36 @@ async def main(args): logging.error('Connection attempt failed: %s', ose) return False logging.info('Connected!') - while True: + async with connection: + await handle_connection(connection) + logging.info('Connection closed') + + +async def handle_connection(connection): + ''' Handle the connection. ''' + while True: + try: + logger.debug('top of loop') await trio.sleep(0.1) # allow time for connection logging - try: - cmd = await trio.run_sync_in_worker_thread(input, 'cmd> ', - cancellable=True) - if cmd.startswith('ping '): - await connection.ping(cmd[5:].encode('utf8')) - elif cmd.startswith('send '): - await connection.send_message(cmd[5:]) - message = await connection.get_message() - print('response> {}'.format(message)) - elif cmd.startswith('close'): - try: - reason = cmd[6:] - except IndexError: - reason = None - await connection.close(reason=reason) - break - else: - commands() - except ConnectionClosed: - logging.info('Connection closed') + cmd = await trio.run_sync_in_worker_thread(input, 'cmd> ', + cancellable=True) + if cmd.startswith('ping '): + await connection.ping(cmd[5:].encode('utf8')) + elif cmd.startswith('send '): + await connection.send_message(cmd[5:]) + message = await connection.get_message() + print('response> {}'.format(message)) + elif cmd.startswith('close'): + try: + reason = cmd[6:] + except IndexError: + reason = None + await connection.aclose(code=1000, reason=reason) break + else: + commands() + except ConnectionClosed: + break if __name__ == '__main__': diff --git a/tests/test_connection.py b/tests/test_connection.py index b45f195..b216a58 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,6 +5,8 @@ import logging logging.basicConfig(level=logging.DEBUG) +HOST = 'localhost' +RESOURCE = 'resource' @pytest.fixture @@ -15,38 +17,36 @@ async def handler(conn): await conn.send_message(msg) except ConnectionClosed: pass - server = WebSocketServer(handler, 'localhost', 0, ssl_context=None) + server = WebSocketServer(handler, HOST, 0, ssl_context=None) await nursery.start(server.listen) yield server def client_for_server(server): ''' Create a client configured to connect to ``server``. ''' - return WebSocketClient('localhost', server.port, 'resource', use_ssl=False) + return WebSocketClient(HOST, server.port, RESOURCE, use_ssl=False) async def test_client_send_and_receive(echo_server, nursery): client = client_for_server(echo_server) - conn = await client.connect(nursery) - await conn.send_message('This is a test message.') - received_msg = await conn.get_message() - assert received_msg == 'This is a test message.' - await conn.close() + async with await client.connect(nursery) as conn: + await conn.send_message('This is a test message.') + received_msg = await conn.get_message() + assert received_msg == 'This is a test message.' async def test_client_default_close(echo_server, nursery): client = client_for_server(echo_server) - conn = await client.connect(nursery) - assert conn.closed is None - await conn.close() + async with await client.connect(nursery) as conn: + assert conn.closed is None assert conn.closed.code == 1000 assert conn.closed.reason is None async def test_client_nondefault_close(echo_server, nursery): client = client_for_server(echo_server) - conn = await client.connect(nursery) - assert conn.closed is None - await conn.close(code=1001, reason='test reason') + async with await client.connect(nursery) as conn: + assert conn.closed is None + await conn.aclose(code=1001, reason='test reason') assert conn.closed.code == 1001 assert conn.closed.reason == 'test reason' diff --git a/trio_websocket/__init__.py b/trio_websocket/__init__.py index 7514b84..d4a4c29 100644 --- a/trio_websocket/__init__.py +++ b/trio_websocket/__init__.py @@ -5,6 +5,7 @@ from functools import partial import trio +import trio.abc import wsproto.connection as wsconnection import wsproto.events as wsevents import wsproto.frame_protocol as wsframeproto @@ -77,7 +78,7 @@ def __repr__(self): self.code, self.name, self.reason) -class WebSocketConnection: +class WebSocketConnection(trio.abc.AsyncResource): ''' A WebSocket connection. ''' CONNECTION_ID = itertools.count() @@ -90,7 +91,6 @@ def __init__(self, stream, wsproto, path=None): :param wsproto: a WSConnection instance :param client: a Trio cancel scope (only used by the server) ''' - self._closed = trio.Event() self._close_reason = None self._id = next(self.__class__.CONNECTION_ID) self._message_queue = trio.Queue(0) @@ -101,9 +101,12 @@ def __init__(self, stream, wsproto, path=None): self._str_message = '' self._reader_running = True self._path = path - # Set once the websocket handshake takes place (i.e. - # ConnectionRequested for server or ConnectedEstablished for client). - self._handshake_done = trio.Event() + # Set once the WebSocket open handshake takes place, i.e. + # ConnectionRequested for server or ConnectedEstablished for client. + self._open_handshake = trio.Event() + # Set once a WebSocket closed handshake takes place, i.e after a close + # frame has been sent and a close frame has been received. + self._close_handshake = trio.Event() @property def closed(self): @@ -130,7 +133,7 @@ def path(self): """Returns the path from the HTTP handshake.""" return self._path - async def close(self, code=1000, reason=None): + async def aclose(self, code=1000, reason=None): ''' Close the WebSocket connection. @@ -144,11 +147,17 @@ async def close(self, code=1000, reason=None): :raises ConnectionClosed: if connection is already closed ''' if self._close_reason: - raise ConnectionClosed(self._close_reason) + # Per AsyncResource interface, calling aclose() on a closed resource + # should succeed. + return self._wsproto.close(code=code, reason=reason) - self._close_reason = CloseReason(code, reason) - await self._write_pending() - await self._closed.wait() + try: + await self._write_pending() + await self._close_handshake.wait() + finally: + # If cancelled during WebSocket close, make sure that the underlying + # stream is closed. + await self._close_stream() async def get_message(self): ''' @@ -198,11 +207,21 @@ async def send_message(self, message): self._wsproto.send_data(message) await self._write_pending() - async def _close_message_queue(self): + async def _close_stream(self): + ''' Close the TCP connection. ''' + self._reader_running = False + try: + await self._stream.aclose() + except trio.BrokenStreamError: + # This means the TCP connection is already dead. + pass + + async def _close_web_socket(self, code, reason): ''' - If any tasks are suspended on get_message(), wake them up with a - ConnectionClosed exception. + Mark the WebSocket as closed. If any tasks are suspended on + get_message(), wake them up with a ConnectionClosed exception. ''' + self._close_reason = CloseReason(code, reason) exc = ConnectionClosed(self._close_reason) logger.debug('conn#%d websocket closed %r', self._id, exc) while True: @@ -212,60 +231,91 @@ async def _close_message_queue(self): except trio.WouldBlock: break - async def _close_stream(self): - ''' Close the TCP connection. ''' - self._reader_running = False - try: - await self._stream.aclose() - except trio.BrokenStreamError: - # This means the TCP connection is already dead. - pass - self._closed.set() + async def _handle_connection_requested_event(self, event): + ''' + Handle a ConnectionRequested event. - async def _handle_event(self, event): + :param event: ''' - Process one WebSocket event. + self._path = event.h11request.target + self._wsproto.accept(event) + await self._write_pending() + self._open_handshake.set() - :param event: a wsproto event + async def _handle_connection_established_event(self, event): ''' - if isinstance(event, wsevents.ConnectionRequested): - logger.debug('conn#%d accepting websocket', self._id) - self._path = event.h11request.target - self._wsproto.accept(event) - await self._write_pending() - self._handshake_done.set() - elif isinstance(event, wsevents.ConnectionEstablished): - logger.debug('conn#%d websocket established', self._id) - self._handshake_done.set() - elif isinstance(event, wsevents.ConnectionClosed): - if self._close_reason is None: - self._close_reason = CloseReason(event.code, event.reason) - await self._write_pending() - await self._close_message_queue() - await self._close_stream() - elif isinstance(event, wsevents.BytesReceived): - logger.debug('conn#%d received binary frame', self._id) - self._bytes_message += event.data - if event.message_finished: - await self._message_queue.put(self._bytes_message) - self._bytes_message = b'' - elif isinstance(event, wsevents.TextReceived): - logger.debug('conn#%d received text frame', self._id) - self._str_message += event.data - if event.message_finished: - await self._message_queue.put(self._str_message) - self._str_message = '' - elif isinstance(event, wsevents.PingReceived): - logger.debug('conn#%d ping', self._id) - # wsproto queues a pong automatically, we just need to send it: - await self._write_pending() - elif isinstance(event, wsevents.PongReceived): - logger.debug('conn#%d pong %r', self._id, event.payload) - else: - raise Exception('Unknown websocket event: {!r}'.format(event)) + Handle a ConnectionEstablished event. + + :param event: + ''' + self._open_handshake.set() + + async def _handle_connection_closed_event(self, event): + ''' + Handle a ConnectionClosed event. + + :param event: + ''' + await self._write_pending() + await self._close_web_socket(event.code, event.reason or None) + self._close_handshake.set() + + async def _handle_bytes_received_event(self, event): + ''' + Handle a BytesReceived event. + + :param event: + ''' + self._bytes_message += event.data + if event.message_finished: + await self._message_queue.put(self._bytes_message) + self._bytes_message = b'' + + async def _handle_text_received_event(self, event): + ''' + Handle a TextReceived event. + + :param event: + ''' + self._str_message += event.data + if event.message_finished: + await self._message_queue.put(self._str_message) + self._str_message = '' + + async def _handle_ping_received_event(self, event): + ''' + Handle a PingReceived event. + + Wsproto queues a pong frame automatically, so this handler just needs to + send it. + + :param event: + ''' + await self._write_pending() + + async def _handle_pong_received_event(self, event): + ''' + Handle a PongReceived event. + + Currently we don't do anything special for a Pong frame, but this may + change in the future. This handler is here as a placeholder. + + :param event: + ''' + logger.debug('conn#%d pong %r', self._id, event.payload) async def _reader_task(self): ''' A background task that reads network data and generates events. ''' + handlers = { + 'ConnectionRequested': self._handle_connection_requested_event, + 'ConnectionEstablished': self._handle_connection_established_event, + 'ConnectionClosed': self._handle_connection_closed_event, + 'BytesReceived': self._handle_bytes_received_event, + 'TextReceived': self._handle_text_received_event, + 'PingReceived': self._handle_ping_received_event, + 'PongReceived': self._handle_pong_received_event, + } + if self.is_client: # Clients need to initiate the negotiation: await self._write_pending() @@ -273,7 +323,14 @@ async def _reader_task(self): while self._reader_running: # Process events. for event in self._wsproto.events(): - await self._handle_event(event) + event_type = type(event).__name__ + try: + handler = handlers[event_type] + logger.debug('conn#%d received event: %s', self._id, + event_type) + await handler(event) + except KeyError: + logger.error('Received unknown event type: %s', event_type) # Get network data. try: @@ -286,10 +343,9 @@ async def _reader_task(self): # If TCP closed before WebSocket, then record it as an abnormal # closure. if not self._wsproto.closed: - self._close_creason = CloseReason( + await self._close_web_socket( wsframeproto.CloseReason.ABNORMAL_CLOSURE, 'TCP connection aborted') - await self._close_message_queue() await self._close_stream() break else: @@ -371,7 +427,7 @@ async def _handle_connection(self, stream): wsproto = wsconnection.WSConnection(wsconnection.SERVER) connection = WebSocketConnection(stream, wsproto) nursery.start_soon(connection._reader_task) - await connection._handshake_done.wait() + await connection._open_handshake.wait() nursery.start_soon(self._handler, connection) @@ -425,5 +481,5 @@ async def connect(self, nursery): host=host_header, resource=self._resource) connection = WebSocketConnection(stream, wsproto, path=self._resource) nursery.start_soon(connection._reader_task) - await connection._handshake_done.wait() + await connection._open_handshake.wait() return connection From 3a572130468ad524472137b1c247757159f1cd05 Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Mon, 17 Sep 2018 11:28:27 -0400 Subject: [PATCH 3/3] Address code review feedback --- examples/client.py | 6 ++--- tests/test_connection.py | 45 ++++++++++++++++++++------------------ trio_websocket/__init__.py | 3 ++- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/examples/client.py b/examples/client.py index afdae4a..84fd696 100644 --- a/examples/client.py +++ b/examples/client.py @@ -44,7 +44,7 @@ def parse_args(): async def main(args): ''' Main entry point, returning False in the case of logged error. ''' async with trio.open_nursery() as nursery: - logging.info('Connecting to WebSocket…') + logging.debug('Connecting to WebSocket…') ssl_context = ssl.create_default_context() if args.ssl: try: @@ -63,10 +63,10 @@ async def main(args): except OSError as ose: logging.error('Connection attempt failed: %s', ose) return False - logging.info('Connected!') + logging.debug('Connected!') async with connection: await handle_connection(connection) - logging.info('Connection closed') + logging.debug('Connection closed') async def handle_connection(connection): diff --git a/tests/test_connection.py b/tests/test_connection.py index b216a58..9ab4833 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -11,6 +11,8 @@ @pytest.fixture async def echo_server(nursery): + ''' An echo server reads one message, sends back the same message, + then exits. ''' async def handler(conn): try: msg = await conn.get_message() @@ -22,31 +24,32 @@ async def handler(conn): yield server -def client_for_server(server): - ''' Create a client configured to connect to ``server``. ''' - return WebSocketClient(HOST, server.port, RESOURCE, use_ssl=False) +@pytest.fixture +async def echo_conn(echo_server, nursery): + ''' Return a client connection instance that is connected to an echo + server. ''' + client = WebSocketClient(HOST, echo_server.port, RESOURCE, use_ssl=False) + async with await client.connect(nursery) as conn: + yield conn -async def test_client_send_and_receive(echo_server, nursery): - client = client_for_server(echo_server) - async with await client.connect(nursery) as conn: - await conn.send_message('This is a test message.') - received_msg = await conn.get_message() +async def test_client_send_and_receive(echo_conn, nursery): + async with echo_conn: + await echo_conn.send_message('This is a test message.') + received_msg = await echo_conn.get_message() assert received_msg == 'This is a test message.' -async def test_client_default_close(echo_server, nursery): - client = client_for_server(echo_server) - async with await client.connect(nursery) as conn: - assert conn.closed is None - assert conn.closed.code == 1000 - assert conn.closed.reason is None +async def test_client_default_close(echo_conn, nursery): + async with echo_conn: + assert echo_conn.closed is None + assert echo_conn.closed.code == 1000 + assert echo_conn.closed.reason is None -async def test_client_nondefault_close(echo_server, nursery): - client = client_for_server(echo_server) - async with await client.connect(nursery) as conn: - assert conn.closed is None - await conn.aclose(code=1001, reason='test reason') - assert conn.closed.code == 1001 - assert conn.closed.reason == 'test reason' +async def test_client_nondefault_close(echo_conn, nursery): + async with echo_conn: + assert echo_conn.closed is None + await echo_conn.aclose(code=1001, reason='test reason') + assert echo_conn.closed.code == 1001 + assert echo_conn.closed.reason == 'test reason' diff --git a/trio_websocket/__init__.py b/trio_websocket/__init__.py index d4a4c29..675a9af 100644 --- a/trio_websocket/__init__.py +++ b/trio_websocket/__init__.py @@ -330,7 +330,8 @@ async def _reader_task(self): event_type) await handler(event) except KeyError: - logger.error('Received unknown event type: %s', event_type) + logger.warning('Received unknown event type: %s', + event_type) # Get network data. try: