Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreF committed Jan 10, 2024
1 parent 41355d0 commit e7fa784
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
23 changes: 12 additions & 11 deletions src/paho/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def __init__(
clean_session: bool | None = None,
userdata: Any = None,
protocol: int = MQTTv311,
transport: str = "tcp",
transport: Literal["tcp", "websockets"] = "tcp",
reconnect_on_failure: bool = True,
manual_ack: bool = False,
) -> None:
Expand Down Expand Up @@ -627,8 +627,13 @@ def __init__(
locally.
"""
transport = transport.lower() # type: ignore
if transport not in ("websockets", "tcp"):
raise ValueError(
f'transport must be "websockets" or "tcp", not {transport}')

self._manual_ack = manual_ack
self.transport = transport
self._transport = transport
self._protocol = protocol
self._userdata = userdata
self._sock: SocketLike | None = None
Expand Down Expand Up @@ -790,25 +795,21 @@ def keepalive(self, value: int) -> None:
self._keepalive = value

@property
def transport(self) -> str:
def transport(self) -> Literal["tcp", "websockets"]:
"""Transport method used for the connection."""
return self._transport

@transport.setter
def transport(self, value: str) -> None:
def transport(self, value: Literal["tcp", "websockets"]) -> None:
"""
Update transport which should be "tcp" or "websockets".
This will only be used on future (re)connection. You should probably
use reconnect() to update the connection if established.
"""
if value.lower() not in ("websockets", "tcp"):
raise ValueError(
f'transport must be "websockets" or "tcp", not {value}')

self._transport = value.lower()
self._transport = value

@property
def protocol(self) -> int:
def protocol(self) -> MQTTProtocolVersion:
"""Protocol version used (MQTT v3, MQTT v3.11, MQTTv5)"""
return self.protocol

Expand All @@ -818,7 +819,7 @@ def connect_timeout(self) -> float:
return self._connect_timeout

@connect_timeout.setter
def connect_timeout(self, value: float):
def connect_timeout(self, value: float) -> None:
"Change connect_timeout for future (re)connection"
if value <= 0.0:
raise ValueError("timeout must be a positive number")
Expand Down
10 changes: 8 additions & 2 deletions src/paho/mqtt/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
except ImportError:
from typing_extensions import NotRequired, Required, TypedDict

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore



class AuthParameter(TypedDict, total=False):
username: Required[str]
Expand Down Expand Up @@ -108,7 +114,7 @@ def multiple(
auth: AuthParameter | None = None,
tls: TLSParameter | None = None,
protocol: int = paho.MQTTv311,
transport: str = "tcp",
transport: Literal["tcp", "websockets"] = "tcp",
proxy_args: Any | None = None,
) -> None:
"""Publish multiple messages to a broker, then disconnect cleanly.
Expand Down Expand Up @@ -231,7 +237,7 @@ def single(
auth: AuthParameter | None = None,
tls: TLSParameter | None = None,
protocol: int = paho.MQTTv311,
transport: str = "tcp",
transport: Literal["tcp", "websockets"] = "tcp",
proxy_args: Any | None = None,
) -> None:
"""Publish a single message to a broker, then disconnect cleanly.
Expand Down

0 comments on commit e7fa784

Please sign in to comment.