"""
main.py - FastAPI-based telemetry server for Dereth Tracker.

This service ingests real-time position and event data from plugin clients via
WebSockets, stores telemetry and statistics in a TimescaleDB backend, and exposes
HTTP and WebSocket endpoints for browser clients to retrieve live and historical
data, trails, and per-character stats.
"""
import asyncio
import hmac
import json
import os
from datetime import datetime, timedelta, timezone
from typing import Dict
from typing import Optional

from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Async database support
from sqlalchemy.dialects.postgresql import insert as pg_insert

from db_async import (
    database,
    telemetry_events,
    char_stats,
    rare_stats,
    rare_stats_sessions,
    spawn_events,
    rare_events,
    init_db_async,
)

# In-memory caches for REST endpoints
|
||
_cached_live: dict = {"players": []}
|
||
_cached_trails: dict = {"trails": []}
|
||
_cache_task: asyncio.Task | None = None
|
||
|
||
# SQL for the `/live` cache: newest telemetry row per character inside the
# active window, joined with lifetime and current-session rare counts.
_LIVE_SQL = """
    SELECT sub.*,
           COALESCE(rs.total_rares, 0) AS total_rares,
           COALESCE(rss.session_rares, 0) AS session_rares
    FROM (
        SELECT DISTINCT ON (character_name) *
        FROM telemetry_events
        WHERE timestamp > :cutoff
        ORDER BY character_name, timestamp DESC
    ) sub
    LEFT JOIN rare_stats rs
        ON sub.character_name = rs.character_name
    LEFT JOIN rare_stats_sessions rss
        ON sub.character_name = rss.character_name
       AND sub.session_id = rss.session_id
"""

# SQL for the `/trails` cache: every position sample since :cutoff,
# grouped by character and ordered chronologically.
_TRAILS_SQL = """
    SELECT timestamp, character_name, ew, ns, z
    FROM telemetry_events
    WHERE timestamp >= :cutoff
    ORDER BY character_name, timestamp
"""

# How far back the `/trails` cache reaches.
_TRAIL_WINDOW = timedelta(seconds=600)


async def _refresh_live_cache() -> None:
    """Replace ``_cached_live["players"]`` with characters active within ACTIVE_WINDOW."""
    cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW
    rows = await database.fetch_all(_LIVE_SQL, {"cutoff": cutoff})
    _cached_live["players"] = [dict(r) for r in rows]


async def _refresh_trails_cache() -> None:
    """Replace ``_cached_trails["trails"]`` with position samples from the last 600 s."""
    # Timezone-aware "now", consistent with the live cutoff above
    # (replaces the deprecated datetime.utcnow()).
    cutoff = datetime.now(timezone.utc) - _TRAIL_WINDOW
    rows = await database.fetch_all(_TRAILS_SQL, {"cutoff": cutoff})
    _cached_trails["trails"] = [
        {
            "timestamp": r["timestamp"],
            "character_name": r["character_name"],
            "ew": r["ew"],
            "ns": r["ns"],
            "z": r["z"],
        }
        for r in rows
    ]


async def _refresh_cache_loop(interval: float = 5.0) -> None:
    """Background task: keep the `/live` and `/trails` caches fresh.

    Every *interval* seconds (default 5) this recomputes:

    * ``_cached_live["players"]``: the newest telemetry row per character seen
      within ``ACTIVE_WINDOW``, with total and per-session rare counts;
    * ``_cached_trails["trails"]``: every position sample from the last 600
      seconds, ordered by character and then by time.

    Each cache is refreshed independently, so a failing query for one does not
    leave the other stale. Errors are printed and the loop keeps running. Only
    cancellation stops the task (``asyncio.CancelledError`` is not an
    ``Exception`` subclass, so it is not caught here).
    """
    while True:
        for refresh in (_refresh_live_cache, _refresh_trails_cache):
            try:
                await refresh()
            except Exception as e:
                # Log and keep going: a transient DB error must not kill the task.
                print(f"[CACHE] refresh error: {e}")
        await asyncio.sleep(interval)


# ------------------------------------------------------------------
# Application object and shared state
# ------------------------------------------------------------------
app = FastAPI()

# In-memory store mapping character_name to the most recent telemetry snapshot
# (the dict form of the last TelemetrySnapshot received for that character).
live_snapshots: Dict[str, dict] = {}

# Shared secret used to authenticate plugin WebSocket connections.
# Set the SHARED_SECRET environment variable in production. The placeholder
# default exists only so a local development setup works unchanged.
SHARED_SECRET = os.getenv("SHARED_SECRET", "your_shared_secret")

