Skip to content

Commit

Permalink
bot: Split into multiple compartmentalized ABC mixins
Browse files Browse the repository at this point in the history
This helps accomodate the growing Bot class, since it's easily divided
into several logical components anyway. Each component is an ABC
(Abstract Base Class) that functions as a mixin and thus can access all
the other attributes of the Bot class. Bot subclasses all the mixins to
unify them. Telethon uses the same approach to divide TelegramClient
into many mixins that each contain different categories of API methods.

An adaptive MixinBase reference that is defined as abc.ABC during
runtime and Bot during type checking is used as the subclass for each
mixin to work around python/mypy#5837.
Furthermore, the MixinBase reference itself has its type specified as
Any to work around python/mypy#2477.

The only change for outsiders should be the Bot class being moved to the
`core` module.

Signed-off-by: Danny Lin <[email protected]>
  • Loading branch information
kdrag0n committed Dec 15, 2019
1 parent 1b8f07a commit 6edfb62
Show file tree
Hide file tree
Showing 12 changed files with 537 additions and 432 deletions.
429 changes: 0 additions & 429 deletions pyrobud/bot.py

This file was deleted.

2 changes: 1 addition & 1 deletion pyrobud/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import telethon as tg

if TYPE_CHECKING:
from .bot import Bot
from .core import Bot

