MosswartOverlord/main.py
2025-05-26 21:47:56 +00:00

466 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
main.py - FastAPI-based telemetry server for Dereth Tracker.
This service ingests real-time position and event data from plugin clients via WebSockets,
stores telemetry and statistics in a TimescaleDB backend, and exposes HTTP and WebSocket
endpoints for browser clients to retrieve live and historical data, trails, and per-character stats.
"""
from datetime import datetime, timedelta, timezone
import json
import os
from typing import Dict
from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from typing import Optional
# Async database support
from sqlalchemy.dialects.postgresql import insert as pg_insert
from db_async import (
database,
telemetry_events,
char_stats,
rare_stats,
rare_stats_sessions,
spawn_events,
rare_events,
init_db_async
)
import asyncio
# In-memory caches for REST endpoints
_cached_live: dict = {"players": []}
_cached_trails: dict = {"trails": []}
_cache_task: asyncio.Task | None = None
async def _refresh_cache_loop() -> None:
    """Background task: refresh the ``/live`` and ``/trails`` caches every 5 seconds.

    Each iteration recomputes:

    * ``_cached_live["players"]``: the newest ``telemetry_events`` row per
      character whose timestamp falls inside ``ACTIVE_WINDOW``. Each row is
      joined with the lifetime rare count (``rare_stats.total_rares``) and the
      rare count for the row's session (``rare_stats_sessions.session_rares``).
      Both counts are 0 when there is no matching row.
    * ``_cached_trails["trails"]``: every position sample (timestamp,
      character_name, ew, ns, z) from the last 600 seconds, ordered by
      character and then by time.

    Any ``Exception`` during a refresh is printed. A cache that had not yet been
    reassigned in that iteration keeps its previous contents, and the loop
    tries again after the usual 5-second sleep. ``asyncio.CancelledError`` is
    a ``BaseException`` and is not caught, so cancelling the task stops the loop.
    """
    sql_live = """
        SELECT sub.*,
               COALESCE(rs.total_rares, 0) AS total_rares,
               COALESCE(rss.session_rares, 0) AS session_rares
        FROM (
            SELECT DISTINCT ON (character_name) *
            FROM telemetry_events
            WHERE timestamp > :cutoff
            ORDER BY character_name, timestamp DESC
        ) sub
        LEFT JOIN rare_stats rs
          ON sub.character_name = rs.character_name
        LEFT JOIN rare_stats_sessions rss
          ON sub.character_name = rss.character_name
         AND sub.session_id = rss.session_id
    """
    sql_trail = """
        SELECT timestamp, character_name, ew, ns, z
        FROM telemetry_events
        WHERE timestamp >= :cutoff
        ORDER BY character_name, timestamp
    """
    trail_window = timedelta(seconds=600)
    trail_columns = ("timestamp", "character_name", "ew", "ns", "z")
    while True:
        try:
            # Live players: most recent row per character inside ACTIVE_WINDOW.
            live_cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW
            live_rows = await database.fetch_all(sql_live, {"cutoff": live_cutoff})
            _cached_live["players"] = [dict(r) for r in live_rows]

            # Trails: last 600 s of positions. The cutoff uses the same
            # timezone-aware clock as above; datetime.utcnow() is deprecated.
            trail_cutoff = datetime.now(timezone.utc) - trail_window
            trail_rows = await database.fetch_all(sql_trail, {"cutoff": trail_cutoff})
            _cached_trails["trails"] = [
                {col: r[col] for col in trail_columns} for r in trail_rows
            ]
        except Exception as e:
            print(f"[CACHE] refresh error: {e}")
        await asyncio.sleep(5)
# ------------------------------------------------------------------
app = FastAPI()
# In-memory store mapping character_name to the most recent telemetry snapshot
live_snapshots: Dict[str, dict] = {}
# Shared secret used to authenticate plugin WebSocket connections. It is read
# from the SHARED_SECRET environment variable. The placeholder default keeps
# existing setups working but must be overridden in production.
SHARED_SECRET = os.getenv("SHARED_SECRET", "your_shared_secret")
# ------------------------------------------------------------------
# Time window defining "online" players (last 30 seconds)
ACTIVE_WINDOW = timedelta(seconds=30)

# Data models for plugin events:
# - TelemetrySnapshot: periodic telemetry data from a player client
# - SpawnEvent: information about a mob spawn event
# - RareEvent: details of a rare event reported by a plugin
class TelemetrySnapshot(BaseModel):
    """Periodic telemetry message sent by a plugin client (``type == "telemetry"``).

    Validated with ``parse_obj`` in the plugin WebSocket handler. The result is
    stored in ``live_snapshots``, inserted into ``telemetry_events`` and
    broadcast to browser clients. ``kills`` is compared with the last value seen
    for the same (session_id, character_name) to update ``char_stats.total_kills``.
    """
    character_name: str
    char_tag: Optional[str] = None
    # Identifies the plugin session; also keys per-session rare counters.
    session_id: str
    timestamp: datetime
    ew: float # +E / W
    ns: float # +N / S
    z: float
    # Kill counter reported by the client; deltas between snapshots feed char_stats.
    kills: int
    kills_per_hour: Optional[float] = None
    # Free-form string as sent by the plugin (format not validated here).
    onlinetime: Optional[str] = None
    deaths: int
    # Removed from telemetry payload; defaults to 0 and the telemetry handler
    # writes 0 to the database. Rare counts are tracked via "rare" messages.
    rares_found: int = 0
    prismatic_taper_count: int
    vt_state: str
    # Optional telemetry metrics
    mem_mb: Optional[float] = None
    cpu_pct: Optional[float] = None
    mem_handles: Optional[int] = None
    latency_ms: Optional[float] = None
class SpawnEvent(BaseModel):
    """
    Model for a ``spawn`` message emitted by plugin clients when a mob appears.
    Records character context, mob type, timestamp, and spawn location; each
    valid message is inserted as one row into ``spawn_events``.
    """
    character_name: str
    # Mob name as reported by the plugin.
    mob: str
    timestamp: datetime
    ew: float
    ns: float
    # Optional in the payload; defaults to ground level 0.0.
    z: float = 0.0
class RareEvent(BaseModel):
    """
    Model for a ``rare`` message reported by a plugin client.
    Includes character, the rare's name, timestamp, and location coordinates;
    each valid message is stored as one row in ``rare_events``.
    """
    character_name: str
    # Name of the rare as reported by the plugin (not validated further here).
    name: str
    timestamp: datetime
    ew: float
    ns: float
    # Optional in the payload; defaults to 0.0.
    z: float = 0.0
@app.on_event("startup")
async def on_startup():
    """Startup hook: connect to the database, then start the cache refresher.

    Connection and schema initialisation are retried up to 5 times, with 5
    seconds between attempts, in case Postgres is still starting. If every
    attempt fails, RuntimeError is raised and startup aborts. On success the
    background task that refreshes the /live and /trails caches is created.
    """
    global _cache_task
    max_attempts = 5
    attempt = 0
    while True:
        attempt += 1
        try:
            await database.connect()
            await init_db_async()
        except Exception as e:
            print(f"DB connection failed (attempt {attempt}/{max_attempts}): {e}")
            if attempt >= max_attempts:
                raise RuntimeError(f"Could not connect to database after {max_attempts} attempts")
            await asyncio.sleep(5)
        else:
            print(f"DB connected on attempt {attempt}")
            break
    _cache_task = asyncio.create_task(_refresh_cache_loop())
@app.on_event("shutdown")
async def on_shutdown():
    """Shutdown hook: stop the cache refresher, then close the database.

    The refresh task is cancelled and awaited before the database disconnects.
    The CancelledError this produces is expected and suppressed.
    """
    task = _cache_task
    if task:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
    await database.disconnect()
# ------------------------ GET -----------------------------------
@app.get("/debug")
def debug():
    """Simple health check that always returns a static OK payload."""
    status_payload = {"status": "OK"}
    return status_payload
@app.get("/live", response_model=dict)
@app.get("/live/", response_model=dict)
async def get_live_players():
    """Serve the cached list of online players.

    The payload is the ``_cached_live`` dict maintained by the background
    refresh task. It is JSON-encoded here so datetimes serialize.
    """
    encoded = jsonable_encoder(_cached_live)
    return JSONResponse(content=encoded)
# --- GET Trails ---------------------------------
@app.get("/trails")
@app.get("/trails/")
async def get_trails(
    seconds: int = Query(600, ge=0, description="Lookback window in seconds"),
):
    """Return position trails for the last ``seconds`` seconds.

    The default 600-second window is served from ``_cached_trails``, which the
    background task refreshes every 5 seconds. Any other window is queried
    directly from ``telemetry_events`` with the same columns and ordering, so
    the ``seconds`` parameter is honoured.

    Returns:
        ``{"trails": [{"timestamp", "character_name", "ew", "ns", "z"}, ...]}``

    Raises:
        HTTPException(422): if ``seconds`` is too large to form a valid cutoff.
    """
    # The refresh loop precomputes exactly the 600 s window.
    if seconds == 600:
        return JSONResponse(content=jsonable_encoder(_cached_trails))
    try:
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=seconds)
    except OverflowError:
        raise HTTPException(status_code=422, detail="seconds is out of range") from None
    rows = await database.fetch_all(
        """
        SELECT timestamp, character_name, ew, ns, z
        FROM telemetry_events
        WHERE timestamp >= :cutoff
        ORDER BY character_name, timestamp
        """,
        {"cutoff": cutoff},
    )
    trails = [
        {col: r[col] for col in ("timestamp", "character_name", "ew", "ns", "z")}
        for r in rows
    ]
    return JSONResponse(content=jsonable_encoder({"trails": trails}))
# -------------------- WebSocket endpoints -----------------------
## WebSocket connection tracking
# Browser sockets (from /ws/live) that receive telemetry and chat broadcasts.
browser_conns: set[WebSocket] = set()
# Plugin sockets keyed by the character name they registered with, used to
# route browser commands to the right game client.
plugin_conns: Dict[str, WebSocket] = dict()
async def _broadcast_to_browser_clients(snapshot: dict):
    """Send a telemetry or chat message to every connected browser client.

    The payload goes through ``jsonable_encoder`` once, so datetimes and similar
    types serialize. A client whose send fails for any reason is removed from
    ``browser_conns`` and the broadcast continues. Failures never propagate
    to the caller (the plugin WebSocket handler), so a dead browser tab cannot
    disconnect a plugin.
    """
    data = jsonable_encoder(snapshot)
    # Iterate over a copy: the set can change while we await each send.
    for ws in list(browser_conns):
        try:
            await ws.send_json(data)
        except Exception as exc:
            # Depending on the Starlette/server version, a closed socket may
            # raise WebSocketDisconnect, RuntimeError or a server-specific
            # connection error. Treat all of them as "client gone".
            print(f"[WS] Dropping browser client {ws.client}: {exc!r}")
            # discard(): ws_live_updates may have removed it already.
            browser_conns.discard(ws)
def _plugin_secret_ok(key: str | None) -> bool:
    """Return True if *key* equals SHARED_SECRET (constant-time comparison)."""
    import hmac  # local import keeps this block self-contained

    if key is None:
        return False
    # Compare bytes so non-ASCII input cannot make compare_digest raise TypeError.
    return hmac.compare_digest(key.encode("utf-8"), SHARED_SECRET.encode("utf-8"))


def _parse_plugin_event(model, data: dict):
    """Validate a plugin message (minus its ``type`` key) against *model*.

    Returns the model instance, or None if validation fails. Failures are
    printed rather than raised, so one malformed message cannot end the
    plugin connection.
    """
    payload = {k: v for k, v in data.items() if k != "type"}
    try:
        return model.parse_obj(payload)
    except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
        print(f"[WS] Ignoring invalid {model.__name__} message: {exc}")
        return None


async def _handle_spawn(data: dict) -> None:
    """Persist a valid ``spawn`` message as one row in ``spawn_events``."""
    spawn = _parse_plugin_event(SpawnEvent, data)
    if spawn is None:
        return
    await database.execute(spawn_events.insert().values(**spawn.dict()))


async def _handle_telemetry(data: dict) -> None:
    """Store a ``telemetry`` snapshot, update kill totals and broadcast it.

    The snapshot replaces the character's entry in ``live_snapshots``. It is
    inserted into ``telemetry_events`` with ``rares_found`` forced to 0. Any
    positive increase in ``kills`` since the last snapshot for the same
    (session_id, character_name) is added to ``char_stats.total_kills`` in the
    same transaction.
    """
    snap = _parse_plugin_event(TelemetrySnapshot, data)
    if snap is None:
        return
    snap_dict = snap.dict()
    live_snapshots[snap.character_name] = snap_dict
    # Rares are tracked via "rare" messages; never trust the client's count.
    db_data = {**snap_dict, "rares_found": 0}
    kills_key = (snap.session_id, snap.character_name)
    last_kills = ws_receive_snapshots._last_kills
    delta = snap.kills - last_kills.get(kills_key, 0)
    # Persist snapshot and any kill delta in a single transaction
    async with database.transaction():
        await database.execute(telemetry_events.insert().values(**db_data))
        if delta > 0:
            stmt = pg_insert(char_stats).values(
                character_name=snap.character_name,
                total_kills=delta,
            ).on_conflict_do_update(
                index_elements=["character_name"],
                set_={"total_kills": char_stats.c.total_kills + delta},
            )
            await database.execute(stmt)
    last_kills[kills_key] = snap.kills
    await _broadcast_to_browser_clients(snap_dict)


async def _handle_rare(data: dict) -> None:
    """Count a ``rare`` message and persist it.

    If ``character_name`` is a string, this increments
    ``rare_stats.total_rares``. If the character's latest in-memory snapshot
    has a session_id, it also increments ``rare_stats_sessions.session_rares``.
    The event row in ``rare_events`` is stored on a best-effort basis:
    validation or insert failures are printed and otherwise ignored.
    """
    name = data.get("character_name")
    if isinstance(name, str):
        # Total rare count per character
        stmt_tot = pg_insert(rare_stats).values(
            character_name=name,
            total_rares=1,
        ).on_conflict_do_update(
            index_elements=["character_name"],
            set_={"total_rares": rare_stats.c.total_rares + 1},
        )
        await database.execute(stmt_tot)
        # Session-specific rare count (only if telemetry has been seen)
        session_id = live_snapshots.get(name, {}).get("session_id")
        if session_id:
            stmt_sess = pg_insert(rare_stats_sessions).values(
                character_name=name,
                session_id=session_id,
                session_rares=1,
            ).on_conflict_do_update(
                index_elements=["character_name", "session_id"],
                set_={"session_rares": rare_stats_sessions.c.session_rares + 1},
            )
            await database.execute(stmt_sess)
    rare_ev = _parse_plugin_event(RareEvent, data)
    if rare_ev is None:
        return
    try:
        await database.execute(rare_events.insert().values(**rare_ev.dict()))
    except Exception as exc:
        print(f"[WS] Failed to persist rare event: {exc}")


@app.websocket("/ws/position")
async def ws_receive_snapshots(
    websocket: WebSocket,
    secret: str | None = Query(None),
    x_plugin_secret: str | None = Header(None)
):
    """WebSocket endpoint for plugin clients to send telemetry and events.

    The shared secret may be passed as the ``secret`` query parameter or the
    ``X-Plugin-Secret`` header. It is checked in constant time, and the socket
    is closed with code 1008 if it does not match. Each text frame must be a
    JSON object with a ``type``:

    - register: record this socket in ``plugin_conns`` for command forwarding
    - spawn: persist spawn event
    - telemetry: store snapshot, update kill stats, broadcast to browsers
    - rare: update total and session rare counts, persist event
    - chat: broadcast the message to browsers

    Frames that are not JSON objects, invalid payloads and unknown types are
    skipped. On exit, every ``plugin_conns`` entry pointing at this socket is
    removed.
    """
    if not _plugin_secret_ok(secret or x_plugin_secret):
        # Reject without completing the WebSocket handshake
        await websocket.close(code=1008)
        return
    await websocket.accept()
    print(f"[WS] Plugin connected: {websocket.client}")
    try:
        while True:
            try:
                raw = await websocket.receive_text()
                # Debug: log all incoming plugin WebSocket messages
                print(f"[WS-PLUGIN RX] {websocket.client}: {raw}")
            except WebSocketDisconnect:
                print(f"[WS] Plugin disconnected: {websocket.client}")
                break
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (list, number, ...) has no "type".
            if not isinstance(data, dict):
                continue
            msg_type = data.get("type")
            if msg_type == "register":
                name = data.get("character_name") or data.get("player_name")
                if isinstance(name, str):
                    plugin_conns[name] = websocket
            elif msg_type == "spawn":
                await _handle_spawn(data)
            elif msg_type == "telemetry":
                await _handle_telemetry(data)
            elif msg_type == "rare":
                await _handle_rare(data)
            elif msg_type == "chat":
                await _broadcast_to_browser_clients(data)
            # Unknown message types are ignored
    finally:
        # Clean up any plugin registrations for this socket
        to_remove = [n for n, ws in plugin_conns.items() if ws is websocket]
        for n in to_remove:
            del plugin_conns[n]
        print(f"[WS] Cleaned up plugin connections for {websocket.client}")
# Last `kills` value seen per (session_id, character_name), stored as a function
# attribute on the plugin endpoint. Each new telemetry snapshot's positive
# difference from it is added to char_stats.total_kills.
# NOTE(review): entries are never evicted, so this grows with every new session.
setattr(ws_receive_snapshots, "_last_kills", dict())
@app.websocket("/ws/live")
async def ws_live_updates(websocket: WebSocket):
    """WebSocket endpoint for browser clients: receive live broadcasts, send commands.

    The socket is added to ``browser_conns`` so broadcasts reach it. Each
    incoming text frame is parsed as JSON. Two command envelopes are recognised:

    * ``{"player_name": ..., "command": ...}``: forwarded unchanged;
    * legacy ``{"type": "command", "character_name": ..., "text": ...}``:
      rewritten to the first form.

    The command goes to the plugin socket registered under that name in
    ``plugin_conns``. Invalid JSON, non-object frames, unknown envelopes,
    unknown targets and failed sends are skipped without closing this socket.

    NOTE(review): browser clients are not authenticated here, so any client
    that can reach /ws/live can send commands to connected plugins.
    """
    await websocket.accept()
    browser_conns.add(websocket)
    try:
        while True:
            try:
                raw = await websocket.receive_text()
            except WebSocketDisconnect:
                break
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Debug: log all incoming browser WebSocket messages
            print(f"[WS-LIVE RX] {websocket.client}: {data}")
            if not isinstance(data, dict):
                continue
            # Determine command envelope format (new or legacy)
            if "player_name" in data and "command" in data:
                target_name = data["player_name"]
                payload = data
            elif data.get("type") == "command" and "character_name" in data and "text" in data:
                target_name = data.get("character_name")
                payload = {"player_name": target_name, "command": data.get("text")}
            else:
                continue
            # plugin_conns is keyed by str; this also rejects unhashable values.
            if not isinstance(target_name, str):
                continue
            target_ws = plugin_conns.get(target_name)
            if target_ws is None:
                continue
            try:
                await target_ws.send_json(payload)
            except Exception as exc:
                # A dead plugin socket must not take down this browser connection.
                print(f"[WS] Failed to forward command to {target_name}: {exc!r}")
    except WebSocketDisconnect:
        pass
    finally:
        # discard(): the broadcaster may already have removed this socket.
        browser_conns.discard(websocket)
## -------------------- static frontend ---------------------------
## (static mount moved to end of file, below API routes)
# List registered HTTP API routes for convenience. This runs as a startup hook,
# not at import time, so routes declared further down the module (such as
# /stats/{character_name}) are included.
@app.on_event("startup")
async def _log_registered_routes() -> None:
    """Print every registered HTTP API route with its allowed methods."""
    print("🔍 Registered HTTP API routes:")
    for route in app.routes:
        if isinstance(route, APIRoute):
            # Log the path and allowed methods for each API route
            print(f"{route.path} -> {route.methods}")
# Per-character statistics endpoint
@app.get("/stats/{character_name}")
async def get_stats(character_name: str):
    """
    HTTP GET endpoint returning per-character metrics:
    - latest_snapshot: newest telemetry_events row for the character
    - total_kills: char_stats.total_kills (0 when absent)
    - total_rares: rare_stats.total_rares (0 when absent)
    Raises 404 when the character has no recorded telemetry.
    """
    params = {"cn": character_name}
    latest = await database.fetch_one(
        "SELECT * FROM telemetry_events "
        "WHERE character_name = :cn "
        "ORDER BY timestamp DESC LIMIT 1",
        params,
    )
    if not latest:
        raise HTTPException(status_code=404, detail="Character not found")

    kills_row = await database.fetch_one(
        "SELECT total_kills FROM char_stats WHERE character_name = :cn", params
    )
    rares_row = await database.fetch_one(
        "SELECT total_rares FROM rare_stats WHERE character_name = :cn", params
    )
    payload = {
        "character_name": character_name,
        "latest_snapshot": dict(latest),
        "total_kills": kills_row["total_kills"] if kills_row else 0,
        "total_rares": rares_row["total_rares"] if rares_row else 0,
    }
    return JSONResponse(content=jsonable_encoder(payload))
# -------------------- static frontend ---------------------------
# Serve the single-page application from ./static at the root path (html=True
# serves index.html for directory requests). Keep this last: a mount at "/"
# matches every path, so routes added after it would be shadowed.
app.mount(
    path="/",
    app=StaticFiles(html=True, directory="static"),
    name="static",
)