Skip to content

Commit

Permalink
✨ Enhanced cooldown system with bucket types (#57)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
Co-authored-by: nicebots-xyz-bot <[email protected]>
  • Loading branch information
3 people authored Dec 22, 2024
1 parent eecc0c3 commit 125de62
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 8 deletions.
36 changes: 36 additions & 0 deletions src/extensions/nice_errors/handlers/cooldown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) NiceBots
# SPDX-License-Identifier: MIT

from typing import Any, final, override

import discord

from src import custom
from src.i18n.classes import RawTranslation, apply_locale
from src.utils.cooldown import CooldownExceeded

from .base import BaseErrorHandler, ErrorHandlerRType


@final
class CooldownErrorHandler(BaseErrorHandler[CooldownExceeded]):
def __init__(self, translations: dict[str, RawTranslation]) -> None:
self.translations = translations
super().__init__(CooldownExceeded)

@override
async def __call__(
self,
error: CooldownExceeded,
ctx: custom.Context | discord.Interaction,
sendargs: dict[str, Any],
message: str,
report: bool,
) -> ErrorHandlerRType:
translations = apply_locale(self.translations, self._get_locale(ctx))

message = translations.error_cooldown_exceeded

sendargs["ephemeral"] = True

return False, False, message, sendargs
3 changes: 3 additions & 0 deletions src/extensions/nice_errors/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from schema import Optional, Schema

from src import custom
from src.utils.cooldown import CooldownExceeded

from .handlers import error_handler
from .handlers.cooldown import CooldownErrorHandler
from .handlers.forbidden import ForbiddenErrorHandler
from .handlers.generic import GenericErrorHandler
from .handlers.not_found import NotFoundErrorHandler
Expand Down Expand Up @@ -64,3 +66,4 @@ def setup(bot: custom.Bot, config: dict[str, Any]) -> None:
error_handler.add_error_handler(None, GenericErrorHandler(config["translations"]))
error_handler.add_error_handler(commands.CommandNotFound, NotFoundErrorHandler(config["translations"]))
error_handler.add_error_handler(discord.Forbidden, ForbiddenErrorHandler(config["translations"]))
error_handler.add_error_handler(CooldownExceeded, CooldownErrorHandler(config["translations"]))
8 changes: 8 additions & 0 deletions src/extensions/nice_errors/translations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ strings:
it: Ops! Non ho i permessi necessari per farlo.
es-ES: ¡Ups! No tengo el permiso necesario para hacer eso.
ru: Упс! У меня нет необходимых прав для выполнения этого действия.
error_cooldown_exceeded:
en-US: Whoops! You're doing that too fast. Please wait before trying again.
de: Hoppla! Du machst das zu schnell. Bitte warte, bevor du es erneut versuchst.
nl: Oeps! Je doet dat te snel. Wacht even voordat je het opnieuw probeert.
fr: Oups ! Vous faites cela trop vite. Veuillez attendre avant de réessayer.
it: Ops! Stai facendo troppo in fretta. Attendi prima di riprovare.
es-ES: ¡Ups! Estás haciendo eso demasiado rápido. Por favor, espera antes de intentarlo de nuevo.
ru: Упс! Вы делаете это слишком быстро. Пожалуйста, подождите, прежде чем попробовать снова.
error_generic:
en-US: Whoops! An error occurred while executing this command.
de: Hoppla! Bei der Ausführung dieses Kommandos ist ein Fehler aufgetreten.
Expand Down
3 changes: 2 additions & 1 deletion src/extensions/ping/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from src import custom
from src.log import logger
from src.utils.cooldown import cooldown
from src.utils.cooldown import BucketType, cooldown

