Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

set-pl command and fixed reply #29

Merged
merged 6 commits into from
Oct 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mautrix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.7.10"
__version__ = "0.7.11"
__author__ = "Tulir Asokan <[email protected]>"
__all__ = ["api", "appservice", "bridge", "client", "crypto", "errors", "util", "types"]
5 changes: 3 additions & 2 deletions mautrix/bridge/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .handler import (HelpSection, HelpCacheKey, command_handler, CommandHandler, CommandProcessor,
CommandHandlerFunc, CommandEvent, SECTION_GENERAL)
CommandHandlerFunc, CommandEvent, SECTION_GENERAL, SECTION_ADMIN)
from .meta import cancel, unknown_command, help_cmd
from . import admin

__all__ = ["HelpSection", "HelpCacheKey", "command_handler", "CommandHandler", "CommandProcessor",
"CommandHandlerFunc", "CommandEvent", "SECTION_GENERAL"]
"CommandHandlerFunc", "CommandEvent", "SECTION_GENERAL", "SECTION_ADMIN"]
37 changes: 37 additions & 0 deletions mautrix/bridge/commands/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2020 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from mautrix.types import EventID

from mautrix.errors import (MatrixRequestError, IntentError)

from .handler import (command_handler, CommandEvent, SECTION_ADMIN)


@command_handler(needs_admin=True, needs_auth=False, name="set-pl",
help_section=SECTION_ADMIN,
help_args="<_level_> [_mxid_]",
help_text="Set a temporary power level without affecting the bridge.")
async def set_power_level(evt: CommandEvent) -> EventID:
try:
level = int(evt.args[0])
except (KeyError, IndexError):
return await evt.reply("**Usage:** `$cmdprefix+sp set-pl <level> [mxid]`")
except ValueError:
return await evt.reply("The level must be an integer.")
if evt.is_portal:
portal = await evt.processor.bridge.get_portal(evt.room_id)
intent = portal.main_intent
else:
intent = evt.az.intent
levels = await intent.get_power_levels(evt.room_id)
mxid = evt.args[1] if len(evt.args) > 1 else evt.sender.mxid
levels.users[mxid] = level
try:
return await intent.set_power_levels(evt.room_id, levels)
except (MatrixRequestError, IntentError):
evt.log.exception("Failed to set power level.")
return await evt.reply("Failed to set power level.")
35 changes: 28 additions & 7 deletions mautrix/bridge/commands/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
HelpCacheKey = NamedTuple('HelpCacheKey', is_management=bool, is_portal=bool)

SECTION_GENERAL = HelpSection("General", 0, "")
SECTION_ADMIN = HelpSection("Administration", 50, "")


def ensure_trailing_newline(s: str) -> str:
Expand Down Expand Up @@ -112,7 +113,7 @@ def print_error_traceback(self) -> bool:
"""
return self.is_management

def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
async def reply(self, message: str, allow_html: bool = False, render_markdown: bool = True
) -> Awaitable[EventID]:
"""Write a reply to the room in which the command was issued.

