Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.8] gh-108342: Make ssl TestPreHandshakeClose more reliable (GH-108370) #108408

Merged
merged 2 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,11 @@ Tools/ssl/win32
# Ignore ./python binary on Unix but still look into ./Python/ directory.
/python
!/Python/

# Artifacts generated by 3.11 lying around when switching branches:
/_bootstrap_python
/Programs/_freeze_module
/Modules/Setup.bootstrap
/Modules/Setup.stdlib
/Python/deepfreeze/
/Python/frozen_modules/
102 changes: 71 additions & 31 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4828,12 +4828,16 @@ class TestPreHandshakeClose(unittest.TestCase):

class SingleConnectionTestServerThread(threading.Thread):

def __init__(self, *, name, call_after_accept):
def __init__(self, *, name, call_after_accept, timeout=None):
self.call_after_accept = call_after_accept
self.received_data = b'' # set by .run()
self.wrap_error = None # set by .run()
self.listener = None # set by .start()
self.port = None # set by .start()
if timeout is None:
self.timeout = support.SHORT_TIMEOUT
else:
self.timeout = timeout
super().__init__(name=name)

def __enter__(self):
Expand All @@ -4856,13 +4860,19 @@ def start(self):
self.ssl_ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
self.listener = socket.socket()
self.port = support.bind_port(self.listener)
self.listener.settimeout(2.0)
self.listener.settimeout(self.timeout)
self.listener.listen(1)
super().start()

def run(self):
conn, address = self.listener.accept()
self.listener.close()
try:
conn, address = self.listener.accept()
except TimeoutError:
# on timeout, just close the listener
return
finally:
self.listener.close()

with conn:
if self.call_after_accept(conn):
return
Expand Down Expand Up @@ -4890,8 +4900,13 @@ def non_linux_skip_if_other_okay_error(self, err):
# we're specifically trying to test. The way this test is written
# is known to work on Linux. We'll skip it anywhere else that it
# does not present as doing so.
self.skipTest(f"Could not recreate conditions on {sys.platform}:"
f" {err=}")
try:
self.skipTest(f"Could not recreate conditions on {sys.platform}:"
f" {err=}")
finally:
# gh-108342: Explicitly break the reference cycle
err = None

# If maintaining this conditional winds up being a problem.
# just turn this into an unconditional skip anything but Linux.
# The important thing is that our CI has the logic covered.
Expand All @@ -4902,7 +4917,7 @@ def test_preauth_data_to_tls_server(self):

def call_after_accept(unused):
server_accept_called.set()
if not ready_for_server_wrap_socket.wait(2.0):
if not ready_for_server_wrap_socket.wait(support.SHORT_TIMEOUT):
raise RuntimeError("wrap_socket event never set, test may fail.")
return False # Tell the server thread to continue.

Expand All @@ -4924,20 +4939,31 @@ def call_after_accept(unused):

ready_for_server_wrap_socket.set()
server.join()

wrap_error = server.wrap_error
self.assertEqual(b"", server.received_data)
self.assertIsInstance(wrap_error, OSError) # All platforms.
self.non_linux_skip_if_other_okay_error(wrap_error)
self.assertIsInstance(wrap_error, ssl.SSLError)
self.assertIn("before TLS handshake with data", wrap_error.args[1])
self.assertIn("before TLS handshake with data", wrap_error.reason)
self.assertNotEqual(0, wrap_error.args[0])
self.assertIsNone(wrap_error.library, msg="attr must exist")
server.wrap_error = None
try:
self.assertEqual(b"", server.received_data)
self.assertIsInstance(wrap_error, OSError) # All platforms.
self.non_linux_skip_if_other_okay_error(wrap_error)
self.assertIsInstance(wrap_error, ssl.SSLError)
self.assertIn("before TLS handshake with data", wrap_error.args[1])
self.assertIn("before TLS handshake with data", wrap_error.reason)
self.assertNotEqual(0, wrap_error.args[0])
self.assertIsNone(wrap_error.library, msg="attr must exist")
finally:
# gh-108342: Explicitly break the reference cycle
wrap_error = None
server = None

