373 lines
14 KiB
Python
373 lines
14 KiB
Python
from datetime import datetime, timedelta, timezone
|
||
import json
|
||
import os
|
||
from typing import Dict
|
||
|
||
from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.routing import APIRoute
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.encoders import jsonable_encoder
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
|
||
# Async database support
|
||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||
from db_async import (
|
||
database,
|
||
telemetry_events,
|
||
char_stats,
|
||
rare_stats,
|
||
rare_stats_sessions,
|
||
spawn_events,
|
||
init_db_async
|
||
)
|
||
import asyncio
|
||
|
||
# ------------------------------------------------------------------

app = FastAPI()

# In-memory store of the last packet per character
live_snapshots: Dict[str, dict] = {}

# Shared secret that plugin clients must present on /ws/position.
# Read from the SHARED_SECRET environment variable so the real value does not
# have to live in source control; the historical placeholder remains the
# fallback so a deployment without the variable behaves exactly as before.
SHARED_SECRET = os.environ.get("SHARED_SECRET", "your_shared_secret")

# ------------------------------------------------------------------

ACTIVE_WINDOW = timedelta(seconds=30)  # player is "online" if seen in last 30 s
|
||
|
||
|
||
class TelemetrySnapshot(BaseModel):
    """One telemetry packet received from a plugin client on /ws/position.

    The parsed model's ``.dict()`` is kept in ``live_snapshots``, inserted
    into ``telemetry_events`` and broadcast to browser clients.
    """

    # --- identity ---
    character_name: str
    char_tag: Optional[str] = None
    session_id: str
    timestamp: datetime

    # --- position ---
    ew: float  # +E / -W
    ns: float  # +N / -S
    z: float

    # --- progress counters ---
    kills: int
    kills_per_hour: Optional[float] = None
    onlinetime: Optional[str] = None
    deaths: int
    # Removed from telemetry payload; always enforced to 0 and tracked via rare events
    rares_found: int = 0
    prismatic_taper_count: int
    vt_state: str

    # --- optional client resource metrics ---
    mem_mb: Optional[float] = None
    cpu_pct: Optional[float] = None
    mem_handles: Optional[int] = None
    latency_ms: Optional[float] = None
|
||
|
||
|
||
class SpawnEvent(BaseModel):
    """A mob spawn reported by a plugin client (``type: "spawn"`` message).

    Validated instances are inserted into ``spawn_events``.
    """

    character_name: str  # character whose client reported the spawn
    mob: str
    timestamp: datetime
    # Location of the spawn; z is optional and defaults to ground level 0.0.
    ew: float
    ns: float
    z: float = 0.0
|
||
|
||
|
||
@app.on_event("startup")
async def on_startup():
    """Connect to the database and initialise its schema, retrying on failure.

    The database may not be ready yet when the service boots, so up to five
    attempts are made, sleeping five seconds between them. When the last
    attempt fails, a RuntimeError is raised out of the startup hook.
    """
    max_attempts = 5
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await database.connect()
            await init_db_async()
        except Exception as e:
            print(f"DB connection failed (attempt {attempt}/{max_attempts}): {e}")
            if attempt == max_attempts:
                raise RuntimeError(f"Could not connect to database after {max_attempts} attempts")
            await asyncio.sleep(5)
        else:
            print(f"DB connected on attempt {attempt}")
            return
|
||
|
||
@app.on_event("shutdown")
async def on_shutdown():
    """Close the database connection that on_startup opened."""
    await database.disconnect()
|
||
|
||
|
||
|
||
# ------------------------ GET -----------------------------------
|
||
@app.get("/debug")
def debug():
    """Return a constant ``{"status": "OK"}`` payload (quick check that the server is up)."""
    payload = {"status": "OK"}
    return payload
