diff --git a/tests/test_application.py b/tests/test_application.py index 84981cd..8cee602 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -12,6 +12,7 @@ from zigpy_deconz import types as t import zigpy_deconz.api as deconz_api +from zigpy_deconz.config import CONF_DECONZ_CONFIG, CONF_MAX_CONCURRENT_REQUESTS import zigpy_deconz.exception import zigpy_deconz.zigbee.application as application @@ -24,7 +25,8 @@ zigpy.config.CONF_NWK_EXTENDED_PAN_ID: "11:22:33:44:55:66:77:88", zigpy.config.CONF_NWK_UPDATE_ID: 22, zigpy.config.CONF_NWK_KEY: [0xAA] * 16, - } + }, + CONF_DECONZ_CONFIG: {CONF_MAX_CONCURRENT_REQUESTS: 20}, } @@ -369,40 +371,6 @@ async def test_request_send_aps_data_error(app): assert r[0] != 0 -async def test_request_retry(app): - req_id = sentinel.req_id - app.get_sequence = MagicMock(return_value=req_id) - - device = zigpy.device.Device(app, sentinel.ieee, 0x1122) - device.relays = [0x5678, 0x1234] - app.get_device = MagicMock(return_value=device) - - async def req_mock( - req_id, - dst_addr_ep, - profile, - cluster, - src_ep, - data, - *, - relays=None, - tx_options=t.DeconzTransmitOptions.USE_NWK_KEY_SECURITY, - radius=0 - ): - app._pending[req_id].result.set_result(1) - - app._api.aps_data_request = MagicMock(side_effect=req_mock) - app._api.protocol_version = application.PROTO_VER_MANUAL_SOURCE_ROUTE - - await app.request(device, 0x0260, 1, 2, 3, 123, b"\x01\x02\x03") - - assert len(app._api.aps_data_request.mock_calls) == 2 - without_relays, with_relays = app._api.aps_data_request.mock_calls - - assert without_relays[2]["relays"] is None - assert with_relays[2]["relays"] == [0x0000, 0x1234, 0x5678] - - async def _test_broadcast(app, send_success=True, aps_data_error=False, **kwargs): seq = sentinel.req_id @@ -746,3 +714,45 @@ async def test_delayed_scan(): with patch.object(app, "get_device", return_value=coord): await app._delayed_neighbour_scan() assert coord.neighbors.scan.await_count == 1 + + +async def test_request_concurrency(app): + """Test the request concurrency limit.""" + max_concurrency = 0 + num_concurrent = 0 + + async def req_mock( + req_id, + dst_addr_ep, + profile, + cluster, + src_ep, + data, + *, + relays=None, + tx_options=t.DeconzTransmitOptions.USE_NWK_KEY_SECURITY, + radius=0 + ): + nonlocal num_concurrent, max_concurrency + + num_concurrent += 1 + max_concurrency = max(num_concurrent, max_concurrency) + + try: + await asyncio.sleep(0.01) + app._pending[req_id].result.set_result(0) + finally: + num_concurrent -= 1 + + app._api.aps_data_request = MagicMock(side_effect=req_mock) + app._api.protocol_version = 0 + device = zigpy.device.Device(app, sentinel.ieee, 0x1122) + app.get_device = MagicMock(return_value=device) + + requests = [ + app.request(device, 0x0260, 1, 2, 3, seq, b"\x01\x02\x03") for seq in range(100) + ] + + await asyncio.gather(*requests) + + assert max_concurrency == 20 diff --git a/zigpy_deconz/config.py b/zigpy_deconz/config.py index f12a056..ed2ad40 100644 --- a/zigpy_deconz/config.py +++ b/zigpy_deconz/config.py @@ -18,6 +18,11 @@ cv_boolean, ) +CONF_DECONZ_CONFIG = "deconz_config" + +CONF_MAX_CONCURRENT_REQUESTS = "max_concurrent_requests" +CONF_MAX_CONCURRENT_REQUESTS_DEFAULT = 8 + CONF_WATCHDOG_TTL = "watchdog_ttl" CONF_WATCHDOG_TTL_DEFAULT = 600 @@ -25,6 +30,14 @@ { vol.Optional(CONF_WATCHDOG_TTL, default=CONF_WATCHDOG_TTL_DEFAULT): vol.All( int, vol.Range(min=180) - ) + ), + vol.Optional(CONF_DECONZ_CONFIG, default={}): vol.Schema( + { + vol.Optional( + CONF_MAX_CONCURRENT_REQUESTS, + default=CONF_MAX_CONCURRENT_REQUESTS_DEFAULT, + ): vol.All(int, vol.Range(min=1)) + } + ), } ) diff --git a/zigpy_deconz/zigbee/application.py b/zigpy_deconz/zigbee/application.py index 5d57358..b68fba6 100644 --- a/zigpy_deconz/zigbee/application.py +++ b/zigpy_deconz/zigbee/application.py @@ -2,8 +2,10 @@ import asyncio import binascii +import contextlib import logging import re +import time from typing import Any, Dict import zigpy.application @@ -18,7 +20,13 @@ from zigpy_deconz import types as t from zigpy_deconz.api import Deconz, NetworkParameter, NetworkState, Status -from zigpy_deconz.config import CONF_WATCHDOG_TTL, CONFIG_SCHEMA, SCHEMA_DEVICE +from zigpy_deconz.config import ( + CONF_DECONZ_CONFIG, + CONF_MAX_CONCURRENT_REQUESTS, + CONF_WATCHDOG_TTL, + CONFIG_SCHEMA, + SCHEMA_DEVICE, +) import zigpy_deconz.exception LOGGER = logging.getLogger(__name__) @@ -31,6 +39,8 @@ PROTO_VER_NEIGBOURS = 0x0107 WATCHDOG_TTL = 600 +MAX_REQUEST_RETRY_DELAY = 1.0 + class ControllerApplication(zigpy.application.ControllerApplication): SCHEMA = CONFIG_SCHEMA @@ -43,7 +53,13 @@ def __init__(self, config: Dict[str, Any]): super().__init__(config=zigpy.config.ZIGPY_SCHEMA(config)) self._api = None + self._pending = zigpy.util.Requests() + self._concurrent_requests_semaphore = asyncio.Semaphore( + self._config[CONF_DECONZ_CONFIG][CONF_MAX_CONCURRENT_REQUESTS] + ) + self._currently_waiting_requests = 0 + self._nwk = 0 self.version = 0 @@ -199,6 +215,38 @@ async def form_network(self): await asyncio.sleep(CHANGE_NETWORK_WAIT) raise Exception("Could not form network.") + @contextlib.asynccontextmanager + async def _limit_concurrency(self): + """Async context manager to prevent devices from being overwhelmed by requests. + + Mainly a thin wrapper around `asyncio.Semaphore` that logs when it has to wait. + """ + + start_time = time.time() + was_locked = self._concurrent_requests_semaphore.locked() + + if was_locked: + self._currently_waiting_requests += 1 + LOGGER.debug( + "Max concurrency (%s) reached, delaying requests (%s enqueued)", + self._config[CONF_DECONZ_CONFIG][CONF_MAX_CONCURRENT_REQUESTS], + self._currently_waiting_requests, + ) + + try: + async with self._concurrent_requests_semaphore: + if was_locked: + LOGGER.debug( + "Previously delayed request is now running, " + "delayed by %0.2f seconds", + time.time() - start_time, + ) + + yield + finally: + if was_locked: + self._currently_waiting_requests -= 1 + async def mrequest( self, group_id, @@ -238,20 +286,21 @@ async def mrequest( dst_addr_ep.address_mode = t.ADDRESS_MODE.GROUP dst_addr_ep.address = group_id - with self._pending.new(req_id) as req: - try: - await self._api.aps_data_request( - req_id, dst_addr_ep, profile, cluster, min(1, src_ep), data - ) - except zigpy_deconz.exception.CommandError as ex: - return ex.status, "Couldn't enqueue send data request: {}".format(ex) + async with self._limit_concurrency(): + with self._pending.new(req_id) as req: + try: + await self._api.aps_data_request( + req_id, dst_addr_ep, profile, cluster, min(1, src_ep), data + ) + except zigpy_deconz.exception.CommandError as ex: + return ex.status, f"Couldn't enqueue send data request: {ex!r}" - r = await asyncio.wait_for(req.result, SEND_CONFIRM_TIMEOUT) - if r: - LOGGER.debug("Error while sending %s req id frame: %s", req_id, r) - return r, "message send failure" + r = await asyncio.wait_for(req.result, SEND_CONFIRM_TIMEOUT) + if r: + LOGGER.debug("Error while sending %s req id frame: %s", req_id, r) + return r, f"message send failure: {r}" - return Status.SUCCESS, "message send success" + return Status.SUCCESS, "message send success" @zigpy.util.retryable_request async def request( @@ -282,13 +331,9 @@ async def request( dst_addr_ep.address_mode = t.uint8_t(t.ADDRESS_MODE.NWK) dst_addr_ep.address = device.nwk - relays = None tx_options = t.DeconzTransmitOptions.USE_NWK_KEY_SECURITY - if expect_reply: - tx_options |= t.DeconzTransmitOptions.USE_APS_ACKS - - for attempt in (1, 2): + async with self._limit_concurrency(): with self._pending.new(req_id) as req: try: await self._api.aps_data_request( @@ -298,25 +343,18 @@ async def request( cluster, min(1, src_ep), data, - relays=relays, tx_options=tx_options, ) except zigpy_deconz.exception.CommandError as ex: - return ex.status, f"Couldn't enqueue send data request: {ex}" + return ex.status, f"Couldn't enqueue send data request: {ex!r}" r = await asyncio.wait_for(req.result, SEND_CONFIRM_TIMEOUT) - if not r: - return r, "message send success" - - LOGGER.debug("Error while sending %s req id frame: %s", req_id, r) - - if attempt == 2: + if r: + LOGGER.debug("Error while sending %s req id frame: %s", req_id, r) return r, "message send failure" - elif self._api.protocol_version >= PROTO_VER_MANUAL_SOURCE_ROUTE: - # Force the request to send by including the coordinator - relays = [0x0000] + (device.relays or [])[::-1] - LOGGER.debug("Trying manual source route: %s", relays) + + return r, "message send success" async def broadcast( self, @@ -342,23 +380,26 @@ async def broadcast( dst_addr_ep.address = t.uint16_t(broadcast_address) dst_addr_ep.endpoint = dst_ep - with self._pending.new(req_id) as req: - try: - await self._api.aps_data_request( - req_id, dst_addr_ep, profile, cluster, min(1, src_ep), data - ) - except zigpy_deconz.exception.CommandError as ex: - return ( - ex.status, - "Couldn't enqueue send data request for broadcast: {}".format(ex), - ) + async with self._limit_concurrency(): + with self._pending.new(req_id) as req: + try: + await self._api.aps_data_request( + req_id, dst_addr_ep, profile, cluster, min(1, src_ep), data + ) + except zigpy_deconz.exception.CommandError as ex: + return ( + ex.status, + f"Couldn't enqueue send data request for broadcast: {ex!r}", + ) - r = await asyncio.wait_for(req.result, SEND_CONFIRM_TIMEOUT) + r = await asyncio.wait_for(req.result, SEND_CONFIRM_TIMEOUT) - if r: - LOGGER.debug("Error while sending %s req id broadcast: %s", req_id, r) - return r, "broadcast send failure" - return r, "broadcast send success" + if r: + LOGGER.debug( + "Error while sending %s req id broadcast: %s", req_id, r + ) + return r, f"broadcast send failure: {r}" + return r, "broadcast send success" async def permit_ncp(self, time_s=60): assert 0 <= time_s <= 254