From bca2ce4782d74078cd9a11cceff295cbddb27b9a Mon Sep 17 00:00:00 2001 From: Min RK Date: Wed, 3 Apr 2024 08:49:38 +0200 Subject: [PATCH 1/3] types: fully resolve default socket type zmq.Context() returns a sync `zmq.Context[zmq.Socket[bytes]]` has consequences for super where it's more restrictive than it should be, but that shouldn't be a big deal --- zmq/eventloop/future.py | 2 +- zmq/sugar/context.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/zmq/eventloop/future.py b/zmq/eventloop/future.py index b2fca5b7f..0f34f0ef9 100644 --- a/zmq/eventloop/future.py +++ b/zmq/eventloop/future.py @@ -101,4 +101,4 @@ def __init__(self: Context, *args: Any, **kwargs: Any) -> None: DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/zmq/sugar/context.py b/zmq/sugar/context.py index 54708b1f9..98ff2c9ef 100644 --- a/zmq/sugar/context.py +++ b/zmq/sugar/context.py @@ -78,18 +78,18 @@ class Context(ContextBase, AttributeSetter, Generic[_SocketType]): _socket_class: type[_SocketType] = Socket # type: ignore @overload - def __init__(self: Context[Socket], io_threads: int = 1): ... + def __init__(self: Context[Socket[bytes]], io_threads: int = 1): ... @overload - def __init__(self: Context[Socket], io_threads: Context): + def __init__(self: Context[Socket[bytes]], io_threads: Context): # this should be positional-only, but that requires 3.8 ... @overload - def __init__(self: Context[Socket], *, shadow: Context | int): ... + def __init__(self: Context[Socket[bytes]], *, shadow: Context | int): ... def __init__( - self: Context[Socket], + self: Context[Socket[bytes]], io_threads: int | Context = 1, shadow: Context | int = 0, ) -> None: From 9b16992b796670dc29f4aedb520318e667aed636 Mon Sep 17 00:00:00 2001 From: Min RK Date: Wed, 3 Apr 2024 08:56:55 +0200 Subject: [PATCH 2/3] add some overloads for async socket.recv --- zmq/_future.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/zmq/_future.py b/zmq/_future.py index 59aafbfb6..052f8f6b5 100644 --- a/zmq/_future.py +++ b/zmq/_future.py @@ -292,6 +292,19 @@ def recv_multipart( 'recv_multipart', dict(flags=flags, copy=copy, track=track) ) + @overload # type: ignore + def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ... + + @overload + def recv( + self, flags: int = 0, *, copy: Literal[True], track: bool = False + ) -> Awaitable[bytes]: ... + + @overload + def recv( + self, flags: int = 0, *, copy: Literal[False], track: bool = False + ) -> Awaitable[_zmq.Frame]: ... + def recv( # type: ignore self, flags: int = 0, copy: bool = True, track: bool = False ) -> Awaitable[bytes | _zmq.Frame]: From f5a664d6b8cee1984a9238265e82e4e868238bbc Mon Sep 17 00:00:00 2001 From: Min RK Date: Wed, 3 Apr 2024 10:29:44 +0200 Subject: [PATCH 3/3] add SyncSocket, SyncContext TypeAliases and note in the docs about Generics --- docs/source/api/zmq.md | 23 +++++++++++++++++++++++ zmq/_typing.py | 9 +++++++++ zmq/sugar/context.py | 16 ++++++++++------ zmq/sugar/socket.py | 6 ++++-- 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/docs/source/api/zmq.md b/docs/source/api/zmq.md index 4dc1a3c79..fa265ce31 100644 --- a/docs/source/api/zmq.md +++ b/docs/source/api/zmq.md @@ -9,6 +9,29 @@ ## Basic Classes +````{note} +For typing purposes, `zmq.Context` and `zmq.Socket` are Generics, +which means they will accept any Context or Socket implementation. + +The base `zmq.Context()` constructor returns the type +`zmq.Context[zmq.Socket[bytes]]`. +If you are using type annotations and want to _exclude_ the async subclasses, +use the resolved types instead of the base Generics: + +```python +ctx: zmq.Context[zmq.Socket[bytes]] = zmq.Context() +sock: zmq.Socket[bytes] +``` + +in pyzmq 26, these are available as the Type Aliases (not actual classes!): + +```python +ctx: zmq.SyncContext = zmq.Context() +sock: zmq.SyncSocket +``` + +```` + ### {class}`Context` ```{eval-rst} diff --git a/zmq/_typing.py b/zmq/_typing.py index 7bb211af2..92ec879c2 100644 --- a/zmq/_typing.py +++ b/zmq/_typing.py @@ -19,3 +19,12 @@ def __getitem__(self, key): class TypedDict(Dict): # type: ignore pass + + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + try: + from typing_extensions import TypeAlias + except ImportError: + TypeAlias = type # type: ignore diff --git a/zmq/sugar/context.py b/zmq/sugar/context.py index 98ff2c9ef..a83e4cc88 100644 --- a/zmq/sugar/context.py +++ b/zmq/sugar/context.py @@ -13,13 +13,14 @@ from weakref import WeakSet import zmq +from zmq._typing import TypeAlias from zmq.backend import Context as ContextBase from zmq.constants import ContextOption, Errno, SocketOption from zmq.error import ZMQError from zmq.utils.interop import cast_int_addr from .attrsettr import AttributeSetter, OptValT -from .socket import Socket +from .socket import Socket, SyncSocket # notice when exiting, to avoid triggering term on exit _exiting = False @@ -78,18 +79,18 @@ class Context(ContextBase, AttributeSetter, Generic[_SocketType]): _socket_class: type[_SocketType] = Socket # type: ignore @overload - def __init__(self: Context[Socket[bytes]], io_threads: int = 1): ... + def __init__(self: SyncContext, io_threads: int = 1): ... @overload - def __init__(self: Context[Socket[bytes]], io_threads: Context): + def __init__(self: SyncContext, io_threads: Context): # this should be positional-only, but that requires 3.8 ... @overload - def __init__(self: Context[Socket[bytes]], *, shadow: Context | int): ... + def __init__(self: SyncContext, *, shadow: Context | int): ... def __init__( - self: Context[Socket[bytes]], + self: SyncContext, io_threads: int | Context = 1, shadow: Context | int = 0, ) -> None: @@ -415,4 +416,7 @@ def __delattr__(self, key: str) -> None: del self.sockopts[opt] -__all__ = ['Context'] +SyncContext: TypeAlias = Context[SyncSocket] + + +__all__ = ['Context', 'SyncContext'] diff --git a/zmq/sugar/socket.py b/zmq/sugar/socket.py index 4332d0d9c..fbd5390a4 100644 --- a/zmq/sugar/socket.py +++ b/zmq/sugar/socket.py @@ -23,7 +23,7 @@ from warnings import warn import zmq -from zmq._typing import Literal +from zmq._typing import Literal, TypeAlias from zmq.backend import Socket as SocketBase from zmq.error import ZMQBindError, ZMQError from zmq.utils import jsonapi @@ -1107,4 +1107,6 @@ def disable_monitor(self) -> None: self.monitor(None, 0) -__all__ = ['Socket'] +SyncSocket: TypeAlias = Socket[bytes] + +__all__ = ['Socket', 'SyncSocket']