diff --git a/src/lvmapi/routers/alerts.py b/src/lvmapi/routers/alerts.py index 45933d6..1ba6bb0 100644 --- a/src/lvmapi/routers/alerts.py +++ b/src/lvmapi/routers/alerts.py @@ -13,14 +13,13 @@ import warnings import polars -from aiocache import Cache, cached -from aiocache.serializers import PickleSerializer -from fastapi import APIRouter, Request +from fastapi import APIRouter from pydantic import BaseModel from lvmopstools.weather import get_weather_data, is_weather_data_safe from lvmapi.tools.alerts import enclosure_alerts, spec_temperature_alerts +from lvmapi.tools.general import cache_response class AlertsSummary(BaseModel): @@ -45,17 +44,12 @@ class AlertsSummary(BaseModel): @router.get("") @router.get("/") @router.get("/summary") -@cached( - ttl=60, - cache=Cache.REDIS, # type: ignore - key="alerts_summary", - serializer=PickleSerializer(), - port=6379, - namespace="lvmapi", -) -async def summary(request: Request) -> AlertsSummary: +@cache_response("alerts:summary", ttl=60, response_model=AlertsSummary) +async def summary(): """Summary of alerts.""" + from lvmapi.app import app + now = time.time() tasks: list[asyncio.Task] = [] @@ -110,11 +104,11 @@ async def summary(request: Request) -> AlertsSummary: door_alert = enclosure_alerts_response.get("door_alert", None) # These fake states are just for testing. - if request.app.state.use_fake_states: - humidity_alert = request.app.state.fake_states["humidity_alert"] - wind_alert = request.app.state.fake_states["wind_alert"] - rain_sensor_alarm = request.app.state.fake_states["rain_alert"] - door_alert = request.app.state.fake_states["door_alert"] + if app.state.use_fake_states: + humidity_alert = app.state.fake_states["humidity_alert"] + wind_alert = app.state.fake_states["wind_alert"] + rain_sensor_alarm = app.state.fake_states["rain_alert"] + door_alert = app.state.fake_states["door_alert"] return AlertsSummary( humidity_alert=humidity_alert, diff --git a/src/lvmapi/routers/enclosure.py b/src/lvmapi/routers/enclosure.py index 9d73380..f2027ab 100644 --- a/src/lvmapi/routers/enclosure.py +++ b/src/lvmapi/routers/enclosure.py @@ -13,8 +13,6 @@ from typing import Annotated, Any, Literal -from aiocache import Cache, cached -from aiocache.serializers import PickleSerializer from fastapi import APIRouter, Body, HTTPException, Path from pydantic import BaseModel, Field, create_model, field_validator @@ -23,6 +21,7 @@ from lvmapi.auth import AuthDependency from lvmapi.tasks import move_dome_task +from lvmapi.tools.general import cache_response from lvmapi.tools.rabbitmq import send_clu_command @@ -37,7 +36,9 @@ class PLCStatus(BaseModel): labels: list[str] = [] @field_validator("value", mode="before") - def cast_value(cls, value: str) -> int: + def cast_value(cls, value: str | int) -> int: + if isinstance(value, int): + return value return int(value, 16) @field_validator("labels", mode="after") @@ -93,15 +94,8 @@ class NPSBody(BaseModel): @router.get("") @router.get("/") @router.get("/status") -@cached( - ttl=5, - cache=Cache.REDIS, # type: ignore - key="enclosure_status", - serializer=PickleSerializer(), - port=6379, - namespace="lvmapi", -) -async def status() -> EnclosureStatus: +@cache_response("enclosure:status", ttl=60, response_model=EnclosureStatus) +async def status(): """Performs an emergency shutdown of the enclosure and telescopes.""" try: diff --git a/src/lvmapi/tools/general.py b/src/lvmapi/tools/general.py index dc89fa7..3a1caf7 100644 --- a/src/lvmapi/tools/general.py +++ b/src/lvmapi/tools/general.py @@ -9,17 +9,25 @@ from __future__ import annotations import functools -from datetime import datetime, timedelta +import json +from datetime import datetime, timedelta, timezone -from typing import Any +from typing import Any, Type, TypeVar import psycopg import psycopg.sql +from aiocache import Cache +from fastapi import HTTPException +from pydantic import BaseModel from lvmapi import config -__all__ = ["timed_cache", "get_db_connection", "insert_to_database"] +__all__ = [ + "timed_cache", + "get_db_connection", + "insert_to_database", +] def timed_cache(seconds: float): @@ -39,7 +47,7 @@ def timed_cache(seconds: float): def _wrapper(f): update_delta = timedelta(seconds=seconds) - next_update = datetime.utcnow() + update_delta + next_update = datetime.now(timezone.utc) + update_delta # Apply @lru_cache to f with no cache size limit f = functools.lru_cache(None)(f) @@ -47,7 +55,7 @@ def _wrapper(f): @functools.wraps(f) def _wrapped(*args, **kwargs): nonlocal next_update - now = datetime.utcnow() + now = datetime.now(timezone.utc) if now >= next_update: f.cache_clear() next_update = now + update_delta @@ -123,3 +131,59 @@ async def insert_to_database( for row in data: values = [row.get(col, None) for col in columns] await acursor.execute(query, values) + + +T = TypeVar("T", bound=BaseModel) + + +def cache_response( + key: str, + ttl: int = 60, + namespace: str = "lvmapi", + response_model: Type[T] | None = None, +): + """Caching decorator for FastAPI endpoints. + + See https://dev.to/sivakumarmanoharan/caching-in-fastapi-unlocking-high-performance-development-20ej + + """ + + def decorator(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + cache_key = f"{namespace}:{key}" + + assert Cache.REDIS + cache = Cache.REDIS( + endpoint="localhost", # type: ignore + port=6379, # type: ignore + namespace=namespace, + ) + + # Try to retrieve data from cache + cached_value = await cache.get(cache_key) + if cached_value: + if response_model: + return response_model(**json.loads(cached_value)) + return json.loads(cached_value) + + # Call the actual function if cache is not hit + response: T | Any = await func(*args, **kwargs) + + try: + # Store the response in Redis with a TTL + if response_model: + cacheable = response.model_dump_json() + else: + cacheable = json.dumps(response) + + await cache.set(cache_key, cacheable, ttl=ttl) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error caching data: {e}") + + return response + + return wrapper + + return decorator