Refactor to async TimescaleDB backend & add Alembic migrations

This commit is contained in:
erik 2025-05-18 19:07:23 +00:00
parent d396942deb
commit c20d54d037
9 changed files with 328 additions and 99 deletions

main.py (172 lines changed)

@@ -1,6 +1,6 @@
from datetime import datetime, timedelta, timezone
import json
import sqlite3
import os
from typing import Dict
from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
@@ -11,9 +11,11 @@ from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from typing import Optional
from db import init_db, save_snapshot, DB_FILE
# Async database support
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from db_async import database, telemetry_events, char_stats, init_db_async
import asyncio
from starlette.concurrency import run_in_threadpool
# ------------------------------------------------------------------
app = FastAPI()
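The db_async module referenced above is one of the eight other files in this commit and is not shown in this diff. A minimal sketch of what it is assumed to provide, inferred purely from the names imported here (database, telemetry_events, char_stats, init_db_async); column names are guesses based on how main.py uses the tables:

# db_async.py -- hypothetical sketch, not the actual file from this commit
import os
import sqlalchemy
from databases import Database

DATABASE_URL = os.environ.get(
    "DATABASE_URL", "postgresql://postgres:postgres@localhost:5432/telemetry"
)
database = Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()

telemetry_events = sqlalchemy.Table(
    "telemetry_events",
    metadata,
    sqlalchemy.Column("timestamp", sqlalchemy.DateTime(timezone=True), nullable=False),
    sqlalchemy.Column("character_name", sqlalchemy.Text, nullable=False),
    sqlalchemy.Column("session_id", sqlalchemy.Text, nullable=False),
    sqlalchemy.Column("ew", sqlalchemy.Float),
    sqlalchemy.Column("ns", sqlalchemy.Float),
    sqlalchemy.Column("z", sqlalchemy.Float),
    sqlalchemy.Column("kills", sqlalchemy.Integer),
    # ... remaining TelemetrySnapshot fields omitted for brevity ...
)

char_stats = sqlalchemy.Table(
    "char_stats",
    metadata,
    sqlalchemy.Column("character_name", sqlalchemy.Text, primary_key=True),
    sqlalchemy.Column("total_kills", sqlalchemy.Integer, nullable=False, server_default="0"),
)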
@@ -29,7 +31,7 @@ ACTIVE_WINDOW = timedelta(seconds=30) # player is “online” if seen in last 30 s
class TelemetrySnapshot(BaseModel):
character_name: str
char_tag: str
char_tag: Optional[str] = None
session_id: str
timestamp: datetime
@@ -38,17 +40,29 @@ class TelemetrySnapshot(BaseModel):
z: float
kills: int
kills_per_hour: Optional[str] = None # now optional
onlinetime: Optional[str] = None # now optional
kills_per_hour: Optional[float] = None
onlinetime: Optional[str] = None
deaths: int
rares_found: int
prismatic_taper_count: int
vt_state: str
# Optional telemetry metrics
mem_mb: Optional[float] = None
cpu_pct: Optional[float] = None
mem_handles: Optional[int] = None
latency_ms: Optional[float] = None
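For reference, a payload that validates against the revised model might look like the following; the values are illustrative, and the ew/ns position fields are inferred from the /trails query further down (this hunk hides the model lines between timestamp and z):

# Illustrative snapshot; kills_per_hour is now a float and the new
# mem/cpu/latency metrics can be omitted entirely.
snap = TelemetrySnapshot.parse_obj({
    "character_name": "Erik",
    "session_id": "abc123",              # char_tag omitted -> None
    "timestamp": "2025-05-18T19:07:23+00:00",
    "ew": 12.3, "ns": -45.6, "z": 0.0,   # position (hidden in this hunk)
    "kills": 10, "deaths": 0, "rares_found": 1,
    "prismatic_taper_count": 200, "vt_state": "Combat",
    "kills_per_hour": 123.4,
})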
@app.on_event("startup")
def on_startup():
init_db()
async def on_startup():
# Connect to database and initialize TimescaleDB hypertable
await database.connect()
await init_db_async()
@app.on_event("shutdown")
async def on_shutdown():
# Disconnect from database
await database.disconnect()
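init_db_async likewise lives in db_async.py and is not part of this diff. A plausible sketch, assuming the tables themselves are created by the Alembic migrations named in the commit title and that this function only handles the TimescaleDB-specific setup (the composite index is an assumption, added to serve the DISTINCT ON query in /live below):

# Hypothetical init_db_async in db_async.py
from sqlalchemy import text

async def init_db_async():
    # Promote telemetry_events to a hypertable partitioned on timestamp;
    # if_not_exists makes the call idempotent across restarts.
    await database.execute(
        text(
            "SELECT create_hypertable("
            "'telemetry_events', 'timestamp', if_not_exists => TRUE)"
        )
    )
    # Matching composite index so the per-character latest-row query in
    # /live is an index scan rather than a full sort.
    await database.execute(
        text(
            "CREATE INDEX IF NOT EXISTS ix_telemetry_char_ts "
            "ON telemetry_events (character_name, timestamp DESC)"
        )
    )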
@@ -60,78 +74,47 @@ def debug():
@app.get("/live", response_model=dict)
@app.get("/live/", response_model=dict)
def get_live_players():
# compute cutoff once
now_utc = datetime.now(timezone.utc)
cutoff = now_utc - ACTIVE_WINDOW
cutoff_sql = cutoff.strftime("%Y-%m-%d %H:%M:%S")
try:
with sqlite3.connect(DB_FILE) as conn:
conn.row_factory = sqlite3.Row
query = """
SELECT *
FROM live_state
WHERE datetime(timestamp) > datetime(?, 'utc')
"""
rows = conn.execute(query, (cutoff_sql,)).fetchall()
except sqlite3.Error as e:
# log e if you have logging set up
raise HTTPException(status_code=500, detail="Database error")
# build list of dicts
players = []
for r in rows:
players.append(dict(r))
async def get_live_players():
"""Return recent live telemetry per character (last 30 seconds)."""
cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW
query = text(
"""
SELECT * FROM (
SELECT DISTINCT ON (character_name) *
FROM telemetry_events
ORDER BY character_name, timestamp DESC
) sub
WHERE timestamp > :cutoff
"""
)
rows = await database.fetch_all(query, {"cutoff": cutoff})
players = [dict(r) for r in rows]
# jsonable_encoder converts the datetime column values asyncpg returns
# into ISO-8601 strings so JSONResponse can serialize them
return JSONResponse(content=jsonable_encoder({"players": players}))
@app.get("/history/")
@app.get("/history")
def get_history(
async def get_history(
from_ts: str | None = Query(None, alias="from"),
to_ts: str | None = Query(None, alias="to"),
):
"""
Returns a time-ordered list of telemetry snapshots:
- timestamp: ISO8601 string
- character_name: str
- kills: cumulative kill count (int)
- kph: kills_per_hour (float)
"""
conn = sqlite3.connect(DB_FILE)
conn.row_factory = sqlite3.Row
# Build the base query
sql = """
SELECT
timestamp,
character_name,
kills,
CAST(kills_per_hour AS REAL) AS kph
FROM telemetry_log
"""
params: list[str] = []
"""Returns a time-ordered list of telemetry snapshots."""
sql = (
"SELECT timestamp, character_name, kills, kills_per_hour AS kph "
"FROM telemetry_events"
)
values: dict = {}
conditions: list[str] = []
# Add optional filters
if from_ts:
conditions.append("timestamp >= ?")
params.append(from_ts)
conditions.append("timestamp >= :from_ts")
values["from_ts"] = from_ts
if to_ts:
conditions.append("timestamp <= ?")
params.append(to_ts)
conditions.append("timestamp <= :to_ts")
values["to_ts"] = to_ts
if conditions:
sql += " WHERE " + " AND ".join(conditions)
sql += " ORDER BY timestamp"
rows = conn.execute(sql, params).fetchall()
conn.close()
rows = await database.fetch_all(text(sql), values)
data = [
{
"timestamp": row["timestamp"],
@@ -144,33 +127,23 @@ def get_history(
return JSONResponse(content={"data": data})
# ------------------------ GET Trails ---------------------------------
# --- GET Trails ---------------------------------
@app.get("/trails")
@app.get("/trails/")
def get_trails(
seconds: int = Query(600, ge=0, description="Lookback window in seconds")
async def get_trails(
seconds: int = Query(600, ge=0, description="Lookback window in seconds"),
):
"""
Return position snapshots (timestamp, character_name, ew, ns, z)
for the past `seconds` seconds.
"""
# match the same string format as stored timestamps (via str(datetime))
cutoff_dt = datetime.utcnow().replace(tzinfo=timezone.utc) - timedelta(
seconds=seconds
)
cutoff = str(cutoff_dt)
conn = sqlite3.connect(DB_FILE)
conn.row_factory = sqlite3.Row
rows = conn.execute(
"""Return position snapshots (timestamp, character_name, ew, ns, z) for the past `seconds`."""
cutoff = datetime.now(timezone.utc) - timedelta(seconds=seconds)
sql = text(
"""
SELECT timestamp, character_name, ew, ns, z
FROM telemetry_log
WHERE timestamp >= ?
ORDER BY character_name, timestamp
""",
(cutoff,),
).fetchall()
conn.close()
FROM telemetry_events
WHERE timestamp >= :cutoff
ORDER BY character_name, timestamp
"""
)
rows = await database.fetch_all(sql, {"cutoff": cutoff})
trails = [
{
"timestamp": r["timestamp"],
@@ -236,11 +209,29 @@
continue
# Telemetry message: save to DB and broadcast
if msg_type == "telemetry":
# Parse and broadcast telemetry snapshot
payload = data.copy()
payload.pop("type", None)
snap = TelemetrySnapshot.parse_obj(payload)
live_snapshots[snap.character_name] = snap.dict()
await run_in_threadpool(save_snapshot, snap.dict())
# Persist to TimescaleDB
await database.execute(
telemetry_events.insert().values(**snap.dict())
)
# Update persistent kill stats (delta per session)
key = (snap.session_id, snap.character_name)
last = ws_receive_snapshots._last_kills.get(key, 0)
delta = snap.kills - last
if delta > 0:
stmt = pg_insert(char_stats).values(
character_name=snap.character_name,
total_kills=delta
).on_conflict_do_update(
index_elements=["character_name"],
set_={"total_kills": char_stats.c.total_kills + delta},
)
await database.execute(stmt)
ws_receive_snapshots._last_kills[key] = snap.kills
await _broadcast_to_browser_clients(snap.dict())
continue
# Chat message: broadcast to browser clients only (no DB write)
@@ -255,6 +246,9 @@ async def ws_receive_snapshots(
del plugin_conns[n]
print(f"[WS] Cleaned up plugin connections for {websocket.client}")
# In-memory store of last kills per session for delta calculations
ws_receive_snapshots._last_kills = {}
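Note that _last_kills is process-local state: it is lost on restart, so the first snapshot of an in-flight session is treated as a full delta and its cumulative kill count is credited to char_stats again. A worked example of the bookkeeping, values illustrative:

# Cumulative kills 10, 13, 13 arrive for one session:
#   delta = 10 - 0  -> +10 upserted into char_stats.total_kills
#   delta = 13 - 10 -> +3 upserted
#   delta = 13 - 13 -> 0, no write
# After a restart the dict is empty, so a snapshot with kills=13 would
# re-credit all 13 kills; keying by (session_id, character_name) limits
# the damage to sessions that were live across the restart.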
@app.websocket("/ws/live")
async def ws_live_updates(websocket: WebSocket):
# Browser clients connect here to receive telemetry and chat, and send commands