Skip to content

Commit

Permalink
Merge pull request #874 from njsmith/redo-fd-holder
Browse files Browse the repository at this point in the history
Rewrite unix pipe fd handling logic
  • Loading branch information
oremanj authored Jan 26, 2019
2 parents 14ff717 + ad28183 commit 360f1fb
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 136 deletions.
6 changes: 6 additions & 0 deletions newsfragments/661.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Fixed several bugs in the new subprocess pipe support on Unix:
(a) operations on a closed pipe could accidentally affect another
unrelated pipe due to internal file-descriptor reuse, and (b) in very
rare circumstances, two tasks calling ``send_all`` on the same pipe at
the same time could end up with intermingled data instead of a
:exc:`BusyResourceError`.
3 changes: 3 additions & 0 deletions newsfragments/874.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
If you have a :class:`SocketStream` that's already been closed, then
``await socket_stream.send_all(b"")`` will now correctly raise
:exc:`ClosedResourceError`.
4 changes: 4 additions & 0 deletions trio/_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ async def send_all(self, data):
with memoryview(data) as data:
if not data:
await _core.checkpoint()
if self.socket.fileno() == -1:
raise _core.ClosedResourceError(
"socket was already closed"
)
return
total_sent = 0
while total_sent < len(data):
Expand Down
227 changes: 132 additions & 95 deletions trio/_unix_pipes.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,157 @@
import fcntl
import os
from typing import Tuple
import errno

from . import _core
from ._abc import SendStream, ReceiveStream

__all__ = ["PipeSendStream", "PipeReceiveStream", "make_pipe"]


class _PipeMixin:
def __init__(self, pipefd: int) -> None:
if not isinstance(pipefd, int):
raise TypeError(
"{0.__class__.__name__} needs a pipe fd".format(self)
)

self._pipe = pipefd
self._closed = False

flags = fcntl.fcntl(self._pipe, fcntl.F_GETFL)
fcntl.fcntl(self._pipe, fcntl.F_SETFL, flags | os.O_NONBLOCK)

def _close(self):
if self._closed:
from ._util import ConflictDetector