# Time window defining "online" players: anyone with telemetry newer than this.
ACTIVE_WINDOW = timedelta(seconds=30)

# ------------------------------------------------------------------
# Data models for plugin events:
# - TelemetrySnapshot: periodic telemetry data from a player client
# - SpawnEvent: information about a mob spawn event
# - RareEvent: details of a rare mob event
# ------------------------------------------------------------------


class TelemetrySnapshot(BaseModel):
    """Periodic telemetry payload sent by a plugin client.

    Built from a WebSocket message of type "telemetry" (with the "type" key
    removed) by the /ws/position handler. The resulting dict is stored in
    ``live_snapshots``, inserted into ``telemetry_events`` and broadcast to
    browser clients.
    """

    character_name: str
    char_tag: Optional[str] = None
    # Plugin session identifier; combined with character_name to key kill
    # deltas and per-session rare counts.
    session_id: str
    timestamp: datetime

    ew: float  # +E / –W
    ns: float  # +N / –S
    z: float

    # Running kill count reported by the plugin. The server adds only the
    # increase since the previous snapshot of the same
    # (session_id, character_name) to char_stats.
    kills: int
    kills_per_hour: Optional[float] = None
    onlinetime: Optional[str] = None  # format not visible here; TODO confirm
    deaths: int
    # No longer part of the telemetry payload. The database copy is always
    # forced to 0; rares are counted from "rare" messages instead.
    rares_found: int = 0
    prismatic_taper_count: int
    vt_state: str  # NOTE(review): presumably the plugin's VTank state string; confirm
    # Optional client resource/performance metrics
    mem_mb: Optional[float] = None
    cpu_pct: Optional[float] = None
    mem_handles: Optional[int] = None
    latency_ms: Optional[float] = None


class SpawnEvent(BaseModel):
    """
    Model for a spawn event emitted by plugin clients when a mob appears.

    Built from a WebSocket message of type "spawn" (with the "type" key
    removed) and inserted as one row into ``spawn_events``. Records the
    reporting character, mob name, timestamp and spawn location; ``z``
    defaults to 0.0 when the plugin omits it.
    """
    character_name: str
    mob: str
    timestamp: datetime
    ew: float  # east/west coordinate, same convention as TelemetrySnapshot.ew
    ns: float  # north/south coordinate, same convention as TelemetrySnapshot.ns
    z: float = 0.0


class RareEvent(BaseModel):
    """
    Model for a rare event reported when a player encounters or discovers a rare entity.

    Built from a WebSocket message of type "rare" (with the "type" key removed)
    and inserted as one row into ``rare_events``. Includes the character, the
    rare's name, timestamp and location; ``z`` defaults to 0.0 when omitted.
    """
    character_name: str
    name: str
    timestamp: datetime
    ew: float  # east/west coordinate, same convention as TelemetrySnapshot.ew
    ns: float  # north/south coordinate, same convention as TelemetrySnapshot.ns
    z: float = 0.0


@app.on_event("startup")
async def on_startup():
    """Connect to the database, initialise it and start the cache refresher.

    The connection is retried up to 5 times, 5 seconds apart, to tolerate the
    database becoming ready after this service (e.g. Postgres still booting).

    Raises:
        RuntimeError: if every attempt fails. The last underlying error is
            chained as ``__cause__`` so the real reason shows in the traceback.
    """
    global _cache_task
    max_attempts = 5
    retry_delay_s = 5
    for attempt in range(1, max_attempts + 1):
        try:
            await database.connect()
            await init_db_async()
        except Exception as e:
            print(f"DB connection failed (attempt {attempt}/{max_attempts}): {e}")
            if attempt == max_attempts:
                raise RuntimeError(
                    f"Could not connect to database after {max_attempts} attempts"
                ) from e
            await asyncio.sleep(retry_delay_s)
        else:
            print(f"DB connected on attempt {attempt}")
            break

    # Start background cache refresh (live & trails)
    _cache_task = asyncio.create_task(_refresh_cache_loop())


@app.on_event("shutdown")
async def on_shutdown():
    """Shutdown hook: stop the cache refresher, then close the database.

    The refresh task is cancelled and awaited so it is no longer running
    (and holding a query) when the database connection is closed.
    """
    task = _cache_task
    if task is not None:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected result of cancelling the task.
            pass
    await database.disconnect()


# ------------------------ GET -----------------------------------
@app.get("/debug")
def debug():
    """Trivial health check that always answers with status OK."""
    status = "OK"
    return dict(status=status)