Expand All @@ -136,11 +137,16 @@ def reply(self, message: str, allow_html: bool = False, render_markdown: bool =
html = self._render_message(message, allow_html=allow_html,
render_markdown=render_markdown)

return self.az.intent.send_notice(self.room_id, message, html=html)
if self.is_portal:
portal = await self.processor.bridge.get_portal(self.room_id)
return await portal.main_intent.send_notice(self.room_id, message, html=html)
else:
return await self.az.intent.send_notice(self.room_id, message, html=html)

def mark_read(self) -> Awaitable[None]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def mark_read(self) -> Awaitable[None]:
async def mark_read(self) -> Awaitable[None]:

"""Marks the command as read by the bot."""
return self.az.intent.mark_read(self.room_id, self.event_id)
if not self.is_portal:
return self.az.intent.mark_read(self.room_id, self.event_id)
Comment on lines +148 to +149
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not self.is_portal:
return self.az.intent.mark_read(self.room_id, self.event_id)
if self.room_id in await self.az.intent.get_joined_rooms():
return self.az.intent.mark_read(self.room_id, self.event_id)


def _replace_command_prefix(self, message: str) -> str:
"""Returns the string with the proper command prefix entered."""
Expand Down Expand Up @@ -184,20 +190,26 @@ class CommandHandler:
name: The name of this command.
help_section: Section of the help in which this command will appear.
"""
management_only: bool
name: str

management_only: bool
needs_admin: bool
needs_auth: bool

_help_text: str
_help_args: str
help_section: HelpSection

def __init__(self, handler: CommandHandlerFunc, management_only: bool, name: str,
help_text: str, help_args: str, help_section: HelpSection, **kwargs) -> None:
help_text: str, help_args: str, help_section: HelpSection,
needs_auth: bool, needs_admin: bool, **kwargs) -> None:
"""
Args:
handler: The function handling the execution of this command.
management_only: Whether the command can exclusively be issued
in a management room.
needs_auth: Whether the command needs the bridge to be authed already
needs_admin: Whether the command needs the issuer to be bridge admin
name: The name of this command.
help_text: The text displayed in the help for this command.
help_args: Help text for the arguments of this command.
Expand All @@ -207,6 +219,8 @@ def __init__(self, handler: CommandHandlerFunc, management_only: bool, name: str
setattr(self, key, value)
self._handler = handler
self.management_only = management_only
self.needs_admin = needs_admin
self.needs_auth = needs_auth
self.name = name
self._help_text = help_text
self._help_args = help_args
Expand All @@ -224,6 +238,10 @@ async def get_permission_error(self, evt: CommandEvent) -> Optional[str]:
if self.management_only and not evt.is_management:
return (f"`{evt.command}` is a restricted command: "
"you may only run it in management rooms.")
elif self.needs_admin and not evt.sender.is_admin:
return "This command requires administrator privileges."
elif self.needs_auth and not await evt.sender.is_logged_in():
return "This command requires you to be logged in."
return None

def has_permission(self, key: HelpCacheKey) -> bool:
Expand All @@ -236,7 +254,9 @@ def has_permission(self, key: HelpCacheKey) -> bool:
True if a user with the given state is allowed to issue the
command.
"""
return not self.management_only or key.is_management
return ((not self.management_only or key.is_management) and
(not self.needs_admin or key.is_admin) and
(not self.needs_auth or key.is_logged_in))

async def __call__(self, evt: CommandEvent) -> Any:
"""Executes the command if evt was issued with proper rights.
Expand Down Expand Up @@ -267,13 +287,14 @@ def command_handler(_func: Optional[CommandHandlerFunc] = None, *, management_on
name: Optional[str] = None, help_text: str = "", help_args: str = "",
help_section: HelpSection = None, aliases: Optional[List[str]] = None,
_handler_class: Type[CommandHandler] = CommandHandler,
needs_auth: bool = True, needs_admin: bool = False,
**kwargs) -> Callable[[CommandHandlerFunc], CommandHandler]:
"""Decorator to create CommandHandlers"""

def decorator(func: CommandHandlerFunc) -> CommandHandler:
actual_name = name or func.__name__.replace("_", "-")
handler = _handler_class(func, management_only, actual_name, help_text, help_args,
help_section, **kwargs)
help_section, needs_auth, needs_admin, **kwargs)
command_handlers[handler.name] = handler
if aliases:
for alias in aliases:
Expand Down
69 changes: 36 additions & 33 deletions mautrix/crypto/encrypt_megolm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Tuple
from collections import defaultdict
from datetime import timedelta
import asyncio
Expand All @@ -18,7 +18,7 @@
from .types import DeviceIdentity, TrustState
from .encrypt_olm import OlmEncryptionMachine
from .device_lists import DeviceListMachine
from .sessions import OutboundGroupSession, InboundGroupSession
from .sessions import OutboundGroupSession, InboundGroupSession, Session


class Sentinel:
Expand All @@ -28,23 +28,23 @@ class Sentinel:
already_shared = Sentinel()
key_missing = Sentinel()

DeviceSessionWrapper = Tuple[Session, DeviceIdentity]
DeviceMap = Dict[UserID, Dict[DeviceID, DeviceSessionWrapper]]
SessionEncryptResult = Union[
type(already_shared), # already shared
EncryptedOlmEventContent, # share successful
DeviceSessionWrapper, # share successful
RoomKeyWithheldEventContent, # won't share
type(key_missing), # missing device
]


class MegolmEncryptionMachine(OlmEncryptionMachine, DeviceListMachine):
_megolm_locks: Dict[RoomID, asyncio.Lock]
_olm_locks: Dict[IdentityKey, asyncio.Lock]
_sharing_group_session: Dict[RoomID, asyncio.Event]

def __init__(self) -> None:
super().__init__()
self._megolm_locks = defaultdict(lambda: asyncio.Lock())
self._olm_locks = defaultdict(lambda: asyncio.Lock())
self._sharing_group_session = {}

async def encrypt_megolm_event(self, room_id: RoomID, event_type: EventType, content: Any
Expand Down Expand Up @@ -89,8 +89,8 @@ async def _encrypt_megolm_event(self, room_id: RoomID, event_type: EventType, co
relates_to = None
await self.crypto_store.update_outbound_group_session(session)
return EncryptedMegolmEventContent(sender_key=self.account.identity_key,
device_id=self.client.device_id, session_id=session.id,
ciphertext=ciphertext, relates_to=relates_to)
device_id=self.client.device_id, ciphertext=ciphertext,
session_id=SessionID(session.id), relates_to=relates_to)

def is_sharing_group_session(self, room_id: RoomID) -> bool:
"""
Expand Down Expand Up @@ -158,7 +158,7 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No
self.log.debug("Got stored encryption state event and configured session to rotate "
f"after {session.max_messages} messages or {session.max_age}")

share_key_msgs = defaultdict(lambda: {})
olm_sessions: DeviceMap = defaultdict(lambda: {})
withhold_key_msgs = defaultdict(lambda: {})
missing_sessions: Dict[UserID, Dict[DeviceID, DeviceIdentity]] = defaultdict(lambda: {})
fetch_keys = []
Expand All @@ -173,13 +173,13 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No
else:
self.log.debug(f"Trying to encrypt group session {session.id} for {user_id}")
for device_id, device in devices.items():
result = await self._encrypt_group_session(session, user_id, device_id, device)
if isinstance(result, EncryptedOlmEventContent):
share_key_msgs[user_id][device_id] = result
elif isinstance(result, RoomKeyWithheldEventContent):
result = await self._find_olm_sessions(session, user_id, device_id, device)
if isinstance(result, RoomKeyWithheldEventContent):
withhold_key_msgs[user_id][device_id] = result
elif result == key_missing:
missing_sessions[user_id][device_id] = device
elif isinstance(result, tuple):
olm_sessions[user_id][device_id] = result

if fetch_keys:
self.log.debug(f"Fetching missing keys for {fetch_keys}")
Expand All @@ -193,17 +193,16 @@ async def _share_group_session(self, room_id: RoomID, users: List[UserID]) -> No

for user_id, devices in missing_sessions.items():
for device_id, device in devices.items():
result = await self._encrypt_group_session(session, user_id, device_id, device)
if isinstance(result, EncryptedOlmEventContent):
share_key_msgs[user_id][device_id] = result
elif isinstance(result, RoomKeyWithheldEventContent):
result = await self._find_olm_sessions(session, user_id, device_id, device)
if isinstance(result, RoomKeyWithheldEventContent):
withhold_key_msgs[user_id][device_id] = result
elif isinstance(result, tuple):
olm_sessions[user_id][device_id] = result
# We don't care about missing keys at this point

if len(share_key_msgs) > 0:
event_count = sum(len(map) for map in share_key_msgs.values())
self.log.debug(f"Sending {event_count} to-device events to share {session.id}")
await self.client.send_to_device(EventType.TO_DEVICE_ENCRYPTED, share_key_msgs)
if len(olm_sessions) > 0:
async with self._olm_lock:
await self._encrypt_and_share_group_session(session, olm_sessions)
if len(withhold_key_msgs) > 0:
event_count = sum(len(map) for map in withhold_key_msgs.values())
self.log.debug(f"Sending {event_count} to-device events "
Expand All @@ -221,6 +220,19 @@ async def _new_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSes
room_id, SessionID(session.id), session.session_key)
return session

async def _encrypt_and_share_group_session(self, session: OutboundGroupSession,
olm_sessions: DeviceMap):
msgs = defaultdict(lambda: {})
count = 0
for user_id, devices in olm_sessions.items():
count += len(devices)
for device_id, (olm_session, device_identity) in devices.items():
msgs[user_id][device_id] = await self._encrypt_olm_event(
olm_session, device_identity, EventType.ROOM_KEY, session.share_content)
self.log.debug(f"Sending to-device events to {count} devices of {len(msgs)} users "
f"to share {session.id}")
await self.client.send_to_device(EventType.TO_DEVICE_ENCRYPTED, msgs)

async def _create_group_session(self, sender_key: IdentityKey, signing_key: SigningKey,
room_id: RoomID, session_id: SessionID, session_key: str
) -> None:
Expand All @@ -231,15 +243,9 @@ async def _create_group_session(self, sender_key: IdentityKey, signing_key: Sign
self._mark_session_received(session_id)
self.log.debug(f"Created inbound group session {room_id}/{sender_key}/{session_id}")

async def _encrypt_group_session(self, session: OutboundGroupSession, user_id: UserID,
device_id: DeviceID, device: DeviceIdentity
) -> SessionEncryptResult:
async with self._olm_locks[device.identity_key]:
return await self._encrypt_group_session_locked(session, user_id, device_id, device)

async def _encrypt_group_session_locked(self, session: OutboundGroupSession, user_id: UserID,
device_id: DeviceID, device: DeviceIdentity
) -> SessionEncryptResult:
async def _find_olm_sessions(self, session: OutboundGroupSession, user_id: UserID,
device_id: DeviceID, device: DeviceIdentity
) -> SessionEncryptResult:
key = (user_id, device_id)
if key in session.users_ignored or key in session.users_shared_with:
return already_shared
Expand Down Expand Up @@ -267,8 +273,5 @@ async def _encrypt_group_session_locked(self, session: OutboundGroupSession, use
device_session = await self.crypto_store.get_latest_session(device.identity_key)
if not device_session:
return key_missing
encrypted = await self._encrypt_olm_event(device_session, device, EventType.ROOM_KEY,
session.share_content)
session.users_shared_with.add(key)
self.log.debug(f"Encrypted group session {session.id} for {device_id} of {user_id}")
return encrypted
return device_session, device
9 changes: 6 additions & 3 deletions mautrix/crypto/encrypt_olm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

class OlmEncryptionMachine(BaseOlmMachine):
_claim_keys_lock: asyncio.Lock
_olm_lock: asyncio.Lock

def __init__(self):
self._claim_keys_lock = asyncio.Lock()
self._olm_lock = asyncio.Lock()

async def _encrypt_olm_event(self, session: Session, recipient: DeviceIdentity,
event_type: EventType, content: Any) -> EncryptedOlmEventContent:
Expand Down Expand Up @@ -66,6 +68,7 @@ async def send_encrypted_to_device(self, device: DeviceIdentity, event_type: Eve
content: ToDeviceEventContent) -> None:
await self._create_outbound_sessions({device.user_id: {device.device_id: device}})
session = await self.crypto_store.get_latest_session(device.identity_key)
encrypted_content = await self._encrypt_olm_event(session, device, event_type, content)
await self.client.send_to_one_device(EventType.TO_DEVICE_ENCRYPTED, device.user_id,
device.device_id, encrypted_content)
async with self._olm_lock:
encrypted_content = await self._encrypt_olm_event(session, device, event_type, content)
await self.client.send_to_one_device(EventType.TO_DEVICE_ENCRYPTED, device.user_id,
device.device_id, encrypted_content)
Loading