diff --git a/.env.example b/.env.example index c7a7ad6a..04a10443 100644 --- a/.env.example +++ b/.env.example @@ -9,11 +9,9 @@ GOOGLE_PRIVATE_KEY_ID="CHANGEME" GOOGLE_PRIVATE_KEY="CHANGEME" GOOGLE_CLIENT_EMAIL="CHANGEME" FEEDBACK_SHEET_KEY="CHANGEME" -# Mapping of guild IDs => gsheet keys -SCHEDULE_SHEET_KEYS="123=321" -# Comma-delimited list of channel IDs where to send daily schedules -SCHEDULE_CHANNELS="456" +TOPICS_SHEET_KEY="CHANGEME" DAILY_PRACTICE_SEND_TIME=14:00 +GUILD_SETTINGS="" # Mapping of Discord usernames w/ discriminator => Zoom user email addresses or IDs ZOOM_USERS="bob#1234=bob@example.com,alice#5678=alice@example.com" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3269e80c..cabb9b6c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,8 +10,8 @@ jobs: with: python-version: 3.9.0 - name: Install dependencies - run: . script/bootstrap + run: ./script/bootstrap - name: Run migrations run: PYTHONPATH=. alembic upgrade head - name: Run tests - run: . script/test + run: ./script/test diff --git a/DEVELOPING.md b/DEVELOPING.md index 77e4f7f9..f29658bc 100644 --- a/DEVELOPING.md +++ b/DEVELOPING.md @@ -9,7 +9,7 @@ NOTE: If you're not on macOS, you'll need to install [pyenv](https://github.com/ With Docker running: ``` -. script/bootstrap +./script/bootstrap ``` Edit `.env` with the proper values. @@ -19,13 +19,19 @@ Edit `.env` with the proper values. Re-run the bootstrap script ``` -. script/bootstrap +./script/update ``` ## Running tests ``` -. script/test +./script/test +``` + +## Resetting the database + +``` +DB_RESET=1 SKIP_BOOTSTRAP=1 ./script/update ``` ## Releasing diff --git a/bot.py b/bot.py index 9ae960bd..518e59c3 100644 --- a/bot.py +++ b/bot.py @@ -51,8 +51,6 @@ GOOGLE_CLIENT_EMAIL = env.str("GOOGLE_CLIENT_EMAIL", required=True) GOOGLE_TOKEN_URI = env.str("GOOGLE_TOKEN_URI", "https://oauth2.googleapis.com/token") FEEDBACK_SHEET_KEY = env.str("FEEDBACK_SHEET_KEY", required=True) -SCHEDULE_SHEET_KEYS = env.dict("SCHEDULE_SHEET_KEYS", required=True, subcast_key=int) -SCHEDULE_CHANNELS = env.list("SCHEDULE_CHANNELS", required=True, subcast=int) ZOOM_USERS = env.dict("ZOOM_USERS", required=True) ZOOM_JWT = env.str("ZOOM_JWT", required=True) @@ -85,7 +83,9 @@ intents=intents, ) -store = database.Store(database_url=DATABASE_URL, force_rollback=TESTING) +store = database.Store( + database_url=TEST_DATABASE_URL if TESTING else DATABASE_URL, force_rollback=TESTING +) # ----------------------------------------------------------------------------- @@ -207,12 +207,11 @@ async def sign_error(ctx, error): def handshape_impl(name: str): logger.info(f"handshape: '{name}'") - if name == "random": - name = random.choice(tuple(handshapes.HANDSHAPES.keys())) - logger.info(f"chose '{name}'") - try: - handshape = handshapes.get_handshape(name) + if name == "random": + handshape = handshapes.get_random_handshape() + else: + handshape = handshapes.get_handshape(name) except handshapes.HandshapeNotFoundError: logger.info(f"handshape '{name}' not found") suggestion = did_you_mean(name, tuple(handshapes.HANDSHAPES.keys())) @@ -350,21 +349,23 @@ class PracticeSession(NamedTuple): notes: str -def get_practice_worksheet_for_guild(guild_id: int): +async def get_practice_worksheet_for_guild(guild_id: int): logger.info(f"fetching practice worksheet {guild_id}") client = get_gsheet_client() - sheet = client.open_by_key(SCHEDULE_SHEET_KEYS[guild_id]) + sheet_key = await store.get_guild_schedule_sheet_key(guild_id) + assert sheet_key is not None + sheet = client.open_by_key(sheet_key) return sheet.get_worksheet(0) -def get_practice_sessions( +async def get_practice_sessions( guild_id: int, dtime: dt.datetime, *, worksheet=None, parse_settings: Optional[dict] = None, ) -> List[PracticeSession]: - worksheet = worksheet or get_practice_worksheet_for_guild(guild_id) + worksheet = worksheet or await get_practice_worksheet_for_guild(guild_id) all_values = worksheet.get_all_values() return sorted( [ @@ -425,7 +426,7 @@ def format_multi_time(dtime: dt.datetime) -> str: ) -def make_practice_session_embed( +async def make_practice_session_embed( guild_id: int, sessions: List[PracticeSession], *, dtime: dt.datetime ) -> discord.Embed: now_pacific = utcnow().astimezone(PACIFIC) @@ -435,7 +436,7 @@ def make_practice_session_embed( description = f"Today - {description}" elif (dtime_pacific.date() - now_pacific.date()).days == 1: description = f"Tomorrow - {description}" - sheet_key = SCHEDULE_SHEET_KEYS[guild_id] + sheet_key = await store.get_guild_schedule_sheet_key(guild_id) schedule_url = f"https://docs.google.com/spreadsheets/d/{sheet_key}/edit" embed = discord.Embed( description=description, @@ -467,10 +468,10 @@ def make_practice_session_embed( return embed -def make_practice_sessions_today_embed(guild_id: int) -> discord.Embed: +async def make_practice_sessions_today_embed(guild_id: int) -> discord.Embed: now = utcnow() - sessions = get_practice_sessions(guild_id, dtime=now) - return make_practice_session_embed(guild_id, sessions, dtime=now) + sessions = await get_practice_sessions(guild_id, dtime=now) + return await make_practice_session_embed(guild_id, sessions, dtime=now) async def is_in_guild(ctx: Context) -> bool: @@ -483,7 +484,8 @@ async def is_in_guild(ctx: Context) -> bool: async def has_practice_schedule(ctx: Context) -> bool: await is_in_guild(ctx) - if ctx.guild.id not in SCHEDULE_SHEET_KEYS: + has_practice_schedule = await store.guild_has_practice_schedule(ctx.guild.id) + if not has_practice_schedule: raise commands.errors.CheckFailure( "⚠️ No configured practice schedule for this server. If you think this is a mistake, contact the bot owner." ) @@ -506,7 +508,7 @@ async def has_practice_schedule(ctx: Context) -> bool: ) -def schedule_impl(guild_id: int, when: Optional[str]): +async def schedule_impl(guild_id: int, when: Optional[str]): settings: Optional[Dict[str, str]] if when and when.strip().lower() != "today": settings = {"PREFER_DATES_FROM": "future"} @@ -515,8 +517,8 @@ def schedule_impl(guild_id: int, when: Optional[str]): else: settings = None dtime = utcnow() - sessions = get_practice_sessions(guild_id, dtime=dtime, parse_settings=settings) - embed = make_practice_session_embed(guild_id, sessions, dtime=dtime) + sessions = await get_practice_sessions(guild_id, dtime=dtime, parse_settings=settings) + embed = await make_practice_session_embed(guild_id, sessions, dtime=dtime) return {"embed": embed} @@ -524,7 +526,8 @@ def schedule_impl(guild_id: int, when: Optional[str]): @commands.check(has_practice_schedule) async def schedule_command(ctx: Context, *, when: Optional[str]): await ctx.channel.trigger_typing() - await ctx.send(**schedule_impl(guild_id=ctx.guild.id, when=when)) + ret = await schedule_impl(guild_id=ctx.guild.id, when=when) + await ctx.send(**ret) PRACTICE_HELP = """Schedule a practice session @@ -632,12 +635,16 @@ async def practice_impl(*, guild_id: int, host: str, start_time: str, user_id: i ) row = (display_dtime, host, notes) logger.info(f"adding new practice session to sheet: {row}") - worksheet = get_practice_worksheet_for_guild(guild_id) + worksheet = await get_practice_worksheet_for_guild(guild_id) worksheet.append_row(row) dtime_pacific = dtime.astimezone(PACIFIC) short_display_date = f"{dtime_pacific:%a, %b %d} {format_multi_time(dtime)}" - sessions = get_practice_sessions(guild_id=guild_id, dtime=dtime, worksheet=worksheet) - embed = make_practice_session_embed(guild_id=guild_id, sessions=sessions, dtime=dtime) + sessions = await get_practice_sessions( + guild_id=guild_id, dtime=dtime, worksheet=worksheet + ) + embed = await make_practice_session_embed( + guild_id=guild_id, sessions=sessions, dtime=dtime + ) if str(used_timezone) != str(user_timezone): await store.set_user_timezone(user_id, used_timezone) return { @@ -704,36 +711,73 @@ async def daily_practice_message(): if now_eastern.time() > DAILY_PRACTICE_SEND_TIME: date = now_eastern.date() + dt.timedelta(days=1) then = EASTERN.localize(dt.datetime.combine(date, DAILY_PRACTICE_SEND_TIME)) + channel_ids = list(await store.get_daily_message_channel_ids()) logger.info( - f"practice schedules for {len(SCHEDULE_CHANNELS)} channels will be sent at {then.isoformat()}" + f"practice schedules for {len(channel_ids)} channels will be sent at {then.isoformat()}" ) await discord.utils.sleep_until(then.astimezone(dt.timezone.utc)) - for channel_id in SCHEDULE_CHANNELS: + for channel_id in channel_ids: try: - channel = bot.get_channel(channel_id) - guild = channel.guild - logger.info( - f'sending daily practice schedule for guild: "{guild.name}" in #{channel.name}' - ) - asyncio.create_task( - channel.send(embed=make_practice_sessions_today_embed(guild.id)) - ) + asyncio.create_task(send_daily_message(channel_id)) except Exception: logger.exception(f"could not send to channel {channel_id}") +random.seed("howsignbot") +SHUFFLED_HANDSHAPE_NAMES = sorted(list(handshapes.HANDSHAPES.keys())) +random.shuffle(SHUFFLED_HANDSHAPE_NAMES) + + +def get_daily_handshape(dtime: Optional[dt.datetime] = None) -> handshapes.Handshape: + dtime = dtime or utcnow() + day_of_year = dtime.timetuple().tm_yday + name = SHUFFLED_HANDSHAPE_NAMES[day_of_year % len(SHUFFLED_HANDSHAPE_NAMES)] + return handshapes.get_handshape(name) + + +async def send_daily_message(channel_id: int): + channel = bot.get_channel(channel_id) + guild = channel.guild + logger.info(f'sending daily message for guild: "{guild.name}" in #{channel.name}') + embed = await make_practice_sessions_today_embed(guild.id) + file_ = None + + settings = await store.get_guild_settings(guild.id) + + # Handshape of the Day + if settings.get("include_handshape_of_the_day"): + handshape = get_daily_handshape() + filename = f"{handshape.name}.png" + file_ = discord.File(handshape.path, filename=filename) + embed.set_thumbnail(url=f"attachment://{filename}") + embed.add_field( + name="Handshape of the Day", value=f'"{handshape.name}"', inline=False + ) + + # Topics of the Day + if settings.get("include_topics_of_the_day"): + topic = await store.get_topic_for_guild(guild.id) + topic2 = await store.get_topic_for_guild(guild.id) + embed.add_field(name="Discuss...", value=f'"{topic}"\n\n"{topic2}"', inline=False) + + await channel.send(file=file_, embed=embed) + + @bot.command( - name="send_schedule", + name="send_daily_message", help="BOT OWNER ONLY: Manually send a daily practice schedule for a channel", ) @commands.is_owner() -async def send_schedule_command(ctx: Context, channel_id: int): - if channel_id not in SCHEDULE_CHANNELS: +async def send_daily_message_command(ctx: Context, channel_id: int): + channel_ids = set(await store.get_daily_message_channel_ids()) + if channel_id not in channel_ids: await ctx.send(f"⚠️ Schedule channel not configured for Channel ID {channel_id}") + return + await send_daily_message(channel_id) + channel = bot.get_channel(channel_id) guild = channel.guild - await channel.send(embed=make_practice_sessions_today_embed(guild.id)) - await ctx.send(f'🗓 Schedule sent to "{guild.name}", #{channel.name}') + await ctx.send(f'🗓 Daily message sent to "{guild.name}", #{channel.name}') # ----------------------------------------------------------------------------- @@ -844,7 +888,6 @@ async def idiom_command(ctx, spoiler: Optional[str]): # ----------------------------------------------------------------------------- - ZOOM_CLOSED_MESSAGE = "✨ _Zoom meeting ended_" @@ -1179,8 +1222,6 @@ async def presence_command_error(ctx, error): await ctx.send(content=message) -# Used for getting channel IDs for SCHEDULE_CHANNELS - CHANNEL_INFO_TEMPLATE = """Guild name: {ctx.guild.name} Guild ID: {ctx.guild.id} Channel name: {ctx.channel.name} diff --git a/dump_guild_settings.py b/dump_guild_settings.py new file mode 100644 index 00000000..cac8f64a --- /dev/null +++ b/dump_guild_settings.py @@ -0,0 +1,44 @@ +import json +from base64 import b64encode + +dev_guild_settings = [ + # CHANGEME + { + "guild_id": 123, + "schedule_sheet_key": "changeme", + "daily_message_channel_id": 321, + "include_handshape_of_the_day": True, + "include_topics_of_the_day": True, + } +] + +prod_guild_settings = [ + # CHANGEME + { + "guild_id": 123, + "schedule_sheet_key": "changeme", + "daily_message_channel_id": 321, + "include_handshape_of_the_day": True, + "include_topics_of_the_day": True, + } +] + + +def encode_settings(settings): + return b64encode(bytes(json.dumps(settings), "utf-8")).decode("utf-8") + + +def main(): + print("Dev settings:\n") + dev_encoded = encode_settings(dev_guild_settings) + print(f"GUILD_SETTINGS={dev_encoded}") + + print() + + print("Prod settings:\n") + prod_encoded = encode_settings(prod_guild_settings) + print(f"GUILD_SETTINGS={prod_encoded}") + + +if __name__ == "__main__": + main() diff --git a/lib/database/__init__.py b/lib/database/__init__.py index 8fb4963d..c935a9cf 100644 --- a/lib/database/__init__.py +++ b/lib/database/__init__.py @@ -1,15 +1,20 @@ +import datetime as dt import logging -from typing import Optional, Union +import random +from typing import Optional, Union, Iterator import databases import pytz +from databases.backends.postgres import Record import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import insert, BIGINT +from sqlalchemy import sql +from sqlalchemy.dialects.postgresql import insert, BIGINT, TIMESTAMP as _TIMESTAMP # re-export DatabaseURL = databases.DatabaseURL metadata = sa.MetaData() +NULL = sql.null() logger = logging.getLogger(__name__) @@ -28,8 +33,32 @@ def process_result_value(self, value, dialect): return pytz.timezone(value) +class TIMESTAMP(sa.TypeDecorator): + impl = _TIMESTAMP(timezone=True) + + # ----------------------------------------------------------------------------- +guild_settings = sa.Table( + "guild_settings", + metadata, + sa.Column("guild_id", BIGINT, primary_key=True), + sa.Column("schedule_sheet_key", sa.Text), + sa.Column("daily_message_channel_id", BIGINT), + sa.Column( + "include_handshape_of_the_day", + sa.Boolean, + server_default=sql.false(), + nullable=False, + ), + sa.Column( + "include_topics_of_the_day", + sa.Boolean, + server_default=sql.false(), + nullable=False, + ), +) + user_settings = sa.Table( "user_settings", metadata, @@ -37,6 +66,23 @@ def process_result_value(self, value, dialect): sa.Column("timezone", TimeZone), ) +topics = sa.Table( + "topics", + metadata, + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("content", sa.Text, unique=True), +) + +topic_usages = sa.Table( + "topic_usages", + metadata, + sa.Column("guild_id", BIGINT, primary_key=True), + sa.Column( + "topic_id", sa.ForeignKey(topics.c.id, ondelete="CASCADE"), primary_key=True + ), + sa.Column("last_used_at", TIMESTAMP, nullable=False), +) + # ----------------------------------------------------------------------------- @@ -70,3 +116,82 @@ async def get_user_timezone(self, user_id: int) -> Optional[pytz.BaseTzInfo]: logger.info(f"retrieving timezone for user_id {user_id}") query = user_settings.select().where(user_settings.c.user_id == user_id) return await self.db.fetch_val(query=query, column=user_settings.c.timezone) + + async def get_guild_settings(self, guild_id: int) -> Record: + logger.info(f"retrieving guild settings sheet key for guild_id {guild_id}") + query = guild_settings.select().where(guild_settings.c.guild_id == guild_id) + return await self.db.fetch_one(query=query) + + async def get_guild_schedule_sheet_key(self, guild_id: int) -> Optional[str]: + query = guild_settings.select().where(guild_settings.c.guild_id == guild_id) + return await self.db.fetch_val( + query=query, column=guild_settings.c.schedule_sheet_key + ) + + async def guild_has_practice_schedule(self, guild_id: int) -> bool: + select = sa.select( + ( + sa.exists() + .where( + (guild_settings.c.guild_id == guild_id) + & (guild_settings.c.schedule_sheet_key != NULL) + ) + .label("result"), + ) + ) + record = await self.db.fetch_one(select) + return record.get("result") + + async def get_daily_message_channel_ids(self) -> Iterator[int]: + all_settings = await self.db.fetch_all( + guild_settings.select().where( + guild_settings.c.daily_message_channel_id != NULL + ) + ) + return (record.get("daily_message_channel_id") for record in all_settings) + + async def mark_topic_used(self, guild_id: int, topic_id: int): + # Upsert topic_usage + stmt = insert(topic_usages).values( + topic_id=topic_id, + last_used_at=dt.datetime.now(dt.timezone.utc), + guild_id=guild_id, + ) + stmt = stmt.on_conflict_do_update( + index_elements=(topic_usages.c.topic_id, topic_usages.c.guild_id), + set_=dict( + last_used_at=stmt.excluded.last_used_at, + guild_id=stmt.excluded.guild_id, + ), + ) + await self.db.execute(stmt) + + async def get_topic_for_guild(self, guild_id: Optional[int] = None) -> str: + # If not sending within a guild, just randomly choose among all topics + if not guild_id: + all_topics = await self.db.fetch_all(topics.select()) + return random.choice(all_topics).get("content") + + unused_topics_records = await self.db.fetch_all( + topics.select().where( + ~sa.exists().where(topic_usages.c.topic_id == topics.c.id) + ) + ) + if unused_topics_records: + # Randomly select among unused topics + topic = random.choice(unused_topics_records) + else: + # Randomly choose from 20 least-recently used topics + least_recently_used = await self.db.fetch_all( + topics.select() + .select_from( + topics.join(topic_usages, topic_usages.c.topic_id == topics.c.id) + ) + .order_by(topic_usages.c.last_used_at) + .limit(20) + ) + topic = random.choice(least_recently_used) + + await self.mark_topic_used(guild_id, topic.get("id")) + + return topic.get("content") diff --git a/lib/handshapes/__init__.py b/lib/handshapes/__init__.py index 5f4f56f9..c1250ba9 100644 --- a/lib/handshapes/__init__.py +++ b/lib/handshapes/__init__.py @@ -1,3 +1,4 @@ +import random from dataclasses import dataclass from pathlib import Path @@ -74,3 +75,8 @@ def get_handshape(name): f"Could not find handshape with name '{name}'" ) from error return Handshape(name=cased_name, path=path) + + +def get_random_handshape(): + name = random.choice(tuple(HANDSHAPES.keys())) + return get_handshape(name) diff --git a/migrations/versions/90f23f127d05_add_topics.py b/migrations/versions/90f23f127d05_add_topics.py new file mode 100644 index 00000000..77c3db30 --- /dev/null +++ b/migrations/versions/90f23f127d05_add_topics.py @@ -0,0 +1,44 @@ +"""add topics + +Revision ID: 90f23f127d05 +Revises: a429443e6c16 +Create Date: 2020-11-17 23:49:07.131973 + +""" +from alembic import op +import sqlalchemy as sa +import database + + +# revision identifiers, used by Alembic. +revision = "90f23f127d05" +down_revision = "a429443e6c16" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "topics", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("content", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("content"), + ) + op.create_table( + "topic_usages", + sa.Column("guild_id", sa.BIGINT(), nullable=False), + sa.Column("topic_id", sa.Integer(), nullable=False), + sa.Column("last_used_at", database.TIMESTAMP(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["topic_id"], ["topics.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("guild_id", "topic_id"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("topic_usages") + op.drop_table("topics") + # ### end Alembic commands ### diff --git a/migrations/versions/f9c59562a31a_add_guild_settings.py b/migrations/versions/f9c59562a31a_add_guild_settings.py new file mode 100644 index 00000000..12b80464 --- /dev/null +++ b/migrations/versions/f9c59562a31a_add_guild_settings.py @@ -0,0 +1,69 @@ +"""add guild settings + +Revision ID: f9c59562a31a +Revises: 90f23f127d05 +Create Date: 2020-11-19 00:12:17.159256 + +""" +import json +from base64 import b64decode + +from alembic import op +import sqlalchemy as sa +from environs import Env + + +# revision identifiers, used by Alembic. +revision = "f9c59562a31a" +down_revision = "90f23f127d05" +branch_labels = None +depends_on = None + +env = Env() +env.read_env() + + +def decode_settings(encoded): + return json.loads(b64decode(encoded)) + + +GUILD_SETTINGS = env.str("GUILD_SETTINGS", None) + + +def load_table(connection, table): + return sa.Table(table, sa.MetaData(), autoload_with=connection) + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "guild_settings", + sa.Column("guild_id", sa.BIGINT(), nullable=False), + sa.Column("schedule_sheet_key", sa.Text(), nullable=True), + sa.Column("daily_message_channel_id", sa.BIGINT(), nullable=True), + sa.Column( + "include_handshape_of_the_day", + sa.Boolean(), + server_default=sa.text("false"), + nullable=False, + ), + sa.Column( + "include_topics_of_the_day", + sa.Boolean(), + server_default=sa.text("false"), + nullable=False, + ), + sa.PrimaryKeyConstraint("guild_id"), + ) + # ### end Alembic commands ### + if GUILD_SETTINGS: + values = decode_settings(GUILD_SETTINGS) + connection = op.get_bind() + guild_settings = load_table(connection, "guild_settings") + connection.execute(guild_settings.insert().values(values)) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("guild_settings") + # ### end Alembic commands ### diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..d0237628 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +filterwarnings = + ignore:.*@coroutine.*:DeprecationWarning diff --git a/requirements-dev.txt b/requirements-dev.txt index 168629fa..03cdea15 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,4 +7,5 @@ python-semantic-release==7.3.0 syrupy==1.0.0 SQLAlchemy-Utils==0.36.8 pytest-asyncio==0.14.0 +asynctest==0.13.0 -e .[dev] diff --git a/script/bootstrap b/script/bootstrap old mode 100644 new mode 100755 index 520965f5..9cd42918 --- a/script/bootstrap +++ b/script/bootstrap @@ -1,20 +1,13 @@ +#!/bin/sh # script/bootstrap: Resolve all dependencies that the application requires to run. - -if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then - brew bundle check >/dev/null 2>&1 || { - echo "==> Installing Homebrew dependencies..." - brew bundle - } -fi +set -e if [ -z "$CI" ]; then name=${PWD##*/} # If python version is passed as first argument, use that, otherwise use 3.8 python_version=${1:-3.9.0} echo "==> Bootstrapping a Python $python_version environment called $name..." - pyenv install "$python_version" --skip-existing - pyenv virtualenv "$python_version" "$name" - pyenv local "$name" + pyenv virtualenv $python_version $name | true eval "$(pyenv init -)" unset name python_version fi @@ -22,32 +15,37 @@ fi # XXX Workaround to get psycopg2 to install properly on macOS if grep --quiet 'psycopg2' -- 'requirements.txt' && [ "$(uname -s)" = "Darwin" ]; then echo "==> Installing psycopg2..." - LDFLAGS=-L/usr/local/opt/openssl/lib python -m pip install $(grep 'psycopg2' -- 'requirements.txt') + LDFLAGS=-L/usr/local/opt/openssl/lib python -m pip install -q $(grep 'psycopg2' -- 'requirements.txt') fi -if [[ -f requirements-dev.txt ]]; then +if [ -f requirements-dev.txt ]; then echo "==> Installing/updating from requirements-dev.txt..." - python -m pip install -U -r requirements-dev.txt + python -m pip install -q -U -r requirements-dev.txt elif [ -f requirements.txt ]; then echo "==> Installing/updating from requirements.txt..." - python -m pip install -U -r requirements.txt + python -m pip install -q -U -r requirements.txt fi -if [[ -f .pre-commit-config.yaml ]]; then +if [ -f .pre-commit-config.yaml ]; then echo "==> Installing/updating pre-commit..." python -m pip install -U pre-commit echo "==> Installing/updating pre-commit hook..." pre-commit install -f fi -if [[ -f .env.example && ! -f .env ]]; then +if [ -f .env.example ] && [ ! -f .env ]; then echo "==> Copying .env.example to .env..." cp .env.example .env fi -if [[ -f docker-compose.yml ]]; then +if [ -f docker-compose.yml ]; then echo "==> Starting containers..." docker-compose up -d fi +if [ -f setup.py ]; then + echo "==> Installing from setup.py..." + python -m pip install -q -e '.[dev]' +fi + echo "==> Bootstrapping finished." diff --git a/script/deploy b/script/deploy index 0cc17c7c..bc0baab9 100644 --- a/script/deploy +++ b/script/deploy @@ -1,3 +1,4 @@ +#!/bin/sh echo "==> Checking out master..." git checkout master echo "==> Merging dev into master..." diff --git a/script/reset_db.py b/script/reset_db.py new file mode 100755 index 00000000..b222d3a1 --- /dev/null +++ b/script/reset_db.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +from sqlalchemy_utils import create_database, drop_database + +import bot + + +def main(): + drop_database(str(bot.DATABASE_URL)) + create_database(str(bot.DATABASE_URL)) + + +if __name__ == "__main__": + main() diff --git a/script/sync_topics.py b/script/sync_topics.py new file mode 100755 index 00000000..f371a8f1 --- /dev/null +++ b/script/sync_topics.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +"""Sync database with a Google sheet. +This should be run locally. + +Usage: PYTHONPATH=. ./script/sync_topics.py +""" +import asyncio +from pprint import pprint + +from environs import Env + +from databases import Database +from database import topics +from sqlalchemy.dialects.postgresql import insert + +from bot import get_gsheet_client + +env = Env() +env.read_env() + +DATABASE_URL = env.str("DATABASE_URL", required=True) +STAGING_DATABASE_URL = env.str("STAGING_DATABASE_URL", None) +PROD_DATABASE_URL = env.str("PROD_DATABASE_URL", None) +TOPICS_SHEET_KEY = env.str("TOPICS_SHEET_KEY", required=True) + + +async def sync_topics(database_url, rows): + all_ids = tuple(row["id"] for row in rows) + async with Database(database_url) as db: + async with db.transaction(): + stmt = insert(topics).values(rows) + stmt = stmt.on_conflict_do_update( + index_elements=(topics.c.id,), set_=dict(content=stmt.excluded.content) + ) + await db.execute(stmt) + await db.execute(topics.delete().where(~topics.c.id.in_(all_ids))) + + +async def main(): + client = get_gsheet_client() + sheet = client.open_by_key(TOPICS_SHEET_KEY) + worksheet = sheet.get_worksheet(0) + rows = worksheet.get_all_records() + pprint(rows) + + for database_url in (DATABASE_URL, STAGING_DATABASE_URL, PROD_DATABASE_URL): + if database_url: + await sync_topics(database_url, rows) + + print(f"Synced {len(rows)} topics.") + + +if __name__ == "__main__": + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) diff --git a/script/test b/script/test old mode 100644 new mode 100755 index 1a6e12c4..bfd4eba7 --- a/script/test +++ b/script/test @@ -1,2 +1,4 @@ +#!/bin/sh +set -e pre-commit run --all-files pytest diff --git a/script/update b/script/update old mode 100644 new mode 100755 index b922b836..52cca659 --- a/script/update +++ b/script/update @@ -1,11 +1,21 @@ +#!/bin/sh # Based on conventions from https://github.com/github/scripts-to-rule-them-all # script/update: Update application to run for its current checkout. +set -e if [ -z "$SKIP_BOOTSTRAP" ]; then - . script/bootstrap + ./script/bootstrap +fi + +if [ -n "$DB_RESET" ]; then + echo "==> Recreating database" + PYTHONPATH=. ./script/reset_db.py fi echo "==> Running migrations" PYTHONPATH=. alembic upgrade head +echo "==> Syncing topics" +PYTHONPATH=. ./script/sync_topics.py + echo "==> Update finished." diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..0124c477 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,37 @@ +import os +from contextlib import suppress + +import pytest +from sqlalchemy.exc import ProgrammingError +from sqlalchemy import create_engine +from sqlalchemy_utils import create_database, drop_database + +# Must be before bot import +os.environ["TESTING"] = "true" + +import bot # noqa:E402 + + +# https://www.starlette.io/database/#test-isolation +@pytest.fixture(scope="session", autouse=True) +def create_test_database(): + url = str(bot.TEST_DATABASE_URL) + engine = create_engine(url) + with suppress(ProgrammingError): + drop_database(url) + create_database(url) + bot.store.metadata.create_all(engine) + yield + drop_database(url) + + +@pytest.fixture +async def store(create_test_database): + await bot.store.connect() + yield bot.store + await bot.store.disconnect() + + +@pytest.fixture +def db(store): + return store.db diff --git a/tests/test_bot.py b/tests/test_bot.py index 3b036810..05eabbab 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -6,38 +6,20 @@ import gspread import pytest import pytz +from asynctest import patch from discord.ext import commands from freezegun import freeze_time from syrupy.filters import props -from sqlalchemy import create_engine -from sqlalchemy_utils import create_database, drop_database # Must be before bot import os.environ["TESTING"] = "true" import bot # noqa:E402 +import database # noqa:E402 random.seed(1) -# https://www.starlette.io/database/#test-isolation -@pytest.fixture(scope="session", autouse=True) -def create_test_database(): - url = str(bot.TEST_DATABASE_URL) - engine = create_engine(url) - create_database(url) - bot.store.metadata.create_all(engine) - yield - drop_database(url) - - -@pytest.fixture -async def store(): - await bot.store.connect() - yield bot.store - await bot.store.disconnect() - - @pytest.mark.parametrize( "word", ( @@ -93,10 +75,12 @@ def test_idiom(snapshot, spoiler): @pytest.fixture -def mock_worksheet(monkeypatch): - monkeypatch.setattr(bot, "SCHEDULE_SHEET_KEYS", {1234: "abc"}, raising=True) +async def mock_worksheet(monkeypatch, db): monkeypatch.setattr(bot, "GOOGLE_PRIVATE_KEY", "fake", raising=True) - with mock.patch("bot.get_practice_worksheet_for_guild") as mock_get_worksheet: + await db.execute( + database.guild_settings.insert(), {"guild_id": 1234, "schedule_sheet_key": "abc"} + ) + with patch("bot.get_practice_worksheet_for_guild") as mock_get_worksheet: WorksheetMock = mock.Mock(spec=gspread.Worksheet) WorksheetMock.get_all_values.return_value = [ ["docs", "more docs", ""], @@ -121,15 +105,17 @@ def mock_worksheet(monkeypatch): "9/27", ), ) +@pytest.mark.asyncio @freeze_time("2020-09-25 14:00:00") -def test_schedule(snapshot, mock_worksheet, when): - result = bot.schedule_impl(1234, when) +async def test_schedule(snapshot, mock_worksheet, store, when): + result = await bot.schedule_impl(1234, when) assert result == snapshot +@pytest.mark.asyncio @freeze_time("2020-09-25 14:00:00") -def test_schedule_no_practices(snapshot, mock_worksheet): - result = bot.schedule_impl(1234, "9/28/2020") +async def test_schedule_no_practices(snapshot, mock_worksheet): + result = await bot.schedule_impl(1234, "9/28/2020") embed = result["embed"] assert "September 28" in embed.description assert "There are no scheduled practices yet" in embed.description diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 00000000..7df9cce3 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,76 @@ +import datetime as dt +import random + +import pytest +from freezegun import freeze_time + +import database + +pytestmark = pytest.mark.asyncio + +random.seed(1) + + +async def test_get_topic_no_guild(store, db): + await db.execute(database.topics.insert(), {"id": 1, "content": "What's up?"}) + result = await store.get_topic_for_guild() + assert result == "What's up?" + + +@freeze_time("2020-09-25 14:00:00") +async def test_get_topic_for_guild(store, db): + await db.execute(database.topics.insert(), {"id": 1, "content": "What's up?"}) + result = await store.get_topic_for_guild(123) + assert result == "What's up?" + usage = await db.fetch_one( + database.topic_usages.select().where(database.topic_usages.c.topic_id == 1) + ) + assert usage is not None + assert usage.get("last_used_at").date() == dt.date(2020, 9, 25) + + +async def test_get_topic_for_guild_with_used_topic(store, db): + result = await db.execute_many( + database.topics.insert(), + values=( + {"id": 1, "content": "What's up?"}, + {"id": 2, "content": "Why did you learn ASL?"}, + ), + ) + await db.execute( + database.topic_usages.insert(), + { + "guild_id": 123, + "topic_id": 2, + "last_used_at": dt.datetime(2020, 11, 17, tzinfo=dt.timezone.utc), + }, + ) + result = await store.get_topic_for_guild(123) + assert result == "What's up?" + + +async def test_get_topic_for_guild_with_all_topics_used(store, db): + result = await db.execute_many( + database.topics.insert(), + values=( + {"id": 1, "content": "What's up?"}, + {"id": 2, "content": "Why did you learn ASL?"}, + ), + ) + await db.execute_many( + database.topic_usages.insert(), + [ + { + "guild_id": 123, + "topic_id": 1, + "last_used_at": dt.datetime(2020, 11, 16, tzinfo=dt.timezone.utc), + }, + { + "guild_id": 123, + "topic_id": 2, + "last_used_at": dt.datetime(2020, 11, 17, tzinfo=dt.timezone.utc), + }, + ], + ) + result = await store.get_topic_for_guild(123) + assert result == "What's up?"