MosswartOverlord/main.py

373 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sqlalchemy.dialects.postgresql import insert as pg_insert

# Async database support (project-local)
from db_async import (
    database,
    telemetry_events,
    char_stats,
    rare_stats,
    rare_stats_sessions,
    spawn_events,
    init_db_async,
)
# ------------------------------------------------------------------
app = FastAPI()

# In-memory store of the last telemetry packet (as a dict) per character name.
live_snapshots: Dict[str, dict] = {}

# Secret that plugins must present when connecting to /ws/position.
# It is read from the environment so real deployments need not ship the
# placeholder; the default keeps the previous behaviour when the variable is unset.
SHARED_SECRET = os.getenv("SHARED_SECRET", "your_shared_secret")

# ------------------------------------------------------------------
# A player counts as "online" if a packet was seen within the last 30 s.
ACTIVE_WINDOW = timedelta(seconds=30)
class TelemetrySnapshot(BaseModel):
    """One telemetry packet sent by a plugin over the /ws/position socket."""

    character_name: str
    char_tag: str | None = None
    session_id: str
    timestamp: datetime
    # Map position: positive ew points East (negative West),
    # positive ns points North (negative South).
    ew: float
    ns: float
    z: float
    kills: int
    kills_per_hour: float | None = None
    onlinetime: str | None = None
    deaths: int
    # No longer expected in the plugin payload. The database write forces it
    # to 0; rare counts are tracked through separate "rare" events.
    rares_found: int = 0
    prismatic_taper_count: int
    vt_state: str
    # Optional client-side resource / latency metrics.
    mem_mb: float | None = None
    cpu_pct: float | None = None
    mem_handles: int | None = None
    latency_ms: float | None = None
class SpawnEvent(BaseModel):
    """A mob spawn reported by a plugin; persisted to spawn_events for heatmaps."""

    character_name: str
    mob: str
    timestamp: datetime
    ew: float
    ns: float
    # The height coordinate may be omitted by the sender; it then defaults to 0.0.
    z: float = 0.0
@app.on_event("startup")
async def on_startup():
    """Connect to the database and initialise it, retrying while the DB comes up.

    Makes up to ``max_attempts`` attempts, ``retry_delay_s`` seconds apart.

    Raises:
        RuntimeError: if every attempt fails (chained to the last error).
    """
    max_attempts = 5
    retry_delay_s = 5
    for attempt in range(1, max_attempts + 1):
        try:
            # A previous attempt may have connected successfully before
            # init_db_async() failed; reconnecting an open connection would
            # fail on every later retry. (Assumes `database` exposes
            # `is_connected`, as databases.Database does — confirm in db_async.)
            if not database.is_connected:
                await database.connect()
            await init_db_async()
            print(f"DB connected on attempt {attempt}")
            return
        except Exception as e:
            print(f"DB connection failed (attempt {attempt}/{max_attempts}): {e}")
            if attempt < max_attempts:
                await asyncio.sleep(retry_delay_s)
            else:
                raise RuntimeError(
                    f"Could not connect to database after {max_attempts} attempts"
                ) from e
@app.on_event("shutdown")
async def on_shutdown():
    """Release the shared database connection when the application stops."""
    db = database
    await db.disconnect()
# ------------------------ GET -----------------------------------
@app.get("/debug")
def debug():
    """Liveness probe: always answers with a fixed OK status."""
    payload = dict(status="OK")
    return payload
@app.get("/live", response_model=dict)
@app.get("/live/", response_model=dict)
async def get_live_players():
    """Return the newest telemetry row for each character seen within ACTIVE_WINDOW.

    Response: ``{"players": [...]}``. Each entry holds every telemetry_events
    column plus ``total_rares`` and ``session_rares``, both 0 when no matching
    rare_stats / rare_stats_sessions row exists.
    """
    cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW
    # The cutoff is applied inside the sub-select so DISTINCT ON only sees
    # recent rows instead of the whole table. The result is unchanged: a
    # character's newest row is newer than the cutoff iff any of its rows is.
    sql = """
        SELECT sub.*,
               COALESCE(rs.total_rares, 0)    AS total_rares,
               COALESCE(rss.session_rares, 0) AS session_rares
        FROM (
            SELECT DISTINCT ON (character_name) *
            FROM telemetry_events
            WHERE timestamp > :cutoff
            ORDER BY character_name, timestamp DESC
        ) sub
        LEFT JOIN rare_stats rs
            ON sub.character_name = rs.character_name
        LEFT JOIN rare_stats_sessions rss
            ON sub.character_name = rss.character_name
           AND sub.session_id = rss.session_id
    """
    rows = await database.fetch_all(sql, {"cutoff": cutoff})
    players = [dict(r) for r in rows]
    # jsonable_encoder converts datetimes and other non-JSON types.
    return JSONResponse(content=jsonable_encoder({"players": players}))
