From 6edfb6255eba33b8aa7530ca860f413c58834997 Mon Sep 17 00:00:00 2001 From: Danny Lin Date: Sat, 14 Dec 2019 21:29:08 -0800 Subject: [PATCH] bot: Split into multiple compartmentalized ABC mixins This helps accomodate the growing Bot class, since it's easily divided into several logical components anyway. Each component is an ABC (Abstract Base Class) that functions as a mixin and thus can access all the other attributes of the Bot class. Bot subclasses all the mixins to unify them. Telethon uses the same approach to divide TelegramClient into many mixins that each contain different categories of API methods. An adaptive MixinBase reference that is defined as abc.ABC during runtime and Bot during type checking is used as the subclass for each mixin to work around https://github.com/python/mypy/issues/5837. Furthermore, the MixinBase reference itself has its type specified as Any to work around https://github.com/python/mypy/issues/2477. The only change for outsiders should be the Bot class being moved to the `core` module. Signed-off-by: Danny Lin --- pyrobud/bot.py | 429 ----------------------------- pyrobud/command.py | 2 +- pyrobud/core/__init__.py | 1 + pyrobud/core/bot.py | 38 +++ pyrobud/core/bot_mixin_base.py | 11 + pyrobud/core/command_dispatcher.py | 133 +++++++++ pyrobud/core/database_provider.py | 26 ++ pyrobud/core/event_dispatcher.py | 77 ++++++ pyrobud/core/module_extender.py | 85 ++++++ pyrobud/core/telegram_bot.py | 163 +++++++++++ pyrobud/launch.py | 2 +- pyrobud/module.py | 2 +- 12 files changed, 537 insertions(+), 432 deletions(-) delete mode 100644 pyrobud/bot.py create mode 100644 pyrobud/core/__init__.py create mode 100644 pyrobud/core/bot.py create mode 100644 pyrobud/core/bot_mixin_base.py create mode 100644 pyrobud/core/command_dispatcher.py create mode 100644 pyrobud/core/database_provider.py create mode 100644 pyrobud/core/event_dispatcher.py create mode 100644 pyrobud/core/module_extender.py create mode 100644 pyrobud/core/telegram_bot.py diff --git a/pyrobud/bot.py b/pyrobud/bot.py deleted file mode 100644 index 83d8815e..00000000 --- a/pyrobud/bot.py +++ /dev/null @@ -1,429 +0,0 @@ -import asyncio -import bisect -import importlib -import inspect -import logging -import os -from types import ModuleType -from typing import Any, Mapping, MutableMapping, MutableSequence, Optional, Type, Union - -import aiohttp -import plyvel -import sentry_sdk -import telethon as tg - -from . import command, custom_modules, module, modules, util -from .listener import Listener, ListenerFunc - - -class Bot: - # Initialized during instantiation - commands: MutableMapping[str, command.Command] - modules: MutableMapping[str, module.Module] - listeners: MutableMapping[str, MutableSequence[Listener]] - config: util.config.Config - log: logging.Logger - http_session: aiohttp.ClientSession - _db: util.db.AsyncDB - db: util.db.AsyncDB - client: tg.TelegramClient - - # Initialized during startup - loop: asyncio.AbstractEventLoop - prefix: str - user: tg.types.User - uid: int - start_time_us: int - - def __init__(self, config: util.config.Config): - # Initialize module dicts - self.commands = {} - self.modules = {} - self.listeners = {} - - # Save reference to config - self.config = config - - # Initialize other objects - self.log = logging.getLogger("bot") - self.http_session = aiohttp.ClientSession() - - # Initialize database - self._db = util.db.AsyncDB(plyvel.DB(config["bot"]["db_path"], create_if_missing=True)) - self.db = self.get_db("bot") - - # Initialize Telegram client - self.init_client() - - def init_client(self) -> None: - tg_config: Mapping[str, Union[int, str]] = self.config["telegram"] - - session_name = tg_config["session_name"] - if not isinstance(session_name, str): - raise TypeError("Session name must be a str") - - api_id = tg_config["api_id"] - if not isinstance(api_id, int): - raise TypeError("API ID must be an int") - - api_hash = tg_config["api_hash"] - if not isinstance(api_hash, str): - raise TypeError("API hash must be a str") - - self.client = tg.TelegramClient(session_name, api_id, api_hash) - - def get_db(self, prefix: str) -> util.db.AsyncDB: - return self._db.prefixed_db(prefix + ".") - - def register_command(self, mod: module.Module, name: str, func: command.CommandFunc) -> None: - cmd = command.Command(name, mod, func) - - if name in self.commands: - orig = self.commands[name] - raise module.ExistingCommandError(orig, cmd) - - self.commands[name] = cmd - - for alias in cmd.aliases: - if alias in self.commands: - orig = self.commands[alias] - raise module.ExistingCommandError(orig, cmd, alias=True) - - self.commands[alias] = cmd - - def unregister_command(self, cmd: command.Command) -> None: - del self.commands[cmd.name] - - for alias in cmd.aliases: - try: - del self.commands[alias] - except KeyError: - continue - - def register_commands(self, mod: module.Module) -> None: - for name, func in util.find_prefixed_funcs(mod, "cmd_"): - done = False - - try: - self.register_command(mod, name, func) - done = True - finally: - if not done: - self.unregister_commands(mod) - - def unregister_commands(self, mod: module.Module) -> None: - # Can't unregister while iterating, so collect commands to unregister afterwards - to_unreg = [] - - for name, cmd in self.commands.items(): - # Let unregister_command deal with aliases - if name != cmd.name: - continue - - if cmd.module == mod: - to_unreg.append(cmd) - - # Actually unregister the commands - for cmd in to_unreg: - self.unregister_command(cmd) - - def register_listener(self, mod: module.Module, event: str, func: ListenerFunc, priority: int = 100) -> None: - listener = Listener(event, func, mod, priority) - - if event in self.listeners: - bisect.insort(self.listeners[event], listener) - else: - self.listeners[event] = [listener] - - def unregister_listener(self, listener: Listener) -> None: - self.listeners[listener.event].remove(listener) - - def register_listeners(self, mod: module.Module) -> None: - for event, func in util.find_prefixed_funcs(mod, "on_"): - done = True - try: - self.register_listener(mod, event, func, priority=getattr(func, "_listener_priority", 100)) - done = True - finally: - if not done: - self.unregister_listeners(mod) - - def unregister_listeners(self, mod: module.Module) -> None: - # Can't unregister while iterating, so collect listeners to unregister afterwards - to_unreg = [] - - for lst in self.listeners.values(): - for listener in lst: - if listener.module == mod: - to_unreg.append(listener) - - # Actually unregister the listeners - for listener in to_unreg: - self.unregister_listener(listener) - - def load_module(self, cls: Type[module.Module], *, comment: Optional[str] = None) -> None: - _comment = comment + " " if comment else "" - self.log.info( - f"Loading {_comment}module '{cls.name}' ({cls.__name__}) from '{os.path.relpath(inspect.getfile(cls))}'" - ) - - if cls.name in self.modules: - old = type(self.modules[cls.name]) - raise module.ExistingModuleError(old, cls) - - mod = cls(self) - mod.comment = comment - self.register_listeners(mod) - self.register_commands(mod) - self.modules[cls.name] = mod - - def unload_module(self, mod: module.Module) -> None: - _comment = mod.comment + " " if mod.comment else "" - - cls = type(mod) - path = os.path.relpath(inspect.getfile(cls)) - self.log.info(f"Unloading {_comment}module '{cls.name}' ({cls.__name__}) from '{path}'") - - self.unregister_listeners(mod) - self.unregister_commands(mod) - del self.modules[cls.name] - - def _load_modules_from_metamod(self, metamod: ModuleType, *, comment: str = None) -> None: - for _sym in getattr(metamod, "__all__", ()): - module_mod: ModuleType = getattr(metamod, _sym) - - if inspect.ismodule(module_mod): - for sym in dir(module_mod): - cls = getattr(module_mod, sym) - if inspect.isclass(cls) and issubclass(cls, module.Module): - self.load_module(cls, comment=comment) - - def load_all_modules(self) -> None: - self.log.info("Loading modules") - self._load_modules_from_metamod(modules) - self._load_modules_from_metamod(custom_modules, comment="custom") - self.log.info("All modules loaded.") - - def unload_all_modules(self) -> None: - self.log.info("Unloading modules...") - - # Can't modify while iterating, so collect a list first - for mod in list(self.modules.values()): - self.unload_module(mod) - - self.log.info("All modules unloaded.") - - async def reload_module_pkg(self) -> None: - self.log.info("Reloading base module class...") - await util.run_sync(importlib.reload, module) - - self.log.info("Reloading master module...") - await util.run_sync(importlib.reload, modules) - - self.log.info("Reloading custom master module...") - await util.run_sync(importlib.reload, custom_modules) - - def command_predicate(self, event: tg.events.NewMessage.Event) -> bool: - if event.raw_text.startswith(self.prefix): - parts = event.raw_text.split() - parts[0] = parts[0][len(self.prefix) :] - - event.segments = parts - return True - - return False - - async def start(self) -> None: - # Get and store current event loop, since this is the first coroutine - self.loop = asyncio.get_event_loop() - - # Load prefix - self.prefix = await self.db.get("prefix", self.config["bot"]["default_prefix"]) - - # Load modules - self.load_all_modules() - await self.dispatch_event("load") - - # Start Telegram client - await self.client.start() - - # Get info - user = await self.client.get_me() - if not isinstance(user, tg.types.User): - raise TypeError("Missing full self user information") - self.user = user - self.uid = user.id - - # Set Sentry username if enabled - if self.config["bot"]["report_username"]: - with sentry_sdk.configure_scope() as scope: - scope.set_user({"username": self.user.username}) - - # Record start time and dispatch start event - self.start_time_us = util.time.usec() - await self.dispatch_event("start", self.start_time_us) - - # Register handlers - self.client.add_event_handler(self.on_message, tg.events.NewMessage()) - self.client.add_event_handler(self.on_message_edit, tg.events.MessageEdited()) - self.client.add_event_handler( - self.on_command, tg.events.NewMessage(outgoing=True, func=self.command_predicate), - ) - self.client.add_event_handler(self.on_chat_action, tg.events.ChatAction()) - - self.log.info("Bot is ready") - - # Catch up on missed events - self.log.info("Catching up on missed events") - await self.client.catch_up() - self.log.info("Finished catching up") - - async def stop(self) -> None: - await self.dispatch_event("stop") - await self.http_session.close() - await self._db.close() - - self.log.info("Running post-stop hooks") - await self.dispatch_event("stopped") - - async def dispatch_event(self, event: str, *args: Any, **kwargs: Any) -> None: - tasks = set() - - try: - listeners = self.listeners[event] - except KeyError: - return None - - if not listeners: - return - - for lst in listeners: - task = self.loop.create_task(lst.func(*args, **kwargs)) - tasks.add(task) - - self.log.debug(f"Dispatching event '{event}' with data {args}") - await asyncio.wait(tasks) - - def dispatch_event_nowait(self, event: str, *args: Any, **kwargs: Any) -> None: - self.loop.create_task(self.dispatch_event(event, *args, **kwargs)) - - async def on_message(self, event: tg.events.NewMessage.Event) -> None: - await self.dispatch_event("message", event) - - async def on_message_edit(self, event: tg.events.MessageEdited.Event) -> None: - await self.dispatch_event("message_edit", event) - - async def on_chat_action(self, event: tg.events.ChatAction.Event) -> None: - await self.dispatch_event("chat_action", event) - - async def on_command(self, msg: tg.events.NewMessage.Event) -> None: - cmd = None - - try: - # Attempt to get command info - try: - cmd = self.commands[msg.segments[0]] - except KeyError: - return - - # Construct invocation context - ctx = command.Context(self, msg.message, msg.segments, len(self.prefix) + len(msg.segments[0]) + 1,) - - if not (cmd.usage is None or cmd.usage_optional or ctx.input): - err_base = f"⚠️ Missing parameters: {cmd.usage}" - - if cmd.usage_reply: - if msg.is_reply: - reply_msg = await msg.get_reply_message() - if reply_msg.text: - ctx.input = reply_msg.text - ctx.parsed_input = reply_msg.raw_text - else: - await ctx.respond(f"{err_base}\n__The message you replied to doesn't contain text.__") - return - else: - await ctx.respond(f"{err_base} (replying is also supported)") - return - else: - await ctx.respond(err_base) - return - - # Invoke command function - try: - ret = await cmd.func(ctx) - - # Response shortcut - if ret is not None: - await ctx.respond(ret) - except Exception as e: - cmd.module.log.error(f"Error in command '{cmd.name}'", exc_info=e) - await ctx.respond(f"⚠️ Error executing command:\n```{util.format_exception(e)}```") - - await self.dispatch_event("command", cmd, msg) - except Exception as e: - if cmd is not None: - cmd.module.log.error("Error in command handler", exc_info=e) - - await self.respond(msg.message, f"⚠️ Error in command handler:\n```{util.format_exception(e)}```") - - # Flexible response function with filtering, truncation, redaction, etc. - async def respond( - self, - msg: tg.custom.Message, - text: Optional[str] = None, - *, - mode: Optional[str] = None, - redact: Optional[bool] = None, - response: Optional[tg.custom.Message] = None, - **kwargs: Any, - ) -> tg.custom.Message: - # Read redaction setting from config - if redact is None: - redact = self.config["bot"]["redact_responses"] - - # Filter text - if text is not None: - # Redact sensitive information if enabled and known - if redact: - tg_config: Mapping[str, str] = self.config["telegram"] - api_id = str(tg_config["api_id"]) - api_hash = tg_config["api_hash"] - - if api_id in text: - text = text.replace(api_id, "[REDACTED]") - if api_hash in text: - text = text.replace(api_hash, "[REDACTED]") - if self.user.phone is not None and self.user.phone in text: - text = text.replace(self.user.phone, "[REDACTED]") - - # Truncate messages longer than Telegram's 4096-character length limit - text = util.tg.truncate(text) - - # Default to disabling link previews in responses - if "link_preview" not in kwargs: - kwargs["link_preview"] = False - - # Use selected response mode if not overridden by invoker - if mode is None: - mode = self.config["bot"]["response_mode"] - - if mode == "edit": - return await msg.edit(text=text, **kwargs) - elif mode == "reply": - if response is not None: - # Already replied, so just edit the existing reply to reduce spam - return await response.edit(text=text, **kwargs) - else: - # Reply since we haven't done so yet - return await msg.reply(text, **kwargs) - elif mode == "repost": - if response is not None: - # Already reposted, so just edit the existing reply to reduce spam - return await response.edit(text=text, **kwargs) - else: - # Repost since we haven't done so yet - response = await msg.respond(text, reply_to=msg.reply_to_msg_id, **kwargs) - await msg.delete() - return response - else: - raise ValueError(f"Unknown response mode '{mode}'") diff --git a/pyrobud/command.py b/pyrobud/command.py index 56f0dee6..d75339bf 100644 --- a/pyrobud/command.py +++ b/pyrobud/command.py @@ -3,7 +3,7 @@ import telethon as tg if TYPE_CHECKING: - from .bot import Bot + from .core import Bot CommandFunc = Union[ Callable[..., Coroutine[Any, Any, None]], Callable[..., Coroutine[Any, Any, Optional[str]]], diff --git a/pyrobud/core/__init__.py b/pyrobud/core/__init__.py new file mode 100644 index 00000000..e34e5d92 --- /dev/null +++ b/pyrobud/core/__init__.py @@ -0,0 +1 @@ +from .bot import Bot diff --git a/pyrobud/core/bot.py b/pyrobud/core/bot.py new file mode 100644 index 00000000..7ea79517 --- /dev/null +++ b/pyrobud/core/bot.py @@ -0,0 +1,38 @@ +import logging + +import aiohttp +import telethon as tg + +from ..util.config import Config +from .command_dispatcher import CommandDispatcher +from .database_provider import DatabaseProvider +from .event_dispatcher import EventDispatcher +from .module_extender import ModuleExtender +from .telegram_bot import TelegramBot + + +class Bot(TelegramBot, ModuleExtender, CommandDispatcher, DatabaseProvider, EventDispatcher): + # Initialized during instantiation + config: Config + log: logging.Logger + http_session: aiohttp.ClientSession + client: tg.TelegramClient + + def __init__(self, config: Config): + # Save reference to config + self.config = config + + # Initialize other objects + self.log = logging.getLogger("bot") + self.http_session = aiohttp.ClientSession() + + # Initialize mixins + super().__init__() + + async def stop(self) -> None: + await self.dispatch_event("stop") + await self.http_session.close() + await self._db.close() + + self.log.info("Running post-stop hooks") + await self.dispatch_event("stopped") diff --git a/pyrobud/core/bot_mixin_base.py b/pyrobud/core/bot_mixin_base.py new file mode 100644 index 00000000..e8155135 --- /dev/null +++ b/pyrobud/core/bot_mixin_base.py @@ -0,0 +1,11 @@ +from typing import TYPE_CHECKING, Any + +MixinBase: Any +if TYPE_CHECKING: + from .bot import Bot + + MixinBase = Bot +else: + import abc + + MixinBase = abc.ABC diff --git a/pyrobud/core/command_dispatcher.py b/pyrobud/core/command_dispatcher.py new file mode 100644 index 00000000..40b78acb --- /dev/null +++ b/pyrobud/core/command_dispatcher.py @@ -0,0 +1,133 @@ +from typing import TYPE_CHECKING, Any, MutableMapping + +import telethon as tg + +from .. import command, module, util +from .bot_mixin_base import MixinBase + +if TYPE_CHECKING: + from .bot import Bot + + +class CommandDispatcher(MixinBase): + # Initialized during instantiation + commands: MutableMapping[str, command.Command] + + def __init__(self: "Bot", **kwargs: Any) -> None: + # Initialize command map + self.commands = {} + + # Propagate initialization to other mixins + super().__init__(**kwargs) + + def register_command(self: "Bot", mod: module.Module, name: str, func: command.CommandFunc) -> None: + cmd = command.Command(name, mod, func) + + if name in self.commands: + orig = self.commands[name] + raise module.ExistingCommandError(orig, cmd) + + self.commands[name] = cmd + + for alias in cmd.aliases: + if alias in self.commands: + orig = self.commands[alias] + raise module.ExistingCommandError(orig, cmd, alias=True) + + self.commands[alias] = cmd + + def unregister_command(self: "Bot", cmd: command.Command) -> None: + del self.commands[cmd.name] + + for alias in cmd.aliases: + try: + del self.commands[alias] + except KeyError: + continue + + def register_commands(self: "Bot", mod: module.Module) -> None: + for name, func in util.find_prefixed_funcs(mod, "cmd_"): + done = False + + try: + self.register_command(mod, name, func) + done = True + finally: + if not done: + self.unregister_commands(mod) + + def unregister_commands(self: "Bot", mod: module.Module) -> None: + # Can't unregister while iterating, so collect commands to unregister afterwards + to_unreg = [] + + for name, cmd in self.commands.items(): + # Let unregister_command deal with aliases + if name != cmd.name: + continue + + if cmd.module == mod: + to_unreg.append(cmd) + + # Actually unregister the commands + for cmd in to_unreg: + self.unregister_command(cmd) + + def command_predicate(self: "Bot", event: tg.events.NewMessage.Event) -> bool: + if event.raw_text.startswith(self.prefix): + parts = event.raw_text.split() + parts[0] = parts[0][len(self.prefix) :] + + event.segments = parts + return True + + return False + + async def on_command(self: "Bot", msg: tg.events.NewMessage.Event) -> None: + cmd = None + + try: + # Attempt to get command info + try: + cmd = self.commands[msg.segments[0]] + except KeyError: + return + + # Construct invocation context + ctx = command.Context(self, msg.message, msg.segments, len(self.prefix) + len(msg.segments[0]) + 1,) + + if not (cmd.usage is None or cmd.usage_optional or ctx.input): + err_base = f"⚠️ Missing parameters: {cmd.usage}" + + if cmd.usage_reply: + if msg.is_reply: + reply_msg = await msg.get_reply_message() + if reply_msg.text: + ctx.input = reply_msg.text + ctx.parsed_input = reply_msg.raw_text + else: + await ctx.respond(f"{err_base}\n__The message you replied to doesn't contain text.__") + return + else: + await ctx.respond(f"{err_base} (replying is also supported)") + return + else: + await ctx.respond(err_base) + return + + # Invoke command function + try: + ret = await cmd.func(ctx) + + # Response shortcut + if ret is not None: + await ctx.respond(ret) + except Exception as e: + cmd.module.log.error(f"Error in command '{cmd.name}'", exc_info=e) + await ctx.respond(f"⚠️ Error executing command:\n```{util.format_exception(e)}```") + + await self.dispatch_event("command", cmd, msg) + except Exception as e: + if cmd is not None: + cmd.module.log.error("Error in command handler", exc_info=e) + + await self.respond(msg.message, f"⚠️ Error in command handler:\n```{util.format_exception(e)}```") diff --git a/pyrobud/core/database_provider.py b/pyrobud/core/database_provider.py new file mode 100644 index 00000000..bbf6397e --- /dev/null +++ b/pyrobud/core/database_provider.py @@ -0,0 +1,26 @@ +from typing import TYPE_CHECKING, Any + +import plyvel + +from .. import util +from .bot_mixin_base import MixinBase + +if TYPE_CHECKING: + from .bot import Bot + + +class DatabaseProvider(MixinBase): + # Initialized during instantiation + _db: util.db.AsyncDB + db: util.db.AsyncDB + + def __init__(self: "Bot", **kwargs: Any) -> None: + # Initialize database + self._db = util.db.AsyncDB(plyvel.DB(self.config["bot"]["db_path"], create_if_missing=True)) + self.db = self.get_db("bot") + + # Propagate initialization to other mixins + super().__init__(**kwargs) + + def get_db(self: "Bot", prefix: str) -> util.db.AsyncDB: + return self._db.prefixed_db(prefix + ".") diff --git a/pyrobud/core/event_dispatcher.py b/pyrobud/core/event_dispatcher.py new file mode 100644 index 00000000..b5980303 --- /dev/null +++ b/pyrobud/core/event_dispatcher.py @@ -0,0 +1,77 @@ +import asyncio +import bisect +from typing import TYPE_CHECKING, Any, MutableMapping, MutableSequence + +from .. import module, util +from ..listener import Listener, ListenerFunc +from .bot_mixin_base import MixinBase + +if TYPE_CHECKING: + from .bot import Bot + + +class EventDispatcher(MixinBase): + # Initialized during instantiation + listeners: MutableMapping[str, MutableSequence[Listener]] + + def __init__(self: "Bot", **kwargs: Any) -> None: + # Initialize listener map + self.listeners = {} + + # Propagate initialization to other mixins + super().__init__(**kwargs) + + def register_listener(self: "Bot", mod: module.Module, event: str, func: ListenerFunc, priority: int = 100) -> None: + listener = Listener(event, func, mod, priority) + + if event in self.listeners: + bisect.insort(self.listeners[event], listener) + else: + self.listeners[event] = [listener] + + def unregister_listener(self: "Bot", listener: Listener) -> None: + self.listeners[listener.event].remove(listener) + + def register_listeners(self: "Bot", mod: module.Module) -> None: + for event, func in util.find_prefixed_funcs(mod, "on_"): + done = True + try: + self.register_listener(mod, event, func, priority=getattr(func, "_listener_priority", 100)) + done = True + finally: + if not done: + self.unregister_listeners(mod) + + def unregister_listeners(self: "Bot", mod: module.Module) -> None: + # Can't unregister while iterating, so collect listeners to unregister afterwards + to_unreg = [] + + for lst in self.listeners.values(): + for listener in lst: + if listener.module == mod: + to_unreg.append(listener) + + # Actually unregister the listeners + for listener in to_unreg: + self.unregister_listener(listener) + + async def dispatch_event(self: "Bot", event: str, *args: Any, **kwargs: Any) -> None: + tasks = set() + + try: + listeners = self.listeners[event] + except KeyError: + return None + + if not listeners: + return + + for lst in listeners: + task = self.loop.create_task(lst.func(*args, **kwargs)) + tasks.add(task) + + self.log.debug(f"Dispatching event '{event}' with data {args}") + await asyncio.wait(tasks) + + def dispatch_event_nowait(self: "Bot", event: str, *args: Any, **kwargs: Any) -> None: + self.loop.create_task(self.dispatch_event(event, *args, **kwargs)) diff --git a/pyrobud/core/module_extender.py b/pyrobud/core/module_extender.py new file mode 100644 index 00000000..32cd045f --- /dev/null +++ b/pyrobud/core/module_extender.py @@ -0,0 +1,85 @@ +import importlib +import inspect +import os +from types import ModuleType +from typing import TYPE_CHECKING, Any, MutableMapping, Optional, Type + +from .. import custom_modules, module, modules, util +from .bot_mixin_base import MixinBase + +if TYPE_CHECKING: + from .bot import Bot + + +class ModuleExtender(MixinBase): + # Initialized during instantiation + modules: MutableMapping[str, module.Module] + + def __init__(self: "Bot", **kwargs: Any) -> None: + # Initialize module map + self.modules = {} + + # Propagate initialization to other mixins + super().__init__(**kwargs) + + def load_module(self: "Bot", cls: Type[module.Module], *, comment: Optional[str] = None) -> None: + _comment = comment + " " if comment else "" + self.log.info( + f"Loading {_comment}module '{cls.name}' ({cls.__name__}) from '{os.path.relpath(inspect.getfile(cls))}'" + ) + + if cls.name in self.modules: + old = type(self.modules[cls.name]) + raise module.ExistingModuleError(old, cls) + + mod = cls(self) + mod.comment = comment + self.register_listeners(mod) + self.register_commands(mod) + self.modules[cls.name] = mod + + def unload_module(self: "Bot", mod: module.Module) -> None: + _comment = mod.comment + " " if mod.comment else "" + + cls = type(mod) + path = os.path.relpath(inspect.getfile(cls)) + self.log.info(f"Unloading {_comment}module '{cls.name}' ({cls.__name__}) from '{path}'") + + self.unregister_listeners(mod) + self.unregister_commands(mod) + del self.modules[cls.name] + + def _load_modules_from_metamod(self: "Bot", metamod: ModuleType, *, comment: str = None) -> None: + for _sym in getattr(metamod, "__all__", ()): + module_mod: ModuleType = getattr(metamod, _sym) + + if inspect.ismodule(module_mod): + for sym in dir(module_mod): + cls = getattr(module_mod, sym) + if inspect.isclass(cls) and issubclass(cls, module.Module): + self.load_module(cls, comment=comment) + + def load_all_modules(self: "Bot") -> None: + self.log.info("Loading modules") + self._load_modules_from_metamod(modules) + self._load_modules_from_metamod(custom_modules, comment="custom") + self.log.info("All modules loaded.") + + def unload_all_modules(self: "Bot") -> None: + self.log.info("Unloading modules...") + + # Can't modify while iterating, so collect a list first + for mod in list(self.modules.values()): + self.unload_module(mod) + + self.log.info("All modules unloaded.") + + async def reload_module_pkg(self: "Bot") -> None: + self.log.info("Reloading base module class...") + await util.run_sync(importlib.reload, module) + + self.log.info("Reloading master module...") + await util.run_sync(importlib.reload, modules) + + self.log.info("Reloading custom master module...") + await util.run_sync(importlib.reload, custom_modules) diff --git a/pyrobud/core/telegram_bot.py b/pyrobud/core/telegram_bot.py new file mode 100644 index 00000000..46614281 --- /dev/null +++ b/pyrobud/core/telegram_bot.py @@ -0,0 +1,163 @@ +import asyncio +from typing import TYPE_CHECKING, Any, Mapping, Optional, Union + +import sentry_sdk +import telethon as tg + +from .. import util +from .bot_mixin_base import MixinBase + +if TYPE_CHECKING: + from .bot import Bot + +TelegramConfig = Mapping[str, Union[int, str]] + + +class TelegramBot(MixinBase): + # Initialized during instantiation + tg_config: TelegramConfig + + # Initialized during startup + loop: asyncio.AbstractEventLoop + prefix: str + user: tg.types.User + uid: int + start_time_us: int + + def __init__(self: "Bot", **kwargs: Any) -> None: + # Get Telegram parameters from config and check types + self.tg_config = self.config["telegram"] + + session_name = self.tg_config["session_name"] + if not isinstance(session_name, str): + raise TypeError("Session name must be a string") + + api_id = self.tg_config["api_id"] + if not isinstance(api_id, int): + raise TypeError("API ID must be an integer") + + api_hash = self.tg_config["api_hash"] + if not isinstance(api_hash, str): + raise TypeError("API hash must be a string") + + # Initialize Telegram client with gathered parameters + self.client = tg.TelegramClient(session_name, api_id, api_hash) + + # Propagate initialization to other mixins + super().__init__(**kwargs) + + async def start(self: "Bot") -> None: + # Get and store current event loop, since this is the first coroutine that runs + self.loop = asyncio.get_event_loop() + + # Load prefix + self.prefix = await self.db.get("prefix", self.config["bot"]["default_prefix"]) + + # Load modules + self.load_all_modules() + await self.dispatch_event("load") + + # Start Telegram client + await self.client.start() + + # Get info + user = await self.client.get_me() + if not isinstance(user, tg.types.User): + raise TypeError("Missing full self user information") + self.user = user + self.uid = user.id + + # Set Sentry username if enabled + if self.config["bot"]["report_username"]: + with sentry_sdk.configure_scope() as scope: + scope.set_user({"username": self.user.username}) + + # Record start time and dispatch start event + self.start_time_us = util.time.usec() + await self.dispatch_event("start", self.start_time_us) + + # Register handlers + self.client.add_event_handler(self.on_message, tg.events.NewMessage()) + self.client.add_event_handler(self.on_message_edit, tg.events.MessageEdited()) + self.client.add_event_handler( + self.on_command, tg.events.NewMessage(outgoing=True, func=self.command_predicate), + ) + self.client.add_event_handler(self.on_chat_action, tg.events.ChatAction()) + + self.log.info("Bot is ready") + + # Catch up on missed events + self.log.info("Catching up on missed events") + await self.client.catch_up() + self.log.info("Finished catching up") + + async def on_message(self: "Bot", event: tg.events.NewMessage.Event) -> None: + await self.dispatch_event("message", event) + + async def on_message_edit(self: "Bot", event: tg.events.MessageEdited.Event) -> None: + await self.dispatch_event("message_edit", event) + + async def on_chat_action(self: "Bot", event: tg.events.ChatAction.Event) -> None: + await self.dispatch_event("chat_action", event) + + # Flexible response function with filtering, truncation, redaction, etc. + async def respond( + self: "Bot", + msg: tg.custom.Message, + text: Optional[str] = None, + *, + mode: Optional[str] = None, + redact: Optional[bool] = None, + response: Optional[tg.custom.Message] = None, + **kwargs: Any, + ) -> tg.custom.Message: + # Read redaction setting from config + if redact is None: + redact = self.config["bot"]["redact_responses"] + + # Filter text + if text is not None: + # Redact sensitive information if enabled and known + if redact: + tg_config: Mapping[str, str] = self.config["telegram"] + api_id = str(tg_config["api_id"]) + api_hash = tg_config["api_hash"] + + if api_id in text: + text = text.replace(api_id, "[REDACTED]") + if api_hash in text: + text = text.replace(api_hash, "[REDACTED]") + if self.user.phone is not None and self.user.phone in text: + text = text.replace(self.user.phone, "[REDACTED]") + + # Truncate messages longer than Telegram's 4096-character length limit + text = util.tg.truncate(text) + + # Default to disabling link previews in responses + if "link_preview" not in kwargs: + kwargs["link_preview"] = False + + # Use selected response mode if not overridden by invoker + if mode is None: + mode = self.config["bot"]["response_mode"] + + if mode == "edit": + return await msg.edit(text=text, **kwargs) + elif mode == "reply": + if response is not None: + # Already replied, so just edit the existing reply to reduce spam + return await response.edit(text=text, **kwargs) + else: + # Reply since we haven't done so yet + return await msg.reply(text, **kwargs) + elif mode == "repost": + if response is not None: + # Already reposted, so just edit the existing reply to reduce spam + return await response.edit(text=text, **kwargs) + else: + # Repost since we haven't done so yet + response = await msg.respond(text, reply_to=msg.reply_to_msg_id, **kwargs) + await msg.delete() + return response + else: + raise ValueError(f"Unknown response mode '{mode}'") diff --git a/pyrobud/launch.py b/pyrobud/launch.py index 1a39e90d..15310141 100644 --- a/pyrobud/launch.py +++ b/pyrobud/launch.py @@ -4,7 +4,7 @@ import tomlkit from . import DEFAULT_CONFIG_PATH, util -from .bot import Bot +from .core import Bot log = logging.getLogger("launch") diff --git a/pyrobud/module.py b/pyrobud/module.py index 2062d4d9..e12a6e86 100644 --- a/pyrobud/module.py +++ b/pyrobud/module.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional, Type if TYPE_CHECKING: - from .bot import Bot + from .core import Bot from .command import Command