From c41b515e3ee44cbe8d67770cea3297f2b42d059b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 14 Jul 2021 13:37:52 +0300 Subject: [PATCH] Add support for MSC3202 in appservice module --- mautrix/appservice/as_handler.py | 50 +++++++++++++++++++++----------- mautrix/types/misc.py | 13 +++++++-- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/mautrix/appservice/as_handler.py b/mautrix/appservice/as_handler.py index 9e78e0a2..dac524a8 100644 --- a/mautrix/appservice/as_handler.py +++ b/mautrix/appservice/as_handler.py @@ -4,13 +4,14 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. # Partly based on github.com/Cadair/python-appservice-framework (MIT license) -from typing import Optional, Callable, Awaitable, List, Set +from typing import Optional, Callable, Awaitable, List, Set, Dict, Any from json import JSONDecodeError from aiohttp import web import asyncio import logging -from mautrix.types import JSON, UserID, RoomAlias, Event, EphemeralEvent, SerializerError +from mautrix.types import (JSON, UserID, RoomAlias, Event, EphemeralEvent, SerializerError, + DeviceOTKCount, DeviceLists) QueryFunc = Callable[[web.Request], Awaitable[Optional[web.Response]]] HandlerFunc = Callable[[Event], Awaitable] @@ -102,6 +103,17 @@ async def _http_query_alias(self, request: web.Request) -> web.Response: return web.json_response({}, status=404) return web.json_response(response) + @staticmethod + def _get_with_fallback(json: Dict[str, Any], field: str, unstable_prefix: str, + default: Any = None) -> Any: + try: + return json.pop(field) + except KeyError: + try: + return json.pop(f"{unstable_prefix}.{field}") + except KeyError: + return default + async def _http_handle_transaction(self, request: web.Request) -> web.Response: if not self._check_token(request): return web.json_response({"error": "Invalid auth token"}, status=401) @@ -116,29 +128,30 @@ async def _http_handle_transaction(self, request: web.Request) -> web.Response: return web.json_response({"error": "Body is not JSON"}, status=400) try: - events = json["events"] + events = json.pop("events") except KeyError: return web.json_response({"error": "Missing events object in body"}, status=400) - if self.ephemeral_events: - try: - ephemeral = json["ephemeral"] - except KeyError: - try: - ephemeral = json["de.sorunome.msc2409.ephemeral"] - except KeyError: - ephemeral = None - else: - ephemeral = None + ephemeral = (self._get_with_fallback(json, "ephemeral", "de.sorunome.msc2409") + if self.ephemeral_events else None) + device_lists = DeviceLists.deserialize( + self._get_with_fallback(json, "device_lists", "org.matrix.msc3202")) + otk_counts = {user_id: DeviceOTKCount.deserialize(count) + for user_id, count + in self._get_with_fallback(json, "device_one_time_keys_count", + "org.matrix.msc3202", default={}).items()} try: - await self.handle_transaction(transaction_id, events=events, ephemeral=ephemeral) + output = await self.handle_transaction(transaction_id, events=events, extra_data=json, + ephemeral=ephemeral, device_lists=device_lists, + device_otk_count=otk_counts) except Exception: self.log.exception("Exception in transaction handler") + output = None self.transactions.add(transaction_id) - return web.json_response({}) + return web.json_response(output or {}) @staticmethod def _fix_prev_content(raw_event: JSON) -> None: @@ -150,8 +163,10 @@ def _fix_prev_content(raw_event: JSON) -> None: except KeyError: pass - async def handle_transaction(self, txn_id: str, events: List[JSON], - ephemeral: Optional[List[JSON]] = None) -> None: + async def handle_transaction(self, txn_id: str, *, events: List[JSON], extra_data: JSON, + ephemeral: Optional[List[JSON]] = None, + device_otk_count: Optional[Dict[UserID, DeviceOTKCount]] = None, + device_lists: Optional[DeviceLists] = None) -> Optional[JSON]: for raw_edu in ephemeral or []: try: edu = EphemeralEvent.deserialize(raw_edu) @@ -167,6 +182,7 @@ async def handle_transaction(self, txn_id: str, events: List[JSON], self.log.exception("Failed to deserialize event %s", raw_event) else: self.handle_matrix_event(event) + return {} def handle_matrix_event(self, event: Event) -> None: if event.type.is_state and event.state_key is None: diff --git a/mautrix/types/misc.py b/mautrix/types/misc.py index 6f941767..883e3542 100644 --- a/mautrix/types/misc.py +++ b/mautrix/types/misc.py @@ -13,8 +13,17 @@ from .util import SerializableAttrs from .event import Event -DeviceLists = NamedTuple("DeviceLists", changed=List[UserID], left=List[UserID]) -DeviceOTKCount = NamedTuple("DeviceOTKCount", curve25519=int, signed_curve25519=int) + +@dataclass +class DeviceLists(SerializableAttrs): + changed: List[UserID] = attr.ib(factory=lambda: []) + left: List[UserID] = attr.ib(factory=lambda: []) + + +@dataclass +class DeviceOTKCount(SerializableAttrs): + curve25519: int + signed_curve25519: int class RoomCreatePreset(Enum):