def _as_utc(ts: datetime | None) -> datetime | None:
    """Return *ts* unchanged if it is None or timezone-aware; otherwise tag it as UTC."""
    if ts is None or ts.tzinfo is not None:
        return ts
    return ts.replace(tzinfo=timezone.utc)


@app.get("/history/")
@app.get("/history")
async def get_history(
    from_ts: datetime | None = Query(None, alias="from"),
    to_ts: datetime | None = Query(None, alias="to"),
):
    """Return a time-ordered list of kill snapshots.

    Query parameters ``from`` / ``to`` are optional inclusive bounds, given
    as ISO-8601 datetimes. They are parsed into ``datetime`` objects because
    binding raw strings against the timestamp column is rejected by
    type-strict drivers such as asyncpg. Bounds without an offset are
    treated as UTC.

    Response: ``{"data": [{timestamp, character_name, kills, kph}, ...]}``.
    """
    sql = (
        "SELECT timestamp, character_name, kills, kills_per_hour AS kph "
        "FROM telemetry_events"
    )
    values: dict = {}
    conditions: list[str] = []
    if from_ts is not None:
        conditions.append("timestamp >= :from_ts")
        values["from_ts"] = _as_utc(from_ts)
    if to_ts is not None:
        conditions.append("timestamp <= :to_ts")
        values["to_ts"] = _as_utc(to_ts)
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)
    sql += " ORDER BY timestamp"
    rows = await database.fetch_all(sql, values)
    data = [
        {
            "timestamp": row["timestamp"],
            "character_name": row["character_name"],
            "kills": row["kills"],
            "kph": row["kph"],
        }
        for row in rows
    ]
    # jsonable_encoder converts datetimes and other non-JSON types.
    return JSONResponse(content=jsonable_encoder({"data": data}))
# --- GET Trails ---------------------------------
@app.get("/trails")
@app.get("/trails/")
async def get_trails(
    seconds: int = Query(600, ge=0, description="Lookback window in seconds"),
):
    """Return position snapshots from the past ``seconds`` seconds.

    Response: ``{"trails": [{timestamp, character_name, ew, ns, z}, ...]}``,
    ordered by character name and then by timestamp.
    """
    # datetime.utcnow() is naive and deprecated since Python 3.12; build an
    # aware UTC timestamp directly instead.
    cutoff = datetime.now(timezone.utc) - timedelta(seconds=seconds)
    sql = """
        SELECT timestamp, character_name, ew, ns, z
        FROM telemetry_events
        WHERE timestamp >= :cutoff
        ORDER BY character_name, timestamp
    """
    rows = await database.fetch_all(sql, {"cutoff": cutoff})
    trails = [
        {
            "timestamp": r["timestamp"],
            "character_name": r["character_name"],
            "ew": r["ew"],
            "ns": r["ns"],
            "z": r["z"],
        }
        for r in rows
    ]
    # jsonable_encoder converts datetimes and other non-JSON types.
    return JSONResponse(content=jsonable_encoder({"trails": trails}))
