MosswartOverlord/agent/tools.py
Erik 1196746dbe fix(agent): SQL parser robust against sqlglot version drift
The query_telemetry_db tool was crashing with AttributeError because
exp.AlterTable doesn't exist in this sqlglot version (renamed to Alter).
Made the deny-class list build defensively via getattr and dropped any
classes that the installed sqlglot doesn't expose.

Also broadened the deny list (Alter, AlterColumn, AlterDatabase, Truncate,
Grant, Revoke, Copy) and made the top-level allowlist tolerant of missing
classes too. The walk() return shape is also normalized in case sqlglot
versions yield (node, parent, key) tuples vs. bare nodes.

Belt-and-suspenders is fine — the GRANT-SELECT-only PG role is the real
write barrier; the parser is just a faster/friendlier reject path.
2026-04-25 23:07:00 +02:00

451 lines
15 KiB
Python

"""Tool implementations exposed to Claude via the MCP server.
These are pure functions — the MCP server (mcp_overlord.py) only handles
the protocol wrapping. Keep tool logic here so it's easy to test in
isolation and reuse from elsewhere (e.g. /agent/ask shortcuts).
Two flavors of data access:
* HTTP loopback to the dereth-tracker container (for endpoints that
already exist and have validated logic).
* Direct asyncpg to the read-only PG role for ad-hoc queries
(rare_events, telemetry, anything not exposed via HTTP).
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
from typing import Any
from urllib.parse import quote
import asyncpg
import httpx
import sqlglot
import sqlglot.errors
import sqlglot.expressions as exp
logger = logging.getLogger(__name__)
# The dereth-tracker FastAPI app, reachable from the host because Docker
# port-forwards 127.0.0.1:8765:8765 in docker-compose.yml.
TRACKER_URL = os.getenv("TRACKER_URL", "http://127.0.0.1:8765")
# Read-only PG role; see deployment plan. The default DSN carries no
# password — presumably trust/peer auth on loopback; override with
# AGENT_DB_DSN in other setups (TODO confirm auth mode).
DB_DSN = os.getenv(
    "AGENT_DB_DSN",
    "postgresql://overlord_agent_ro@127.0.0.1:5432/dereth",
)
# Hard caps for the SQL tool to keep the agent honest.
SQL_TIMEOUT_S = float(os.getenv("AGENT_SQL_TIMEOUT_S", "10"))  # seconds
SQL_MAX_ROWS = int(os.getenv("AGENT_SQL_MAX_ROWS", "200"))  # rows per result
# ─── HTTP loopback helpers ──────────────────────────────────────────
_http_client: httpx.AsyncClient | None = None
async def _http() -> httpx.AsyncClient:
    """Return the process-wide httpx client, creating it on first use.

    One shared AsyncClient keeps a single connection pool to the tracker.
    """
    global _http_client
    if _http_client is not None:
        return _http_client
    _http_client = httpx.AsyncClient(base_url=TRACKER_URL, timeout=30.0)
    return _http_client
async def _get_json(path: str) -> Any:
    """GET *path* on the tracker and return the parsed JSON body.

    Raises httpx.HTTPStatusError on non-2xx responses.
    """
    http = await _http()
    response = await http.get(path)
    response.raise_for_status()
    return response.json()
# ─── DB helpers ─────────────────────────────────────────────────────
_db_pool: asyncpg.Pool | None = None
async def _db() -> asyncpg.Pool:
    """Return the shared read-only connection pool, creating it lazily."""
    global _db_pool
    if _db_pool is not None:
        return _db_pool
    _db_pool = await asyncpg.create_pool(
        DB_DSN, min_size=1, max_size=4, command_timeout=SQL_TIMEOUT_S
    )
    return _db_pool
# ─── SQL safety ─────────────────────────────────────────────────────
# Expression classes a query may be at the top level. Resolved via getattr
# so sqlglot version drift (renamed/removed classes) just drops the missing
# entry instead of raising AttributeError at import time.
_ALLOWED_TOPLEVEL = tuple(
    cls for cls in (
        getattr(exp, "Select", None),
        getattr(exp, "With", None),
        getattr(exp, "Union", None),
        getattr(exp, "Subquery", None),
        getattr(exp, "Intersect", None),
        getattr(exp, "Except", None),
    )
    if cls is not None
)
class SqlNotAllowed(ValueError):
    """Raised when the agent attempts a non-read-only SQL statement.

    Subclasses ValueError so generic "bad input" handlers also catch it.
    """
def assert_read_only(sql: str) -> None:
    """Raise SqlNotAllowed unless *sql* is a single read-only query.

    Belt-and-suspenders: the PG role is also read-only (GRANT SELECT only),
    so even a parser bypass can't actually mutate. This is the first line
    of defense — friendlier error messages and faster reject.

    Raises:
        SqlNotAllowed: on parse failure, empty input, multiple statements,
            a non-SELECT/WITH top-level expression, or any DML/DDL node
            nested anywhere in the tree (e.g. an INSERT inside a CTE,
            which postgres permits).
    """
    try:
        parsed = sqlglot.parse(sql, read="postgres")
    except sqlglot.errors.ParseError as e:
        raise SqlNotAllowed(f"SQL parse error: {e}") from e
    if not parsed:
        raise SqlNotAllowed("empty SQL")
    if len(parsed) > 1:
        raise SqlNotAllowed("only one statement allowed")
    root = parsed[0]
    if root is None:
        raise SqlNotAllowed("empty parse result")
    if not isinstance(root, _ALLOWED_TOPLEVEL):
        raise SqlNotAllowed(
            f"only SELECT / WITH allowed, got {type(root).__name__}"
        )
    # Deny-list resolved by name via getattr so sqlglot version drift
    # (e.g. AlterTable renamed to Alter) never crashes the whole tool;
    # classes the installed version doesn't expose are simply skipped.
    deny_names = (
        "Insert", "Update", "Delete", "Drop", "Create", "Merge",
        "Alter", "AlterTable", "AlterColumn", "AlterDatabase",
        "Truncate", "TruncateTable",
        "Grant", "Revoke",
        "Copy",  # PostgreSQL COPY can write files
    )
    denied = tuple(
        filter(None, (getattr(exp, name, None) for name in deny_names))
    )
    for item in root.walk():
        # walk() yields bare nodes in some sqlglot versions and
        # (node, parent, key) tuples in others — unwrap either shape.
        node = item[0] if isinstance(item, tuple) else item
        if isinstance(node, denied):
            raise SqlNotAllowed(
                f"writes/DDL not allowed (found {type(node).__name__})"
            )
# ─── Tools ──────────────────────────────────────────────────────────
async def get_live_players() -> dict[str, Any]:
    """Characters with telemetry seen in roughly the last 30 seconds.

    Mirrors `GET /live`:
        { "players": [ { character_name, ew, ns, z, kills, ... } ] }
    """
    return await _get_json("/live")
async def get_recent_rares(hours: int = 24, limit: int = 100) -> dict[str, Any]:
    """Rare item finds in the last *hours* hours, newest first.

    *hours* is clamped to [1, 720] (30 days) and *limit* to
    [1, SQL_MAX_ROWS] before the query runs.
    """
    window_h = min(max(int(hours), 1), 24 * 30)  # cap at 30 days
    row_cap = min(max(int(limit), 1), SQL_MAX_ROWS)
    pool = await _db()
    records = await pool.fetch(
        """
        SELECT timestamp, character_name, name, ew, ns, z
        FROM rare_events
        WHERE timestamp >= NOW() - ($1::int || ' hours')::interval
        ORDER BY timestamp DESC
        LIMIT $2
        """,
        window_h,
        row_cap,
    )
    rares = [
        {
            "timestamp": rec["timestamp"].isoformat(),
            "character_name": rec["character_name"],
            "name": rec["name"],
            "ew": rec["ew"],
            "ns": rec["ns"],
            "z": rec["z"],
        }
        for rec in records
    ]
    return {"hours": window_h, "count": len(rares), "rares": rares}
async def query_telemetry_db(sql: str) -> dict[str, Any]:
    """Run a read-only SQL statement against the telemetry DB.

    The query is parsed and any non-SELECT/WITH statement is rejected
    (see assert_read_only). The connection role is also GRANT SELECT only
    (defense in depth).

    Useful for ad-hoc questions: "top 5 KPH today", "kill count by character
    yesterday", etc.

    Args:
        sql: A single read-only statement in the postgres dialect.

    Returns:
        {"row_count": int, "truncated": bool, "rows": [dict, ...]} with at
        most SQL_MAX_ROWS rows, values converted via _json_safe.

    Raises:
        SqlNotAllowed: if the statement is not read-only, or if it exceeds
            the SQL_TIMEOUT_S wall-clock budget.
    """
    assert_read_only(sql)
    pool = await _db()
    try:
        # wait_for layers on top of the pool's command_timeout: it also
        # bounds time spent waiting for a free connection, not only the
        # statement itself.
        rows = await asyncio.wait_for(pool.fetch(sql), timeout=SQL_TIMEOUT_S)
    except asyncio.TimeoutError as e:
        # Chain the cause so the original timeout isn't lost from tracebacks
        # (the original re-raise dropped it).
        raise SqlNotAllowed(
            f"query exceeded {SQL_TIMEOUT_S:.0f}s timeout"
        ) from e
    truncated = len(rows) > SQL_MAX_ROWS
    if truncated:
        rows = rows[:SQL_MAX_ROWS]
    return {
        "row_count": len(rows),
        "truncated": truncated,
        "rows": [
            {k: _json_safe(v) for k, v in dict(r).items()} for r in rows
        ],
    }
def _json_safe(v: Any) -> Any:
"""Convert datetime / Decimal / etc. to JSON-friendly types."""
from datetime import date, datetime, timedelta
from decimal import Decimal
if v is None:
return None
if isinstance(v, (str, int, float, bool)):
return v
if isinstance(v, (datetime, date)):
return v.isoformat()
if isinstance(v, timedelta):
return v.total_seconds()
if isinstance(v, Decimal):
return float(v)
if isinstance(v, (list, tuple)):
return [_json_safe(x) for x in v]
if isinstance(v, dict):
return {k: _json_safe(x) for k, x in v.items()}
return str(v)
# ─── Per-character lookups (HTTP loopback) ──────────────────────────
async def get_player_state(character_name: str) -> dict[str, Any]:
    """Combined snapshot for one character: live telemetry + character stats.

    Returns:
        {
            "character_name": str,
            "online": bool,                    # telemetry found in /live
            "telemetry": {...} | None,         # from /live, None if offline
            "character_stats": {...} | None,   # from /character-stats/<name>
        }

    (A previous docstring also advertised a "vitals" key, but the function
    never produced one — removed rather than invented.)
    """
    name = character_name.strip()
    live = await _get_json("/live")
    players = live.get("players", []) if isinstance(live, dict) else []
    telemetry = next(
        (p for p in players if p.get("character_name") == name), None
    )
    char_stats: dict[str, Any] | None = None
    try:
        client = await _http()
        resp = await client.get(f"/character-stats/{quote(name, safe='')}")
        if resp.status_code == 200:
            char_stats = resp.json()
    except Exception:
        # Best-effort: a failing stats endpoint shouldn't sink the whole
        # snapshot — the live-telemetry half is still useful on its own.
        char_stats = None
    return {
        "character_name": name,
        "online": telemetry is not None,
        "telemetry": telemetry,
        "character_stats": char_stats,
    }
async def get_inventory(character_name: str) -> dict[str, Any]:
    """Full inventory for one character.

    Items only — for filtered queries use get_inventory_search.
    """
    encoded = quote(character_name, safe="")
    client = await _http()
    resp = await client.get(f"/inventory/{encoded}")
    resp.raise_for_status()
    return resp.json()
async def get_inventory_search(
    character_name: str, filters: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Filtered inventory search for one character.

    *filters* is a dict of query params, e.g. {"name": "pearl",
    "armor_level_min": 500}. The supported keys are whatever the
    dereth-tracker /inventory/{name}/search route accepts — they are
    passed through opaquely.
    """
    params = filters or {}
    encoded = quote(character_name, safe="")
    client = await _http()
    resp = await client.get(f"/inventory/{encoded}/search", params=params)
    resp.raise_for_status()
    return resp.json()