default = {
"enabled": True,
Expand All @@ -32,6 +32,7 @@ def __init__(self, bot: custom.Bot) -> None:
limit=1,
per=5,
strong=True,
bucket_type=BucketType.USER,
)
async def ping(
self,
Expand Down
68 changes: 61 additions & 7 deletions src/utils/cooldown.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import time
from collections.abc import Awaitable, Callable, Coroutine
from enum import Enum
from functools import wraps
from inspect import isawaitable
from typing import Any, Concatenate, cast
Expand All @@ -15,6 +16,16 @@
type CogCommandFunction[T: commands.Cog, **P] = Callable[Concatenate[T, custom.ApplicationContext, P], Awaitable[None]]


class BucketType(Enum):
DEFAULT = "default" # Uses provided key as is
USER = "user" # Per-user cooldown
MEMBER = "member" # Per-member (user+guild) cooldown
GUILD = "guild" # Per-guild cooldown
CHANNEL = "channel" # Per-channel cooldown
CATEGORY = "category" # Per-category cooldown
ROLE = "role" # Per-role cooldown (uses highest role)


async def parse_reactive_setting[T](value: ReactiveCooldownSetting[T], bot: custom.Bot, ctx: custom.Context) -> T:
if isinstance(value, type):
return value # pyright: ignore [reportReturnType]
Expand All @@ -26,22 +37,58 @@ async def parse_reactive_setting[T](value: ReactiveCooldownSetting[T], bot: cust


class CooldownExceeded(commands.CheckFailure):
def __init__(self, retry_after: float) -> None:
def __init__(self, retry_after: float, bucket_type: BucketType) -> None:
self.retry_after: float = retry_after
super().__init__("You are on cooldown")
self.bucket_type: BucketType = bucket_type
super().__init__(f"You are on {bucket_type.value} cooldown")


# inspired by https://github.com/ItsDrike/code-jam-2024/blob/main/src/utils/ratelimit.py
def get_bucket_key(ctx: custom.ApplicationContext, base_key: str, bucket_type: BucketType) -> str: # noqa: PLR0911
"""Generate a cooldown key based on the bucket type."""
match bucket_type:
case BucketType.USER:
return f"{base_key}:user:{ctx.author.id}"
case BucketType.MEMBER:
return (
f"{base_key}:member:{ctx.guild_id}:{ctx.author.id}" if ctx.guild else f"{base_key}:user:{ctx.author.id}"
)
case BucketType.GUILD:
return f"{base_key}:guild:{ctx.guild_id}" if ctx.guild else base_key
case BucketType.CHANNEL:
return f"{base_key}:channel:{ctx.channel.id}"
case BucketType.CATEGORY:
category_id = ctx.channel.category_id if hasattr(ctx.channel, "category_id") else None
return f"{base_key}:category:{category_id}" if category_id else f"{base_key}:channel:{ctx.channel.id}"
case BucketType.ROLE:
if ctx.guild and hasattr(ctx.author, "roles"):
top_role_id = max((role.id for role in ctx.author.roles), default=0)
return f"{base_key}:role:{top_role_id}"
return f"{base_key}:user:{ctx.author.id}"
case _: # BucketType.DEFAULT
return base_key


def cooldown[C: commands.Cog, **P](
def cooldown[C: commands.Cog, **P]( # noqa: PLR0913
key: ReactiveCooldownSetting[str],
*,
limit: ReactiveCooldownSetting[int],
per: ReactiveCooldownSetting[int],
bucket_type: ReactiveCooldownSetting[BucketType] = BucketType.DEFAULT,
strong: ReactiveCooldownSetting[bool] = False,
cls: ReactiveCooldownSetting[type[CooldownExceeded]] = CooldownExceeded,
) -> Callable[[CogCommandFunction[C, P]], CogCommandFunction[C, P]]:
"""Enhanced cooldown decorator that supports different bucket types.
Args:
key: Base key for the cooldown
limit: Number of uses allowed
per: Time period in seconds
bucket_type: Type of bucket to use for the cooldown
strong: If True, adds current timestamp even if limit is reached
cls: Custom exception class to raise
"""

def inner(func: CogCommandFunction[C, P]) -> CogCommandFunction[C, P]:
@wraps(func)
async def wrapper(self: C, ctx: custom.ApplicationContext, *args: P.args, **kwargs: P.kwargs) -> None:
Expand All @@ -51,17 +98,24 @@ async def wrapper(self: C, ctx: custom.ApplicationContext, *args: P.args, **kwar
per_value: int = await parse_reactive_setting(per, ctx.bot, ctx)
strong_value: bool = await parse_reactive_setting(strong, ctx.bot, ctx)
cls_value: type[CooldownExceeded] = await parse_reactive_setting(cls, ctx.bot, ctx)
bucket_type_value: BucketType = await parse_reactive_setting(bucket_type, ctx.bot, ctx)

# Generate the full cooldown key based on bucket type
full_key = get_bucket_key(ctx, key_value, bucket_type_value)

now = time.time()
time_stamps = cast(tuple[float, ...], await cache.get(key_value, default=(), namespace="cooldown"))
time_stamps = cast(tuple[float, ...], await cache.get(full_key, default=(), namespace="cooldown"))
time_stamps = tuple(filter(lambda x: x > now - per_value, time_stamps))
time_stamps = time_stamps[-limit_value:]

if len(time_stamps) < limit_value or strong_value:
time_stamps = (*time_stamps, now)
await cache.set(key_value, time_stamps, namespace="cooldown", ttl=per_value)
await cache.set(full_key, time_stamps, namespace="cooldown", ttl=per_value)
limit_value += 1 # to account for the current command

if len(time_stamps) >= limit_value:
raise cls_value(min(time_stamps) - now + per_value)
raise cls_value(min(time_stamps) - now + per_value, bucket_type_value)

await func(self, ctx, *args, **kwargs)

return wrapper
Expand Down

0 comments on commit 125de62

Please sign in to comment.