MosswartOverlord/main.py
2025-05-09 23:31:01 +00:00

302 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import secrets
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta, timezone
from typing import Dict
from typing import Optional

from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool

from db import init_db, save_snapshot, DB_FILE
# ------------------------------------------------------------------
app = FastAPI()

# In-memory store of the most recent telemetry snapshot per character name.
live_snapshots: Dict[str, dict] = {}

# Secret that plugins must present to /ws/position. It is read from the
# environment so the real value need not live in source control; the old
# literal stays as the fallback so existing deployments keep working.
SHARED_SECRET = os.environ.get("SHARED_SECRET", "your_shared_secret")

# ------------------------------------------------------------------
# A player counts as "online" if a snapshot was seen within this window.
ACTIVE_WINDOW = timedelta(seconds=30)
class TelemetrySnapshot(BaseModel):
    """One telemetry packet reported by a game plugin.

    Built in ws_receive_snapshots from a "telemetry" WebSocket message (with
    its "type" key removed). The validated model is then stored in
    live_snapshots, passed to save_snapshot and broadcast to browser clients.
    Field declaration order determines the key order of ``.dict()`` output.
    """
    character_name: str  # key for live_snapshots (one entry per character)
    char_tag: str  # NOTE(review): meaning not visible in this file - confirm with the plugin
    session_id: str
    timestamp: datetime  # pydantic coerces the incoming value (e.g. an ISO string) to datetime
    ew: float  # east/west coordinate: +E / W
    ns: float  # north/south coordinate: +N / S
    z: float  # NOTE(review): presumably the vertical coordinate - confirm
    kills: int
    kills_per_hour: Optional[str] = None  # optional; kept as text (/history CASTs it to REAL)
    onlinetime: Optional[str] = None  # optional; any string is accepted, no format check here
    deaths: int
    rares_found: int
    prismatic_taper_count: int
    vt_state: str  # NOTE(review): state string as reported by the plugin - format not validated here
@app.on_event("startup")
def on_startup() -> None:
    """Call ``db.init_db()`` once when the application starts up."""
    init_db()
# ------------------------ GET -----------------------------------
@app.get("/debug")
def debug():
    """Return a constant status payload; handy as a trivial health check."""
    status_payload = {"status": "OK"}
    return status_payload
@app.get("/live", response_model=dict)
@app.get("/live/", response_model=dict)
def get_live_players():
    """Return every live_state row updated within ACTIVE_WINDOW.

    Responds with ``{"players": [...]}``. Each entry is one ``live_state``
    row as a dict keyed by column name.

    Raises:
        HTTPException: status 500 if the SQLite query fails.
    """
    # Compute the cutoff once, in UTC, with no zone suffix.
    cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW
    cutoff_sql = cutoff.strftime("%Y-%m-%d %H:%M:%S")

    # No 'utc' modifier on the parameter. That modifier treats its input as
    # *local* time and shifts it to UTC, which double-converts an already-UTC
    # cutoff on any host whose timezone is not UTC.
    # NOTE(review): assumes live_state.timestamp is stored in UTC (or with a
    # zone suffix SQLite normalises to UTC) - confirm against the db module.
    query = """
        SELECT *
        FROM live_state
        WHERE datetime(timestamp) > datetime(?)
    """
    try:
        # sqlite3's own context manager only commits/rolls back. closing()
        # makes sure the connection is actually released.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(query, (cutoff_sql,)).fetchall()
    except sqlite3.Error as e:
        print(f"[DB] /live query failed: {e}")
        raise HTTPException(status_code=500, detail="Database error") from e

    players = [dict(r) for r in rows]
    return JSONResponse(content={"players": players})
@app.get("/history/")
@app.get("/history")
def get_history(
    from_ts: str | None = Query(None, alias="from"),
    to_ts: str | None = Query(None, alias="to"),
):
    """Return a time-ordered list of telemetry snapshots.

    The optional ``from`` / ``to`` query parameters bound the range
    (inclusive). They are compared as plain strings against the stored
    ``timestamp`` column, so they must use the same format as stored values.

    Responds with ``{"data": [...]}``, where each item has:
      - timestamp: ISO-8601 timestamp string as stored
      - character_name: str
      - kills: cumulative kill count (int)
      - kph: kills_per_hour cast to REAL (float, or None when NULL)
    """
    conditions: list[str] = []
    params: list[str] = []
    if from_ts:
        conditions.append("timestamp >= ?")
        params.append(from_ts)
    if to_ts:
        conditions.append("timestamp <= ?")
        params.append(to_ts)

    sql = """
        SELECT
            timestamp,
            character_name,
            kills,
            CAST(kills_per_hour AS REAL) AS kph
        FROM telemetry_log
    """
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)
    sql += " ORDER BY timestamp"

    # closing() releases the connection even if the query raises.
    with closing(sqlite3.connect(DB_FILE)) as conn:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(sql, params).fetchall()

    # Row keys follow the SELECT list: timestamp, character_name, kills, kph.
    data = [dict(row) for row in rows]
    return JSONResponse(content={"data": data})