async def search_items_global(filters: dict[str, Any]) -> dict[str, Any]:
    """Cross-character item search via the inventory service's /search/items.

    Use this INSTEAD of looping per-character when the user asks "find an X
    on any of my chars" — one DB query vs. 60+ HTTP roundtrips.

    Common filter keys (forwarded verbatim as query params):
        include_all_characters: bool (set true to search every char)
        character: str (single char) | characters: "A,B,C"
        text: str (name/description substring)
        has_spell: "Legendary Acid Ward" — exact spell name
        spell_contains: "Legendary" — substring match
        legendary_cantrips: "Foo,Bar"
        equipment_status: "equipped" | "unequipped"
        equipment_slot: int (bitmask: 4=chest, 2048=bracelet, 4096=ring, ...)
        slot_names: "Bracelet,Ring"
        armor_only / jewelry_only / weapon_only: bool
        min_armor / max_armor / min_damage / max_damage: int
        ...and many more — see /search/items endpoint docs.
    """
    params = dict(filters or {})
    # The endpoint refuses unscoped queries with a 400 — default to an
    # all-character search when the caller didn't pick a scope.
    scope_keys = ("character", "characters", "include_all_characters")
    if all(key not in params for key in scope_keys):
        params["include_all_characters"] = True
    client = await _http()
    resp = await client.get("/search/items", params=params)
    resp.raise_for_status()
    return resp.json()
