diff --git a/bot/database.py b/bot/database.py index 713c6493..fcd78c07 100644 --- a/bot/database.py +++ b/bot/database.py @@ -2,6 +2,7 @@ import datetime as dt import logging +import uuid from typing import Iterator, Mapping, Sequence import databases @@ -13,7 +14,7 @@ from sqlalchemy import sql from sqlalchemy.dialects.postgresql import BIGINT from sqlalchemy.dialects.postgresql import TIMESTAMP as _TIMESTAMP -from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.dialects.postgresql import UUID, insert from sqlalchemy.sql.schema import ForeignKey from . import settings @@ -46,10 +47,20 @@ class TIMESTAMP(sa.TypeDecorator): impl = _TIMESTAMP(timezone=True) +def id_column(name="id", **kwargs): + return sa.Column( + name, UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, **kwargs + ) + + def created_at_column(name="created_at", **kwargs): return sa.Column(name, TIMESTAMP, nullable=False, default=now, **kwargs) +def updated_at_column(name="updated_at", **kwargs): + return sa.Column(name, TIMESTAMP, nullable=False, default=now, onupdate=now, **kwargs) + + NANOID_ALPHABET = "23456789abcdefghijkmnopqrstuvwxyz" NANOID_SIZE = 5 @@ -235,6 +246,26 @@ def generate_nanoid() -> str: created_at_column(), ) +user_stars = sa.Table( + "user_stars", + metadata, + sa.Column("user_id", BIGINT, primary_key=True, doc="Discord user ID"), + sa.Column("star_count", sa.Integer, doc="Number of stars for the user"), + created_at_column(), + updated_at_column(), +) + +star_logs = sa.Table( + "star_logs", + metadata, + id_column(), + sa.Column("from_user_id", BIGINT, doc="Discord user ID of the giver"), + sa.Column("to_user_id", BIGINT, doc="Discord user ID of the recipient"), + sa.Column("message_id", BIGINT, doc="Discord message ID"), + sa.Column("action", sa.Text, nullable=False, doc="The type of action"), + created_at_column(), +) + # ----------------------------------------------------------------------------- @@ -690,6 +721,96 @@ async def remove_scheduled_event(self, *, event_id: int): scheduled_events.delete().where(scheduled_events.c.event_id == event_id) ) + ##### Stars ##### + + async def give_star(self, *, from_user_id: int, to_user_id: int, message_id: int): + created_at = now() + # Insert a star log + stmt = insert(star_logs).values( + id=uuid.uuid4(), + from_user_id=from_user_id, + to_user_id=to_user_id, + message_id=message_id, + created_at=created_at, + action="ADD", + ) + await self.db.execute(stmt) + + # Update the user's star count + stmt = insert(user_stars).values( + user_id=to_user_id, + star_count=1, + created_at=created_at, + updated_at=created_at, + ) + stmt = stmt.on_conflict_do_update( + index_elements=(user_stars.c.user_id,), + set_=dict(star_count=user_stars.c.star_count + 1, updated_at=created_at), + ) + await self.db.execute(stmt) + + async def remove_star(self, *, from_user_id: int, to_user_id: int, message_id: int): + created_at = now() + # Insert a star log + stmt = insert(star_logs).values( + id=uuid.uuid4(), + from_user_id=from_user_id, + to_user_id=to_user_id, + message_id=message_id, + created_at=created_at, + action="REMOVE", + ) + await self.db.execute(stmt) + + # Update the user's star count + stmt = insert(user_stars).values( + user_id=to_user_id, + star_count=0, + created_at=created_at, + updated_at=created_at, + ) + stmt = stmt.on_conflict_do_update( + index_elements=(user_stars.c.user_id,), + set_=dict(star_count=user_stars.c.star_count - 1, updated_at=created_at), + ) + await self.db.execute(stmt) + + async def get_user_stars(self, user_id: int) -> int: + query = user_stars.select().where(user_stars.c.user_id == user_id) + return await self.db.fetch_val(query=query, column=user_stars.c.star_count) or 0 + + async def set_user_stars( + self, *, from_user_id: int, to_user_id: int, star_count: int + ): + created_at = now() + # Insert a star log + stmt = insert(star_logs).values( + id=uuid.uuid4(), + from_user_id=from_user_id, + to_user_id=to_user_id, + message_id=None, + created_at=created_at, + action="SET", + ) + await self.db.execute(stmt) + # Update the user's star count + stmt = insert(user_stars).values( + user_id=to_user_id, + star_count=star_count, + created_at=created_at, + updated_at=created_at, + ) + stmt = stmt.on_conflict_do_update( + index_elements=(user_stars.c.user_id,), + set_=dict(star_count=stmt.excluded.star_count, updated_at=created_at), + ) + await self.db.execute(stmt) + + async def list_user_stars(self, limit: int) -> list[Mapping]: + return await self.db.fetch_all( + user_stars.select().order_by(user_stars.c.star_count.desc()).limit(limit) + ) + store = Store( database_url=settings.TEST_DATABASE_URL diff --git a/bot/exts/stars.py b/bot/exts/stars.py new file mode 100644 index 00000000..f203208f --- /dev/null +++ b/bot/exts/stars.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import logging +from typing import cast + +import disnake +from disnake import Embed, GuildCommandInteraction +from disnake.ext import commands +from disnake.ext.commands import Bot, Cog, Context, slash_command + +from bot import settings +from bot.database import store +from bot.utils.discord import display_name +from bot.utils.reactions import get_reaction_message, should_handle_reaction + +logger = logging.getLogger(__name__) + +COMMAND_PREFIX = settings.COMMAND_PREFIX +STAR_EMOJI = "⭐" + + +class Stars(Cog): + def __init__(self, bot: commands.Bot): + self.bot = bot + + def cog_check(self, ctx: Context): + if not bool(ctx.guild) or ctx.guild.id != settings.SIGN_CAFE_GUILD_ID: + raise commands.errors.CheckFailure( + f"⚠️ `{COMMAND_PREFIX}{ctx.invoked_with}` must be run within the Sign Cafe server (not a DM)." + ) + return True + + @slash_command(name="stars", guild_ids=(settings.SIGN_CAFE_GUILD_ID,)) + async def stars_command(self, inter: GuildCommandInteraction): + pass + + @stars_command.sub_command(name="set") + @commands.has_permissions(kick_members=True) # Staff + async def stars_set( + self, inter: GuildCommandInteraction, user: disnake.User, stars: int + ): + """(Authorized users only) Set a user's star count + + Parameters + ---------- + user: The user to change + stars: The star count to set + """ + assert inter.user is not None + async with store.transaction(): + await store.set_user_stars( + from_user_id=inter.user.id, + to_user_id=user.id, + star_count=stars, + ) + embed = Embed( + description=f"Set star count for {user.mention}", color=disnake.Color.yellow() + ) + user_stars = await store.get_user_stars(user.id) + embed.add_field(name=f"{STAR_EMOJI} count", value=str(user_stars)) + embed.set_author( + name=display_name(user), + icon_url=user.avatar.url if user.avatar else Embed.Empty, + ) + await inter.response.send_message(embed=embed) + + @stars_command.sub_command(name="board") + async def stars_board(self, inter: GuildCommandInteraction): + """Show the star leaderboard""" + records = await store.list_user_stars(limit=10) + embed = Embed( + title=f"{STAR_EMOJI} Leaderboard", + description="\n".join( + [ + f"{i+1}. <@{record['user_id']}> | {record['star_count']} {STAR_EMOJI}" + for i, record in enumerate(records) + ] + ), + color=disnake.Color.yellow(), + ) + await inter.response.send_message(embed=embed) + + async def maybe_get_star_reaction_message_and_user( + self, + payload: disnake.RawReactionActionEvent, + ) -> tuple[disnake.Message | None, disnake.Member | None]: + if not settings.SIGN_CAFE_ENABLE_STARS: + return None, None + if not should_handle_reaction(self.bot, payload, {STAR_EMOJI}): + return None, None + message = await get_reaction_message(self.bot, payload) + if not message: + return None, None + if not message.guild: + return None, None + if not message.guild.id == settings.SIGN_CAFE_GUILD_ID: + return None, None + if bool(getattr(message.author, "bot", None)): # User is a bot + return None, None + channel = cast(disnake.TextChannel, message.channel) + if not channel.guild: + return None, None + from_user = await channel.guild.get_or_fetch_member(payload.user_id) + if not from_user: + return None, None + permissions = channel.permissions_for(from_user) + is_staff = getattr(permissions, "kick_members", False) is True + if not is_staff: + return None, None + return message, from_user + + @Cog.listener() + async def on_raw_reaction_add(self, payload: disnake.RawReactionActionEvent) -> None: + message, from_user = await self.maybe_get_star_reaction_message_and_user(payload) + if message is None or from_user is None: + return + + to_user = message.author + + async with store.transaction(): + await store.give_star( + from_user_id=from_user.id, + to_user_id=to_user.id, + message_id=message.id, + ) + channel = cast( + disnake.TextChannel, self.bot.get_channel(settings.SIGN_CAFE_BOT_CHANNEL_ID) + ) + embed = Embed( + description=f"{to_user.mention} received a {STAR_EMOJI} from {from_user.mention}\n[Source message]({message.jump_url})", + color=disnake.Color.yellow(), + ) + user_stars = await store.get_user_stars(to_user.id) + embed.add_field(name=f"{STAR_EMOJI} count", value=str(user_stars)) + embed.set_author( + name=display_name(to_user), + icon_url=to_user.avatar.url if to_user.avatar else Embed.Empty, + ) + await channel.send(embed=embed) + + @Cog.listener() + async def on_raw_reaction_remove( + self, payload: disnake.RawReactionActionEvent + ) -> None: + message, from_user = await self.maybe_get_star_reaction_message_and_user(payload) + if message is None or from_user is None: + return + + to_user = message.author + + async with store.transaction(): + await store.remove_star( + from_user_id=from_user.id, + to_user_id=to_user.id, + message_id=message.id, + ) + channel = cast( + disnake.TextChannel, self.bot.get_channel(settings.SIGN_CAFE_BOT_CHANNEL_ID) + ) + embed = Embed( + description=f"{to_user.mention} removed a {STAR_EMOJI} from {from_user.mention}\n[Source message]({message.jump_url})", + color=disnake.Color.yellow(), + ) + user_stars = await store.get_user_stars(to_user.id) + embed.add_field(name=f"{STAR_EMOJI} count", value=str(user_stars)) + embed.set_author( + name=display_name(to_user), + icon_url=to_user.avatar.url if to_user.avatar else Embed.Empty, + ) + await channel.send(embed=embed) + + +def setup(bot: Bot) -> None: + bot.add_cog(Stars(bot)) diff --git a/bot/settings.py b/bot/settings.py index e403b1c2..d75db979 100644 --- a/bot/settings.py +++ b/bot/settings.py @@ -56,6 +56,7 @@ ) SIGN_CAFE_AGE_ROLE_IDS = env.list("SIGN_CAFE_AGE_ROLE_IDS", subcast=int) SIGN_CAFE_ENABLE_UNMUTE_WARNING = env.bool("SIGN_CAFE_ENABLE_UNMUTE_WARNING", True) +SIGN_CAFE_ENABLE_STARS = env.bool("SIGN_CAFE_ENABLE_STARS", True) SIGN_CAFE_INACTIVE_DAYS = env.int("SIGN_CAFE_INACTIVE_DAYS", 30) SIGN_CAFE_PRUNE_DAYS = env.int("SIGN_CAFE_PRUNE_DAYS", 30) SIGN_CAFE_ZOOM_WATCH_LIST = env.list("SIGN_CAFE_ZOOM_WATCH_LIST", default=[], subcast=str) diff --git a/migrations/versions/273609b17e5b_add_star_tables.py b/migrations/versions/273609b17e5b_add_star_tables.py new file mode 100644 index 00000000..51c7eb05 --- /dev/null +++ b/migrations/versions/273609b17e5b_add_star_tables.py @@ -0,0 +1,47 @@ +"""add star tables + +Revision ID: 273609b17e5b +Revises: 697c0fc1f5a6 +Create Date: 2022-02-20 19:13:49.695248 + +""" +from alembic import op +import sqlalchemy as sa +import bot +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "273609b17e5b" +down_revision = "697c0fc1f5a6" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "star_logs", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("from_user_id", sa.BIGINT(), nullable=True), + sa.Column("to_user_id", sa.BIGINT(), nullable=True), + sa.Column("message_id", sa.BIGINT(), nullable=True), + sa.Column("action", sa.Text(), nullable=False), + sa.Column("created_at", bot.database.TIMESTAMP(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user_stars", + sa.Column("user_id", sa.BIGINT(), nullable=False), + sa.Column("star_count", sa.Integer(), nullable=True), + sa.Column("created_at", bot.database.TIMESTAMP(timezone=True), nullable=False), + sa.Column("updated_at", bot.database.TIMESTAMP(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("user_id"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("user_stars") + op.drop_table("star_logs") + # ### end Alembic commands ###