-
-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
bot: Split into multiple compartmentalized ABC mixins
This helps accomodate the growing Bot class, since it's easily divided into several logical components anyway. Each component is an ABC (Abstract Base Class) that functions as a mixin and thus can access all the other attributes of the Bot class. Bot subclasses all the mixins to unify them. Telethon uses the same approach to divide TelegramClient into many mixins that each contain different categories of API methods. An adaptive MixinBase reference that is defined as abc.ABC during runtime and Bot during type checking is used as the subclass for each mixin to work around python/mypy#5837. Furthermore, the MixinBase reference itself has its type specified as Any to work around python/mypy#2477. The only change for outsiders should be the Bot class being moved to the `core` module. Signed-off-by: Danny Lin <[email protected]>
- Loading branch information
Showing
12 changed files
with
537 additions
and
432 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .bot import Bot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import logging | ||
|
||
import aiohttp | ||
import telethon as tg | ||
|
||
from ..util.config import Config | ||
from .command_dispatcher import CommandDispatcher | ||
from .database_provider import DatabaseProvider | ||
from .event_dispatcher import EventDispatcher | ||
from .module_extender import ModuleExtender | ||
from .telegram_bot import TelegramBot | ||
|
||
|
||
class Bot(TelegramBot, ModuleExtender, CommandDispatcher, DatabaseProvider, EventDispatcher): | ||
# Initialized during instantiation | ||
config: Config | ||
log: logging.Logger | ||
http_session: aiohttp.ClientSession | ||
client: tg.TelegramClient | ||
|
||
def __init__(self, config: Config): | ||
# Save reference to config | ||
self.config = config | ||
|
||
# Initialize other objects | ||
self.log = logging.getLogger("bot") | ||
self.http_session = aiohttp.ClientSession() | ||
|
||
# Initialize mixins | ||
super().__init__() | ||
|
||
async def stop(self) -> None: | ||
await self.dispatch_event("stop") | ||
await self.http_session.close() | ||
await self._db.close() | ||
|
||
self.log.info("Running post-stop hooks") | ||
await self.dispatch_event("stopped") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from typing import TYPE_CHECKING, Any | ||
|
||
MixinBase: Any | ||
if TYPE_CHECKING: | ||
from .bot import Bot | ||
|
||
MixinBase = Bot | ||
else: | ||
import abc | ||
|
||
MixinBase = abc.ABC |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from typing import TYPE_CHECKING, Any, MutableMapping | ||
|
||
import telethon as tg | ||
|
||
from .. import command, module, util | ||
from .bot_mixin_base import MixinBase | ||
|
||
if TYPE_CHECKING: | ||
from .bot import Bot | ||
|
||
|
||
class CommandDispatcher(MixinBase): | ||
# Initialized during instantiation | ||
commands: MutableMapping[str, command.Command] | ||
|
||
def __init__(self: "Bot", **kwargs: Any) -> None: | ||
# Initialize command map | ||
self.commands = {} | ||
|
||
# Propagate initialization to other mixins | ||
super().__init__(**kwargs) | ||
|
||
def register_command(self: "Bot", mod: module.Module, name: str, func: command.CommandFunc) -> None: | ||
cmd = command.Command(name, mod, func) | ||
|
||
if name in self.commands: | ||
orig = self.commands[name] | ||
raise module.ExistingCommandError(orig, cmd) | ||
|
||
self.commands[name] = cmd | ||
|
||
for alias in cmd.aliases: | ||
if alias in self.commands: | ||
orig = self.commands[alias] | ||
raise module.ExistingCommandError(orig, cmd, alias=True) | ||
|
||
self.commands[alias] = cmd | ||
|
||
def unregister_command(self: "Bot", cmd: command.Command) -> None: | ||
del self.commands[cmd.name] | ||
|
||
for alias in cmd.aliases: | ||
try: | ||
del self.commands[alias] | ||
except KeyError: | ||
continue | ||
|
||
def register_commands(self: "Bot", mod: module.Module) -> None: | ||
for name, func in util.find_prefixed_funcs(mod, "cmd_"): | ||
done = False | ||
|
||
try: | ||
self.register_command(mod, name, func) | ||
done = True | ||
finally: | ||
if not done: | ||
self.unregister_commands(mod) | ||
|
||
def unregister_commands(self: "Bot", mod: module.Module) -> None: | ||
# Can't unregister while iterating, so collect commands to unregister afterwards | ||
to_unreg = [] | ||
|
||
for name, cmd in self.commands.items(): | ||
# Let unregister_command deal with aliases | ||
if name != cmd.name: | ||
continue | ||
|
||
if cmd.module == mod: | ||
to_unreg.append(cmd) | ||
|
||
# Actually unregister the commands | ||
for cmd in to_unreg: | ||
self.unregister_command(cmd) | ||
|
||
def command_predicate(self: "Bot", event: tg.events.NewMessage.Event) -> bool: | ||
if event.raw_text.startswith(self.prefix): | ||
parts = event.raw_text.split() | ||
parts[0] = parts[0][len(self.prefix) :] | ||
|
||
event.segments = parts | ||
return True | ||
|
||
return False | ||
|
||
async def on_command(self: "Bot", msg: tg.events.NewMessage.Event) -> None: | ||
cmd = None | ||
|
||
try: | ||
# Attempt to get command info | ||
try: | ||
cmd = self.commands[msg.segments[0]] | ||
except KeyError: | ||
return | ||
|
||
# Construct invocation context | ||
ctx = command.Context(self, msg.message, msg.segments, len(self.prefix) + len(msg.segments[0]) + 1,) | ||
|
||
if not (cmd.usage is None or cmd.usage_optional or ctx.input): | ||
err_base = f"⚠️ Missing parameters: {cmd.usage}" | ||
|
||
if cmd.usage_reply: | ||
if msg.is_reply: | ||
reply_msg = await msg.get_reply_message() | ||
if reply_msg.text: | ||
ctx.input = reply_msg.text | ||
ctx.parsed_input = reply_msg.raw_text | ||
else: | ||
await ctx.respond(f"{err_base}\n__The message you replied to doesn't contain text.__") | ||
return | ||
else: | ||
await ctx.respond(f"{err_base} (replying is also supported)") | ||
return | ||
else: | ||
await ctx.respond(err_base) | ||
return | ||
|
||
# Invoke command function | ||
try: | ||
ret = await cmd.func(ctx) | ||
|
||
# Response shortcut | ||
if ret is not None: | ||
await ctx.respond(ret) | ||
except Exception as e: | ||
cmd.module.log.error(f"Error in command '{cmd.name}'", exc_info=e) | ||
await ctx.respond(f"⚠️ Error executing command:\n```{util.format_exception(e)}```") | ||
|
||
await self.dispatch_event("command", cmd, msg) | ||
except Exception as e: | ||
if cmd is not None: | ||
cmd.module.log.error("Error in command handler", exc_info=e) | ||
|
||
await self.respond(msg.message, f"⚠️ Error in command handler:\n```{util.format_exception(e)}```") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import TYPE_CHECKING, Any | ||
|
||
import plyvel | ||
|
||
from .. import util | ||
from .bot_mixin_base import MixinBase | ||
|
||
if TYPE_CHECKING: | ||
from .bot import Bot | ||
|
||
|
||
class DatabaseProvider(MixinBase): | ||
# Initialized during instantiation | ||
_db: util.db.AsyncDB | ||
db: util.db.AsyncDB | ||
|
||
def __init__(self: "Bot", **kwargs: Any) -> None: | ||
# Initialize database | ||
self._db = util.db.AsyncDB(plyvel.DB(self.config["bot"]["db_path"], create_if_missing=True)) | ||
self.db = self.get_db("bot") | ||
|
||
# Propagate initialization to other mixins | ||
super().__init__(**kwargs) | ||
|
||
def get_db(self: "Bot", prefix: str) -> util.db.AsyncDB: | ||
return self._db.prefixed_db(prefix + ".") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import asyncio | ||
import bisect | ||
from typing import TYPE_CHECKING, Any, MutableMapping, MutableSequence | ||
|
||
from .. import module, util | ||
from ..listener import Listener, ListenerFunc | ||
from .bot_mixin_base import MixinBase | ||
|
||
if TYPE_CHECKING: | ||
from .bot import Bot | ||
|
||
|
||
class EventDispatcher(MixinBase): | ||
# Initialized during instantiation | ||
listeners: MutableMapping[str, MutableSequence[Listener]] | ||
|
||
def __init__(self: "Bot", **kwargs: Any) -> None: | ||
# Initialize listener map | ||
self.listeners = {} | ||
|
||
# Propagate initialization to other mixins | ||
super().__init__(**kwargs) | ||
|
||
def register_listener(self: "Bot", mod: module.Module, event: str, func: ListenerFunc, priority: int = 100) -> None: | ||
listener = Listener(event, func, mod, priority) | ||
|
||
if event in self.listeners: | ||
bisect.insort(self.listeners[event], listener) | ||
else: | ||
self.listeners[event] = [listener] | ||
|
||
def unregister_listener(self: "Bot", listener: Listener) -> None: | ||
self.listeners[listener.event].remove(listener) | ||
|
||
def register_listeners(self: "Bot", mod: module.Module) -> None: | ||
for event, func in util.find_prefixed_funcs(mod, "on_"): | ||
done = True | ||
try: | ||
self.register_listener(mod, event, func, priority=getattr(func, "_listener_priority", 100)) | ||
done = True | ||
finally: | ||
if not done: | ||
self.unregister_listeners(mod) | ||
|
||
def unregister_listeners(self: "Bot", mod: module.Module) -> None: | ||
# Can't unregister while iterating, so collect listeners to unregister afterwards | ||
to_unreg = [] | ||
|
||
for lst in self.listeners.values(): | ||
for listener in lst: | ||
if listener.module == mod: | ||
to_unreg.append(listener) | ||
|
||
# Actually unregister the listeners | ||
for listener in to_unreg: | ||
self.unregister_listener(listener) | ||
|
||
async def dispatch_event(self: "Bot", event: str, *args: Any, **kwargs: Any) -> None: | ||
tasks = set() | ||
|
||
try: | ||
listeners = self.listeners[event] | ||
except KeyError: | ||
return None | ||
|
||
if not listeners: | ||
return | ||
|
||
for lst in listeners: | ||
task = self.loop.create_task(lst.func(*args, **kwargs)) | ||
tasks.add(task) | ||
|
||
self.log.debug(f"Dispatching event '{event}' with data {args}") | ||
await asyncio.wait(tasks) | ||
|
||
def dispatch_event_nowait(self: "Bot", event: str, *args: Any, **kwargs: Any) -> None: | ||
self.loop.create_task(self.dispatch_event(event, *args, **kwargs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import importlib | ||
import inspect | ||
import os | ||
from types import ModuleType | ||
from typing import TYPE_CHECKING, Any, MutableMapping, Optional, Type | ||
|
||
from .. import custom_modules, module, modules, util | ||
from .bot_mixin_base import MixinBase | ||
|
||
if TYPE_CHECKING: | ||
from .bot import Bot | ||
|
||
|
||
class ModuleExtender(MixinBase): | ||
# Initialized during instantiation | ||
modules: MutableMapping[str, module.Module] | ||
|
||
def __init__(self: "Bot", **kwargs: Any) -> None: | ||
# Initialize module map | ||
self.modules = {} | ||
|
||
# Propagate initialization to other mixins | ||
super().__init__(**kwargs) | ||
|
||
def load_module(self: "Bot", cls: Type[module.Module], *, comment: Optional[str] = None) -> None: | ||
_comment = comment + " " if comment else "" | ||
self.log.info( | ||
f"Loading {_comment}module '{cls.name}' ({cls.__name__}) from '{os.path.relpath(inspect.getfile(cls))}'" | ||
) | ||
|
||
if cls.name in self.modules: | ||
old = type(self.modules[cls.name]) | ||
raise module.ExistingModuleError(old, cls) | ||
|
||
mod = cls(self) | ||
mod.comment = comment | ||
self.register_listeners(mod) | ||
self.register_commands(mod) | ||
self.modules[cls.name] = mod | ||
|
||
def unload_module(self: "Bot", mod: module.Module) -> None: | ||
_comment = mod.comment + " " if mod.comment else "" | ||
|
||
cls = type(mod) | ||
path = os.path.relpath(inspect.getfile(cls)) | ||
self.log.info(f"Unloading {_comment}module '{cls.name}' ({cls.__name__}) from '{path}'") | ||
|
||
self.unregister_listeners(mod) | ||
self.unregister_commands(mod) | ||
del self.modules[cls.name] | ||
|
||
def _load_modules_from_metamod(self: "Bot", metamod: ModuleType, *, comment: str = None) -> None: | ||
for _sym in getattr(metamod, "__all__", ()): | ||
module_mod: ModuleType = getattr(metamod, _sym) | ||
|
||
if inspect.ismodule(module_mod): | ||
for sym in dir(module_mod): | ||
cls = getattr(module_mod, sym) | ||
if inspect.isclass(cls) and issubclass(cls, module.Module): | ||
self.load_module(cls, comment=comment) | ||
|
||
def load_all_modules(self: "Bot") -> None: | ||
self.log.info("Loading modules") | ||
self._load_modules_from_metamod(modules) | ||
self._load_modules_from_metamod(custom_modules, comment="custom") | ||
self.log.info("All modules loaded.") | ||
|
||
def unload_all_modules(self: "Bot") -> None: | ||
self.log.info("Unloading modules...") | ||
|
||
# Can't modify while iterating, so collect a list first | ||
for mod in list(self.modules.values()): | ||
self.unload_module(mod) | ||
|
||
self.log.info("All modules unloaded.") | ||
|
||
async def reload_module_pkg(self: "Bot") -> None: | ||
self.log.info("Reloading base module class...") | ||
await util.run_sync(importlib.reload, module) | ||
|
||
self.log.info("Reloading master module...") | ||
await util.run_sync(importlib.reload, modules) | ||
|
||
self.log.info("Reloading custom master module...") | ||
await util.run_sync(importlib.reload, custom_modules) |
Oops, something went wrong.