# ------------------------ GET Trails ---------------------------------
@app.get("/trails")
@app.get("/trails/")
def get_trails(
    seconds: int = Query(600, ge=0, description="Lookback window in seconds")
):
    """Return position snapshots recorded in the last ``seconds`` seconds.

    Responds with ``{"trails": [...]}``. Each item has timestamp,
    character_name, ew, ns and z, ordered by character and then timestamp.
    """
    # str() of an aware datetime yields "YYYY-MM-DD HH:MM:SS[.ffffff]+00:00",
    # chosen to match the string form of stored timestamps so that the
    # string comparison in the query orders correctly.
    # NOTE(review): only valid if stored timestamps are also UTC with a
    # "+00:00" suffix - confirm against the db module.
    cutoff_dt = datetime.now(timezone.utc) - timedelta(seconds=seconds)
    cutoff = str(cutoff_dt)

    # closing() releases the connection even if the query raises.
    with closing(sqlite3.connect(DB_FILE)) as conn:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            """
            SELECT timestamp, character_name, ew, ns, z
            FROM telemetry_log
            WHERE timestamp >= ?
            ORDER BY character_name, timestamp
            """,
            (cutoff,),
        ).fetchall()

    # Row keys follow the SELECT list: timestamp, character_name, ew, ns, z.
    trails = [dict(r) for r in rows]
    return JSONResponse(content={"trails": trails})
# -------------------- WebSocket endpoints -----------------------
# Browser clients currently connected to /ws/live.
browser_conns: set[WebSocket] = set()
# Map of registered plugin clients: character_name -> WebSocket
plugin_conns: Dict[str, WebSocket] = {}


async def _broadcast_to_browser_clients(snapshot: dict):
    """Send ``snapshot`` as JSON to every connected browser client.

    A client whose send fails is dropped from ``browser_conns``. This means
    one dead socket cannot stop delivery to the others, and no exception
    propagates back to the caller (the plugin receive loop).
    """
    # Ensure all data (e.g. datetime) is JSON-serializable
    data = jsonable_encoder(snapshot)
    # Iterate over a copy because failing clients are removed during the loop.
    for ws in list(browser_conns):
        try:
            await ws.send_json(data)
        except Exception as exc:  # noqa: BLE001 - any send failure means this client is gone
            # A vanished peer may surface as WebSocketDisconnect, as
            # RuntimeError (socket already closed) or as a server-specific
            # I/O error, depending on the server stack.
            print(f"[WS] Dropping browser client {ws.client}: {exc!r}")
            # discard(): ws_live_updates may already have removed it.
            browser_conns.discard(ws)