|
||
|
||
|
||
@app.get("/live", response_model=dict)
@app.get("/live/", response_model=dict)
async def get_live_players():
    """Return recent live telemetry per character (seen within ACTIVE_WINDOW).

    Each player entry is the most recent ``telemetry_events`` row for that
    character, extended with:

    * ``total_rares``   -- lifetime rare count from ``rare_stats`` (0 if none)
    * ``session_rares`` -- rare count for the row's ``session_id`` from
      ``rare_stats_sessions`` (0 if none)

    Response shape: ``{"players": [{...}, ...]}``.
    """
    cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW
    # The time filter sits *inside* the DISTINCT ON subquery so the database
    # only scans recent rows instead of the whole telemetry history. The
    # result is the same as filtering afterwards: a character's newest row is
    # newer than the cutoff exactly when it has any row newer than the
    # cutoff, and that newest row is the one DISTINCT ON keeps.
    sql = """
        SELECT sub.*,
               COALESCE(rs.total_rares, 0) AS total_rares,
               COALESCE(rss.session_rares, 0) AS session_rares
        FROM (
            SELECT DISTINCT ON (character_name) *
            FROM telemetry_events
            WHERE timestamp > :cutoff
            ORDER BY character_name, timestamp DESC
        ) sub
        LEFT JOIN rare_stats rs
          ON sub.character_name = rs.character_name
        LEFT JOIN rare_stats_sessions rss
          ON sub.character_name = rss.character_name
         AND sub.session_id = rss.session_id
    """
    rows = await database.fetch_all(sql, {"cutoff": cutoff})
    players = [dict(r) for r in rows]
    # Ensure all types (e.g. datetime) are JSON serializable
    return JSONResponse(content=jsonable_encoder({"players": players}))
|
||
|
||
|
||
def _parse_history_bound(value: str | None, param: str) -> datetime | None:
    """Parse an ISO-8601 ``from``/``to`` query value into an aware datetime.

    Returns None when *value* is None or empty. A trailing ``Z`` is accepted
    as UTC (``datetime.fromisoformat`` only understands it from Python 3.11
    on). A value without a UTC offset is interpreted as UTC, matching the
    UTC-aware cutoffs the other endpoints in this module compare against.

    Raises:
        HTTPException: status 400 if *value* is not a valid ISO-8601 timestamp.
    """
    if not value:
        return None
    text = value.strip()
    if text.endswith(("Z", "z")):
        text = text[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid '{param}' timestamp: {value!r} (expected ISO-8601)",
        ) from None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


@app.get("/history/")
@app.get("/history")
async def get_history(
    from_ts: str | None = Query(None, alias="from"),
    to_ts: str | None = Query(None, alias="to"),
):
    """Returns a time-ordered list of telemetry snapshots.

    Query parameters (both optional, ISO-8601, inclusive bounds):
        from: only rows with ``timestamp >= from``
        to:   only rows with ``timestamp <= to``

    Response shape:
        ``{"data": [{"timestamp", "character_name", "kills", "kph"}, ...]}``

    Raises:
        HTTPException: status 400 if ``from`` or ``to`` cannot be parsed.
    """
    # Bind real datetimes rather than raw strings: drivers such as asyncpg
    # reject str values for timestamp parameters, and bad input now yields a
    # clear 400 instead of a database error.
    start = _parse_history_bound(from_ts, "from")
    end = _parse_history_bound(to_ts, "to")

    conditions: list[str] = []
    values: dict = {}
    if start is not None:
        conditions.append("timestamp >= :from_ts")
        values["from_ts"] = start
    if end is not None:
        conditions.append("timestamp <= :to_ts")
        values["to_ts"] = end

    sql = (
        "SELECT timestamp, character_name, kills, kills_per_hour AS kph "
        "FROM telemetry_events"
    )
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)
    sql += " ORDER BY timestamp"

    rows = await database.fetch_all(sql, values)
    data = [
        {
            "timestamp": row["timestamp"],
            "character_name": row["character_name"],
            "kills": row["kills"],
            "kph": row["kph"],
        }
        for row in rows
    ]
    # Ensure all types (e.g. datetime) are JSON serializable
    return JSONResponse(content=jsonable_encoder({"data": data}))