class _FdHolder:
# This class holds onto a raw file descriptor, in non-blocking mode, and
# is responsible for managing its lifecycle. In particular, it's
# responsible for making sure it gets closed, and also for tracking
# whether it's been closed.
#
# The way we track closure is to set the .fd field to -1, discarding the
# original value. You might think that this is a strange idea, since it
# overloads the same field to do two different things. Wouldn't it be more
# natural to have a dedicated .closed field? But that would be more
# error-prone. Fds are represented by small integers, and once an fd is
# closed, its integer value may be reused immediately. If we accidentally
# used the old fd after being closed, we might end up doing something to
# another unrelated fd that happened to get assigned the same integer
# value. By throwing away the integer value immediately, it becomes
# impossible to make this mistake – we'll just get an EBADF.
#
# (This trick was copied from the stdlib socket module.)
# Store the raw file descriptor `fd` on this holder and put it into
# non-blocking mode. Raises TypeError when `fd` is not an int.
def __init__(self, fd: int):
# make sure self.fd is always initialized to *something*, because even
# if we error out here then __del__ will run and access it.
self.fd = -1
if not isinstance(fd, int):
raise TypeError("file descriptor must be an int")
self.fd = fd
# Flip the fd to non-blocking mode
# (read the current status flags first and OR in O_NONBLOCK, so the flags
# that were already set are written back unchanged)
flags = fcntl.fcntl(self.fd, fcntl.F_GETFL)
fcntl.fcntl(self.fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

# True once the descriptor has been given up: closure is recorded by
# resetting .fd to -1 rather than by keeping a separate flag.
@property
def closed(self):
return self.fd == -1

# Close the underlying descriptor with os.close, at most once; calling this
# again after the holder is closed does nothing.
def _raw_close(self):
# This doesn't assume it's in a trio context, so it can be called from
# __del__. You should never call it from Trio context, because it
# skips calling notify_fd_close. But from __del__, skipping that is
# OK, because notify_fd_close just wakes up other tasks that are
# waiting on this fd, and those tasks hold a reference to this object.
# So if __del__ is being called, we know there aren't any tasks that
# need to be woken.
if self.closed:
return
# Copy the number to a local and reset self.fd to -1 *before* closing, so
# the stale integer is no longer stored on the object once os.close runs.
fd = self.fd
self.fd = -1
os.close(fd)

self._closed = True
os.close(self._pipe)
# Finalizer: hand off to _raw_close so the descriptor is closed if it is
# still open when this object is destroyed.
def __del__(self):
self._raw_close()

async def aclose(self):
# XX: This would be in _close, but this can only be used from an
# async context.
_core.notify_fd_close(self._pipe)
self._close()
if not self.closed:
_core.notify_fd_close(self.fd)
self._raw_close()
await _core.checkpoint()

# Returns the value stored in self._pipe as-is; no closed-state check is
# made here.
def fileno(self) -> int:
"""Gets the file descriptor for this pipe."""
return self._pipe

# Finalizer: delegates to self._close() when the object is destroyed.
def __del__(self):
self._close()


class PipeSendStream(_PipeMixin, SendStream):
class PipeSendStream(SendStream):
"""Represents a send stream over an os.pipe object."""

# Wrap the raw descriptor `fd` in an _FdHolder and create the
# ConflictDetector that this stream's send methods enter with `async with`.
def __init__(self, fd: int):
self._fd_holder = _FdHolder(fd)
# NOTE(review): ConflictDetector is imported from ._util and not visible
# here; presumably it reports overlapping use by two tasks with the message
# below -- confirm against its definition.
self._conflict_detector = ConflictDetector(
"another task is using this pipe"
)

async def send_all(self, data: bytes):
# we have to do this no matter what
await _core.checkpoint()
if self._closed:
raise _core.ClosedResourceError("this pipe is already closed")
async with self._conflict_detector:
# have to check up front, because send_all(b"") on a closed pipe
# should raise
if self._fd_holder.closed:
raise _core.ClosedResourceError("this pipe was already closed")

length = len(data)
# adapted from the SocketStream code
with memoryview(data) as view:
sent = 0
while sent < length:
with view[sent:] as remaining:
try:
sent += os.write(self._fd_holder.fd, remaining)
except BlockingIOError:
await _core.wait_writable(self._fd_holder.fd)
except OSError as e:
if e.errno == errno.EBADF:
raise _core.ClosedResourceError(
"this pipe was closed"
) from None
else:
raise _core.BrokenResourceError from e

if not data:
return
# Wait until _core.wait_writable reports the pipe's descriptor as writable.
# Raises ClosedResourceError when the fd holder is already closed, and
# BrokenResourceError when the wait itself raises BrokenPipeError.
async def wait_send_all_might_not_block(self) -> None:
async with self._conflict_detector:
# check up front so a closed pipe is reported as closed before waiting
if self._fd_holder.closed:
raise _core.ClosedResourceError("this pipe was already closed")
try:
await _core.wait_writable(self._fd_holder.fd)
except BrokenPipeError as e:
# kqueue: raises EPIPE on wait_writable instead
# of sending, which is annoying
raise _core.BrokenResourceError from e

length = len(data)
# adapted from the SocketStream code
with memoryview(data) as view:
total_sent = 0
while total_sent < length:
with view[total_sent:] as remaining:
try:
total_sent += os.write(self._pipe, remaining)
except BrokenPipeError as e:
await _core.checkpoint()
raise _core.BrokenResourceError from e
except BlockingIOError:
await self.wait_send_all_might_not_block()
# Close this stream by awaiting the fd holder's aclose().
async def aclose(self):
await self._fd_holder.aclose()

# Wait until _core.wait_writable reports self._pipe as writable.
# Raises ClosedResourceError when self._closed is set, and BrokenResourceError
# when the wait raises BrokenPipeError; both error paths run a checkpoint
# before raising.
async def wait_send_all_might_not_block(self) -> None:
if self._closed:
await _core.checkpoint()
raise _core.ClosedResourceError("This pipe is already closed")

try:
await _core.wait_writable(self._pipe)
except BrokenPipeError as e:
# kqueue: raises EPIPE on wait_writable instead
# of sending, which is annoying
# also doesn't checkpoint so we have to do that
# ourselves here too
await _core.checkpoint()
raise _core.BrokenResourceError from e


class PipeReceiveStream(_PipeMixin, ReceiveStream):
"""Represents a receive stream over an os.pipe object."""
# Return the descriptor number currently stored on the fd holder; the
# holder resets it to -1 once it has been closed.
def fileno(self):
return self._fd_holder.fd

async def receive_some(self, max_bytes: int) -> bytes:
if self._closed:
await _core.checkpoint()
raise _core.ClosedResourceError("this pipe is already closed")

if not isinstance(max_bytes, int):
await _core.checkpoint()
raise TypeError("max_bytes must be integer >= 1")
class PipeReceiveStream(ReceiveStream):
"""Represents a receive stream over an os.pipe object."""

if max_bytes < 1:
await _core.checkpoint()
raise ValueError("max_bytes must be integer >= 1")
# Wrap the raw descriptor `fd` in an _FdHolder and create the
# ConflictDetector that receive_some enters with `async with`.
def __init__(self, fd: int):
self._fd_holder = _FdHolder(fd)
# NOTE(review): ConflictDetector is imported from ._util and not visible
# here; presumably it reports overlapping use by two tasks with the message
# below -- confirm against its definition.
self._conflict_detector = ConflictDetector(
"another task is using this pipe"
)

while True:
try:
await _core.checkpoint_if_cancelled()
data = os.read(self._pipe, max_bytes)
except BlockingIOError:
await _core.wait_readable(self._pipe)
else:
await _core.cancel_shielded_checkpoint()
break
# Read up to `max_bytes` bytes from the pipe with os.read and return them,
# waiting for the descriptor to become readable whenever the read would block.
#
# Raises TypeError when max_bytes is not an int, ValueError when it is < 1,
# ClosedResourceError when the read fails with EBADF, and BrokenResourceError
# for any other OSError.
async def receive_some(self, max_bytes: int) -> bytes:
async with self._conflict_detector:
if not isinstance(max_bytes, int):
raise TypeError("max_bytes must be integer >= 1")

if max_bytes < 1:
raise ValueError("max_bytes must be integer >= 1")

# keep retrying the read until it succeeds or fails with a real error
while True:
try:
data = os.read(self._fd_holder.fd, max_bytes)
except BlockingIOError:
# nothing to read yet: wait for readability, then loop and try again
await _core.wait_readable(self._fd_holder.fd)
except OSError as e:
# The fd holder resets its fd to -1 when closed, so a read on a closed
# pipe fails with EBADF; that case is reported as "closed", and every
# other OSError is reported as "broken" with the original error chained.
if e.errno == errno.EBADF:
raise _core.ClosedResourceError(
"this pipe was closed"
) from None
else:
raise _core.BrokenResourceError from e
else:
break

return data
return data

# Close this stream by awaiting the fd holder's aclose().
async def aclose(self):
await self._fd_holder.aclose()

async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]:
    """Create an OS pipe and wrap both of its ends as streams.

    Returns:
        A ``(send_stream, receive_stream)`` pair: the send stream wraps the
        pipe's write end and the receive stream wraps its read end.
    """
    read_fd, write_fd = os.pipe()
    # Build the send side first, then the receive side, matching the order
    # in which the returned pair is laid out.
    send_stream = PipeSendStream(write_fd)
    receive_stream = PipeReceiveStream(read_fd)
    return send_stream, receive_stream
# Return the descriptor number currently stored on the fd holder; the
# holder resets it to -1 once it has been closed.
def fileno(self):
return self._fd_holder.fd
4 changes: 4 additions & 0 deletions trio/testing/_check_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ async def expect_broken_stream_on_send():
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"x" * 100)

# even if it's an empty send
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"")

# ditto for wait_send_all_might_not_block
with _assert_raises(_core.ClosedResourceError):
with assert_checkpoints():
Expand Down
Loading

0 comments on commit 360f1fb

Please sign in to comment.