From cdeb882865145399ee0fb7d0e7623418916d6b78 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 08:48:53 +0100 Subject: [PATCH] Don't log an error when process_request returns a response. Fix #1513. --- src/websockets/asyncio/client.py | 6 +- src/websockets/asyncio/server.py | 17 ++-- src/websockets/protocol.py | 29 ++++-- src/websockets/server.py | 6 -- src/websockets/sync/client.py | 6 +- src/websockets/sync/server.py | 11 ++- tests/asyncio/test_connection.py | 2 +- tests/asyncio/test_server.py | 146 ++++++++++++++++++++----------- tests/test_protocol.py | 20 ++++- 9 files changed, 163 insertions(+), 80 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index ff7916d39..d276ac171 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -95,9 +95,9 @@ async def handshake( return_when=asyncio.FIRST_COMPLETED, ) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 180d3a5a9..15c9ba13e 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -192,10 +192,13 @@ async def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -360,7 +363,11 @@ async def conn_handler(self, connection: ServerConnection) -> None: connection.close_transport() return - assert connection.protocol.state is OPEN + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + connection.close_transport() + return + try: connection.start_keepalive() await self.handler(connection) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 19b813526..0f6fea250 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -518,15 +518,34 @@ def close_expected(self) -> bool: Whether the TCP connection is expected to close soon. """ - # We expect a TCP close if and only if we sent a close frame: + # During the opening handshake, when our state is CONNECTING, we expect + # a TCP close if and only if the hansdake fails. When it does, we start + # the TCP closing handshake by sending EOF with send_eof(). + + # Once the opening handshake completes successfully, we expect a TCP + # close if and only if we sent a close frame, meaning that our state + # progressed to CLOSING: + # * Normal closure: once we send a close frame, we expect a TCP close: # server waits for client to complete the TCP closing handshake; # client waits for server to initiate the TCP closing handshake. + # * Abnormal closure: we always send a close frame and the same logic # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. - # We already got a TCP Close if and only if the state is CLOSED. - return self.state is CLOSING or self.handshake_exc is not None + + # If our state is CLOSED, we already received a TCP close so we don't + # expect it anymore. + + # Micro-optimization: put the most common case first + if self.state is OPEN: + return False + if self.state is CLOSING: + return True + if self.state is CLOSED: + return False + assert self.state is CONNECTING + return self.eof_sent # Private methods for receiving data. @@ -616,14 +635,14 @@ def discard(self) -> Generator[None]: # connection in the same circumstances where discard() replaces parse(). # The client closes it when it receives EOF from the server or times # out. (The latter case cannot be handled in this Sans-I/O layer.) - assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent) + assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") # A server closes the TCP connection immediately, while a client # waits for the server to close the TCP connection. - if self.state != CONNECTING and self.side is CLIENT: + if self.side is CLIENT and self.state is not CONNECTING: self.send_eof() self.state = CLOSED # If discard() completes normally, execution ends here. diff --git a/src/websockets/server.py b/src/websockets/server.py index 527db8990..e3fdcc646 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -14,7 +14,6 @@ InvalidHeader, InvalidHeaderValue, InvalidOrigin, - InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -536,11 +535,6 @@ def send_response(self, response: Response) -> None: self.logger.info("connection open") else: - # handshake_exc may be already set if accept() encountered an error. - # If the connection isn't open, set handshake_exc to guarantee that - # handshake_exc is None if and only if opening handshake succeeded. - if self.handshake_exc is None: - self.handshake_exc = InvalidStatus(response) self.logger.info( "connection rejected (%d %s)", response.status_code, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 0aada658e..54d0aef68 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -87,9 +87,9 @@ def handshake( if not self.response_rcvd.wait(timeout): raise TimeoutError("timed out during handshake") - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 44dbd7290..8601ccef9 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -170,10 +170,13 @@ def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index a3b65e956..c98765d80 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -51,7 +51,7 @@ async def asyncTearDown(self): if sys.version_info[:2] < (3, 10): # pragma: no cover @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): + def assertNoLogs(self, logger=None, level=None): """ No message is logged on the given logger with at least the given level. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 1dcb8c7b7..c817f5ef6 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -148,14 +148,17 @@ def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" @@ -166,44 +169,65 @@ async def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_request_raises_exception(self): """Server returns an error if async process_request raises an exception.""" async def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_process_response_returns_none(self): """Server runs process_response but keeps the handshake response.""" @@ -277,31 +301,49 @@ async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_response_raises_exception(self): """Server returns an error if async process_response raises an exception.""" async def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_override_server(self): """Server can override Server header with server_header.""" diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 0ae804bb3..1c092459d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -20,7 +20,7 @@ Frame, ) from websockets.protocol import * -from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER +from websockets.protocol import CLIENT, CLOSED, CLOSING, CONNECTING, SERVER from .extensions.utils import Rsv2Extension from .test_frames import FramesTestCase @@ -1696,6 +1696,24 @@ def test_server_fails_connection(self): server.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(server.close_expected()) + def test_client_is_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + self.assertFalse(client.close_expected()) + + def test_server_is_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + self.assertFalse(server.close_expected()) + + def test_client_failed_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + client.send_eof() + self.assertTrue(client.close_expected()) + + def test_server_failed_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + server.send_eof() + self.assertTrue(server.close_expected()) + class ConnectionClosedTests(ProtocolTestCase): """