Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix subscribing topics with wildcards #1

Merged
merged 8 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[MASTER]
max-line-length = 88

[FORMAT]
good-names=dt
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## Unreleased

### Fixed:

- Subscribing topics with wildcards, [PR-1](https://github.com/panda-official/DriftMqtt/pull/1)

## 0.1.0 - 2022-09-23

Initial implementation
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Drift MQTT tools

![GitHub Workflow Status](https://img.shields.io/github/workflow/status/panda-official/DriftMqtt/ci)
![PyPI](https://img.shields.io/pypi/v/drift-mqtt)
![PyPI - Downloads](https://img.shields.io/pypi/dm/drift-mqtt)


A collection of helpers to work with MQTT:
* `Client` - wrapper around `paho.mqtt.Client` that correctly handles subscriptions after reconnect

Expand Down
4 changes: 2 additions & 2 deletions pkg/drift_mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def _load_version() -> str:
"""Load version from VERSION file"""
from pathlib import Path # pylint: disable=import-outside-toplevel

here = Path(__file__).parent.resolve()
with open(str(here / "VERSION")) as version_file:
here = Path(__file__).parent
with open(here / "VERSION", "r", encoding="utf-8") as version_file:
return version_file.read().strip()


Expand Down
18 changes: 11 additions & 7 deletions pkg/drift_mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
"""

import logging
import re
from dataclasses import dataclass
from typing import Callable
from urllib.parse import urlparse

from paho.mqtt.client import Client as PahoClient
from paho.mqtt.client import MQTTMessage
import paho.mqtt.client as mqtt

logger = logging.getLogger("drift-mqtt")
Expand Down Expand Up @@ -35,7 +38,7 @@ def __init__(self, uri: str, client_id: str = "drift-mqtt-client"):
self._uri = urlparse(uri)
self._transport = "websockets" if self._uri.scheme == "ws" else "tcp"
# TODO: check if we need v311 or v5 # pylint: disable=fixme
self._client = mqtt.Client(client_id=client_id, transport=self._transport)
self._client = PahoClient(client_id=client_id, transport=self._transport)
self._client.on_connect = self.on_connect
self._client.on_disconnect = self.on_disconnect
self._client.on_subscribe = self.on_subscribe
Expand All @@ -44,14 +47,14 @@ def __init__(self, uri: str, client_id: str = "drift-mqtt-client"):
self._client.enable_logger()
self._subscriptions = []

def on_message(self, _client, _userdata, message: mqtt.MQTTMessage):
def on_message(self, _client, _userdata, message: MQTTMessage):
"""Message read callback"""
for sub in self._subscriptions:
if message.topic.startswith(sub.topic):
try:
try:
if re.match(sub.topic.replace("#", "(.*)"), message.topic):
sub.handler(message)
except Exception: # pylint: disable=broad-except
logger.exception("Error in a message handler")
except Exception as err: # pylint: disable=broad-except
logger.exception("Error in a message handler: %s", err)

def __getattr__(self, item):
"""Forward unknown methods to MQTT client"""
Expand Down Expand Up @@ -80,7 +83,8 @@ def on_connect(self, _client, _userdata, _flags, return_code, _properties=None):
@staticmethod
def on_disconnect(_client, _userdata, return_code):
"""Callback on mqtt disconnected"""
# this is a bug in paho, return_code 1 is a general error (connection lost in this case)
# this is a bug in paho, return_code 1 is a general error
# (connection lost in this case)
if return_code == mqtt.MQTT_ERR_NOMEM:
return_code = mqtt.MQTT_ERR_CONN_LOST

Expand Down
6 changes: 2 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,9 @@ def get_long_description(base_path: Path):
python_requires=">=3.7",
install_requires=["paho-mqtt~=1.6.0"],
extras_require={
"test": [
"pytest~=6.1.2",
],
"test": ["pytest~=7.1.3", "pytest-mock~=3.8.2 "],
"lint": [
"pylint~=2.5.3",
"pylint~=2.15.3",
],
"format": ["black~=22.8.0"],
},
Expand Down
64 changes: 62 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,76 @@
""" Test Mqtt client
"""

from pathlib import Path

import pytest
from paho.mqtt.client import MQTTMessage

from drift_mqtt import Client

HERE = Path(__file__).parent.resolve()


def test_connection_refused():
@pytest.fixture(name="paho_client")
def _make_client(mocker):
client_klass = mocker.patch("drift_mqtt.client.PahoClient")
client = mocker.Mock()
client_klass.return_value = client
return client


def test__connection_refused():
"""Should rise connection exception"""
client = Client(uri="tcp://0.0.0.0:1883", client_id="some_test_client_id")
with pytest.raises(ConnectionRefusedError):
client.connect()


@pytest.fixture(name="last_message")
def _make_last_message():
last_message = [None]
return last_message


@pytest.fixture(name="handler")
def _make_handler(last_message):
def handler(message: MQTTMessage):
"""Simple handler just to keep a received message"""
last_message[0] = message

return handler


@pytest.mark.parametrize(
"topic, received_topic",
[
("drift/topic1", b"drift/topic1"),
("drift/#", b"drift/topic1"),
("drft/#/subpath", b"drft/topic/subpath"),
],
)
def test__handle_subscriptions(
paho_client, topic, received_topic, handler, last_message
):
"""Should parse topic and call handlers"""
client = Client(uri="tcp://0.0.0.0:1883")

client.subscribe(topic, handler)
paho_client.on_message(None, None, message=MQTTMessage(topic=received_topic))
assert last_message[0].topic == received_topic.decode("ascii")


@pytest.mark.parametrize(
"topic, received_topic",
[
("drift/topic1", b"drift/skip"),
("drift/#", b"skip/topic1"),
("drft/#/subpath", b"drft/topic/skip"),
],
)
def test__skip_subscriptions(paho_client, topic, received_topic, handler, last_message):
"""Should parse topic and skip"""
client = Client(uri="tcp://0.0.0.0:1883")

client.subscribe(topic, handler)
paho_client.on_message(None, None, message=MQTTMessage(topic=received_topic))
assert last_message[0] is None