302 lines
9.5 KiB
Python
302 lines
9.5 KiB
Python
import asyncio
import json
import os
import secrets
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, ValidationError
from starlette.concurrency import run_in_threadpool

from db import init_db, save_snapshot, DB_FILE
|
||
|
||
# ------------------------------------------------------------------
app = FastAPI()

# In-memory store of the last telemetry packet per character, keyed by
# character_name (written by the /ws/position handler).
live_snapshots: Dict[str, dict] = {}

# Secret that plugin clients must present to /ws/position. It can be supplied
# through the TELEMETRY_SHARED_SECRET environment variable so it does not have to
# live in source control; the fallback keeps the previous hard-coded value so
# existing deployments keep working unchanged.
SHARED_SECRET = os.environ.get("TELEMETRY_SHARED_SECRET", "your_shared_secret")

# ------------------------------------------------------------------
ACTIVE_WINDOW = timedelta(seconds=30)  # player is "online" if seen in last 30 s
|
||
|
||
|
||
class TelemetrySnapshot(BaseModel):
    """One telemetry packet sent by a plugin client over /ws/position.

    The plugin WebSocket handler validates incoming "telemetry" messages with
    ``parse_obj`` and uses ``.dict()`` of the result for the in-memory cache,
    for ``save_snapshot`` and for the browser broadcast. Field declaration order
    therefore determines the key order of that dict.
    """

    character_name: str
    char_tag: str
    session_id: str
    timestamp: datetime

    # Position coordinates (the /trails endpoint reads these back as ew, ns, z).
    ew: float  # +E / –W
    ns: float  # +N / –S
    z: float

    kills: int
    # Sent as a string; /history converts it with CAST(kills_per_hour AS REAL).
    kills_per_hour: Optional[str] = None  # optional; defaults to None
    onlinetime: Optional[str] = None  # optional; defaults to None
    deaths: int
    rares_found: int
    prismatic_taper_count: int
    vt_state: str
|
||
|
||
|
||
@app.on_event("startup")
def on_startup() -> None:
    """Run ``init_db()`` once when the application starts up."""
    init_db()
|
||
|
||
|
||
|
||
# ------------------------ GET -----------------------------------
@app.get("/debug")
def debug():
    """Return a constant status payload; usable as a trivial health check."""
    payload = dict(status="OK")
    return payload
|
||
|
||
|
||
@app.get("/live", response_model=dict)
@app.get("/live/", response_model=dict)
def get_live_players():
    """Return every ``live_state`` row whose timestamp lies within ACTIVE_WINDOW.

    Response body: ``{"players": [<row as dict>, ...]}`` with one dict per
    matching row (all columns of ``live_state``).

    Raises:
        HTTPException: status 500 if the SQLite query fails.
    """
    # Compute the cutoff once, in UTC, formatted without a timezone suffix --
    # the same "YYYY-MM-DD HH:MM:SS" form SQLite's datetime() produces, which
    # SQLite interprets as UTC.
    cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW
    cutoff_sql = cutoff.strftime("%Y-%m-%d %H:%M:%S")

    # The cutoff is already UTC, so it must NOT pass through SQLite's 'utc'
    # modifier: that modifier assumes its input is *local* time and shifts it
    # by the server's UTC offset, which skewed the active window on any host
    # not running in UTC. datetime(timestamp) normalises stored values (with or
    # without a "+HH:MM" suffix) to UTC so both sides are comparable.
    query = """
        SELECT *
        FROM live_state
        WHERE datetime(timestamp) > datetime(?)
    """

    try:
        # sqlite3's connection context manager only commits/rolls back; it does
        # not close the connection, so closing() is needed to avoid a leak.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(query, (cutoff_sql,)).fetchall()
    except sqlite3.Error as exc:
        print(f"[DB] /live query failed: {exc}")
        raise HTTPException(status_code=500, detail="Database error") from exc

    players = [dict(row) for row in rows]
    return JSONResponse(content={"players": players})
