Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove race condition when creating new HTTPChannel #435

Merged
merged 6 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/waitress/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def __init__(self, server, sock, addr, adj, map=None):
self.outbuf_lock = threading.Condition()

wasyncore.dispatcher.__init__(self, sock, map=map)

# Don't let wasyncore.dispatcher throttle self.addr on us.
self.connected = True
self.addr = addr
self.requests = []

Expand All @@ -92,13 +91,7 @@ def handle_write(self):
# Precondition: there's data in the out buffer to be sent, or
# there's a pending will_close request

if not self.connected:
# we dont want to close the channel twice

return

# try to flush any pending output

if not self.requests:
# 1. There are no running tasks, so we don't need to try to lock
# the outbuf before sending
Expand Down
69 changes: 7 additions & 62 deletions src/waitress/wasyncore.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,22 +297,6 @@ def __init__(self, sock=None, map=None):
# get a socket from a blocking source.
sock.setblocking(0)
self.set_socket(sock, map)
self.connected = True
# The constructor no longer requires that the socket
# passed be connected.
try:
self.addr = sock.getpeername()
except OSError as err:
if err.args[0] in (ENOTCONN, EINVAL):
# To handle the case where we got an unconnected
# socket.
self.connected = False
else:
# The socket is broken in some unknown way, alert
# the user and remove it from the map (to prevent
# polling of broken sockets).
self.del_channel(map)
raise
else:
self.socket = None

Expand Down Expand Up @@ -394,23 +378,6 @@ def bind(self, addr):
self.addr = addr
return self.socket.bind(addr)

def connect(self, address):
self.connected = False
self.connecting = True
err = self.socket.connect_ex(address)
if (
err in (EINPROGRESS, EALREADY, EWOULDBLOCK)
or err == EINVAL
and os.name == "nt"
): # pragma: no cover
self.addr = address
return
if err in (0, EISCONN):
self.addr = address
self.handle_connect_event()
else:
raise OSError(err, errorcode[err])

def accept(self):
# XXX can return either an address pair or None
try:
Expand Down Expand Up @@ -469,6 +436,8 @@ def close(self):
if why.args[0] not in (ENOTCONN, EBADF):
raise

self.socket = None

# log and log_info may be overridden to provide more sophisticated
# logging and warning methods. In general, log is for 'hit' logging
# and 'log_info' is for informational, warning and error logging.
Expand Down Expand Up @@ -519,7 +488,11 @@ def handle_expt_event(self):
# handle_expt_event() is called if there might be an error on the
# socket, or if there is OOB data
# check for the error condition first
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
err = (
self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if self.socket is not None
else 1
)
if err != 0:
# we can get here when select.select() says that there is an
# exceptional condition on the socket
Expand Down Expand Up @@ -572,34 +545,6 @@ def handle_close(self):
self.close()


# ---------------------------------------------------------------------------
# adds simple buffered output capability, useful for simple clients.
# [for more sophisticated usage use asynchat.async_chat]
# ---------------------------------------------------------------------------


class dispatcher_with_send(dispatcher):
def __init__(self, sock=None, map=None):
dispatcher.__init__(self, sock, map)
self.out_buffer = b""

def initiate_send(self):
num_sent = 0
num_sent = dispatcher.send(self, self.out_buffer[:65536])
self.out_buffer = self.out_buffer[num_sent:]

handle_write = initiate_send

def writable(self):
return (not self.connected) or len(self.out_buffer)

def send(self, data):
if self.debug: # pragma: no cover
self.log_info("sending %s" % repr(data))
self.out_buffer = self.out_buffer + data
self.initiate_send()


def close_all(map=None, ignore_all=False):
if map is None: # pragma: no cover
map = socket_map
Expand Down
107 changes: 18 additions & 89 deletions tests/test_wasyncore.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import _thread as thread
import contextlib
import errno
from errno import EALREADY, EINPROGRESS, EINVAL, EISCONN, EWOULDBLOCK, errorcode
import functools
import gc
from io import BytesIO
Expand Down Expand Up @@ -641,62 +642,6 @@ def test_strerror(self):
self.assertTrue(err != "")


class dispatcherwithsend_noread(asyncore.dispatcher_with_send): # pragma: no cover
def readable(self):
return False

def handle_connect(self):
pass


class DispatcherWithSendTests(unittest.TestCase):
def setUp(self):
pass

def tearDown(self):
asyncore.close_all()

@reap_threads
def test_send(self):
evt = threading.Event()
sock = socket.socket()
sock.settimeout(3)
port = bind_port(sock)

cap = BytesIO()
args = (evt, cap, sock)
t = threading.Thread(target=capture_server, args=args)
t.start()
try:
# wait a little longer for the server to initialize (it sometimes
# refuses connections on slow machines without this wait)
time.sleep(0.2)

data = b"Suppose there isn't a 16-ton weight?"
d = dispatcherwithsend_noread()
d.create_socket()
d.connect((HOST, port))

# give time for socket to connect
time.sleep(0.1)

d.send(data)
d.send(data)
d.send(b"\n")

n = 1000

while d.out_buffer and n > 0: # pragma: no cover
asyncore.poll()
n -= 1

evt.wait()

self.assertEqual(cap.getvalue(), data * 2)
finally:
join_thread(t, timeout=TIMEOUT)


@unittest.skipUnless(
hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required"
)
Expand Down Expand Up @@ -839,6 +784,23 @@ def __init__(self, family, address):
self.create_socket(family)
self.connect(address)

def connect(self, address):
self.connected = False
self.connecting = True
err = self.socket.connect_ex(address)
if (
err in (EINPROGRESS, EALREADY, EWOULDBLOCK)
or err == EINVAL
and os.name == "nt"
): # pragma: no cover
self.addr = address
return
if err in (0, EISCONN):
self.addr = address
self.handle_connect_event()
else:
raise OSError(err, errorcode[err])

def handle_connect(self):
pass

Expand Down Expand Up @@ -1454,17 +1416,6 @@ def _makeOne(self, sock=None, map=None):

return dispatcher(sock=sock, map=map)

def test_unexpected_getpeername_exc(self):
sock = dummysocket()

def getpeername():
raise OSError(errno.EBADF)

map = {}
sock.getpeername = getpeername
self.assertRaises(socket.error, self._makeOne, sock=sock, map=map)
self.assertEqual(map, {})

def test___repr__accepting(self):
sock = dummysocket()
map = {}
Expand Down Expand Up @@ -1500,13 +1451,6 @@ def setsockopt(*arg, **kw):
inst.set_reuse_addr()
self.assertTrue(sock.errored)

def test_connect_raise_socket_error(self):
sock = dummysocket()
map = {}
sock.connect_ex = lambda *arg: 1
inst = self._makeOne(sock=sock, map=map)
self.assertRaises(socket.error, inst.connect, 0)

def test_accept_raise_TypeError(self):
sock = dummysocket()
map = {}
Expand Down Expand Up @@ -1675,21 +1619,6 @@ def test_handle_accepted(self):
self.assertTrue(sock.closed)


class Test_dispatcher_with_send(unittest.TestCase):
def _makeOne(self, sock=None, map=None):
from waitress.wasyncore import dispatcher_with_send

return dispatcher_with_send(sock=sock, map=map)

def test_writable(self):
sock = dummysocket()
map = {}
inst = self._makeOne(sock=sock, map=map)
inst.out_buffer = b"123"
inst.connected = True
self.assertTrue(inst.writable())


class Test_close_all(unittest.TestCase):
def _callFUT(self, map=None, ignore_all=False):
from waitress.wasyncore import close_all
Expand Down