def test_preauth_data_to_tls_client(self):
server_can_continue_with_wrap_socket = threading.Event()
client_can_continue_with_wrap_socket = threading.Event()

def call_after_accept(conn_to_client):
if not server_can_continue_with_wrap_socket.wait(support.SHORT_TIMEOUT):
print("ERROR: test client took too long")

# This forces an immediate connection close via RST on .close().
set_socket_so_linger_on_with_zero_timeout(conn_to_client)
conn_to_client.send(
Expand All @@ -4959,8 +4985,10 @@ def call_after_accept(conn_to_client):

with socket.socket() as client:
client.connect(server.listener.getsockname())
if not client_can_continue_with_wrap_socket.wait(2.0):
self.fail("test server took too long.")
server_can_continue_with_wrap_socket.set()

if not client_can_continue_with_wrap_socket.wait(support.SHORT_TIMEOUT):
self.fail("test server took too long")
ssl_ctx = ssl.create_default_context()
try:
tls_client = ssl_ctx.wrap_socket(
Expand All @@ -4974,24 +5002,31 @@ def call_after_accept(conn_to_client):
tls_client.close()

server.join()
self.assertEqual(b"", received_data)
self.assertIsInstance(wrap_error, OSError) # All platforms.
self.non_linux_skip_if_other_okay_error(wrap_error)
self.assertIsInstance(wrap_error, ssl.SSLError)
self.assertIn("before TLS handshake with data", wrap_error.args[1])
self.assertIn("before TLS handshake with data", wrap_error.reason)
self.assertNotEqual(0, wrap_error.args[0])
self.assertIsNone(wrap_error.library, msg="attr must exist")
try:
self.assertEqual(b"", received_data)
self.assertIsInstance(wrap_error, OSError) # All platforms.
self.non_linux_skip_if_other_okay_error(wrap_error)
self.assertIsInstance(wrap_error, ssl.SSLError)
self.assertIn("before TLS handshake with data", wrap_error.args[1])
self.assertIn("before TLS handshake with data", wrap_error.reason)
self.assertNotEqual(0, wrap_error.args[0])
self.assertIsNone(wrap_error.library, msg="attr must exist")
finally:
# gh-108342: Explicitly break the reference cycle
wrap_error = None
server = None

def test_https_client_non_tls_response_ignored(self):

server_responding = threading.Event()

class SynchronizedHTTPSConnection(http.client.HTTPSConnection):
def connect(self):
# Call clear text HTTP connect(), not the encrypted HTTPS (TLS)
# connect(): wrap_socket() is called manually below.
http.client.HTTPConnection.connect(self)

# Wait for our fault injection server to have done its thing.
if not server_responding.wait(1.0) and support.verbose:
if not server_responding.wait(support.SHORT_TIMEOUT) and support.verbose:
sys.stdout.write("server_responding event never set.")
self.sock = self._context.wrap_socket(
self.sock, server_hostname=self.host)
Expand All @@ -5006,29 +5041,34 @@ def call_after_accept(conn_to_client):
server_responding.set()
return True # Tell the server to stop.

timeout = 2.0
server = self.SingleConnectionTestServerThread(
call_after_accept=call_after_accept,
name="non_tls_http_RST_responder")
name="non_tls_http_RST_responder",
timeout=timeout)
server.__enter__() # starts it
self.addCleanup(server.__exit__) # ... & unittest.TestCase stops it.
# Redundant; call_after_accept sets SO_LINGER on the accepted conn.
set_socket_so_linger_on_with_zero_timeout(server.listener)

connection = SynchronizedHTTPSConnection(
f"localhost",
server.listener.getsockname()[0],
port=server.port,
context=ssl.create_default_context(),
timeout=2.0,
timeout=timeout,
)

# There are lots of reasons this raises as desired, long before this
# test was added. Sending the request requires a successful TLS wrapped
# socket; that fails if the connection is broken. It may seem pointless
# to test this. It serves as an illustration of something that we never
# want to happen... properly not happening.
with self.assertRaises(OSError) as err_ctx:
with self.assertRaises(OSError):
connection.request("HEAD", "/test", headers={"Host": "localhost"})
response = connection.getresponse()

server.join()


def test_main(verbose=False):
if support.verbose:
Expand Down