|
||
|
||
|
||
# --- GET Trails ---------------------------------
|
||
@app.get("/trails")
@app.get("/trails/")
async def get_trails(
    seconds: int = Query(600, ge=0, description="Lookback window in seconds"),
):
    """Return position snapshots (timestamp, character_name, ew, ns, z) for the past `seconds`.

    Rows are ordered by character and then by time, so each character's
    trail is contiguous. Response shape: ``{"trails": [{...}, ...]}``.
    """
    # Aware "now" in UTC; datetime.utcnow() is deprecated since Python 3.12
    # and this gives the same instant as utcnow().replace(tzinfo=utc).
    cutoff = datetime.now(timezone.utc) - timedelta(seconds=seconds)
    sql = """
        SELECT timestamp, character_name, ew, ns, z
        FROM telemetry_events
        WHERE timestamp >= :cutoff
        ORDER BY character_name, timestamp
    """
    rows = await database.fetch_all(sql, {"cutoff": cutoff})
    trails = [
        {
            "timestamp": r["timestamp"],
            "character_name": r["character_name"],
            "ew": r["ew"],
            "ns": r["ns"],
            "z": r["z"],
        }
        for r in rows
    ]
    # Ensure all types (e.g. datetime) are JSON serializable
    return JSONResponse(content=jsonable_encoder({"trails": trails}))
|
||
|
||
# -------------------- WebSocket endpoints -----------------------
|
||
# Browser clients connected to /ws/live; telemetry and chat broadcasts are
# sent to every socket in this set.
browser_conns: set[WebSocket] = set()

# Map of registered plugin clients: character_name -> WebSocket.
# Filled by "register" messages on /ws/position and read by /ws/live to
# forward browser commands to a specific character's plugin.
plugin_conns: dict[str, WebSocket] = {}
|
||
|
||
async def _broadcast_to_browser_clients(snapshot: dict):
    """Send *snapshot* as JSON to every connected browser client.

    A client whose send fails is dropped from ``browser_conns``. Every send
    error is handled here, not only WebSocketDisconnect. This coroutine is
    awaited from the plugin receive loop, so an exception escaping it would
    tear down a plugin connection just because one browser went away.

    Args:
        snapshot: message payload; run through ``jsonable_encoder`` so values
            such as ``datetime`` serialize cleanly.
    """
    # Encode once and reuse the result for every client.
    data = jsonable_encoder(snapshot)
    # Iterate over a copy: the set can change while each send is awaited.
    for ws in list(browser_conns):
        try:
            await ws.send_json(data)
        except Exception as exc:  # one dead client must not break the broadcast
            print(f"[WS] Dropping browser client {ws.client}: {exc!r}")
            # discard(): the /ws/live handler may already have removed it.
            browser_conns.discard(ws)
|
||
|
||
def _plugin_secret_is_valid(key: str | None) -> bool:
    """Return True if *key* equals SHARED_SECRET, using a constant-time compare."""
    import hmac  # local import keeps this block self-contained

    if not key:
        return False
    # Compare bytes: compare_digest avoids leaking the secret through timing
    # and, unlike with str arguments, does not raise on non-ASCII input.
    return hmac.compare_digest(key.encode("utf-8"), SHARED_SECRET.encode("utf-8"))


async def _store_spawn_event(data: dict) -> None:
    """Validate a "spawn" message and insert it into ``spawn_events``; invalid ones are skipped."""
    payload = {k: v for k, v in data.items() if k != "type"}
    try:
        spawn = SpawnEvent.parse_obj(payload)
    except Exception as exc:
        print(f"[WS] Ignoring invalid spawn payload: {exc}")
        return
    await database.execute(spawn_events.insert().values(**spawn.dict()))


