Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations #206

Merged
merged 6 commits into from
Aug 15, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mautrix_telegram/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from typing import Coroutine, List, Optional
import argparse
import asyncio
import logging.config
Expand Down Expand Up @@ -115,7 +115,7 @@
startup_actions = (init_puppet(context) +
init_user(context) +
[start,
context.mx.init_as_bot()])
context.mx.init_as_bot()]) # type: List[Coroutine]

if context.bot:
startup_actions.append(context.bot.start())
Expand Down
47 changes: 25 additions & 22 deletions mautrix_telegram/abstract_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .tgclient import MautrixTelegramClient

if TYPE_CHECKING:
from .types import TelegramId
from .context import Context
from .config import Config
from .bot import Bot
Expand All @@ -60,17 +61,18 @@ class AbstractUser(ABC):
bot = None # type: Bot
ignore_incoming_bot_events = True # type: bool

def __init__(self):
def __init__(self) -> None:
self.is_admin = False # type: bool
self.matrix_puppet_whitelisted = False # type: bool
self.puppet_whitelisted = False # type: bool
self.whitelisted = False # type: bool
self.relaybot_whitelisted = False # type: bool
self.client = None # type: MautrixTelegramClient
self.tgid = None # type: int
self.tgid = None # type: TelegramId
self.mxid = None # type: str
self.is_relaybot = False # type: bool
self.is_bot = False # type: bool
self.relaybot = None # type: Optional[Bot]

@property
def connected(self) -> bool:
Expand All @@ -93,7 +95,7 @@ def _proxy_settings(self) -> Optional[Tuple[int, str, str, str, str, str]]:
config["telegram.proxy.rdns"],
config["telegram.proxy.username"], config["telegram.proxy.password"])

def _init_client(self):
def _init_client(self) -> None:
self.log.debug(f"Initializing client for {self.name}")
device = f"{platform.system()} {platform.release()}"
sysversion = MautrixTelegramClient.__version__
Expand All @@ -114,18 +116,18 @@ async def update(self, update: TypeUpdate) -> bool:
return False

@abstractmethod
async def post_login(self):
async def post_login(self) -> None:
raise NotImplementedError()

@abstractmethod
def register_portal(self, portal: po.Portal):
def register_portal(self, portal: po.Portal) -> None:
raise NotImplementedError()

@abstractmethod
def unregister_portal(self, portal: po.Portal):
def unregister_portal(self, portal: po.Portal) -> None:
raise NotImplementedError()

async def _update_catch(self, update: TypeUpdate):
async def _update_catch(self, update: TypeUpdate) -> None:
try:
if not await self.update(update):
await self._update(update)
Expand Down Expand Up @@ -175,13 +177,13 @@ async def ensure_started(self, even_if_no_session=False) -> "AbstractUser":
await self.start(delete_unless_authenticated=not even_if_no_session)
return self

async def stop(self):
async def stop(self) -> None:
await self.client.disconnect()
self.client = None

# region Telegram update handling

async def _update(self, update: TypeUpdate):
async def _update(self, update: TypeUpdate) -> None:
if isinstance(update, (UpdateShortChatMessage, UpdateShortMessage, UpdateNewChannelMessage,
UpdateNewMessage, UpdateEditMessage, UpdateEditChannelMessage)):
await self.update_message(update)
Expand All @@ -207,18 +209,18 @@ async def _update(self, update: TypeUpdate):
self.log.debug("Unhandled update: %s", update)

@staticmethod
async def update_pinned_messages(update: UpdateChannelPinnedMessage):
async def update_pinned_messages(update: UpdateChannelPinnedMessage) -> None:
portal = po.Portal.get_by_tgid(update.channel_id)
if portal and portal.mxid:
await portal.receive_telegram_pin_id(update.id)

@staticmethod
async def update_participants(update: UpdateChatParticipants):
async def update_participants(update: UpdateChatParticipants) -> None:
portal = po.Portal.get_by_tgid(update.participants.chat_id)
if portal and portal.mxid:
await portal.update_telegram_participants(update.participants.participants)

async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
async def update_read_receipt(self, update: UpdateReadHistoryOutbox) -> None:
if not isinstance(update.peer, PeerUser):
self.log.debug("Unexpected read receipt peer: %s", update.peer)
return
Expand All @@ -235,7 +237,8 @@ async def update_read_receipt(self, update: UpdateReadHistoryOutbox):
puppet = pu.Puppet.get(update.peer.user_id)
await puppet.intent.mark_read(portal.mxid, message.mxid)

async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]):
async def update_admin(self,
update: Union[UpdateChatAdmins, UpdateChatParticipantAdmin]) -> None:
# TODO duplication not checked
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
if isinstance(update, UpdateChatAdmins):
Expand All @@ -245,15 +248,15 @@ async def update_admin(self, update: Union[UpdateChatAdmins, UpdateChatParticipa
else:
self.log.warning("Unexpected admin status update: %s", update)

async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]):
async def update_typing(self, update: Union[UpdateUserTyping, UpdateChatUserTyping]) -> None:
if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else:
portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
sender = pu.Puppet.get(update.user_id)
await portal.handle_telegram_typing(sender, update)