CommandFunc = Union[
Callable[..., Coroutine[Any, Any, None]], Callable[..., Coroutine[Any, Any, Optional[str]]],
Expand Down
1 change: 1 addition & 0 deletions pyrobud/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .bot import Bot
38 changes: 38 additions & 0 deletions pyrobud/core/bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

import aiohttp
import telethon as tg

from ..util.config import Config
from .command_dispatcher import CommandDispatcher
from .database_provider import DatabaseProvider
from .event_dispatcher import EventDispatcher
from .module_extender import ModuleExtender
from .telegram_bot import TelegramBot


class Bot(TelegramBot, ModuleExtender, CommandDispatcher, DatabaseProvider, EventDispatcher):
# Initialized during instantiation
config: Config
log: logging.Logger
http_session: aiohttp.ClientSession
client: tg.TelegramClient

def __init__(self, config: Config):
# Save reference to config
self.config = config

# Initialize other objects
self.log = logging.getLogger("bot")
self.http_session = aiohttp.ClientSession()

# Initialize mixins
super().__init__()

async def stop(self) -> None:
await self.dispatch_event("stop")
await self.http_session.close()
await self._db.close()

self.log.info("Running post-stop hooks")
await self.dispatch_event("stopped")
11 changes: 11 additions & 0 deletions pyrobud/core/bot_mixin_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import TYPE_CHECKING, Any

MixinBase: Any
if TYPE_CHECKING:
from .bot import Bot

MixinBase = Bot
else:
import abc

MixinBase = abc.ABC
133 changes: 133 additions & 0 deletions pyrobud/core/command_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from typing import TYPE_CHECKING, Any, MutableMapping

import telethon as tg

from .. import command, module, util
from .bot_mixin_base import MixinBase

if TYPE_CHECKING:
from .bot import Bot


class CommandDispatcher(MixinBase):
# Initialized during instantiation
commands: MutableMapping[str, command.Command]

def __init__(self: "Bot", **kwargs: Any) -> None:
# Initialize command map
self.commands = {}

# Propagate initialization to other mixins
super().__init__(**kwargs)

def register_command(self: "Bot", mod: module.Module, name: str, func: command.CommandFunc) -> None:
cmd = command.Command(name, mod, func)

if name in self.commands:
orig = self.commands[name]
raise module.ExistingCommandError(orig, cmd)

self.commands[name] = cmd

for alias in cmd.aliases:
if alias in self.commands:
orig = self.commands[alias]
raise module.ExistingCommandError(orig, cmd, alias=True)

self.commands[alias] = cmd

def unregister_command(self: "Bot", cmd: command.Command) -> None:
del self.commands[cmd.name]

for alias in cmd.aliases:
try:
del self.commands[alias]
except KeyError:
continue

def register_commands(self: "Bot", mod: module.Module) -> None:
for name, func in util.find_prefixed_funcs(mod, "cmd_"):
done = False

try:
self.register_command(mod, name, func)
done = True
finally:
if not done:
self.unregister_commands(mod)

def unregister_commands(self: "Bot", mod: module.Module) -> None:
# Can't unregister while iterating, so collect commands to unregister afterwards
to_unreg = []

for name, cmd in self.commands.items():
# Let unregister_command deal with aliases
if name != cmd.name:
continue

if cmd.module == mod:
to_unreg.append(cmd)

# Actually unregister the commands
for cmd in to_unreg:
self.unregister_command(cmd)

def command_predicate(self: "Bot", event: tg.events.NewMessage.Event) -> bool:
if event.raw_text.startswith(self.prefix):
parts = event.raw_text.split()
parts[0] = parts[0][len(self.prefix) :]

event.segments = parts
return True

return False

async def on_command(self: "Bot", msg: tg.events.NewMessage.Event) -> None:
cmd = None

try:
# Attempt to get command info
try:
cmd = self.commands[msg.segments[0]]
except KeyError:
return

# Construct invocation context
ctx = command.Context(self, msg.message, msg.segments, len(self.prefix) + len(msg.segments[0]) + 1,)

if not (cmd.usage is None or cmd.usage_optional or ctx.input):
err_base = f"⚠️ Missing parameters: {cmd.usage}"

if cmd.usage_reply:
if msg.is_reply:
reply_msg = await msg.get_reply_message()
if reply_msg.text:
ctx.input = reply_msg.text
ctx.parsed_input = reply_msg.raw_text
else:
await ctx.respond(f"{err_base}\n__The message you replied to doesn't contain text.__")
return
else:
await ctx.respond(f"{err_base} (replying is also supported)")
return
else:
await ctx.respond(err_base)
return

# Invoke command function
try:
ret = await cmd.func(ctx)

# Response shortcut
if ret is not None:
await ctx.respond(ret)
except Exception as e:
cmd.module.log.error(f"Error in command '{cmd.name}'", exc_info=e)
await ctx.respond(f"⚠️ Error executing command:\n```{util.format_exception(e)}```")

await self.dispatch_event("command", cmd, msg)
except Exception as e:
if cmd is not None:
cmd.module.log.error("Error in command handler", exc_info=e)

await self.respond(msg.message, f"⚠️ Error in command handler:\n```{util.format_exception(e)}```")
26 changes: 26 additions & 0 deletions pyrobud/core/database_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import TYPE_CHECKING, Any

import plyvel

from .. import util
from .bot_mixin_base import MixinBase

if TYPE_CHECKING:
from .bot import Bot


class DatabaseProvider(MixinBase):
# Initialized during instantiation
_db: util.db.AsyncDB
db: util.db.AsyncDB

def __init__(self: "Bot", **kwargs: Any) -> None:
# Initialize database
self._db = util.db.AsyncDB(plyvel.DB(self.config["bot"]["db_path"], create_if_missing=True))
self.db = self.get_db("bot")

# Propagate initialization to other mixins
super().__init__(**kwargs)

def get_db(self: "Bot", prefix: str) -> util.db.AsyncDB:
return self._db.prefixed_db(prefix + ".")
77 changes: 77 additions & 0 deletions pyrobud/core/event_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import asyncio
import bisect
from typing import TYPE_CHECKING, Any, MutableMapping, MutableSequence

from .. import module, util
from ..listener import Listener, ListenerFunc
from .bot_mixin_base import MixinBase

if TYPE_CHECKING:
from .bot import Bot


class EventDispatcher(MixinBase):
# Initialized during instantiation
listeners: MutableMapping[str, MutableSequence[Listener]]

def __init__(self: "Bot", **kwargs: Any) -> None:
# Initialize listener map
self.listeners = {}

# Propagate initialization to other mixins
super().__init__(**kwargs)

def register_listener(self: "Bot", mod: module.Module, event: str, func: ListenerFunc, priority: int = 100) -> None:
listener = Listener(event, func, mod, priority)

if event in self.listeners:
bisect.insort(self.listeners[event], listener)
else:
self.listeners[event] = [listener]

def unregister_listener(self: "Bot", listener: Listener) -> None:
self.listeners[listener.event].remove(listener)

def register_listeners(self: "Bot", mod: module.Module) -> None:
for event, func in util.find_prefixed_funcs(mod, "on_"):
done = True
try:
self.register_listener(mod, event, func, priority=getattr(func, "_listener_priority", 100))
done = True
finally:
if not done:
self.unregister_listeners(mod)

def unregister_listeners(self: "Bot", mod: module.Module) -> None:
# Can't unregister while iterating, so collect listeners to unregister afterwards
to_unreg = []

for lst in self.listeners.values():
for listener in lst:
if listener.module == mod:
to_unreg.append(listener)

# Actually unregister the listeners
for listener in to_unreg:
self.unregister_listener(listener)

async def dispatch_event(self: "Bot", event: str, *args: Any, **kwargs: Any) -> None:
tasks = set()

try:
listeners = self.listeners[event]
except KeyError:
return None

if not listeners:
return

for lst in listeners:
task = self.loop.create_task(lst.func(*args, **kwargs))
tasks.add(task)

self.log.debug(f"Dispatching event '{event}' with data {args}")
await asyncio.wait(tasks)

def dispatch_event_nowait(self: "Bot", event: str, *args: Any, **kwargs: Any) -> None:
self.loop.create_task(self.dispatch_event(event, *args, **kwargs))
85 changes: 85 additions & 0 deletions pyrobud/core/module_extender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import importlib
import inspect
import os
from types import ModuleType
from typing import TYPE_CHECKING, Any, MutableMapping, Optional, Type

from .. import custom_modules, module, modules, util
from .bot_mixin_base import MixinBase

if TYPE_CHECKING:
from .bot import Bot


class ModuleExtender(MixinBase):
# Initialized during instantiation
modules: MutableMapping[str, module.Module]

def __init__(self: "Bot", **kwargs: Any) -> None:
# Initialize module map
self.modules = {}

# Propagate initialization to other mixins
super().__init__(**kwargs)

def load_module(self: "Bot", cls: Type[module.Module], *, comment: Optional[str] = None) -> None:
_comment = comment + " " if comment else ""
self.log.info(
f"Loading {_comment}module '{cls.name}' ({cls.__name__}) from '{os.path.relpath(inspect.getfile(cls))}'"
)

if cls.name in self.modules:
old = type(self.modules[cls.name])
raise module.ExistingModuleError(old, cls)

mod = cls(self)
mod.comment = comment
self.register_listeners(mod)
self.register_commands(mod)
self.modules[cls.name] = mod

def unload_module(self: "Bot", mod: module.Module) -> None:
_comment = mod.comment + " " if mod.comment else ""

cls = type(mod)
path = os.path.relpath(inspect.getfile(cls))
self.log.info(f"Unloading {_comment}module '{cls.name}' ({cls.__name__}) from '{path}'")

self.unregister_listeners(mod)
self.unregister_commands(mod)
del self.modules[cls.name]

def _load_modules_from_metamod(self: "Bot", metamod: ModuleType, *, comment: str = None) -> None:
for _sym in getattr(metamod, "__all__", ()):
module_mod: ModuleType = getattr(metamod, _sym)

if inspect.ismodule(module_mod):
for sym in dir(module_mod):
cls = getattr(module_mod, sym)
if inspect.isclass(cls) and issubclass(cls, module.Module):
self.load_module(cls, comment=comment)

def load_all_modules(self: "Bot") -> None:
self.log.info("Loading modules")
self._load_modules_from_metamod(modules)
self._load_modules_from_metamod(custom_modules, comment="custom")
self.log.info("All modules loaded.")

def unload_all_modules(self: "Bot") -> None:
self.log.info("Unloading modules...")

# Can't modify while iterating, so collect a list first
for mod in list(self.modules.values()):
self.unload_module(mod)

self.log.info("All modules unloaded.")

async def reload_module_pkg(self: "Bot") -> None:
self.log.info("Reloading base module class...")
await util.run_sync(importlib.reload, module)

self.log.info("Reloading master module...")
await util.run_sync(importlib.reload, modules)

self.log.info("Reloading custom master module...")
await util.run_sync(importlib.reload, custom_modules)
Loading

0 comments on commit 6edfb62

Please sign in to comment.