From 16220d99235e8704380d0d45901d4ee0ffd29b9c Mon Sep 17 00:00:00 2001 From: Yoch Melka Date: Tue, 5 Nov 2024 10:42:24 +0200 Subject: [PATCH] better exception handling in matcher Signed-off-by: Yoch Melka --- src/paho/mqtt/matcher.py | 21 +++++++++++---------- tests/test_matcher.py | 1 + 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/paho/mqtt/matcher.py b/src/paho/mqtt/matcher.py index b73c13ac..4a0651c5 100644 --- a/src/paho/mqtt/matcher.py +++ b/src/paho/mqtt/matcher.py @@ -30,11 +30,12 @@ def __getitem__(self, key): node = self._root for sym in key.split('/'): node = node._children[sym] + except KeyError as err: + raise KeyError(key) from err + else: if node._content is None: raise KeyError(key) return node._content - except KeyError as ke: - raise KeyError(key) from ke def __delitem__(self, key): """Delete the value associated with some topic filter :key""" @@ -44,11 +45,13 @@ def __delitem__(self, key): for k in key.split('/'): parent, node = node, node._children[k] lst.append((parent, k, node)) - # TODO + except KeyError as err: + raise KeyError(key) from err + else: + if node._content is None: + raise KeyError(key) node._content = None - except KeyError as ke: - raise KeyError(key) from ke - else: # cleanup + # cleanup for parent, k, node in reversed(lst): if node._children or node._content is not None: break @@ -66,11 +69,9 @@ def rec(node, i=0): else: part = lst[i] if part in node._children: - for content in rec(node._children[part], i + 1): - yield content + yield from rec(node._children[part], i + 1) if '+' in node._children and (normal or i > 0): - for content in rec(node._children['+'], i + 1): - yield content + yield from rec(node._children['+'], i + 1) if '#' in node._children and (normal or i > 0): content = node._children['#']._content if content is not None: diff --git a/tests/test_matcher.py b/tests/test_matcher.py index e2dc02a4..497fd47b 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -17,6 +17,7 @@ class Test_client_function: ("#", "/foo/bar"), ("/#", "/foo/bar"), ("$SYS/bar", "$SYS/bar"), + ("$SYS/#", "$SYS/foo/bar"), ]) def test_matching(self, sub, topic): assert client.topic_matches_sub(sub, topic)