@app.get("/live", response_model=dict)
@app.get("/live/", response_model=dict)
async def get_live_players():
    """Serve the live-players cache maintained by the background refresh task.

    No database query happens here; the response is whatever the refresher
    last stored.
    """
    body = jsonable_encoder(_cached_live)
    return JSONResponse(content=body)


# --- GET Trails ---------------------------------
@app.get("/trails")
@app.get("/trails/")
async def get_trails(
    seconds: int = Query(600, ge=0, description="Lookback window in seconds"),
):
    """Return cached position trails, optionally narrowed to the last *seconds*.

    The background cache holds the last 600 seconds of samples and is refreshed
    every 5 seconds.

    * ``seconds >= 600`` (including the default) returns the cache unchanged.
      Windows longer than the cache cannot be served and are capped at 600 s.
    * Shorter windows keep only samples with a timestamp at or after
      ``now - seconds``.

    Args:
        seconds: lookback window in seconds (validated ``>= 0`` by FastAPI).
    """
    cache_window = 600  # must match the trail lookback used by _refresh_cache_loop
    trails = _cached_trails["trails"]

    if seconds < cache_window:
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=seconds)

        def _is_recent(point: dict) -> bool:
            ts = point["timestamp"]
            if not isinstance(ts, datetime):
                return True  # unknown format: keep rather than silently drop
            if ts.tzinfo is None:
                # Assume naive DB timestamps are UTC so the comparison is valid.
                ts = ts.replace(tzinfo=timezone.utc)
            return ts >= cutoff

        trails = [p for p in trails if _is_recent(p)]

    return JSONResponse(content=jsonable_encoder({"trails": trails}))


# -------------------- WebSocket endpoints -----------------------
# Connection registries shared by the WebSocket handlers below.

# Browser clients (/ws/live) that receive every broadcast telemetry/chat message.
browser_conns: set[WebSocket] = set()

# Plugin sockets keyed by the character name they registered under; browser
# commands addressed to that character are forwarded to this socket.
plugin_conns: Dict[str, WebSocket] = dict()


async def _broadcast_to_browser_clients(snapshot: dict):
    """Broadcast a telemetry or chat message to all connected browser clients.

    The payload goes through ``jsonable_encoder`` once, so non-JSON types such
    as ``datetime`` are converted before sending.

    A browser whose send fails for any reason (clean disconnect, closed
    transport, "send after close" runtime error) is dropped from
    ``browser_conns``. The failure is never propagated, because this coroutine
    runs inside the plugin handler: one dead browser must not tear down a
    plugin's connection.
    """
    data = jsonable_encoder(snapshot)
    # Iterate over a copy: the set can change while each send is awaited.
    for ws in list(browser_conns):
        try:
            await ws.send_json(data)
        except Exception as e:
            print(f"[WS] Dropping browser client {ws.client}: {e}")
            # discard(), not remove(): /ws/live's own cleanup may already
            # have removed this socket.
            browser_conns.discard(ws)


def _strip_type(message: dict) -> dict:
    """Return a shallow copy of *message* without its routing "type" key."""
    return {k: v for k, v in message.items() if k != "type"}


def _secret_is_valid(provided: str | None) -> bool:
    """Return True if *provided* equals SHARED_SECRET.

    Uses a constant-time comparison so response timing does not reveal how
    much of a guessed secret was correct.
    """
    if not isinstance(provided, str):
        return False
    return hmac.compare_digest(provided.encode("utf-8"), SHARED_SECRET.encode("utf-8"))


async def _handle_spawn(payload: dict) -> None:
    """Validate a "spawn" payload and insert it into spawn_events.

    Invalid payloads are dropped.
    """
    try:
        spawn = SpawnEvent.parse_obj(payload)
    except (TypeError, ValueError):  # pydantic's ValidationError is a ValueError
        return
    await database.execute(spawn_events.insert().values(**spawn.dict()))