|
||
|
||
|
||
@app.get("/history/")
@app.get("/history")
def get_history(
    from_ts: str | None = Query(None, alias="from"),
    to_ts: str | None = Query(None, alias="to"),
):
    """Return a time-ordered list of telemetry snapshots from ``telemetry_log``.

    Query parameters:
        from: optional inclusive lower bound on ``timestamp``.
        to:   optional inclusive upper bound on ``timestamp``.
        Both are compared directly against the stored ``timestamp`` column
        (a string comparison for text timestamps), so they should use the same
        format the rows were stored with.

    Response body: ``{"data": [...]}`` ordered by timestamp, each item having
        - timestamp: the stored timestamp value
        - character_name: str
        - kills: cumulative kill count (int)
        - kph: kills_per_hour cast to REAL (float, or None when NULL)

    Raises:
        HTTPException: status 500 if the SQLite query fails.
    """
    # Collect optional filters; values are bound as parameters, never
    # interpolated into the SQL text.
    conditions: list[str] = []
    params: list[str] = []
    if from_ts:
        conditions.append("timestamp >= ?")
        params.append(from_ts)
    if to_ts:
        conditions.append("timestamp <= ?")
        params.append(to_ts)

    sql = """
        SELECT
            timestamp,
            character_name,
            kills,
            CAST(kills_per_hour AS REAL) AS kph
        FROM telemetry_log
    """
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)
    sql += " ORDER BY timestamp"

    try:
        # closing() releases the connection even when the query raises
        # (previously the connection leaked on any error).
        with closing(sqlite3.connect(DB_FILE)) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(sql, params).fetchall()
    except sqlite3.Error as exc:
        print(f"[DB] /history query failed: {exc}")
        raise HTTPException(status_code=500, detail="Database error") from exc

    data = [
        {
            "timestamp": row["timestamp"],
            "character_name": row["character_name"],
            "kills": row["kills"],
            "kph": row["kph"],
        }
        for row in rows
    ]
    return JSONResponse(content={"data": data})
