ws version with nice DB select
This commit is contained in:
parent
a121d57a13
commit
73ae756e5c
6 changed files with 491 additions and 106 deletions
159
main.py
159
main.py
|
|
@ -7,6 +7,7 @@ from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketD
|
|||
from fastapi.responses import JSONResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -50,31 +51,6 @@ def on_startup():
|
|||
init_db()
|
||||
|
||||
|
||||
# ------------------------ POST ----------------------------------
|
||||
@app.post("/position")
|
||||
@app.post("/position/")
|
||||
async def receive_snapshot(
|
||||
snapshot: TelemetrySnapshot, x_plugin_secret: str = Header(None)
|
||||
):
|
||||
if x_plugin_secret != SHARED_SECRET:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
# cache for /live
|
||||
live_snapshots[snapshot.character_name] = snapshot.dict()
|
||||
|
||||
# save in sqlite
|
||||
save_snapshot(snapshot.dict())
|
||||
|
||||
# optional log-file append
|
||||
# with open(LOG_FILE, "a") as f:
|
||||
# f.write(json.dumps(snapshot.dict(), default=str) + "\n")
|
||||
|
||||
print(
|
||||
f"[{datetime.now()}] {snapshot.character_name} @ NS={snapshot.ns:+.2f}, EW={snapshot.ew:+.2f}"
|
||||
)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ------------------------ GET -----------------------------------
|
||||
@app.get("/debug")
|
||||
|
|
@ -82,22 +58,33 @@ def debug():
|
|||
return {"status": "OK"}
|
||||
|
||||
|
||||
@app.get("/live")
|
||||
@app.get("/live/")
|
||||
@app.get("/live", response_model=dict)
|
||||
@app.get("/live/", response_model=dict)
|
||||
def get_live_players():
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute("SELECT * FROM live_state").fetchall()
|
||||
conn.close()
|
||||
# compute cutoff once
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
cutoff = now_utc - ACTIVE_WINDOW
|
||||
|
||||
# aware cutoff (UTC)
|
||||
cutoff = datetime.utcnow().replace(tzinfo=timezone.utc) - ACTIVE_WINDOW
|
||||
cutoff_sql = cutoff.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
players = [
|
||||
dict(r)
|
||||
for r in rows
|
||||
if datetime.fromisoformat(r["timestamp"].replace("Z", "+00:00")) > cutoff
|
||||
]
|
||||
try:
|
||||
with sqlite3.connect(DB_FILE) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
query = """
|
||||
SELECT *
|
||||
FROM live_state
|
||||
WHERE datetime(timestamp) > datetime(?, 'utc')
|
||||
"""
|
||||
rows = conn.execute(query, (cutoff_sql,)).fetchall()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
# log e if you have logging set up
|
||||
raise HTTPException(status_code=500, detail="Database error")
|
||||
|
||||
# build list of dicts
|
||||
players = []
|
||||
for r in rows:
|
||||
players.append(dict(r))
|
||||
|
||||
return JSONResponse(content={"players": players})
|
||||
|
||||
|
|
@ -198,42 +185,114 @@ def get_trails(
|
|||
|
||||
# -------------------- WebSocket endpoints -----------------------
|
||||
browser_conns: set[WebSocket] = set()
|
||||
# Map of registered plugin clients: character_name -> WebSocket
|
||||
plugin_conns: Dict[str, WebSocket] = {}
|
||||
|
||||
async def _broadcast_to_browser_clients(snapshot: dict):
|
||||
# Ensure all data (e.g. datetime) is JSON-serializable
|
||||
data = jsonable_encoder(snapshot)
|
||||
for ws in list(browser_conns):
|
||||
try:
|
||||
await ws.send_json(snapshot)
|
||||
await ws.send_json(data)
|
||||
except WebSocketDisconnect:
|
||||
browser_conns.remove(ws)
|
||||
|
||||
@app.websocket("/ws/position")
|
||||
async def ws_receive_snapshots(websocket: WebSocket, secret: str = Query(...)):
|
||||
await websocket.accept()
|
||||
if secret != SHARED_SECRET:
|
||||
async def ws_receive_snapshots(
|
||||
websocket: WebSocket,
|
||||
secret: str | None = Query(None),
|
||||
x_plugin_secret: str | None = Header(None)
|
||||
):
|
||||
# Verify shared secret from query parameter or header
|
||||
key = secret or x_plugin_secret
|
||||
if key != SHARED_SECRET:
|
||||
# Reject without completing the WebSocket handshake
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
# Accept the WebSocket connection
|
||||
await websocket.accept()
|
||||
print(f"[WS] Plugin connected: {websocket.client}")
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
snap = TelemetrySnapshot.parse_obj(data)
|
||||
live_snapshots[snap.character_name] = snap.dict()
|
||||
await run_in_threadpool(save_snapshot, snap.dict())
|
||||
await _broadcast_to_browser_clients(snap.dict())
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
# Read next text frame
|
||||
try:
|
||||
raw = await websocket.receive_text()
|
||||
# Debug: log all incoming plugin WebSocket messages
|
||||
print(f"[WS-PLUGIN RX] {websocket.client}: {raw}")
|
||||
except WebSocketDisconnect:
|
||||
print(f"[WS] Plugin disconnected: {websocket.client}")
|
||||
break
|
||||
# Parse JSON payload
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
msg_type = data.get("type")
|
||||
# Registration message: map character to this socket
|
||||
if msg_type == "register":
|
||||
name = data.get("character_name") or data.get("player_name")
|
||||
if isinstance(name, str):
|
||||
plugin_conns[name] = websocket
|
||||
continue
|
||||
# Telemetry message: save to DB and broadcast
|
||||
if msg_type == "telemetry":
|
||||
payload = data.copy()
|
||||
payload.pop("type", None)
|
||||
snap = TelemetrySnapshot.parse_obj(payload)
|
||||
live_snapshots[snap.character_name] = snap.dict()
|
||||
await run_in_threadpool(save_snapshot, snap.dict())
|
||||
await _broadcast_to_browser_clients(snap.dict())
|
||||
continue
|
||||
# Chat message: broadcast to browser clients only (no DB write)
|
||||
if msg_type == "chat":
|
||||
await _broadcast_to_browser_clients(data)
|
||||
continue
|
||||
# Unknown message types are ignored
|
||||
finally:
|
||||
# Clean up any plugin registrations for this socket
|
||||
to_remove = [n for n, ws in plugin_conns.items() if ws is websocket]
|
||||
for n in to_remove:
|
||||
del plugin_conns[n]
|
||||
print(f"[WS] Cleaned up plugin connections for {websocket.client}")
|
||||
|
||||
@app.websocket("/ws/live")
|
||||
async def ws_live_updates(websocket: WebSocket):
|
||||
# Browser clients connect here to receive telemetry and chat, and send commands
|
||||
await websocket.accept()
|
||||
browser_conns.add(websocket)
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(3600)
|
||||
# Receive command messages from browser
|
||||
try:
|
||||
data = await websocket.receive_json()
|
||||
# Debug: log all incoming browser WebSocket messages
|
||||
print(f"[WS-LIVE RX] {websocket.client}: {data}")
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
# Determine command envelope format (new or legacy)
|
||||
if "player_name" in data and "command" in data:
|
||||
# New format: { player_name, command }
|
||||
target_name = data["player_name"]
|
||||
payload = data
|
||||
elif data.get("type") == "command" and "character_name" in data and "text" in data:
|
||||
# Legacy format: { type: 'command', character_name, text }
|
||||
target_name = data.get("character_name")
|
||||
payload = {"player_name": target_name, "command": data.get("text")}
|
||||
else:
|
||||
# Not a recognized command envelope
|
||||
continue
|
||||
# Forward command envelope to the appropriate plugin WebSocket
|
||||
target_ws = plugin_conns.get(target_name)
|
||||
if target_ws:
|
||||
await target_ws.send_json(payload)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
browser_conns.remove(websocket)
|
||||
|
||||
|
||||
# -------------------- static frontend ---------------------------
|
||||
# static frontend
|
||||
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
||||
|
||||
# list routes for convenience
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue