Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preparing Alpha Branch #3709

Merged
merged 6 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/pydiscordsh/pydiscordsh/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .api import Routes, CORS
from .apps import TursoDatabase, DiscordServerManager
from .apps import TursoDatabase, DiscordServerManager, Kilobase
from .api import SetupSchema, Hero, DiscordServer, Health, SchemaEngine, Utils
14 changes: 7 additions & 7 deletions apps/pydiscordsh/pydiscordsh/api/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, List
import os, re, html
from sqlmodel import Field, Session, SQLModel, create_engine, select, JSON, Column
from pydantic import validator, root_validator
from pydantic import field_validator, model_validator
import logging
from pydiscordsh.api.utils import Utils

Expand All @@ -22,7 +22,7 @@ def _sanitize_string(value: str, user_id: Optional[str] = None, server_id: Optio
raise ValueError("Invalid content in input: Contains potentially harmful characters.")
return html.escape(sanitized)

@root_validator(pre=True)
@model_validator(mode="before")
def sanitize_all_fields(cls, values):
user_id = values.get('user_id', None) #TODO: Pass user ID into this somwhow once we get users done
server_id = values.get('server_id', None)
Expand Down Expand Up @@ -72,7 +72,7 @@ class DiscordServer(SanitizedBaseModel, table=True):
created_at: Optional[int] = Field(default=None, nullable=False) # UNIX timestamp for creation date
updated_at: Optional[int] = Field(default=None, nullable=True) # UNIX timestamp for update date

@validator("website", "logo", "banner", pre=True, always=True)
@field_validator("website", "logo", "banner", "url")
def validate_common_urls(cls, value):
try:
return Utils.validate_url(value)
Expand All @@ -82,7 +82,7 @@ def validate_common_urls(cls, value):
raise ValueError(f"Invalid URL format for field '{cls.__name__}'. Please provide a valid URL.") from e
return value

@validator("lang", pre=True, always=True)
@field_validator("lang")
def validate_lang(cls, value):
if value:
if len(value) > 2:
Expand All @@ -93,7 +93,7 @@ def validate_lang(cls, value):
raise ValueError(f"Invalid language code: {lang}. Must be one of {', '.join(valid_languages)}.")
return value

@validator("invite", pre=True, always=True)
@field_validator("invite")
def validate_invite(cls, value):
if not value or not isinstance(value, str):
raise ValueError("Invite must be a valid string.")
Expand All @@ -106,13 +106,13 @@ def validate_invite(cls, value):
return value
raise ValueError(f"Invalid invite link or invite code. Got: {value}")

@validator("categories", pre=True, always=True)
@field_validator("categories")
def validate_categories(cls, value):
if value and len(value) > 2:
raise ValueError("Categories list cannot have more than 2 items.")
return value

@validator("video", pre=True, always=True)
@field_validator("video")
def validate_video(cls, value):
youtube_url_pattern = r"(https?://(?:www\.)?(?:youtube\.com/(?:[^/]+/)*[^/]+(?:\?v=|\/)([a-zA-Z0-9_-]{1,50}))|youtu\.be/([a-zA-Z0-9_-]{1,50}))"
if value:
Expand Down
3 changes: 2 additions & 1 deletion apps/pydiscordsh/pydiscordsh/apps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .turso import TursoDatabase
from .discord import DiscordServerManager
from .discord import DiscordServerManager
from .kilobase import Kilobase
100 changes: 100 additions & 0 deletions apps/pydiscordsh/pydiscordsh/apps/kilobase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import jwt
import time
from supabase import create_client, Client
from jwt import ExpiredSignatureError, InvalidTokenError

class Kilobase:
def __init__(self):
"""Initialize the Supabase client and JWT secret."""
self.supabase_url = os.getenv("SUPABASE_URL")
self.supabase_key = os.getenv("SUPABASE_KEY")
self.jwt_secret = os.getenv("JWT_SECRET") or "default_secret"

if not self.supabase_url or not self.supabase_key:
raise ValueError("SUPABASE_URL and SUPABASE_KEY must be set in the environment variables.")

self.client: Client = create_client(self.supabase_url, self.supabase_key)

def issue_jwt(self, user_id: str, expires_in: int = 3600) -> str:
"""
Issue a JWT for a given user.

Args:
user_id (str): The user ID to include in the token.
expires_in (int): Token expiration time in seconds. Default is 1 hour.

Returns:
str: The generated JWT token.
"""
payload = {
"sub": user_id,
"exp": int(time.time()) + expires_in,
"iat": int(time.time())
}
token = jwt.encode(payload, self.jwt_secret, algorithm="HS256")
return token

def verify_jwt(self, token: str) -> dict:
"""
Verify a JWT token.

Args:
token (str): The JWT token to verify.

Returns:
dict: Decoded payload if the token is valid.

Raises:
ValueError: If the token is expired or invalid.
"""
try:
decoded = jwt.decode(token, self.jwt_secret, algorithms=["HS256"])
return decoded
except ExpiredSignatureError:
raise ValueError("Token has expired.")
except InvalidTokenError:
raise ValueError("Invalid token.")

def get_user_by_id(self, user_id: str):
"""
Fetch a user's data from the Supabase `users` table.

Args:
user_id (str): The user ID to query.

Returns:
dict: User data or None if not found.
"""
response = self.client.table("users").select("*").eq("id", user_id).single().execute()
return response.data if response.data else None

def extract_user_id(self, token: str) -> str:
"""
Extract the user ID (UUID) from a Supabase JWT token.

Args:
token (str): The JWT token to decode.

Returns:
str: The user's unique ID (UUID) if valid.

Raises:
ValueError: If the token is expired or invalid.
"""
try:
# Decode the Supabase token using the Supabase key
payload = jwt.decode(token, self.supabase_key, algorithms=["HS256"])

# Supabase stores the user ID under 'sub' (subject) claim
user_id = payload.get("sub")
if not user_id:
raise ValueError("User ID not found in the token.")
return user_id

except ExpiredSignatureError:
raise ValueError("Token has expired.")
except InvalidTokenError:
raise ValueError("Invalid token.")
except Exception as e:
raise ValueError(f"Token verification error: {e}")
Empty file.
Loading