diff --git a/docs/source/history.rst b/docs/source/history.rst index 3c460e5763..0cbabed0ca 100644 --- a/docs/source/history.rst +++ b/docs/source/history.rst @@ -86,10 +86,11 @@ the :mod:`trio` and :mod:`trio.hazmat` modules https://github.com/python-trio/trio/issues/314 -* ``trio.socket.SocketType`` will no longer be exposed publically in - 0.3.0. Since it had no public constructor, the only thing you could - do with it was ``isinstance(obj, SocketType)``. Instead, use - :func:`trio.socket.is_trio_socket`. (https://github.com/python-trio/trio/issues/170) +* ``trio.socket.SocketType`` is now an empty abstract base class, with + the actual socket class made private. This shouldn't effect anyone, + since the only thing you could directly use it for in the first + place was ``isinstance`` checks, and those still work + (https://github.com/python-trio/trio/issues/170) * The following classes and functions have moved from :mod:`trio` to :mod:`trio.hazmat`: diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 845342417d..6632ca353c 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -248,9 +248,7 @@ library socket into a trio socket: Unlike :func:`socket.socket`, :func:`trio.socket.socket` is a function, not a class; if you want to check whether an object is a -trio socket, use: - -.. autofunction:: is_trio_socket +trio socket, use ``isinstance(obj, trio.socket.SocketType)``. For name lookup, Trio provides the standard functions, but with some changes: @@ -297,7 +295,13 @@ broken features: Socket objects ~~~~~~~~~~~~~~ -.. interface:: The trio socket object interface +.. class:: SocketType + + .. note:: :class:`trio.socket.SocketType` is an abstract class and + cannot be instantiated directly; you get concrete socket objects + by calling constructors like :func:`trio.socket.socket`. + However, you can use it to check if an object is a Trio socket + via ``isinstance(obj, trio.socket.SocketType)``. Trio socket objects are overall very similar to the :ref:`standard library socket objects `, with a few diff --git a/trio/_abc.py b/trio/_abc.py index 536c827bc7..44b934ec50 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -198,6 +198,10 @@ class SocketFactory(metaclass=ABCMeta): def socket(self, family=None, type=None, proto=None): """Create and return a socket object. + Your socket object must inherit from :class:`trio.socket.SocketType`, + which is an empty class whose only purpose is to "mark" which classes + should be considered valid trio sockets. + Called by :func:`trio.socket.socket`. Note that unlike :func:`trio.socket.socket`, this does not take a @@ -207,16 +211,6 @@ def socket(self, family=None, type=None, proto=None): """ - @abstractmethod - def is_trio_socket(self, obj): - """Check if the given object is a socket instance. - - Called by :func:`trio.socket.is_trio_socket`, which returns True if - the given object is a builtin trio socket object *or* if this method - returns True. - - """ - class AsyncResource(metaclass=ABCMeta): """A standard interface for resources that needs to be cleaned up, and diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index ad1b74b419..eaab80bca1 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -60,7 +60,7 @@ class SocketStream(HalfCloseableStream): """ def __init__(self, socket): - if not tsocket.is_trio_socket(socket): + if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires trio socket object") if real_socket_type(socket.type) != tsocket.SOCK_STREAM: raise ValueError("SocketStream requires a SOCK_STREAM socket") @@ -329,7 +329,7 @@ class SocketListener(Listener): """ def __init__(self, socket): - if not tsocket.is_trio_socket(socket): + if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires trio socket object") if real_socket_type(socket.type) != tsocket.SOCK_STREAM: raise ValueError("SocketListener requires a SOCK_STREAM socket") diff --git a/trio/_socket.py b/trio/_socket.py index ef012482e7..0dd7953086 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -23,6 +23,11 @@ def _reexport(name): __all__.append(name) +def _add_to_all(obj): + __all__.append(obj.__name__) + return obj + + # Usage: # # async with _try_sync(): @@ -115,6 +120,7 @@ async def __aexit__(self, etype, value, tb): _overrides = _core.RunLocal(hostname_resolver=None, socket_factory=None) +@_add_to_all def set_custom_hostname_resolver(hostname_resolver): """Set a custom hostname resolver. @@ -147,9 +153,7 @@ def set_custom_hostname_resolver(hostname_resolver): return old -__all__.append("set_custom_hostname_resolver") - - +@_add_to_all def set_custom_socket_factory(socket_factory): """Set a custom socket object factory. @@ -159,8 +163,7 @@ def set_custom_socket_factory(socket_factory): details. Setting a custom socket factory affects all future calls to :func:`socket` - and :func:`is_trio_socket` within the enclosing call to - :func:`trio.run`. + within the enclosing call to :func:`trio.run`. Generally you should call this function just once, right at the beginning of your program. @@ -178,18 +181,10 @@ def set_custom_socket_factory(socket_factory): return old -__all__.append("set_custom_socket_factory") - ################################################################ # getaddrinfo and friends ################################################################ - -def _add_to_all(obj): - __all__.append(obj.__name__) - return obj - - _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV @@ -348,25 +343,6 @@ def socket(family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): return from_stdlib_socket(stdlib_socket) -################################################################ -# Type checking -################################################################ - - -@_add_to_all -def is_trio_socket(obj): - """Check whether the given object is a trio socket. - - This function's behavior can be customized using - :func:`set_custom_socket_factory`. - - """ - sf = _overrides.socket_factory - if sf is not None and sf.is_trio_socket(obj): - return True - return isinstance(obj, _SocketType) - - ################################################################ # _SocketType ################################################################ @@ -394,7 +370,16 @@ def real_socket_type(type_num): return type_num & _SOCK_TYPE_MASK -class _SocketType: +@_add_to_all +class SocketType: + def __init__(self): + raise TypeError( + "SocketType is an abstract class; use trio.socket.socket if you " + "want to construct a socket object" + ) + + +class _SocketType(SocketType): def __init__(self, sock): if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we diff --git a/trio/socket.py b/trio/socket.py index 3a7dd841a4..93af20e383 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -6,14 +6,3 @@ # here. from ._socket import * from ._socket import __all__ - -from . import _deprecate -from ._socket import _SocketType -_deprecate.enable_attribute_deprecations(__name__) -__deprecated_attributes__ = { - "SocketType": - _deprecate.DeprecatedAttribute( - _SocketType, "0.2.0", issue=170, instead="is_trio_socket" - ) -} -del _deprecate, _SocketType diff --git a/trio/tests/test_highlevel_open_tcp_listeners.py b/trio/tests/test_highlevel_open_tcp_listeners.py index 001cf30948..1498e0abae 100644 --- a/trio/tests/test_highlevel_open_tcp_listeners.py +++ b/trio/tests/test_highlevel_open_tcp_listeners.py @@ -145,7 +145,7 @@ class FakeOSError(OSError): @attr.s -class FakeSocket: +class FakeSocket(tsocket.SocketType): family = attr.ib() type = attr.ib() proto = attr.ib() @@ -177,9 +177,6 @@ class FakeSocketFactory: poison_after = attr.ib() sockets = attr.ib(default=attr.Factory(list)) - def is_trio_socket(self, obj): - return isinstance(obj, FakeSocket) - def socket(self, family, type, proto): sock = FakeSocket(family, type, proto) self.poison_after -= 1 diff --git a/trio/tests/test_highlevel_open_tcp_stream.py b/trio/tests/test_highlevel_open_tcp_stream.py index 834c125972..f413c18d45 100644 --- a/trio/tests/test_highlevel_open_tcp_stream.py +++ b/trio/tests/test_highlevel_open_tcp_stream.py @@ -95,7 +95,7 @@ async def test_open_tcp_stream_input_validation(): @attr.s -class FakeSocket: +class FakeSocket(trio.socket.SocketType): scenario = attr.ib() family = attr.ib() type = attr.ib() @@ -154,9 +154,6 @@ def socket(self, family, type, proto): self.socket_count += 1 return FakeSocket(self, family, type, proto) - def is_trio_socket(self, obj): - return isinstance(obj, FakeSocket) - def _ip_to_gai_entry(self, ip): if ":" in ip: family = trio.socket.AF_INET6 diff --git a/trio/tests/test_highlevel_socket.py b/trio/tests/test_highlevel_socket.py index 561a2cd07c..139d878662 100644 --- a/trio/tests/test_highlevel_socket.py +++ b/trio/tests/test_highlevel_socket.py @@ -153,7 +153,7 @@ async def test_SocketListener_socket_closed_underfoot(): async def test_SocketListener_accept_errors(): - class FakeSocket: + class FakeSocket(tsocket.SocketType): def __init__(self, events): self._events = iter(events) @@ -178,12 +178,6 @@ async def accept(self): else: return event, None - class FakeSocketFactory: - def is_trio_socket(self, obj): - return isinstance(obj, FakeSocket) - - tsocket.set_custom_socket_factory(FakeSocketFactory()) - fake_server_sock = FakeSocket([]) fake_listen_sock = FakeSocket( diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index faeef9c494..dbf1f600d8 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -6,7 +6,7 @@ from .. import _core from .. import socket as tsocket -from .._socket import _NUMERIC_ONLY, _try_sync, _SocketType +from .._socket import _NUMERIC_ONLY, _try_sync from ..testing import assert_checkpoints, wait_all_tasks_blocked ################################################################ @@ -191,10 +191,10 @@ async def test_getnameinfo(): async def test_from_stdlib_socket(): sa, sb = stdlib_socket.socketpair() - assert not tsocket.is_trio_socket(sa) + assert not isinstance(sa, tsocket.SocketType) with sa, sb: ta = tsocket.from_stdlib_socket(sa) - assert tsocket.is_trio_socket(ta) + assert isinstance(ta, tsocket.SocketType) assert sa.fileno() == ta.fileno() await ta.send(b"x") assert sb.recv(1) == b"x" @@ -248,13 +248,11 @@ async def test_fromshare(): async def test_socket(): with tsocket.socket() as s: - assert isinstance(s, _SocketType) - assert tsocket.is_trio_socket(s) + assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as s: - assert isinstance(s, _SocketType) - assert tsocket.is_trio_socket(s) + assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET6 @@ -319,8 +317,7 @@ async def test_SocketType_dup(): with a, b: a2 = a.dup() with a2: - assert isinstance(a2, _SocketType) - assert tsocket.is_trio_socket(a2) + assert isinstance(a2, tsocket.SocketType) assert a2.fileno() != a.fileno() a.close() await a2.send(b"x") @@ -820,13 +817,8 @@ class CustomSocketFactory: def socket(self, family, type, proto): return ("hi", family, type, proto) - def is_trio_socket(self, obj): - return obj == "foo" - csf = CustomSocketFactory() - assert not tsocket.is_trio_socket("foo") - assert tsocket.set_custom_socket_factory(csf) is None assert tsocket.socket() == ("hi", tsocket.AF_INET, tsocket.SOCK_STREAM, 0) @@ -844,8 +836,9 @@ def is_trio_socket(self, obj): assert hasattr(a, "bind") assert hasattr(b, "bind") - assert tsocket.is_trio_socket("foo") - assert tsocket.set_custom_socket_factory(None) is csf - assert not tsocket.is_trio_socket("foo") + +async def test_SocketType_is_abstract(): + with pytest.raises(TypeError): + tsocket.SocketType()