Refactor to async TimescaleDB backend & add Alembic migrations
This commit is contained in:
parent
d396942deb
commit
c20d54d037
9 changed files with 328 additions and 99 deletions
172
main.py
172
main.py
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime, timedelta, timezone
|
||||
import json
|
||||
import sqlite3
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||
|
|
@ -11,9 +11,11 @@ from fastapi.encoders import jsonable_encoder
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
from db import init_db, save_snapshot, DB_FILE
|
||||
# Async database support
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from db_async import database, telemetry_events, char_stats, init_db_async
|
||||
import asyncio
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
app = FastAPI()
|
||||
|
|
@ -29,7 +31,7 @@ ACTIVE_WINDOW = timedelta(seconds=30) # player is “online” if seen in last
|
|||
|
||||
class TelemetrySnapshot(BaseModel):
|
||||
character_name: str
|
||||
char_tag: str
|
||||
char_tag: Optional[str] = None
|
||||
session_id: str
|
||||
timestamp: datetime
|
||||
|
||||
|
|
@ -38,17 +40,29 @@ class TelemetrySnapshot(BaseModel):
|
|||
z: float
|
||||
|
||||
kills: int
|
||||
kills_per_hour: Optional[str] = None # now optional
|
||||
onlinetime: Optional[str] = None # now optional
|
||||
kills_per_hour: Optional[float] = None
|
||||
onlinetime: Optional[str] = None
|
||||
deaths: int
|
||||
rares_found: int
|
||||
prismatic_taper_count: int
|
||||
vt_state: str
|
||||
# Optional telemetry metrics
|
||||
mem_mb: Optional[float] = None
|
||||
cpu_pct: Optional[float] = None
|
||||
mem_handles: Optional[int] = None
|
||||
latency_ms: Optional[float] = None
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
init_db()
|
||||
async def on_startup():
|
||||
# Connect to database and initialize TimescaleDB hypertable
|
||||
await database.connect()
|
||||
await init_db_async()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def on_shutdown():
|
||||
# Disconnect from database
|
||||
await database.disconnect()
|
||||
|
||||
|
||||
|
||||
|
|
@ -60,78 +74,47 @@ def debug():
|
|||
|
||||
@app.get("/live", response_model=dict)
|
||||
@app.get("/live/", response_model=dict)
|
||||
def get_live_players():
|
||||
# compute cutoff once
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
cutoff = now_utc - ACTIVE_WINDOW
|
||||
|
||||
cutoff_sql = cutoff.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
try:
|
||||
with sqlite3.connect(DB_FILE) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
query = """
|
||||
SELECT *
|
||||
FROM live_state
|
||||
WHERE datetime(timestamp) > datetime(?, 'utc')
|
||||
"""
|
||||
rows = conn.execute(query, (cutoff_sql,)).fetchall()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
# log e if you have logging set up
|
||||
raise HTTPException(status_code=500, detail="Database error")
|
||||
|
||||
# build list of dicts
|
||||
players = []
|
||||
for r in rows:
|
||||
players.append(dict(r))
|
||||
|
||||
async def get_live_players():
|
||||
"""Return recent live telemetry per character (last 30 seconds)."""
|
||||
cutoff = datetime.now(timezone.utc) - ACTIVE_WINDOW
|
||||
query = text(
|
||||
"""
|
||||
SELECT * FROM (
|
||||
SELECT DISTINCT ON (character_name) *
|
||||
FROM telemetry_events
|
||||
ORDER BY character_name, timestamp DESC
|
||||
) sub
|
||||
WHERE timestamp > :cutoff
|
||||
"""
|
||||
)
|
||||
rows = await database.fetch_all(query, {"cutoff": cutoff})
|
||||
players = [dict(r) for r in rows]
|
||||
return JSONResponse(content={"players": players})
|
||||
|
||||
|
||||
@app.get("/history/")
|
||||
@app.get("/history")
|
||||
def get_history(
|
||||
async def get_history(
|
||||
from_ts: str | None = Query(None, alias="from"),
|
||||
to_ts: str | None = Query(None, alias="to"),
|
||||
):
|
||||
"""
|
||||
Returns a time‐ordered list of telemetry snapshots:
|
||||
- timestamp: ISO8601 string
|
||||
- character_name: str
|
||||
- kills: cumulative kill count (int)
|
||||
- kph: kills_per_hour (float)
|
||||
"""
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# Build the base query
|
||||
sql = """
|
||||
SELECT
|
||||
timestamp,
|
||||
character_name,
|
||||
kills,
|
||||
CAST(kills_per_hour AS REAL) AS kph
|
||||
FROM telemetry_log
|
||||
"""
|
||||
params: list[str] = []
|
||||
"""Returns a time-ordered list of telemetry snapshots."""
|
||||
sql = (
|
||||
"SELECT timestamp, character_name, kills, kills_per_hour AS kph "
|
||||
"FROM telemetry_events"
|
||||
)
|
||||
values: dict = {}
|
||||
conditions: list[str] = []
|
||||
|
||||
# Add optional filters
|
||||
if from_ts:
|
||||
conditions.append("timestamp >= ?")
|
||||
params.append(from_ts)
|
||||
conditions.append("timestamp >= :from_ts")
|
||||
values["from_ts"] = from_ts
|
||||
if to_ts:
|
||||
conditions.append("timestamp <= ?")
|
||||
params.append(to_ts)
|
||||
conditions.append("timestamp <= :to_ts")
|
||||
values["to_ts"] = to_ts
|
||||
if conditions:
|
||||
sql += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
sql += " ORDER BY timestamp"
|
||||
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
conn.close()
|
||||
|
||||
rows = await database.fetch_all(text(sql), values)
|
||||
data = [
|
||||
{
|
||||
"timestamp": row["timestamp"],
|
||||
|
|
@ -144,33 +127,23 @@ def get_history(
|
|||
return JSONResponse(content={"data": data})
|
||||
|
||||
|
||||
# ------------------------ GET Trails ---------------------------------
|
||||
# --- GET Trails ---------------------------------
|
||||
@app.get("/trails")
|
||||
@app.get("/trails/")
|
||||
def get_trails(
|
||||
seconds: int = Query(600, ge=0, description="Lookback window in seconds")
|
||||
async def get_trails(
|
||||
seconds: int = Query(600, ge=0, description="Lookback window in seconds"),
|
||||
):
|
||||
"""
|
||||
Return position snapshots (timestamp, character_name, ew, ns, z)
|
||||
for the past `seconds` seconds.
|
||||
"""
|
||||
# match the same string format as stored timestamps (via str(datetime))
|
||||
cutoff_dt = datetime.utcnow().replace(tzinfo=timezone.utc) - timedelta(
|
||||
seconds=seconds
|
||||
)
|
||||
cutoff = str(cutoff_dt)
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute(
|
||||
"""Return position snapshots (timestamp, character_name, ew, ns, z) for the past `seconds`."""
|
||||
cutoff = datetime.utcnow().replace(tzinfo=timezone.utc) - timedelta(seconds=seconds)
|
||||
sql = text(
|
||||
"""
|
||||
SELECT timestamp, character_name, ew, ns, z
|
||||
FROM telemetry_log
|
||||
WHERE timestamp >= ?
|
||||
ORDER BY character_name, timestamp
|
||||
""",
|
||||
(cutoff,),
|
||||
).fetchall()
|
||||
conn.close()
|
||||
FROM telemetry_events
|
||||
WHERE timestamp >= :cutoff
|
||||
ORDER BY character_name, timestamp
|
||||
"""
|
||||
)
|
||||
rows = await database.fetch_all(sql, {"cutoff": cutoff})
|
||||
trails = [
|
||||
{
|
||||
"timestamp": r["timestamp"],
|
||||
|
|
@ -236,11 +209,29 @@ async def ws_receive_snapshots(
|
|||
continue
|
||||
# Telemetry message: save to DB and broadcast
|
||||
if msg_type == "telemetry":
|
||||
# Parse and broadcast telemetry snapshot
|
||||
payload = data.copy()
|
||||
payload.pop("type", None)
|
||||
snap = TelemetrySnapshot.parse_obj(payload)
|
||||
live_snapshots[snap.character_name] = snap.dict()
|
||||
await run_in_threadpool(save_snapshot, snap.dict())
|
||||
# Persist to TimescaleDB
|
||||
await database.execute(
|
||||
telemetry_events.insert().values(**snap.dict())
|
||||
)
|
||||
# Update persistent kill stats (delta per session)
|
||||
key = (snap.session_id, snap.character_name)
|
||||
last = ws_receive_snapshots._last_kills.get(key, 0)
|
||||
delta = snap.kills - last
|
||||
if delta > 0:
|
||||
stmt = pg_insert(char_stats).values(
|
||||
character_name=snap.character_name,
|
||||
total_kills=delta
|
||||
).on_conflict_do_update(
|
||||
index_elements=["character_name"],
|
||||
set_={"total_kills": char_stats.c.total_kills + delta},
|
||||
)
|
||||
await database.execute(stmt)
|
||||
ws_receive_snapshots._last_kills[key] = snap.kills
|
||||
await _broadcast_to_browser_clients(snap.dict())
|
||||
continue
|
||||
# Chat message: broadcast to browser clients only (no DB write)
|
||||
|
|
@ -255,6 +246,9 @@ async def ws_receive_snapshots(
|
|||
del plugin_conns[n]
|
||||
print(f"[WS] Cleaned up plugin connections for {websocket.client}")
|
||||
|
||||
# In-memory store of last kills per session for delta calculations
|
||||
ws_receive_snapshots._last_kills = {}
|
||||
|
||||
@app.websocket("/ws/live")
|
||||
async def ws_live_updates(websocket: WebSocket):
|
||||
# Browser clients connect here to receive telemetry and chat, and send commands
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue