Skip to content

Commit

Permalink
feat(schedule): add /schedule cancel (#387)
Browse files Browse the repository at this point in the history
* feat(schedule): add /schedule cancel

* Update message on creation to include cancel instructions

* copy edits

* Fix zoom
  • Loading branch information
sloria authored Dec 31, 2021
1 parent 598be65 commit 764e931
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 29 deletions.
78 changes: 55 additions & 23 deletions bot/database.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import datetime as dt
import logging
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Union

import databases
import nanoid
Expand Down Expand Up @@ -188,6 +187,19 @@ def generate_nanoid() -> str:
sa.Column("last_synced_at", TIMESTAMP),
)

scheduled_events = sa.Table(
"scheduled_events",
metadata,
sa.Column("event_id", BIGINT, primary_key=True, doc="Discord event ID"),
sa.Column(
"created_by",
BIGINT,
index=True,
doc="Discord user ID for the user who created the event through the bot",
),
created_at_column(),
)

# -----------------------------------------------------------------------------


Expand All @@ -196,7 +208,7 @@ class Store:

def __init__(
self,
database_url: Union[str, databases.DatabaseURL],
database_url: str | databases.DatabaseURL,
*,
force_rollback: bool = False,
):
Expand All @@ -211,7 +223,7 @@ def disconnect(self):
def transaction(self):
return self.db.transaction()

async def set_user_timezone(self, user_id: int, timezone: Optional[dt.tzinfo]):
async def set_user_timezone(self, user_id: int, timezone: dt.tzinfo | None):
logger.info(f"setting timezone for user_id {user_id}")
stmt = insert(user_settings).values(user_id=user_id, timezone=timezone)
stmt = stmt.on_conflict_do_update(
Expand All @@ -220,23 +232,23 @@ async def set_user_timezone(self, user_id: int, timezone: Optional[dt.tzinfo]):
)
await self.db.execute(stmt)

async def get_user_timezone(self, user_id: int) -> Optional[pytz.BaseTzInfo]:
async def get_user_timezone(self, user_id: int) -> pytz.BaseTzInfo | None:
logger.info(f"retrieving timezone for user_id {user_id}")
query = user_settings.select().where(user_settings.c.user_id == user_id)
return await self.db.fetch_val(query=query, column=user_settings.c.timezone)

async def get_guild_settings(self, guild_id: int) -> Optional[Mapping]:
async def get_guild_settings(self, guild_id: int) -> Mapping | None:
logger.info(f"retrieving guild settings sheet key for guild_id {guild_id}")
query = guild_settings.select().where(guild_settings.c.guild_id == guild_id)
return await self.db.fetch_one(query=query)

async def get_guild_schedule_sheet_key(self, guild_id: int) -> Optional[str]:
async def get_guild_schedule_sheet_key(self, guild_id: int) -> str | None:
query = guild_settings.select().where(guild_settings.c.guild_id == guild_id)
return await self.db.fetch_val(
query=query, column=guild_settings.c.schedule_sheet_key
)

async def get_guild_daily_message_channel_id(self, guild_id: int) -> Optional[int]:
async def get_guild_daily_message_channel_id(self, guild_id: int) -> int | None:
query = guild_settings.select().where(guild_settings.c.guild_id == guild_id)
return await self.db.fetch_val(
query=query, column=guild_settings.c.daily_message_channel_id
Expand Down Expand Up @@ -300,13 +312,13 @@ async def create_zoom_meeting(
)
await self.db.execute(stmt)

async def get_zoom_meeting(self, meeting_id: int) -> Optional[Mapping]:
async def get_zoom_meeting(self, meeting_id: int) -> Mapping | None:
query = zoom_meetings.select().where(zoom_meetings.c.meeting_id == meeting_id)
return await self.db.fetch_one(query=query)

async def get_latest_pending_zoom_meeting_for_user(
self, zoom_user: str
) -> Optional[Mapping]:
) -> Mapping | None:
query = (
zoom_meetings.select()
.where(
Expand Down Expand Up @@ -370,12 +382,12 @@ async def remove_zoom_message(self, *, message_id: int):
zoom_messages.delete().where(zoom_messages.c.message_id == message_id)
)

async def get_zoom_message(self, message_id: int) -> Optional[Mapping]:
async def get_zoom_message(self, message_id: int) -> Mapping | None:
return await self.db.fetch_one(
zoom_messages.select().where(zoom_messages.c.message_id == message_id)
)

async def get_zoom_messages(self, meeting_id: int) -> List[Mapping]:
async def get_zoom_messages(self, meeting_id: int) -> list[Mapping]:
return await self.db.fetch_all(
zoom_messages.select().where(zoom_messages.c.meeting_id == meeting_id)
)
Expand All @@ -385,8 +397,8 @@ async def add_zoom_participant(
*,
meeting_id: int,
name: str,
zoom_id: Optional[str],
email: Optional[str],
zoom_id: str | None,
email: str | None,
joined_at: dt.datetime,
):
stmt = insert(zoom_participants).values(
Expand All @@ -409,16 +421,14 @@ async def add_zoom_participant(
)
await self.db.execute(stmt)

async def get_zoom_participant(
self, *, meeting_id: int, name: str
) -> Optional[Mapping]:
async def get_zoom_participant(self, *, meeting_id: int, name: str) -> Mapping | None:
query = zoom_participants.select().where(
(zoom_participants.c.meeting_id == meeting_id)
& (zoom_participants.c.name == name)
)
return await self.db.fetch_one(query=query)

async def get_zoom_participants(self, meeting_id: int) -> List[Mapping]:
async def get_zoom_participants(self, meeting_id: int) -> list[Mapping]:
return await self.db.fetch_all(
zoom_participants.select()
.where(zoom_participants.c.meeting_id == meeting_id)
Expand Down Expand Up @@ -521,7 +531,7 @@ async def clear_aslpp_members(self):
async def clear_aslpp_intros(self):
await self.db.execute(aslpp_intros.delete())

async def get_aslpp_member(self, user_id: int) -> Optional[Mapping]:
async def get_aslpp_member(self, user_id: int) -> Mapping | None:
return await self.db.fetch_one(
aslpp_members.select().where(aslpp_members.c.user_id == user_id)
)
Expand Down Expand Up @@ -550,7 +560,7 @@ async def remove_aslpp_member(self, *, user_id: int):
aslpp_members.delete().where(aslpp_members.c.user_id == user_id)
)

async def get_aslpp_members_without_intro(self, since: dt.timedelta) -> List[Mapping]:
async def get_aslpp_members_without_intro(self, since: dt.timedelta) -> list[Mapping]:
return await self.db.fetch_all(
aslpp_members.select()
.where(
Expand All @@ -576,20 +586,42 @@ async def has_aslpp_intro(self, user_id: int) -> bool:
return False
return record["result"]

async def mark_aslpp_members_active(self, user_ids: List[int]):
async def mark_aslpp_members_active(self, user_ids: list[int]):
await self.db.execute(
aslpp_members.update()
.where(aslpp_members.c.user_id.in_(user_ids))
.values(is_active=True)
)

async def mark_aslpp_members_inactive(self, user_ids: List[int]):
async def mark_aslpp_members_inactive(self, user_ids: list[int]):
await self.db.execute(
aslpp_members.update()
.where(aslpp_members.c.user_id.in_(user_ids))
.values(is_active=False)
)

# Scheduled events

async def create_scheduled_event(self, *, event_id: int, created_by: int):
await self.db.execute(
insert(scheduled_events).values(
event_id=event_id,
created_by=created_by,
# NOTE: need to pass created_at because default=now
# doesn't have an effect when using postgresql.insert
created_at=now(),
)
)

async def get_scheduled_events_for_user(self, user_id: int) -> list[Mapping]:
query = scheduled_events.select().where(scheduled_events.c.created_by == user_id)
return await self.db.fetch_all(query=query)

async def remove_scheduled_event(self, *, event_id: int):
await self.db.execute(
scheduled_events.delete().where(scheduled_events.c.event_id == event_id)
)


store = Store(
database_url=settings.TEST_DATABASE_URL
Expand Down
5 changes: 4 additions & 1 deletion bot/exts/meetings/meetings.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,10 @@ async def on_select(select_interaction: MessageInteraction, value: str):
)

view = DropdownView.from_options(
options=options, on_select=on_select, placeholder="Choose a user"
options=options,
on_select=on_select,
placeholder="Choose a user",
creator_id=inter.user.id,
)
await inter.send("Choose a user to downgrade to Basic.", view=view)
else:
Expand Down
66 changes: 63 additions & 3 deletions bot/exts/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import pytz
from disnake import GuildCommandInteraction
from disnake import GuildScheduledEvent
from disnake import MessageInteraction
from disnake.ext import commands
from disnake.ext.commands import Cog

from bot import settings
from bot.database import store
Expand All @@ -26,6 +28,7 @@
from bot.utils.discord import display_name
from bot.utils.ui import ButtonGroupOption
from bot.utils.ui import ButtonGroupView
from bot.utils.ui import DropdownView
from bot.utils.ui import LinkView

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,7 +102,10 @@ def check_user_response(m: disnake.Message):
timeout=60,
)
except asyncio.exceptions.TimeoutError:
await inter.send(content="⚠️ You waited too long to respond. Try again.")
await inter.send(
content="⚠️ You waited too long to respond. Try running `/schedule new` again."
)
# TODO: handle this error more gracefully so it doesn't pollute the logs
raise PromptCancelled
if response_message.content.lower() == "cancel":
await response_message.reply(content="✨ _Cancelled_")
Expand All @@ -113,7 +119,7 @@ async def schedule_command(self, inter: GuildCommandInteraction):

@schedule_command.sub_command(name="new")
async def schedule_new(self, inter: GuildCommandInteraction):
"""Quickly add a new scheduled event with guided prompts."""
"""Add a new scheduled event with guided prompts."""
# Step 1: Prompt for the start time
tries = 0
max_retries = 3
Expand Down Expand Up @@ -218,10 +224,14 @@ async def schedule_new(self, inter: GuildCommandInteraction):
reason=f"/schedule command used by user {user.id} ({display_name(user)})",
**video_service_kwargs,
)
await store.create_scheduled_event(event_id=event.id, created_by=user.id)

event_url = f"https://discord.com/events/{event.guild_id}/{event.id}"
await inter.channel.send(
content='🙌 **Successfully created event.** Click "Event Link" below to view/edit your event and mark yourself as "Interested".',
content=(
'🙌 **Successfully created event.** Click "Event Link" below to mark yourself as "Interested".\n'
"To cancel your event, use `/schedule cancel`."
),
embed=make_event_embed(event),
view=LinkView(label="Event Link", url=event_url),
)
Expand All @@ -239,6 +249,56 @@ async def schedule_new(self, inter: GuildCommandInteraction):
except disnake.errors.Forbidden:
logger.warn("cannot send DM to user. skipping...")

@schedule_command.sub_command(name="cancel")
async def schedule_cancel(self, inter: GuildCommandInteraction):
"""Cancel an event created through this bot"""
assert inter.user is not None
scheduled_events = await store.get_scheduled_events_for_user(inter.user.id)
events: list[GuildScheduledEvent] = []
for event in scheduled_events:
event = inter.guild.get_scheduled_event(event["event_id"])
if event:
events.append(event)

if not events:
await inter.send("⚠️You have no events to cancel.", ephemeral=True)
return

await inter.send(
"👌 OK, let's cancel your event. _Fetching your events…_", ephemeral=True
)

async def on_select(select_interaction: MessageInteraction, value: str):
logger.debug(f"selected event {value}")
event = inter.guild.get_scheduled_event(int(value))
assert event is not None
logger.info(f"canceling event {event.id}")
await event.delete()
await select_interaction.response.edit_message(
content=f"✅ Successfully cancelled **{event.name}**.",
view=None,
)

options = [
disnake.SelectOption(
label=f"{event.name} · {format_scheduled_start_time(event.scheduled_start_time)}",
value=str(event.id),
)
for event in events
]
view = DropdownView.from_options(
options=options,
on_select=on_select,
placeholder="Choose an event",
creator_id=inter.user.id,
)
await inter.send(content="Choose an event to cancel.", view=view, ephemeral=True)

@Cog.listener()
async def on_guild_scheduled_event_delete(self, event: GuildScheduledEvent) -> None:
logger.info(f"removing scheduled event {event.id}")
await store.remove_scheduled_event(event_id=event.id)


def setup(bot: commands.Bot) -> None:
bot.add_cog(Schedule(bot))
11 changes: 9 additions & 2 deletions bot/utils/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ async def callback(self, inter: disnake.MessageInteraction):


class DropdownView(disnake.ui.View):
def __init__(self):
def __init__(self, creator_id: int):
super().__init__()
self.creator_id = creator_id
self.dropdown: Dropdown | None = None

@classmethod
Expand All @@ -113,10 +114,16 @@ def from_options(
options: Sequence[disnake.SelectOption],
on_select: Callback,
placeholder: str | None = None,
creator_id: int,
) -> DropdownView:
view = cls()
view = cls(creator_id=creator_id)

async def handle_select(inter: disnake.MessageInteraction, value):
# Ignore clicks by other users
assert inter.user is not None
if inter.user.id != creator_id:
await inter.send("⚠️ You can't interact with this UI.", ephemeral=True)
return
await on_select(inter, value)
view.stop()

Expand Down
42 changes: 42 additions & 0 deletions migrations/versions/16e3c8b20c26_add_scheduled_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""add scheduled_events
Revision ID: 16e3c8b20c26
Revises: ad375ea654d5
Create Date: 2021-12-31 10:29:54.135898
"""
from alembic import op
import sqlalchemy as sa
import bot


# revision identifiers, used by Alembic.
revision = "16e3c8b20c26"
down_revision = "ad375ea654d5"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"scheduled_events",
sa.Column("event_id", sa.BIGINT(), nullable=False),
sa.Column("created_by", sa.BIGINT(), nullable=True),
sa.Column("created_at", bot.database.TIMESTAMP(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("event_id"),
)
op.create_index(
op.f("ix_scheduled_events_created_by"),
"scheduled_events",
["created_by"],
unique=False,
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_scheduled_events_created_by"), table_name="scheduled_events")
op.drop_table("scheduled_events")
# ### end Alembic commands ###

0 comments on commit 764e931

Please sign in to comment.