Skip to content

Commit

Permalink
Merge pull request #3709 from KBVE/dev
Browse files Browse the repository at this point in the history
Preparing Alpha Branch
  • Loading branch information
Fudster authored Jan 10, 2025
2 parents 60a58c1 + 04fe508 commit 42d3035
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 9 deletions.
2 changes: 1 addition & 1 deletion apps/pydiscordsh/pydiscordsh/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .api import Routes, CORS
from .apps import TursoDatabase, DiscordServerManager
from .apps import TursoDatabase, DiscordServerManager, Kilobase
from .api import SetupSchema, Hero, DiscordServer, Health, SchemaEngine, Utils
14 changes: 7 additions & 7 deletions apps/pydiscordsh/pydiscordsh/api/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, List
import os, re, html
from sqlmodel import Field, Session, SQLModel, create_engine, select, JSON, Column
from pydantic import validator, root_validator
from pydantic import field_validator, model_validator
import logging
from pydiscordsh.api.utils import Utils

Expand All @@ -22,7 +22,7 @@ def _sanitize_string(value: str, user_id: Optional[str] = None, server_id: Optio
raise ValueError("Invalid content in input: Contains potentially harmful characters.")
return html.escape(sanitized)

@root_validator(pre=True)
@model_validator(mode="before")
def sanitize_all_fields(cls, values):
user_id = values.get('user_id', None) #TODO: Pass user ID into this somwhow once we get users done
server_id = values.get('server_id', None)
Expand Down Expand Up @@ -72,7 +72,7 @@ class DiscordServer(SanitizedBaseModel, table=True):
created_at: Optional[int] = Field(default=None, nullable=False) # UNIX timestamp for creation date
updated_at: Optional[int] = Field(default=None, nullable=True) # UNIX timestamp for update date

@validator("website", "logo", "banner", pre=True, always=True)
@field_validator("website", "logo", "banner", "url")
def validate_common_urls(cls, value):
try:
return Utils.validate_url(value)
Expand All @@ -82,7 +82,7 @@ def validate_common_urls(cls, value):
raise ValueError(f"Invalid URL format for field '{cls.__name__}'. Please provide a valid URL.") from e
return value

@validator("lang", pre=True, always=True)
@field_validator("lang")
def validate_lang(cls, value):
if value:
if len(value) > 2:
Expand All @@ -93,7 +93,7 @@ def validate_lang(cls, value):
raise ValueError(f"Invalid language code: {lang}. Must be one of {', '.join(valid_languages)}.")
return value

@validator("invite", pre=True, always=True)
@field_validator("invite")
def validate_invite(cls, value):
if not value or not isinstance(value, str):
raise ValueError("Invite must be a valid string.")
Expand All @@ -106,13 +106,13 @@ def validate_invite(cls, value):
return value
raise ValueError(f"Invalid invite link or invite code. Got: {value}")

@validator("categories", pre=True, always=True)
@field_validator("categories")
def validate_categories(cls, value):
if value and len(value) > 2:
raise ValueError("Categories list cannot have more than 2 items.")
return value

@validator("video", pre=True, always=True)
@field_validator("video")
def validate_video(cls, value):
youtube_url_pattern = r"(https?://(?:www\.)?(?:youtube\.com/(?:[^/]+/)*[^/]+(?:\?v=|\/)([a-zA-Z0-9_-]{1,50}))|youtu\.be/([a-zA-Z0-9_-]{1,50}))"
if value:
Expand Down
3 changes: 2 additions & 1 deletion apps/pydiscordsh/pydiscordsh/apps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .turso import TursoDatabase
from .discord import DiscordServerManager
from .discord import DiscordServerManager
from .kilobase import Kilobase
100 changes: 100 additions & 0 deletions apps/pydiscordsh/pydiscordsh/apps/kilobase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import jwt
import time
from supabase import create_client, Client
from jwt import ExpiredSignatureError, InvalidTokenError

class Kilobase:
def __init__(self):
"""Initialize the Supabase client and JWT secret."""
self.supabase_url = os.getenv("SUPABASE_URL")
self.supabase_key = os.getenv("SUPABASE_KEY")
self.jwt_secret = os.getenv("JWT_SECRET") or "default_secret"

if not self.supabase_url or not self.supabase_key:
raise ValueError("SUPABASE_URL and SUPABASE_KEY must be set in the environment variables.")

self.client: Client = create_client(self.supabase_url, self.supabase_key)

def issue_jwt(self, user_id: str, expires_in: int = 3600) -> str:
"""
Issue a JWT for a given user.
Args:
user_id (str): The user ID to include in the token.
expires_in (int): Token expiration time in seconds. Default is 1 hour.
Returns:
str: The generated JWT token.
"""
payload = {
"sub": user_id,
"exp": int(time.time()) + expires_in,
"iat": int(time.time())
}
token = jwt.encode(payload, self.jwt_secret, algorithm="HS256")
return token

def verify_jwt(self, token: str) -> dict:
"""
Verify a JWT token.
Args:
token (str): The JWT token to verify.
Returns:
dict: Decoded payload if the token is valid.
Raises:
ValueError: If the token is expired or invalid.
"""
try:
decoded = jwt.decode(token, self.jwt_secret, algorithms=["HS256"])
return decoded
except ExpiredSignatureError:
raise ValueError("Token has expired.")
except InvalidTokenError:
raise ValueError("Invalid token.")

def get_user_by_id(self, user_id: str):
"""
Fetch a user's data from the Supabase `users` table.
Args:
user_id (str): The user ID to query.
Returns:
dict: User data or None if not found.
"""
response = self.client.table("users").select("*").eq("id", user_id).single().execute()
return response.data if response.data else None

def extract_user_id(self, token: str) -> str:
"""
Extract the user ID (UUID) from a Supabase JWT token.
Args:
token (str): The JWT token to decode.
Returns:
str: The user's unique ID (UUID) if valid.
Raises:
ValueError: If the token is expired or invalid.
"""
try:
# Decode the Supabase token using the Supabase key
payload = jwt.decode(token, self.supabase_key, algorithms=["HS256"])

# Supabase stores the user ID under 'sub' (subject) claim
user_id = payload.get("sub")
if not user_id:
raise ValueError("User ID not found in the token.")
return user_id

except ExpiredSignatureError:
raise ValueError("Token has expired.")
except InvalidTokenError:
raise ValueError("Invalid token.")
except Exception as e:
raise ValueError(f"Token verification error: {e}")
Empty file.

0 comments on commit 42d3035

Please sign in to comment.