async def get_combat_stats(character_name: str) -> dict[str, Any]:
    """Lifetime + session combat stats for one character.

    Includes the per-element split, monster encounters, and surge counts.
    """
    encoded = quote(character_name, safe="")
    client = await _http()
    resp = await client.get(f"/combat-stats/{encoded}")
    resp.raise_for_status()
    return resp.json()
async def get_equipment_cantrips(character_name: str) -> dict[str, Any]:
    """Currently-equipped items plus their active cantrip/spell state."""
    encoded = quote(character_name, safe="")
    client = await _http()
    resp = await client.get(f"/equipment-cantrip-state/{encoded}")
    resp.raise_for_status()
    return resp.json()
async def get_quest_status() -> dict[str, Any]:
    """Active quest timers and progress across all characters."""
    return await _get_json("/quest-status")
async def get_server_health() -> dict[str, Any]:
    """Coldeve server status: up/down, latency, player count, uptime."""
    return await _get_json("/server-health")
async def suitbuilder_search(
    params: dict[str, Any], max_phase_events: int = 50
) -> dict[str, Any]:
    """Drive a suitbuilder constraint search synchronously.

    The dereth-tracker /inv/suitbuilder/search endpoint is an SSE stream.
    We collect events until the stream closes, drop intermediate phase
    chatter (keeping the last *max_phase_events*), and return:

        { "final_suits": [...], "phases": [...latest few...],
          "phase_count": int }

    `params` is the JSON body the suitbuilder expects. Call it like the
    /suitbuilder.html page does.

    Fixes over the previous revision: a trailing event is no longer lost
    when the stream ends without the usual blank-line terminator, and the
    rolling phase window uses deque(maxlen=...) instead of re-slicing the
    list on every event.
    """
    from collections import deque

    final: list[dict[str, Any]] = []
    # maxlen clamps at 0 so a non-positive cap yields no phases instead of
    # a ValueError (or, as before, the full unbounded list).
    phases: deque[dict[str, Any]] = deque(maxlen=max(0, max_phase_events))

    def dispatch(event_name: str, data_lines: list[str]) -> None:
        """Route one completed SSE event into final/phases (no-op if empty).

        Non-result events — including "error" — all land in `phases` with
        the same {"event", "data"} shape.
        """
        if not data_lines:
            return
        raw = "\n".join(data_lines)
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError:
            payload = {"raw": raw}
        if event_name in ("result", "final"):
            final.append(payload)
        else:
            phases.append({"event": event_name, "data": payload})

    # Use a fresh long-timeout client for the SSE stream — don't tie up the
    # shared pool for a 5-minute search.
    async with httpx.AsyncClient(
        base_url=TRACKER_URL, timeout=httpx.Timeout(300.0, connect=10.0)
    ) as stream_client:
        async with stream_client.stream(
            "POST",
            "/inv/suitbuilder/search",
            json=params,
            headers={"Content-Type": "application/json"},
        ) as resp:
            event_name = "message"
            data_lines: list[str] = []
            # aiter_lines yields decoded str lines, not bytes.
            async for raw_line in resp.aiter_lines():
                line = raw_line.rstrip("\r")
                if line.startswith("event:"):
                    event_name = line[6:].strip()
                elif line.startswith("data:"):
                    data_lines.append(line[5:].strip())
                elif line == "":
                    # A blank line terminates an SSE event — dispatch it.
                    dispatch(event_name, data_lines)
                    data_lines = []
                    event_name = "message"
            # Flush a pending event if the stream closed without a final
    # blank line; previously that last event was silently dropped.
            dispatch(event_name, data_lines)
    return {
        "final_suits": final,
        "phases": list(phases),
        "phase_count": len(phases),
    }
# ─── Cleanup ────────────────────────────────────────────────────────
async def shutdown() -> None:
    """Release shared resources (HTTP client, DB pool).

    Call from the MCP server lifespan hook / on exit. Safe to call even if
    nothing was ever created.
    """
    global _http_client, _db_pool
    if _http_client is not None:
        await _http_client.aclose()
    _http_client = None
    if _db_pool is not None:
        await _db_pool.close()
    _db_pool = None