Skip to content

Commit

Permalink
Avoid linear type searches in ServiceBrowsers (#1044)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 24, 2021
1 parent 27e50ff commit ff76634
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
2 changes: 1 addition & 1 deletion tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_integration_with_listener_class(self):
sub_service_updated = Event()
duplicate_service_added = Event()

subtype_name = "My special Subtype"
subtype_name = "_printer"
type_ = "_http._tcp.local."
subtype = subtype_name + "._sub." + type_
name = "UPPERxxxyyyæøå"
Expand Down
22 changes: 22 additions & 0 deletions tests/utils/test_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,25 @@ def test_service_type_name_overlong_full_name():
nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.")
with pytest.raises(BadTypeInNameException):
nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.", strict=False)


def test_possible_types():
"""Test possible types from name."""
assert nameutils.possible_types('.') == set()
assert nameutils.possible_types('local.') == set()
assert nameutils.possible_types('_tcp.local.') == set()
assert nameutils.possible_types('_test-srvc-type._tcp.local.') == {'_test-srvc-type._tcp.local.'}
assert nameutils.possible_types('_any._tcp.local.') == {'_any._tcp.local.'}
assert nameutils.possible_types('.._x._tcp.local.') == {'_x._tcp.local.'}
assert nameutils.possible_types('x.y._http._tcp.local.') == {'_http._tcp.local.'}
assert nameutils.possible_types('1.2.3._mqtt._tcp.local.') == {'_mqtt._tcp.local.'}
assert nameutils.possible_types('x.sub._http._tcp.local.') == {'_http._tcp.local.'}
assert nameutils.possible_types('6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.') == {
'_http._tcp.local.',
'_zget._http._tcp.local.',
}
assert nameutils.possible_types('my._printer._sub._http._tcp.local.') == {
'_http._tcp.local.',
'_sub._http._tcp.local.',
'_printer._sub._http._tcp.local.',
}
15 changes: 5 additions & 10 deletions zeroconf/_services/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
SignalRegistrationInterface,
)
from .._updates import RecordUpdate, RecordUpdateListener
from .._utils.name import service_type_name
from .._utils.name import possible_types, service_type_name
from .._utils.time import current_time_millis, millis_to_seconds
from ..const import (
_BROWSER_BACKOFF_LIMIT,
Expand Down Expand Up @@ -326,7 +326,7 @@ def service_state_changed(self) -> SignalRegistrationInterface:

def _names_matching_types(self, names: Iterable[str]) -> List[Tuple[str, str]]:
"""Return the type and name for records matching the types we are browsing."""
return [(type_, name) for type_ in self.types for name in names if name.endswith(f".{type_}")]
return [(type_, name) for name in names for type_ in self.types.intersection(possible_types(name))]

def _enqueue_callback(
self,
Expand All @@ -352,16 +352,11 @@ def _async_process_record_update(
) -> None:
"""Process a single record update from a batch of updates."""
if isinstance(record, DNSPointer):
name = record.name
alias = record.alias
matches = self._names_matching_types((alias,))
if name in self.types:
matches.append((name, alias))
for type_, name in matches:
for type_ in self.types.intersection(possible_types(record.name)):
if old_record is None:
self._enqueue_callback(ServiceStateChange.Added, type_, name)
self._enqueue_callback(ServiceStateChange.Added, type_, record.alias)
elif record.is_expired(now):
self._enqueue_callback(ServiceStateChange.Removed, type_, name)
self._enqueue_callback(ServiceStateChange.Removed, type_, record.alias)
else:
self.reschedule_type(type_, now, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT))
return
Expand Down
15 changes: 15 additions & 0 deletions zeroconf/_utils/name.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
USA
"""

from typing import Set

from .._exceptions import BadTypeInNameException
from ..const import (
_HAS_ASCII_CONTROL_CHARS,
Expand Down Expand Up @@ -155,3 +157,16 @@ def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: dis
)

return service_name + trailer


def possible_types(name: str) -> Set[str]:
"""Build a set of all possible types from a fully qualified name."""
labels = name.split('.')
label_count = len(labels)
types = set()
for count in range(label_count):
parts = labels[label_count - count - 4 :]
if not parts[0].startswith('_'):
break
types.add('.'.join(parts))
return types

0 comments on commit ff76634

Please sign in to comment.