async def _handle_telemetry_message(data: dict) -> None:
    """Validate a "telemetry" message, persist it, update kill totals and broadcast it.

    An invalid payload is logged and dropped instead of raising, so one bad
    packet no longer closes the plugin's WebSocket.
    """
    payload = {k: v for k, v in data.items() if k != "type"}
    try:
        snap = TelemetrySnapshot.parse_obj(payload)
    except Exception as exc:
        print(f"[WS] Ignoring invalid telemetry payload: {exc}")
        return

    # Update in-memory state (latest packet per character)
    live_snapshots[snap.character_name] = snap.dict()

    # Persist snapshot to TimescaleDB, force rares_found=0 (rares are
    # counted through "rare" messages instead)
    db_data = snap.dict()
    db_data["rares_found"] = 0
    await database.execute(telemetry_events.insert().values(**db_data))

    # Update persistent kill stats with the delta since the previous packet
    # of this (session_id, character_name)
    last_kills = ws_receive_snapshots._last_kills
    kill_key = (snap.session_id, snap.character_name)
    delta = snap.kills - last_kills.get(kill_key, 0)
    if delta > 0:
        stmt = pg_insert(char_stats).values(
            character_name=snap.character_name,
            total_kills=delta,
        ).on_conflict_do_update(
            index_elements=["character_name"],
            set_={"total_kills": char_stats.c.total_kills + delta},
        )
        await database.execute(stmt)
    last_kills[kill_key] = snap.kills

    # Broadcast to browser clients
    await _broadcast_to_browser_clients(snap.dict())


async def _record_rare_event(data: dict) -> None:
    """Increment the total rare count and, if a session is known, the per-session count."""
    name = data.get("character_name")
    if not isinstance(name, str):
        return

    # Total rare count per character
    stmt_tot = pg_insert(rare_stats).values(
        character_name=name,
        total_rares=1,
    ).on_conflict_do_update(
        index_elements=["character_name"],
        set_={"total_rares": rare_stats.c.total_rares + 1},
    )
    await database.execute(stmt_tot)

    # Session-specific rare count, keyed by the session of the character's
    # latest telemetry packet (if one has been received)
    session_id = live_snapshots.get(name, {}).get("session_id")
    if session_id:
        stmt_sess = pg_insert(rare_stats_sessions).values(
            character_name=name,
            session_id=session_id,
            session_rares=1,
        ).on_conflict_do_update(
            index_elements=["character_name", "session_id"],
            set_={"session_rares": rare_stats_sessions.c.session_rares + 1},
        )
        await database.execute(stmt_sess)


@app.websocket("/ws/position")
async def ws_receive_snapshots(
    websocket: WebSocket,
    secret: str | None = Query(None),
    x_plugin_secret: str | None = Header(None),
):
    """Ingest messages from a plugin client.

    The client authenticates with the shared secret, passed either as the
    ``secret`` query parameter or the ``x_plugin_secret`` header parameter.
    Each text frame must be a JSON object whose ``type`` selects the action:

    * ``register``  -- map ``character_name``/``player_name`` to this socket
      so browser commands can be forwarded to it
    * ``spawn``     -- persist a SpawnEvent to ``spawn_events``
    * ``telemetry`` -- persist a TelemetrySnapshot, update kill totals and
      broadcast it to browser clients
    * ``rare``      -- increment the total and per-session rare counters
    * ``chat``      -- broadcast to browser clients only (no DB write)

    Malformed frames (invalid JSON, JSON that is not an object, payloads that
    fail model validation) are skipped rather than closing the connection.
    Unknown message types are ignored.
    """
    # Verify shared secret from query parameter or header
    if not _plugin_secret_is_valid(secret or x_plugin_secret):
        # Reject without completing the WebSocket handshake
        await websocket.close(code=1008)
        return

    # Accept the WebSocket connection
    await websocket.accept()
    print(f"[WS] Plugin connected: {websocket.client}")
    try:
        while True:
            # Read next text frame
            try:
                raw = await websocket.receive_text()
                # Debug: log all incoming plugin WebSocket messages
                print(f"[WS-PLUGIN RX] {websocket.client}: {raw}")
            except WebSocketDisconnect:
                print(f"[WS] Plugin disconnected: {websocket.client}")
                break

            # Parse JSON payload
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (list, number, ...) has no
            # "type" to dispatch on; skip it instead of crashing on .get().
            if not isinstance(data, dict):
                continue

            msg_type = data.get("type")
            if msg_type == "register":
                # Registration message: map character to this socket
                name = data.get("character_name") or data.get("player_name")
                if isinstance(name, str):
                    plugin_conns[name] = websocket
            elif msg_type == "spawn":
                await _store_spawn_event(data)
            elif msg_type == "telemetry":
                await _handle_telemetry_message(data)
            elif msg_type == "rare":
                await _record_rare_event(data)
            elif msg_type == "chat":
                await _broadcast_to_browser_clients(data)
            # Unknown message types are ignored
    finally:
        # Clean up any plugin registrations for this socket
        stale = [n for n, ws in plugin_conns.items() if ws is websocket]
        for n in stale:
            del plugin_conns[n]
        print(f"[WS] Cleaned up plugin connections for {websocket.client}")