# -------------------- WebSocket endpoints -----------------------
# Connected browser sockets; every broadcast is sent to each of them.
browser_conns: set[WebSocket] = set()
# Plugin sockets keyed by the character name they registered under;
# browser commands are routed through this mapping.
plugin_conns: Dict[str, WebSocket] = dict()
async def _broadcast_to_browser_clients(snapshot: dict):
    """Send *snapshot* as JSON to every connected browser socket.

    A socket whose send fails is dropped from ``browser_conns``, and the
    broadcast carries on with the rest. Depending on the Starlette/uvicorn
    version, a dead socket may raise WebSocketDisconnect, RuntimeError or
    OSError. Letting any of these escape would abort the broadcast and
    propagate into the plugin handler that triggered it.
    """
    # jsonable_encoder converts datetimes and other non-JSON types.
    data = jsonable_encoder(snapshot)
    # Iterate over a copy: the set can change while each send is awaited.
    for ws in list(browser_conns):
        try:
            await ws.send_json(data)
        except Exception as exc:
            print(f"[WS] Dropping browser client {ws.client}: {exc!r}")
            # discard(): the socket's own handler may already have removed it.
            browser_conns.discard(ws)
async def _handle_spawn_message(data: dict) -> None:
    """Validate a "spawn" message and persist it to spawn_events (for heatmaps)."""
    payload = {k: v for k, v in data.items() if k != "type"}
    try:
        spawn = SpawnEvent.parse_obj(payload)
    except ValueError:
        # pydantic's ValidationError subclasses ValueError; skip bad payloads.
        return
    await database.execute(spawn_events.insert().values(**spawn.dict()))


async def _handle_telemetry_message(data: dict) -> None:
    """Validate a "telemetry" message, then store, persist, count kills and broadcast it."""
    payload = {k: v for k, v in data.items() if k != "type"}
    try:
        snap = TelemetrySnapshot.parse_obj(payload)
    except ValueError as exc:
        # Previously this exception escaped and closed the plugin connection.
        print(f"[WS] Ignoring invalid telemetry payload: {exc}")
        return
    snap_dict = snap.dict()
    live_snapshots[snap.character_name] = snap_dict

    # Persist to the telemetry table; rares_found is always stored as 0
    # (rares are tracked through "rare" events instead).
    db_data = {**snap_dict, "rares_found": 0}
    await database.execute(telemetry_events.insert().values(**db_data))

    # Add only the kills gained since the last packet of this session to the
    # persistent per-character total.
    last_kills = ws_receive_snapshots._last_kills
    session_key = (snap.session_id, snap.character_name)
    delta = snap.kills - last_kills.get(session_key, 0)
    if delta > 0:
        stmt = pg_insert(char_stats).values(
            character_name=snap.character_name,
            total_kills=delta,
        ).on_conflict_do_update(
            index_elements=["character_name"],
            set_={"total_kills": char_stats.c.total_kills + delta},
        )
        await database.execute(stmt)
    last_kills[session_key] = snap.kills

    await _broadcast_to_browser_clients(snap_dict)


async def _handle_rare_message(data: dict) -> None:
    """Increment the total rare count and, if a live session is known, the session count."""
    name = data.get("character_name")
    if not isinstance(name, str):
        return
    stmt_tot = pg_insert(rare_stats).values(
        character_name=name,
        total_rares=1,
    ).on_conflict_do_update(
        index_elements=["character_name"],
        set_={"total_rares": rare_stats.c.total_rares + 1},
    )
    await database.execute(stmt_tot)
    # The session is taken from the character's last telemetry packet.
    session_id = live_snapshots.get(name, {}).get("session_id")
    if session_id:
        stmt_sess = pg_insert(rare_stats_sessions).values(
            character_name=name,
            session_id=session_id,
            session_rares=1,
        ).on_conflict_do_update(
            index_elements=["character_name", "session_id"],
            set_={"session_rares": rare_stats_sessions.c.session_rares + 1},
        )
        await database.execute(stmt_sess)