|
||
|
||
|
||
# ------------------------ GET Trails ---------------------------------
@app.get("/trails")
@app.get("/trails/")
def get_trails(
    seconds: int = Query(600, ge=0, description="Lookback window in seconds")
):
    """Return position snapshots (timestamp, character_name, ew, ns, z) for the
    past ``seconds`` seconds, ordered by character name and then timestamp.

    Response body: ``{"trails": [...]}``.

    Raises:
        HTTPException: status 500 if the SQLite query fails.
    """
    # The cutoff is an aware UTC datetime rendered with str(), e.g.
    # "2024-01-01 12:00:00.123456+00:00". The WHERE clause compares it to the
    # stored column as a plain string, which assumes stored timestamps also use
    # the str(datetime) format in UTC -- NOTE(review): confirm against
    # save_snapshot. datetime.now(timezone.utc) replaces the deprecated
    # datetime.utcnow().replace(tzinfo=...) and yields the same string.
    cutoff = str(datetime.now(timezone.utc) - timedelta(seconds=seconds))

    try:
        # closing() releases the connection even when the query raises.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                """
                SELECT timestamp, character_name, ew, ns, z
                FROM telemetry_log
                WHERE timestamp >= ?
                ORDER BY character_name, timestamp
                """,
                (cutoff,),
            ).fetchall()
    except sqlite3.Error as exc:
        print(f"[DB] /trails query failed: {exc}")
        raise HTTPException(status_code=500, detail="Database error") from exc

    trails = [
        {
            "timestamp": r["timestamp"],
            "character_name": r["character_name"],
            "ew": r["ew"],
            "ns": r["ns"],
            "z": r["z"],
        }
        for r in rows
    ]
    return JSONResponse(content={"trails": trails})
|
||
|
||
# -------------------- WebSocket endpoints -----------------------
# Browser sockets connected to /ws/live; every broadcast is sent to each of them.
browser_conns: set[WebSocket] = set()
# Registered plugin sockets, keyed by character name (character_name -> WebSocket).
plugin_conns: Dict[str, WebSocket] = dict()
|
||
|
||
async def _broadcast_to_browser_clients(snapshot: dict):
    """Send ``snapshot`` as JSON to every connected browser client.

    A client whose send fails for any reason is treated as gone and removed
    from ``browser_conns``. The failure is logged and the broadcast continues
    with the remaining clients. One dead browser socket therefore cannot abort
    the broadcast or propagate an exception into the caller, such as the
    plugin receive loop.
    """
    # Ensure all data (e.g. datetime) is JSON-serializable
    data = jsonable_encoder(snapshot)
    # Iterate over a copy: the set is mutated here and, while we await, by the
    # /ws/live handler (which adds and removes its own socket).
    for ws in list(browser_conns):
        try:
            await ws.send_json(data)
        except Exception as exc:  # noqa: BLE001 - any send failure drops the client
            # Depending on the server stack, a dead socket can surface as
            # WebSocketDisconnect, RuntimeError or an OSError subclass.
            print(f"[WS] Dropping browser client {ws.client}: {exc!r}")
            # discard(): the /ws/live handler may already have removed it.
            browser_conns.discard(ws)
|
||
|
||
@app.websocket("/ws/position")
async def ws_receive_snapshots(
    websocket: WebSocket,
    secret: str | None = Query(None),
    x_plugin_secret: str | None = Header(None),
):
    """WebSocket endpoint for game plugin clients.

    Authentication: the shared secret must be supplied either as the ``secret``
    query parameter or the ``X-Plugin-Secret`` header. Otherwise the socket is
    closed with code 1008 before the handshake is accepted.

    Each text frame is expected to be a JSON object whose ``type`` selects the
    handling:
      - "register":  maps ``character_name`` (or ``player_name``) to this
                     socket in ``plugin_conns`` so browser commands can be
                     routed to it.
      - "telemetry": validated as TelemetrySnapshot, cached in
                     ``live_snapshots``, saved via ``save_snapshot`` in a
                     worker thread and broadcast to browser clients.
      - "chat":      broadcast to browser clients as-is (no DB write).
    The following frames are skipped without closing the connection:
      - frames that are not valid JSON;
      - frames that are not a JSON object;
      - telemetry that fails validation;
      - frames with an unknown type.
    """
    # Verify shared secret from query parameter or header
    key = secret or x_plugin_secret
    # compare_digest avoids leaking the secret through comparison timing;
    # comparing bytes keeps it from raising on non-ASCII input.
    if key is None or not secrets.compare_digest(key.encode(), SHARED_SECRET.encode()):
        # Reject without completing the WebSocket handshake
        await websocket.close(code=1008)
        return

    # Accept the WebSocket connection
    await websocket.accept()
    print(f"[WS] Plugin connected: {websocket.client}")
    try:
        while True:
            # Read next text frame
            try:
                raw = await websocket.receive_text()
                # Debug: log all incoming plugin WebSocket messages
                print(f"[WS-PLUGIN RX] {websocket.client}: {raw}")
            except WebSocketDisconnect:
                print(f"[WS] Plugin disconnected: {websocket.client}")
                break

            # Parse JSON payload
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Arrays/scalars have no .get(); previously they crashed the loop.
            if not isinstance(data, dict):
                continue

            msg_type = data.get("type")

            # Registration message: map character to this socket
            if msg_type == "register":
                name = data.get("character_name") or data.get("player_name")
                if isinstance(name, str):
                    plugin_conns[name] = websocket
                continue

            # Telemetry message: save to DB and broadcast
            if msg_type == "telemetry":
                payload = data.copy()
                payload.pop("type", None)
                try:
                    snap = TelemetrySnapshot.parse_obj(payload)
                except ValidationError as exc:
                    # A single malformed packet must not tear down the plugin
                    # connection; log it and wait for the next frame.
                    print(f"[WS] Invalid telemetry from {websocket.client}: {exc}")
                    continue
                # Separate .dict() copies per consumer, as before, so that a
                # consumer mutating its dict cannot affect the others.
                live_snapshots[snap.character_name] = snap.dict()
                await run_in_threadpool(save_snapshot, snap.dict())
                await _broadcast_to_browser_clients(snap.dict())
                continue

            # Chat message: broadcast to browser clients only (no DB write)
            if msg_type == "chat":
                await _broadcast_to_browser_clients(data)
                continue

            # Unknown message types are ignored
    finally:
        # Clean up any plugin registrations for this socket (a newer socket
        # registered under the same name is left untouched).
        to_remove = [n for n, ws in plugin_conns.items() if ws is websocket]
        for n in to_remove:
            del plugin_conns[n]
        print(f"[WS] Cleaned up plugin connections for {websocket.client}")
|
||
|
||
@app.websocket("/ws/live")
async def ws_live_updates(websocket: WebSocket):
    """Browser endpoint: receives broadcasts and forwards commands to plugins.

    While connected, the socket is a member of ``browser_conns``; that is the
    set that telemetry and chat broadcasts are sent to.

    Accepted command envelopes (a JSON object per frame):
      - new:    {"player_name": ..., "command": ...}
                forwarded unchanged.
      - legacy: {"type": "command", "character_name": ..., "text": ...}
                converted to {"player_name": ..., "command": ...}.
    The command is sent to the plugin socket registered under that name in
    ``plugin_conns``. It is dropped when no plugin is registered under the
    name or when forwarding fails. Malformed frames are skipped.
    """
    # Browser clients connect here to receive telemetry and chat, and send commands
    await websocket.accept()
    browser_conns.add(websocket)
    try:
        while True:
            # Receive command messages from browser
            try:
                data = await websocket.receive_json()
                # Debug: log all incoming browser WebSocket messages
                print(f"[WS-LIVE RX] {websocket.client}: {data}")
            except WebSocketDisconnect:
                break
            except ValueError:
                # Invalid JSON (JSONDecodeError / UnicodeDecodeError are both
                # ValueError) is skipped instead of closing the browser session.
                continue

            if not isinstance(data, dict):
                continue

            # Determine command envelope format (new or legacy)
            if "player_name" in data and "command" in data:
                # New format: { player_name, command }
                target_name = data["player_name"]
                payload = data
            elif data.get("type") == "command" and "character_name" in data and "text" in data:
                # Legacy format: { type: 'command', character_name, text }
                target_name = data.get("character_name")
                payload = {"player_name": target_name, "command": data.get("text")}
            else:
                # Not a recognized command envelope
                continue

            # Plugins only ever register under str names, so anything else
            # cannot match (and an unhashable value would crash the lookup).
            if not isinstance(target_name, str):
                continue

            # Forward command envelope to the appropriate plugin WebSocket
            target_ws = plugin_conns.get(target_name)
            if target_ws is None:
                continue
            try:
                await target_ws.send_json(payload)
            except Exception as exc:  # noqa: BLE001 - a dead plugin must not end this session
                # Previously a WebSocketDisconnect raised by the *plugin* socket
                # was caught by an outer handler and silently ended the
                # browser's session; the plugin's own handler cleans up its
                # registration when it disconnects.
                print(f"[WS] Failed to forward command to {target_name!r}: {exc!r}")
    finally:
        # discard(): the broadcaster may already have removed this socket.
        browser_conns.discard(websocket)
|
||
|
||
|
||
# -------------------- static frontend ---------------------------
# Serve the browser frontend from ./static; html=True makes directory requests
# return their index.html. Starlette matches routes in registration order, so
# the API/WebSocket routes registered above take precedence over this
# catch-all mount at "/".
app.mount("/", StaticFiles(directory="static", html=True), name="static")

# Print every registered HTTP API route at import time, for convenience.
print("🔍 Registered routes:")
for route in app.routes:
    if not isinstance(route, APIRoute):
        continue
    print(f"{route.path} -> {route.methods}")
|