|
||
|
||
# In-memory store of last kills per session for delta calculations:
# (session_id, character_name) -> last reported cumulative kill count.
# Stored as a function attribute and read by the /ws/position handler.
ws_receive_snapshots._last_kills = dict()
|
||
|
||
@app.websocket("/ws/live")
async def ws_live_updates(websocket: WebSocket):
    """Serve a browser client: receive broadcasts and relay its commands to plugins.

    The socket joins ``browser_conns`` so telemetry and chat broadcasts reach
    it. Each incoming JSON message is treated as a command and forwarded to
    the plugin socket registered for the target character in
    ``plugin_conns``. Two envelope formats are accepted:

    * new:    ``{"player_name": ..., "command": ...}`` (forwarded as-is)
    * legacy: ``{"type": "command", "character_name": ..., "text": ...}``
      (rewritten to the new format before forwarding)

    Malformed frames, and commands whose target is unknown or whose plugin
    socket cannot be written to, are skipped. They no longer close the
    browser connection.
    """
    await websocket.accept()
    browser_conns.add(websocket)
    try:
        while True:
            # Receive command messages from browser
            try:
                data = await websocket.receive_json()
                # Debug: log all incoming browser WebSocket messages
                print(f"[WS-LIVE RX] {websocket.client}: {data}")
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                # Not valid JSON: drop the frame but keep the socket open.
                continue
            if not isinstance(data, dict):
                continue

            # Determine command envelope format (new or legacy)
            if "player_name" in data and "command" in data:
                # New format: { player_name, command }
                target_name = data["player_name"]
                payload = data
            elif data.get("type") == "command" and "character_name" in data and "text" in data:
                # Legacy format: { type: 'command', character_name, text }
                target_name = data.get("character_name")
                payload = {"player_name": target_name, "command": data.get("text")}
            else:
                # Not a recognized command envelope
                continue

            # plugin_conns keys are always str, so any other target can never
            # match; skipping it also avoids a TypeError on unhashable values.
            if not isinstance(target_name, str):
                continue

            # Forward command envelope to the appropriate plugin WebSocket
            target_ws = plugin_conns.get(target_name)
            if target_ws is None:
                continue
            try:
                await target_ws.send_json(payload)
            except Exception as exc:
                # The plugin handler removes its own registration when it disconnects.
                print(f"[WS] Failed to forward command to {target_name!r}: {exc!r}")
    except WebSocketDisconnect:
        pass
    finally:
        # discard(): a failed broadcast may already have removed this socket.
        browser_conns.discard(websocket)
|
||
|
||
|
||
# -------------------- static frontend ---------------------------
|
||
# static frontend
# Serve ./static at "/" (html=True); mounted after every API and WebSocket
# route above has been registered.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

# list routes for convenience
print("🔍 Registered routes:")
for api_route in (r for r in app.routes if isinstance(r, APIRoute)):
    print(f"{api_route.path} -> {api_route.methods}")
|