diff --git a/bot/database.py b/bot/database.py index 78045867..73a56deb 100644 --- a/bot/database.py +++ b/bot/database.py @@ -1,11 +1,10 @@ +from __future__ import annotations + import datetime as dt import logging from typing import Iterator -from typing import List from typing import Mapping -from typing import Optional from typing import Sequence -from typing import Union import databases import nanoid @@ -188,6 +187,19 @@ def generate_nanoid() -> str: sa.Column("last_synced_at", TIMESTAMP), ) +scheduled_events = sa.Table( + "scheduled_events", + metadata, + sa.Column("event_id", BIGINT, primary_key=True, doc="Discord event ID"), + sa.Column( + "created_by", + BIGINT, + index=True, + doc="Discord user ID for the user who created the event through the bot", + ), + created_at_column(), +) + # ----------------------------------------------------------------------------- @@ -196,7 +208,7 @@ class Store: def __init__( self, - database_url: Union[str, databases.DatabaseURL], + database_url: str | databases.DatabaseURL, *, force_rollback: bool = False, ): @@ -211,7 +223,7 @@ def disconnect(self): def transaction(self): return self.db.transaction() - async def set_user_timezone(self, user_id: int, timezone: Optional[dt.tzinfo]): + async def set_user_timezone(self, user_id: int, timezone: dt.tzinfo | None): logger.info(f"setting timezone for user_id {user_id}") stmt = insert(user_settings).values(user_id=user_id, timezone=timezone) stmt = stmt.on_conflict_do_update( @@ -220,23 +232,23 @@ async def set_user_timezone(self, user_id: int, timezone: Optional[dt.tzinfo]): ) await self.db.execute(stmt) - async def get_user_timezone(self, user_id: int) -> Optional[pytz.BaseTzInfo]: + async def get_user_timezone(self, user_id: int) -> pytz.BaseTzInfo | None: logger.info(f"retrieving timezone for user_id {user_id}") query = user_settings.select().where(user_settings.c.user_id == user_id) return await self.db.fetch_val(query=query, column=user_settings.c.timezone) - async def get_guild_settings(self, guild_id: int) -> Optional[Mapping]: + async def get_guild_settings(self, guild_id: int) -> Mapping | None: logger.info(f"retrieving guild settings sheet key for guild_id {guild_id}") query = guild_settings.select().where(guild_settings.c.guild_id == guild_id) return await self.db.fetch_one(query=query) - async def get_guild_schedule_sheet_key(self, guild_id: int) -> Optional[str]: + async def get_guild_schedule_sheet_key(self, guild_id: int) -> str | None: query = guild_settings.select().where(guild_settings.c.guild_id == guild_id) return await self.db.fetch_val( query=query, column=guild_settings.c.schedule_sheet_key ) - async def get_guild_daily_message_channel_id(self, guild_id: int) -> Optional[int]: + async def get_guild_daily_message_channel_id(self, guild_id: int) -> int | None: query = guild_settings.select().where(guild_settings.c.guild_id == guild_id) return await self.db.fetch_val( query=query, column=guild_settings.c.daily_message_channel_id @@ -300,13 +312,13 @@ async def create_zoom_meeting( ) await self.db.execute(stmt) - async def get_zoom_meeting(self, meeting_id: int) -> Optional[Mapping]: + async def get_zoom_meeting(self, meeting_id: int) -> Mapping | None: query = zoom_meetings.select().where(zoom_meetings.c.meeting_id == meeting_id) return await self.db.fetch_one(query=query) async def get_latest_pending_zoom_meeting_for_user( self, zoom_user: str - ) -> Optional[Mapping]: + ) -> Mapping | None: query = ( zoom_meetings.select() .where( @@ -370,12 +382,12 @@ async def remove_zoom_message(self, *, message_id: int): zoom_messages.delete().where(zoom_messages.c.message_id == message_id) ) - async def get_zoom_message(self, message_id: int) -> Optional[Mapping]: + async def get_zoom_message(self, message_id: int) -> Mapping | None: return await self.db.fetch_one( zoom_messages.select().where(zoom_messages.c.message_id == message_id) ) - async def get_zoom_messages(self, meeting_id: int) -> List[Mapping]: + async def get_zoom_messages(self, meeting_id: int) -> list[Mapping]: return await self.db.fetch_all( zoom_messages.select().where(zoom_messages.c.meeting_id == meeting_id) ) @@ -385,8 +397,8 @@ async def add_zoom_participant( *, meeting_id: int, name: str, - zoom_id: Optional[str], - email: Optional[str], + zoom_id: str | None, + email: str | None, joined_at: dt.datetime, ): stmt = insert(zoom_participants).values( @@ -409,16 +421,14 @@ async def add_zoom_participant( ) await self.db.execute(stmt) - async def get_zoom_participant( - self, *, meeting_id: int, name: str - ) -> Optional[Mapping]: + async def get_zoom_participant(self, *, meeting_id: int, name: str) -> Mapping | None: query = zoom_participants.select().where( (zoom_participants.c.meeting_id == meeting_id) & (zoom_participants.c.name == name) ) return await self.db.fetch_one(query=query) - async def get_zoom_participants(self, meeting_id: int) -> List[Mapping]: + async def get_zoom_participants(self, meeting_id: int) -> list[Mapping]: return await self.db.fetch_all( zoom_participants.select() .where(zoom_participants.c.meeting_id == meeting_id) @@ -521,7 +531,7 @@ async def clear_aslpp_members(self): async def clear_aslpp_intros(self): await self.db.execute(aslpp_intros.delete()) - async def get_aslpp_member(self, user_id: int) -> Optional[Mapping]: + async def get_aslpp_member(self, user_id: int) -> Mapping | None: return await self.db.fetch_one( aslpp_members.select().where(aslpp_members.c.user_id == user_id) ) @@ -550,7 +560,7 @@ async def remove_aslpp_member(self, *, user_id: int): aslpp_members.delete().where(aslpp_members.c.user_id == user_id) ) - async def get_aslpp_members_without_intro(self, since: dt.timedelta) -> List[Mapping]: + async def get_aslpp_members_without_intro(self, since: dt.timedelta) -> list[Mapping]: return await self.db.fetch_all( aslpp_members.select() .where( @@ -576,20 +586,42 @@ async def has_aslpp_intro(self, user_id: int) -> bool: return False return record["result"] - async def mark_aslpp_members_active(self, user_ids: List[int]): + async def mark_aslpp_members_active(self, user_ids: list[int]): await self.db.execute( aslpp_members.update() .where(aslpp_members.c.user_id.in_(user_ids)) .values(is_active=True) ) - async def mark_aslpp_members_inactive(self, user_ids: List[int]): + async def mark_aslpp_members_inactive(self, user_ids: list[int]): await self.db.execute( aslpp_members.update() .where(aslpp_members.c.user_id.in_(user_ids)) .values(is_active=False) ) + # Scheduled events + + async def create_scheduled_event(self, *, event_id: int, created_by: int): + await self.db.execute( + insert(scheduled_events).values( + event_id=event_id, + created_by=created_by, + # NOTE: need to pass created_at because default=now + # doesn't have an effect when using postgresql.insert + created_at=now(), + ) + ) + + async def get_scheduled_events_for_user(self, user_id: int) -> list[Mapping]: + query = scheduled_events.select().where(scheduled_events.c.created_by == user_id) + return await self.db.fetch_all(query=query) + + async def remove_scheduled_event(self, *, event_id: int): + await self.db.execute( + scheduled_events.delete().where(scheduled_events.c.event_id == event_id) + ) + store = Store( database_url=settings.TEST_DATABASE_URL diff --git a/bot/exts/meetings/meetings.py b/bot/exts/meetings/meetings.py index f6898ed8..fc9ec18f 100644 --- a/bot/exts/meetings/meetings.py +++ b/bot/exts/meetings/meetings.py @@ -447,7 +447,10 @@ async def on_select(select_interaction: MessageInteraction, value: str): ) view = DropdownView.from_options( - options=options, on_select=on_select, placeholder="Choose a user" + options=options, + on_select=on_select, + placeholder="Choose a user", + creator_id=inter.user.id, ) await inter.send("Choose a user to downgrade to Basic.", view=view) else: diff --git a/bot/exts/schedule.py b/bot/exts/schedule.py index d86fd721..50793f1d 100644 --- a/bot/exts/schedule.py +++ b/bot/exts/schedule.py @@ -11,7 +11,9 @@ import pytz from disnake import GuildCommandInteraction from disnake import GuildScheduledEvent +from disnake import MessageInteraction from disnake.ext import commands +from disnake.ext.commands import Cog from bot import settings from bot.database import store @@ -26,6 +28,7 @@ from bot.utils.discord import display_name from bot.utils.ui import ButtonGroupOption from bot.utils.ui import ButtonGroupView +from bot.utils.ui import DropdownView from bot.utils.ui import LinkView logger = logging.getLogger(__name__) @@ -99,7 +102,10 @@ def check_user_response(m: disnake.Message): timeout=60, ) except asyncio.exceptions.TimeoutError: - await inter.send(content="⚠️ You waited too long to respond. Try again.") + await inter.send( + content="⚠️ You waited too long to respond. Try running `/schedule new` again." + ) + # TODO: handle this error more gracefully so it doesn't pollute the logs raise PromptCancelled if response_message.content.lower() == "cancel": await response_message.reply(content="✨ _Cancelled_") @@ -113,7 +119,7 @@ async def schedule_command(self, inter: GuildCommandInteraction): @schedule_command.sub_command(name="new") async def schedule_new(self, inter: GuildCommandInteraction): - """Quickly add a new scheduled event with guided prompts.""" + """Add a new scheduled event with guided prompts.""" # Step 1: Prompt for the start time tries = 0 max_retries = 3 @@ -218,10 +224,14 @@ async def schedule_new(self, inter: GuildCommandInteraction): reason=f"/schedule command used by user {user.id} ({display_name(user)})", **video_service_kwargs, ) + await store.create_scheduled_event(event_id=event.id, created_by=user.id) event_url = f"https://discord.com/events/{event.guild_id}/{event.id}" await inter.channel.send( - content='🙌 **Successfully created event.** Click "Event Link" below to view/edit your event and mark yourself as "Interested".', + content=( + '🙌 **Successfully created event.** Click "Event Link" below to mark yourself as "Interested".\n' + "To cancel your event, use `/schedule cancel`." + ), embed=make_event_embed(event), view=LinkView(label="Event Link", url=event_url), ) @@ -239,6 +249,56 @@ async def schedule_new(self, inter: GuildCommandInteraction): except disnake.errors.Forbidden: logger.warn("cannot send DM to user. skipping...") + @schedule_command.sub_command(name="cancel") + async def schedule_cancel(self, inter: GuildCommandInteraction): + """Cancel an event created through this bot""" + assert inter.user is not None + scheduled_events = await store.get_scheduled_events_for_user(inter.user.id) + events: list[GuildScheduledEvent] = [] + for event in scheduled_events: + event = inter.guild.get_scheduled_event(event["event_id"]) + if event: + events.append(event) + + if not events: + await inter.send("⚠️You have no events to cancel.", ephemeral=True) + return + + await inter.send( + "👌 OK, let's cancel your event. _Fetching your events…_", ephemeral=True + ) + + async def on_select(select_interaction: MessageInteraction, value: str): + logger.debug(f"selected event {value}") + event = inter.guild.get_scheduled_event(int(value)) + assert event is not None + logger.info(f"canceling event {event.id}") + await event.delete() + await select_interaction.response.edit_message( + content=f"✅ Successfully cancelled **{event.name}**.", + view=None, + ) + + options = [ + disnake.SelectOption( + label=f"{event.name} · {format_scheduled_start_time(event.scheduled_start_time)}", + value=str(event.id), + ) + for event in events + ] + view = DropdownView.from_options( + options=options, + on_select=on_select, + placeholder="Choose an event", + creator_id=inter.user.id, + ) + await inter.send(content="Choose an event to cancel.", view=view, ephemeral=True) + + @Cog.listener() + async def on_guild_scheduled_event_delete(self, event: GuildScheduledEvent) -> None: + logger.info(f"removing scheduled event {event.id}") + await store.remove_scheduled_event(event_id=event.id) + def setup(bot: commands.Bot) -> None: bot.add_cog(Schedule(bot)) diff --git a/bot/utils/ui.py b/bot/utils/ui.py index 14af6660..c6cd5926 100644 --- a/bot/utils/ui.py +++ b/bot/utils/ui.py @@ -102,8 +102,9 @@ async def callback(self, inter: disnake.MessageInteraction): class DropdownView(disnake.ui.View): - def __init__(self): + def __init__(self, creator_id: int): super().__init__() + self.creator_id = creator_id self.dropdown: Dropdown | None = None @classmethod @@ -113,10 +114,16 @@ def from_options( options: Sequence[disnake.SelectOption], on_select: Callback, placeholder: str | None = None, + creator_id: int, ) -> DropdownView: - view = cls() + view = cls(creator_id=creator_id) async def handle_select(inter: disnake.MessageInteraction, value): + # Ignore clicks by other users + assert inter.user is not None + if inter.user.id != creator_id: + await inter.send("⚠️ You can't interact with this UI.", ephemeral=True) + return await on_select(inter, value) view.stop() diff --git a/migrations/versions/16e3c8b20c26_add_scheduled_events.py b/migrations/versions/16e3c8b20c26_add_scheduled_events.py new file mode 100644 index 00000000..96feb84d --- /dev/null +++ b/migrations/versions/16e3c8b20c26_add_scheduled_events.py @@ -0,0 +1,42 @@ +"""add scheduled_events + +Revision ID: 16e3c8b20c26 +Revises: ad375ea654d5 +Create Date: 2021-12-31 10:29:54.135898 + +""" +from alembic import op +import sqlalchemy as sa +import bot + + +# revision identifiers, used by Alembic. +revision = "16e3c8b20c26" +down_revision = "ad375ea654d5" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "scheduled_events", + sa.Column("event_id", sa.BIGINT(), nullable=False), + sa.Column("created_by", sa.BIGINT(), nullable=True), + sa.Column("created_at", bot.database.TIMESTAMP(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("event_id"), + ) + op.create_index( + op.f("ix_scheduled_events_created_by"), + "scheduled_events", + ["created_by"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_scheduled_events_created_by"), table_name="scheduled_events") + op.drop_table("scheduled_events") + # ### end Alembic commands ###