@app.websocket("/ws/position")
async def ws_receive_snapshots(
    websocket: WebSocket,
    secret: str | None = Query(None),
    x_plugin_secret: str | None = Header(None),
):
    """Receive JSON messages from a game plugin.

    The shared secret may be passed as the ``secret`` query parameter or as
    a header. Mismatches are closed with code 1008 before the handshake is
    accepted.

    Message types (the ``type`` key):
      * register  - map character_name/player_name to this socket;
      * spawn     - persist a SpawnEvent;
      * telemetry - store, persist, update kill totals, broadcast to browsers;
      * rare      - increment total and session rare counters;
      * chat      - broadcast to browsers only.
    Anything else, including non-object JSON or undecodable text, is ignored.

    On exit, every plugin_conns entry pointing at this socket is removed.
    """
    key = secret or x_plugin_secret
    # Constant-time comparison. Compare bytes so a non-ASCII key cannot raise
    # TypeError inside compare_digest.
    if key is None or not secrets.compare_digest(key.encode(), SHARED_SECRET.encode()):
        # Reject without completing the WebSocket handshake.
        await websocket.close(code=1008)
        return
    await websocket.accept()
    print(f"[WS] Plugin connected: {websocket.client}")
    try:
        while True:
            try:
                raw = await websocket.receive_text()
                # Debug: log all incoming plugin WebSocket messages
                print(f"[WS-PLUGIN RX] {websocket.client}: {raw}")
            except WebSocketDisconnect:
                print(f"[WS] Plugin disconnected: {websocket.client}")
                break
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                # e.g. a JSON list or string: previously crashed on data.get().
                continue

            msg_type = data.get("type")
            if msg_type == "register":
                name = data.get("character_name") or data.get("player_name")
                if isinstance(name, str):
                    plugin_conns[name] = websocket
            elif msg_type == "spawn":
                await _handle_spawn_message(data)
            elif msg_type == "telemetry":
                await _handle_telemetry_message(data)
            elif msg_type == "rare":
                await _handle_rare_message(data)
            elif msg_type == "chat":
                # Chat is relayed to browsers only; nothing is written to the DB.
                await _broadcast_to_browser_clients(data)
            # Unknown message types are ignored.
    finally:
        # Clean up any plugin registrations for this socket
        to_remove = [n for n, ws in plugin_conns.items() if ws is websocket]
        for n in to_remove:
            del plugin_conns[n]
        print(f"[WS] Cleaned up plugin connections for {websocket.client}")


# Last kill count seen per (session_id, character_name), used to compute
# kill deltas. Entries are never pruned, so a plugin that reconnects with the
# same session does not have its kills counted twice.
ws_receive_snapshots._last_kills = {}
@app.websocket("/ws/live")
async def ws_live_updates(websocket: WebSocket):
    """Serve a browser client: receive broadcasts and relay its commands to plugins.

    Accepted command envelopes:
      * ``{"player_name": ..., "command": ...}`` - forwarded unchanged;
      * legacy ``{"type": "command", "character_name": ..., "text": ...}`` -
        rewritten to ``{"player_name", "command"}`` before forwarding.
    Other frames, including malformed JSON, are ignored. Commands for a
    character without a registered plugin socket are dropped.
    """
    await websocket.accept()
    browser_conns.add(websocket)
    try:
        while True:
            try:
                data = await websocket.receive_json()
                # Debug: log all incoming browser WebSocket messages
                print(f"[WS-LIVE RX] {websocket.client}: {data}")
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                # A malformed frame should not drop the browser connection.
                continue
            if not isinstance(data, dict):
                continue

            if "player_name" in data and "command" in data:
                # New format: { player_name, command }
                target_name = data["player_name"]
                payload = data
            elif data.get("type") == "command" and "character_name" in data and "text" in data:
                # Legacy format: { type: 'command', character_name, text }
                target_name = data.get("character_name")
                payload = {"player_name": target_name, "command": data.get("text")}
            else:
                # Not a recognized command envelope
                continue
            if not isinstance(target_name, str):
                # Unhashable values (lists, dicts) would raise in plugin_conns.get().
                continue

            target_ws = plugin_conns.get(target_name)
            if target_ws is None:
                continue
            try:
                await target_ws.send_json(payload)
            except Exception as exc:
                # A dead plugin socket must not take down this browser connection;
                # the plugin handler cleans up its own registration.
                print(f"[WS] Failed to forward command to {target_name}: {exc!r}")
    except WebSocketDisconnect:
        pass
    finally:
        # discard(): a failed broadcast may already have removed this socket,
        # in which case remove() would raise KeyError.
        browser_conns.discard(websocket)
# -------------------- static frontend ---------------------------
# Serve the browser frontend from ./static. This is mounted last, at "/",
# after all API and WebSocket routes have been declared.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

# Print the registered API routes once at import time, as a startup aid.
print("🔍 Registered routes:")
for api_route in filter(lambda r: isinstance(r, APIRoute), app.routes):
    print(f"{api_route.path} -> {api_route.methods}")