async def _handle_telemetry(payload: dict) -> None:
    """Store a telemetry snapshot, add its kill delta and broadcast it.

    An invalid payload is logged and dropped instead of closing the plugin
    connection. The snapshot row and the char_stats kill increment are
    written in one transaction.
    """
    try:
        snap = TelemetrySnapshot.parse_obj(payload)
    except (TypeError, ValueError) as e:
        print(f"[WS] Dropping invalid telemetry payload: {e}")
        return

    snap_dict = snap.dict()
    live_snapshots[snap.character_name] = snap_dict

    # Rares are counted from "rare" messages, never from telemetry.
    db_data = {**snap_dict, "rares_found": 0}

    # Only the increase in the running kill counter is added to char_stats.
    # NOTE: _last_kills lives in memory, so after a server restart the first
    # snapshot of an ongoing session contributes its full kill count again.
    last_kills = ws_receive_snapshots._last_kills
    kill_key = (snap.session_id, snap.character_name)
    delta = snap.kills - last_kills.get(kill_key, 0)

    async with database.transaction():
        await database.execute(telemetry_events.insert().values(**db_data))
        if delta > 0:
            await database.execute(
                pg_insert(char_stats)
                .values(character_name=snap.character_name, total_kills=delta)
                .on_conflict_do_update(
                    index_elements=["character_name"],
                    set_={"total_kills": char_stats.c.total_kills + delta},
                )
            )
    last_kills[kill_key] = snap.kills

    await _broadcast_to_browser_clients(snap_dict)


async def _handle_rare(data: dict) -> None:
    """Count a rare for its character and persist the individual event.

    Increments rare_stats.total_rares and, when the character has a live
    snapshot with a session_id, rare_stats_sessions.session_rares. Both
    counters are updated in one transaction. Storing the detailed RareEvent
    row is best-effort: a validation or insert failure is logged but does
    not undo the counters or close the connection.
    """
    name = data.get("character_name")
    if not isinstance(name, str):
        return

    async with database.transaction():
        # Total rare count per character
        await database.execute(
            pg_insert(rare_stats)
            .values(character_name=name, total_rares=1)
            .on_conflict_do_update(
                index_elements=["character_name"],
                set_={"total_rares": rare_stats.c.total_rares + 1},
            )
        )
        # Session-specific rare count, when the current session is known
        session_id = live_snapshots.get(name, {}).get("session_id")
        if session_id:
            await database.execute(
                pg_insert(rare_stats_sessions)
                .values(character_name=name, session_id=session_id, session_rares=1)
                .on_conflict_do_update(
                    index_elements=["character_name", "session_id"],
                    set_={"session_rares": rare_stats_sessions.c.session_rares + 1},
                )
            )

    # Persist individual rare event for future analysis
    try:
        rare_ev = RareEvent.parse_obj(_strip_type(data))
        await database.execute(rare_events.insert().values(**rare_ev.dict()))
    except Exception as e:
        print(f"[WS] Could not store rare event for {name}: {e}")


@app.websocket("/ws/position")
async def ws_receive_snapshots(
    websocket: WebSocket,
    secret: str | None = Query(None),
    x_plugin_secret: str | None = Header(None),
):
    """WebSocket endpoint for plugin clients to send telemetry and events.

    Authenticates the client with the shared secret, taken from the ``secret``
    query parameter or the ``X-Plugin-Secret`` header. It then reads JSON text
    frames and dispatches on their "type":

    - register: record this socket under character_name/player_name so browser
      commands can be forwarded to it
    - spawn: persist a spawn event
    - telemetry: store the snapshot, update kill stats, broadcast to browsers
    - rare: update total and session rare counts, persist the event
    - chat: broadcast the message to browsers

    Frames that are not JSON objects, and unknown types, are ignored. On exit,
    every registration that points at this socket is removed.
    """
    if not _secret_is_valid(secret or x_plugin_secret):
        # Reject without completing the WebSocket handshake (1008 = policy violation)
        await websocket.close(code=1008)
        return

    await websocket.accept()
    print(f"[WS] Plugin connected: {websocket.client}")
    try:
        while True:
            try:
                raw = await websocket.receive_text()
                # Debug: log all incoming plugin WebSocket messages
                print(f"[WS-PLUGIN RX] {websocket.client}: {raw}")
            except WebSocketDisconnect:
                print(f"[WS] Plugin disconnected: {websocket.client}")
                break

            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object (e.g. a list): nothing to route on.
                continue

            msg_type = data.get("type")
            if msg_type == "register":
                name = data.get("character_name") or data.get("player_name")
                if isinstance(name, str):
                    plugin_conns[name] = websocket
            elif msg_type == "spawn":
                await _handle_spawn(_strip_type(data))
            elif msg_type == "telemetry":
                await _handle_telemetry(_strip_type(data))
            elif msg_type == "rare":
                await _handle_rare(data)
            elif msg_type == "chat":
                await _broadcast_to_browser_clients(data)
            # Unknown message types are ignored
    finally:
        # Clean up any plugin registrations for this socket
        to_remove = [n for n, ws in plugin_conns.items() if ws is websocket]
        for n in to_remove:
            del plugin_conns[n]
        print(f"[WS] Cleaned up plugin connections for {websocket.client}")