async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]):
async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto]) -> None:
# TODO duplication not checked
puppet = pu.Puppet.get(update.user_id)
if isinstance(update, UpdateUserName):
Expand All @@ -265,7 +268,7 @@ async def update_others_info(self, update: Union[UpdateUserName, UpdateUserPhoto
else:
self.log.warning("Unexpected other user info update: %s", update)

async def update_status(self, update: UpdateUserStatus):
async def update_status(self, update: UpdateUserStatus) -> None:
puppet = pu.Puppet.get(update.user_id)
if isinstance(update.status, UserStatusOnline):
await puppet.default_mxid_intent.set_presence("online")
Expand Down Expand Up @@ -300,15 +303,15 @@ def get_message_details(self, update: UpdateMessage) -> Tuple[UpdateMessageConte
return update, sender, portal

@staticmethod
async def _try_redact(portal: po.Portal, message: DBMessage):
async def _try_redact(portal: po.Portal, message: DBMessage) -> None:
if not portal:
return
try:
await portal.main_intent.redact(message.mx_room, message.mxid)
except MatrixRequestError:
pass

async def delete_message(self, update: UpdateDeleteMessages):
async def delete_message(self, update: UpdateDeleteMessages) -> None:
if len(update.messages) > MAX_DELETIONS:
return

Expand All @@ -324,7 +327,7 @@ async def delete_message(self, update: UpdateDeleteMessages):
await self._try_redact(portal, message)
self.db.commit()

async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
async def delete_channel_message(self, update: UpdateDeleteChannelMessages) -> None:
if len(update.messages) > MAX_DELETIONS:
return

Expand All @@ -340,7 +343,7 @@ async def delete_channel_message(self, update: UpdateDeleteChannelMessages):
await self._try_redact(portal, message)
self.db.commit()

async def update_message(self, original_update: UpdateMessage):
async def update_message(self, original_update: UpdateMessage) -> None:
update, sender, portal = self.get_message_details(original_update)
if self.ignore_incoming_bot_events and self.bot and sender.id == self.bot.tgid:
self.log.debug(f"Ignoring relaybot-sent message %s to %s", update, portal.tgid_log)
Expand Down Expand Up @@ -369,9 +372,9 @@ async def update_message(self, original_update: UpdateMessage):
# endregion


def init(context: "Context"):
def init(context: "Context") -> None:
global config, MAX_DELETIONS
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context
AbstractUser.az, AbstractUser.db, config, AbstractUser.loop, AbstractUser.relaybot = context.core
AbstractUser.ignore_incoming_bot_events = config["bridge.relaybot.ignore_own_incoming_events"]
AbstractUser.session_container = context.session_container
MAX_DELETIONS = config.get("bridge.max_telegram_delete", 10)
55 changes: 32 additions & 23 deletions mautrix_telegram/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Awaitable, Callable, Pattern, Dict, TYPE_CHECKING
from typing import Awaitable, Callable, Dict, List, Optional, Pattern, TYPE_CHECKING
import logging
import re

from telethon.tl.types import *
from telethon.tl.types import (
ChannelParticipantAdmin, ChannelParticipantCreator, ChatForbidden, ChatParticipantAdmin,
ChatParticipantCreator, InputChannel, InputUser, Message, MessageActionChatAddUser,
MessageActionChatDeleteUser, MessageEntityBotCommand, MessageService, PeerChannel, PeerChat,
TypePeer, UpdateNewChannelMessage, UpdateNewMessage)
from telethon.tl.functions.messages import GetChatsRequest, GetFullChatRequest
from telethon.tl.functions.channels import GetChannelsRequest, GetParticipantRequest
from telethon.errors import ChannelInvalidError, ChannelPrivateError

from .types import MatrixUserId
from .abstract_user import AbstractUser
from .db import BotChat
from . import puppet as pu, portal as po, user as u

if TYPE_CHECKING:
from .config import Config
from .context import Context

config = None # type: Config

Expand All @@ -39,7 +45,7 @@ class Bot(AbstractUser):
log = logging.getLogger("mau.bot") # type: logging.Logger
mxid_regex = re.compile("@.+:.+") # type: Pattern

def __init__(self, token: str):
def __init__(self, token: str) -> None:
super().__init__()
self.token = token # type: str
self.puppet_whitelisted = True # type: bool
Expand All @@ -53,7 +59,7 @@ def __init__(self, token: str):
self.whitelist_group_admins = (config["bridge.relaybot.whitelist_group_admins"]
or False) # type: bool

async def init_permissions(self):
async def init_permissions(self) -> None:
whitelist = config["bridge.relaybot.whitelist"] or []
for id in whitelist:
if isinstance(id, str):
Expand All @@ -72,7 +78,7 @@ async def start(self, delete_unless_authenticated: bool = False) -> "Bot":
await self.post_login()
return self

async def post_login(self):
async def post_login(self) -> None:
await self.init_permissions()
info = await self.client.get_me()
self.tgid = info.id
Expand Down Expand Up @@ -100,19 +106,19 @@ async def post_login(self):
except Exception:
self.log.exception("Failed to run catch_up() for bot")

def register_portal(self, portal: po.Portal):
def register_portal(self, portal: po.Portal) -> None:
self.add_chat(portal.tgid, portal.peer_type)

def unregister_portal(self, portal: po.Portal):
def unregister_portal(self, portal: po.Portal) -> None:
self.remove_chat(portal.tgid)

def add_chat(self, id: int, type: str):
def add_chat(self, id: int, type: str) -> None:
if id not in self.chats:
self.chats[id] = type
self.db.add(BotChat(id=id, type=type))
self.db.commit()

def remove_chat(self, id: int):
def remove_chat(self, id: int) -> None:
try:
del self.chats[id]
except KeyError:
Expand Down Expand Up @@ -141,14 +147,15 @@ async def _can_use_commands(self, chat: TypePeer, tgid: int) -> bool:
for p in participants:
if p.user_id == tgid:
return isinstance(p, (ChatParticipantCreator, ChatParticipantAdmin))
return False

async def check_can_use_commands(self, event: Message, reply: ReplyFunc) -> bool:
if not await self._can_use_commands(event.to_id, event.from_id):
await reply("You do not have the permission to use that command.")
return False
return True

async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc):
async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc) -> None:
if not config["bridge.relaybot.authless_portals"]:
return await reply("This bridge doesn't allow portal creation from Telegram.")

Expand All @@ -164,15 +171,16 @@ async def handle_command_portal(self, portal: po.Portal, reply: ReplyFunc):
return await reply(
"Portal is not public. Use `/invite <mxid>` to get an invite.")

async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid: str):
if len(mxid) == 0:
async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc,
mxid_input: MatrixUserId) -> Message:
if len(mxid_input) == 0:
return await reply("Usage: `/invite <mxid>`")
elif not portal.mxid:
return await reply("Portal does not have Matrix room. "
"Create one with /portal first.")
if not self.mxid_regex.match(mxid):
if not self.mxid_regex.match(mxid_input):
return await reply("That doesn't look like a Matrix ID.")
user = await u.User.get_by_mxid(mxid).ensure_started()
user = await u.User.get_by_mxid(MatrixUserId(mxid_input)).ensure_started()
if not user.relaybot_whitelisted:
return await reply("That user is not whitelisted to use the bridge.")
elif await user.is_logged_in():
Expand All @@ -183,7 +191,7 @@ async def handle_command_invite(self, portal: po.Portal, reply: ReplyFunc, mxid:
await portal.main_intent.invite(portal.mxid, user.mxid)
return await reply(f"Invited `{user.mxid}` to the portal.")

def handle_command_id(self, message: Message, reply: ReplyFunc):
def handle_command_id(self, message: Message, reply: ReplyFunc) -> Awaitable[Message]:
# Provide the prefixed ID to the user so that the user wouldn't need to specify whether the
# chat is a normal group or a supergroup/channel when using the ID.
if isinstance(message.to_id, PeerChannel):
Expand All @@ -205,8 +213,8 @@ def match_command(self, text: str, command: str) -> bool:

return False

async def handle_command(self, message: Message):
def reply(reply_text):
async def handle_command(self, message: Message) -> None:
def reply(reply_text: str) -> Awaitable[Message]:
return self.client.send_message(message.to_id, reply_text, reply_to=message.id)

text = message.message
Expand All @@ -227,9 +235,9 @@ def reply(reply_text):
mxid = text[text.index(" ") + 1:]
except ValueError:
mxid = ""
await self.handle_command_invite(portal, reply, mxid=mxid)
await self.handle_command_invite(portal, reply, mxid_input=mxid)

def handle_service_message(self, message: MessageService):
def handle_service_message(self, message: MessageService) -> None:
to_id = message.to_id
if isinstance(to_id, PeerChannel):
to_id = to_id.channel_id
Expand All @@ -246,11 +254,12 @@ def handle_service_message(self, message: MessageService):
elif isinstance(action, MessageActionChatDeleteUser) and action.user_id == self.tgid:
self.remove_chat(to_id)

async def update(self, update):
async def update(self, update) -> bool:
if not isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
return
return False
if isinstance(update.message, MessageService):
return self.handle_service_message(update.message)
self.handle_service_message(update.message)
return False

is_command = (isinstance(update.message, Message)
and update.message.entities and len(update.message.entities) > 0
Expand All @@ -266,7 +275,7 @@ def name(self) -> str:
return "bot"


def init(context) -> Optional[Bot]:
def init(context: 'Context') -> Optional[Bot]:
global config
config = context.config
token = config["telegram.bot_token"]
Expand Down
Loading