diff --git a/main.py b/main.py index 49925606..5600afa9 100644 --- a/main.py +++ b/main.py @@ -31,6 +31,52 @@ from db_async import ( init_db_async ) import asyncio +# In-memory caches for REST endpoints +_cached_live: dict = {"players": []} +_cached_trails: dict = {"trails": []} +_cache_task: asyncio.Task | None = None + +async def _refresh_cache_loop() -> None: + """Background task: refresh `/live` and `/trails` caches every 5 seconds.""" + while True: + try: + # Recompute live players (last 30s) + cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW + sql_live = """ + SELECT sub.*, + COALESCE(rs.total_rares, 0) AS total_rares, + COALESCE(rss.session_rares, 0) AS session_rares + FROM ( + SELECT DISTINCT ON (character_name) * + FROM telemetry_events + WHERE timestamp > :cutoff + ORDER BY character_name, timestamp DESC + ) sub + LEFT JOIN rare_stats rs + ON sub.character_name = rs.character_name + LEFT JOIN rare_stats_sessions rss + ON sub.character_name = rss.character_name + AND sub.session_id = rss.session_id + """ + rows = await database.fetch_all(sql_live, {"cutoff": cutoff}) + _cached_live["players"] = [dict(r) for r in rows] + # Recompute trails (last 600s) + cutoff2 = datetime.utcnow().replace(tzinfo=timezone.utc) - timedelta(seconds=600) + sql_trail = """ + SELECT timestamp, character_name, ew, ns, z + FROM telemetry_events + WHERE timestamp >= :cutoff + ORDER BY character_name, timestamp + """ + rows2 = await database.fetch_all(sql_trail, {"cutoff": cutoff2}) + _cached_trails["trails"] = [ + {"timestamp": r["timestamp"], "character_name": r["character_name"], + "ew": r["ew"], "ns": r["ns"], "z": r["z"]} + for r in rows2 + ] + except Exception as e: + print(f"[CACHE] refresh error: {e}") + await asyncio.sleep(5) # ------------------------------------------------------------------ app = FastAPI() @@ -121,13 +167,23 @@ async def on_startup(): await asyncio.sleep(5) else: raise RuntimeError(f"Could not connect to database after {max_attempts} attempts") - + # Start background cache refresh (live & trails) + global _cache_task + _cache_task = asyncio.create_task(_refresh_cache_loop()) @app.on_event("shutdown") async def on_shutdown(): """Event handler triggered when application is shutting down. Ensures the database connection is closed cleanly. """ + # Stop cache refresh task + global _cache_task + if _cache_task: + _cache_task.cancel() + try: + await _cache_task + except asyncio.CancelledError: + pass await database.disconnect() @@ -141,31 +197,8 @@ def debug(): @app.get("/live", response_model=dict) @app.get("/live/", response_model=dict) async def get_live_players(): - """Return recent live telemetry per character (last 30 seconds).""" - cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW - # Build SQL to select the most recent telemetry entry per character: - # - Use DISTINCT ON (character_name) to get latest row for each player - # - Join rare_stats for cumulative counts and rare_stats_sessions for session-specific counts - sql = """ - SELECT sub.*, - COALESCE(rs.total_rares, 0) AS total_rares, - COALESCE(rss.session_rares, 0) AS session_rares - FROM ( - SELECT DISTINCT ON (character_name) * - FROM telemetry_events - WHERE timestamp > :cutoff - ORDER BY character_name, timestamp DESC - ) sub - LEFT JOIN rare_stats rs - ON sub.character_name = rs.character_name - LEFT JOIN rare_stats_sessions rss - ON sub.character_name = rss.character_name - AND sub.session_id = rss.session_id - """ - rows = await database.fetch_all(sql, {"cutoff": cutoff}) - players = [dict(r) for r in rows] - # Ensure all types (e.g. datetime) are JSON serializable - return JSONResponse(content=jsonable_encoder({"players": players})) + """Return cached live telemetry per character.""" + return JSONResponse(content=jsonable_encoder(_cached_live)) @@ -176,28 +209,8 @@ async def get_live_players(): async def get_trails( seconds: int = Query(600, ge=0, description="Lookback window in seconds"), ): - """Return position snapshots (timestamp, character_name, ew, ns, z) for the past `seconds`.""" - cutoff = datetime.utcnow().replace(tzinfo=timezone.utc) - timedelta(seconds=seconds) - # Query position snapshots for all characters since the cutoff time - sql = """ - SELECT timestamp, character_name, ew, ns, z - FROM telemetry_events - WHERE timestamp >= :cutoff - ORDER BY character_name, timestamp - """ - rows = await database.fetch_all(sql, {"cutoff": cutoff}) - trails = [ - { - "timestamp": r["timestamp"], - "character_name": r["character_name"], - "ew": r["ew"], - "ns": r["ns"], - "z": r["z"], - } - for r in rows - ] - # Ensure all types (e.g. datetime) are JSON serializable - return JSONResponse(content=jsonable_encoder({"trails": trails})) + """Return cached trails (updated every 5 seconds).""" + return JSONResponse(content=jsonable_encoder(_cached_trails)) # -------------------- WebSocket endpoints ----------------------- ## WebSocket connection tracking