# Last running kill count seen per (session_id, character_name). The telemetry
# handler persists only the increase over this value, so char_stats receives
# deltas rather than absolute totals. Nothing in this module evicts entries.
ws_receive_snapshots._last_kills = dict()


@app.websocket("/ws/live")
async def ws_live_updates(websocket: WebSocket):
    """WebSocket endpoint for browser clients to receive live updates and send commands.

    Registers the browser in ``browser_conns`` so it receives broadcasts. Each
    incoming JSON command is forwarded to the plugin registered for the target
    character. Two envelope formats are accepted:

    - new:    ``{"player_name": ..., "command": ...}``, forwarded as-is
    - legacy: ``{"type": "command", "character_name": ..., "text": ...}``,
      rewritten into the new format before forwarding

    Malformed frames, unknown envelopes and commands for characters without a
    registered plugin are ignored. A failure to forward to a plugin socket is
    logged and does not close the browser connection.
    """
    await websocket.accept()
    browser_conns.add(websocket)
    try:
        while True:
            # Receive command messages from browser
            try:
                raw = await websocket.receive_text()
            except WebSocketDisconnect:
                break
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue  # bad JSON from one message should not drop the client
            # Debug: log all incoming browser WebSocket messages
            print(f"[WS-LIVE RX] {websocket.client}: {data}")
            if not isinstance(data, dict):
                continue

            # Determine command envelope format (new or legacy)
            if "player_name" in data and "command" in data:
                target_name = data["player_name"]
                payload = data
            elif data.get("type") == "command" and "character_name" in data and "text" in data:
                target_name = data.get("character_name")
                payload = {"player_name": target_name, "command": data.get("text")}
            else:
                continue

            # plugin_conns keys are always str; anything else cannot match
            # (and could be unhashable).
            if not isinstance(target_name, str):
                continue
            target_ws = plugin_conns.get(target_name)
            if target_ws is None:
                continue
            try:
                await target_ws.send_json(payload)
            except Exception as e:
                print(f"[WS] Failed to forward command to {target_name}: {e}")
    except WebSocketDisconnect:
        pass
    finally:
        # discard(): a failed broadcast may already have removed this socket.
        browser_conns.discard(websocket)


# -------------------- route listing -----------------------------
@app.on_event("startup")
async def _log_api_routes() -> None:
    """Print every registered HTTP API route with its methods, for convenience.

    Runs as a startup hook rather than at import time, so routes declared
    further down this module (e.g. ``/stats/{character_name}``) are included.
    Non-API entries such as the static-files mount are skipped.
    """
    print("🔍 Registered HTTP API routes:")
    for route in app.routes:
        if isinstance(route, APIRoute):
            # Log the path and allowed methods for each API route
            print(f"{route.path} -> {route.methods}")


# Add stats endpoint for per-character metrics
@app.get("/stats/{character_name}")
async def get_stats(character_name: str):
    """Return aggregated metrics for one character.

    The response contains:

    - character_name: echoed from the path
    - latest_snapshot: the newest telemetry_events row for the character
    - total_kills: char_stats.total_kills (0 when there is no row)
    - total_rares: rare_stats.total_rares (0 when there is no row)

    Raises:
        HTTPException: 404 when the character has no recorded telemetry.
    """
    params = {"cn": character_name}

    latest = await database.fetch_one(
        "SELECT * FROM telemetry_events "
        "WHERE character_name = :cn "
        "ORDER BY timestamp DESC LIMIT 1",
        params,
    )
    if not latest:
        raise HTTPException(status_code=404, detail="Character not found")

    async def _counter(sql: str, column: str):
        # Single-row counter lookup; an absent row means "nothing counted yet".
        row = await database.fetch_one(sql, params)
        return row[column] if row else 0

    body = {
        "character_name": character_name,
        "latest_snapshot": dict(latest),
        "total_kills": await _counter(
            "SELECT total_kills FROM char_stats WHERE character_name = :cn",
            "total_kills",
        ),
        "total_rares": await _counter(
            "SELECT total_rares FROM rare_stats WHERE character_name = :cn",
            "total_rares",
        ),
    }
    return JSONResponse(content=jsonable_encoder(body))


# -------------------- static frontend ---------------------------
# Serve the single-page application's static assets from ./static, with
# "/" as a catch-all for frontend routes. Mounted last, after all API routes,
# so that it does not shadow them.
_frontend_files = StaticFiles(directory="static", html=True)
app.mount("/", _frontend_files, name="static")