@app.websocket("/ws/position")
async def ws_receive_snapshots(
    websocket: WebSocket,
    secret: str | None = Query(None),
    x_plugin_secret: str | None = Header(None),
):
    """Receive registration, telemetry and chat messages from game plugins.

    The plugin must present SHARED_SECRET, either as the ``secret`` query
    parameter or the ``x_plugin_secret`` header (sent as ``x-plugin-secret``).
    Otherwise the socket is closed with code 1008 without being accepted.

    Each message is a JSON object with a ``type`` field:
      - "register": maps ``character_name`` (or ``player_name``) to this
        socket, so browser commands can be forwarded to it.
      - "telemetry": validated as a TelemetrySnapshot, cached in
        ``live_snapshots``, persisted via ``save_snapshot`` and broadcast.
      - "chat": broadcast unchanged to browser clients (not persisted).

    Unknown types, non-JSON text, non-object JSON and telemetry that fails
    validation are skipped without dropping the connection.
    """
    # Verify shared secret from query parameter or header.
    key = secret or x_plugin_secret
    # compare_digest avoids leaking the secret through comparison timing.
    # Comparing bytes accepts non-ASCII input instead of raising TypeError.
    if key is None or not secrets.compare_digest(
        key.encode("utf-8"), SHARED_SECRET.encode("utf-8")
    ):
        # Reject without completing the WebSocket handshake
        await websocket.close(code=1008)
        return

    # Accept the WebSocket connection
    await websocket.accept()
    print(f"[WS] Plugin connected: {websocket.client}")
    try:
        while True:
            # Read next text frame
            try:
                raw = await websocket.receive_text()
            except WebSocketDisconnect:
                print(f"[WS] Plugin disconnected: {websocket.client}")
                break
            # Debug: log all incoming plugin WebSocket messages
            print(f"[WS-PLUGIN RX] {websocket.client}: {raw}")

            # Parse JSON payload
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (list, number, ...) has no
            # "type" field. Skip it rather than crash on .get().
            if not isinstance(data, dict):
                continue

            msg_type = data.get("type")

            # Registration message: map character to this socket
            if msg_type == "register":
                name = data.get("character_name") or data.get("player_name")
                if isinstance(name, str):
                    plugin_conns[name] = websocket
                continue

            # Telemetry message: validate, cache, save to DB and broadcast
            if msg_type == "telemetry":
                payload = data.copy()
                payload.pop("type", None)
                try:
                    snap = TelemetrySnapshot.parse_obj(payload)
                except ValueError as exc:
                    # pydantic's ValidationError subclasses ValueError.
                    # Skip the bad packet so it cannot end this loop and
                    # close the plugin's connection.
                    print(f"[WS] Invalid telemetry from {websocket.client}: {exc}")
                    continue
                live_snapshots[snap.character_name] = snap.dict()
                await run_in_threadpool(save_snapshot, snap.dict())
                await _broadcast_to_browser_clients(snap.dict())
                continue

            # Chat message: broadcast to browser clients only (no DB write)
            if msg_type == "chat":
                await _broadcast_to_browser_clients(data)
                continue

            # Unknown message types are ignored
    finally:
        # Clean up any plugin registrations that point at this socket.
        stale_names = [n for n, ws in plugin_conns.items() if ws is websocket]
        for n in stale_names:
            del plugin_conns[n]
        print(f"[WS] Cleaned up plugin connections for {websocket.client}")
@app.websocket("/ws/live")
async def ws_live_updates(websocket: WebSocket):
    """Serve one browser client: push telemetry/chat and forward its commands.

    The socket is added to ``browser_conns`` so that
    _broadcast_to_browser_clients delivers telemetry and chat to it.
    Messages sent by the browser are treated as command envelopes and are
    forwarded to the plugin registered for the target character. Two
    formats are accepted:
      - new:    {"player_name": ..., "command": ...}  (forwarded as-is)
      - legacy: {"type": "command", "character_name": ..., "text": ...}
        (translated to the new format)

    Malformed JSON, unrecognised envelopes and commands for characters with
    no registered plugin are ignored without closing the connection.
    """
    await websocket.accept()
    browser_conns.add(websocket)
    try:
        while True:
            # Receive command messages from browser
            try:
                raw = await websocket.receive_text()
            except WebSocketDisconnect:
                break
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Debug: log all incoming browser WebSocket messages
            print(f"[WS-LIVE RX] {websocket.client}: {data}")
            if not isinstance(data, dict):
                continue

            # Determine command envelope format (new or legacy)
            if "player_name" in data and "command" in data:
                # New format: { player_name, command }
                target_name = data["player_name"]
                payload = data
            elif data.get("type") == "command" and "character_name" in data and "text" in data:
                # Legacy format: { type: 'command', character_name, text }
                target_name = data.get("character_name")
                payload = {"player_name": target_name, "command": data.get("text")}
            else:
                # Not a recognized command envelope
                continue

            # plugin_conns is only keyed by str. A non-string (possibly
            # unhashable) target can never match, so skip it.
            if not isinstance(target_name, str):
                continue

            # Forward command envelope to the appropriate plugin WebSocket
            target_ws = plugin_conns.get(target_name)
            if target_ws is None:
                continue
            try:
                await target_ws.send_json(payload)
            except Exception as exc:  # noqa: BLE001 - plugin socket may be gone
                # The plugin's own handler removes its plugin_conns entry when
                # it ends. Report the failure rather than closing this
                # browser connection.
                print(f"[WS] Failed to forward command to {target_name}: {exc!r}")
    finally:
        # discard(): the broadcaster may already have removed this socket
        # after a failed send.
        browser_conns.discard(websocket)
# -------------------- static frontend ---------------------------
# static frontend
app.mount("/", StaticFiles(directory="static", html=True), name="static")

# Dump every HTTP API route at import time as a quick sanity check.
print("🔍 Registered routes:")
for route in (r for r in app.routes if isinstance(r, APIRoute)):
    print(f"{route.path} -> {route.methods}")