"""MCP tool implementations for the Overlord agent.

Everything here is a plain async function; the MCP server
(mcp_overlord.py) only wraps these in the protocol layer. Keeping the
logic here makes each tool easy to test in isolation and lets other entry
points (e.g. /agent/ask shortcuts) reuse it.

Data is reached in one of two ways:

* HTTP loopback to the dereth-tracker container, for endpoints that
  already exist and carry validated logic.
* Direct asyncpg queries through the read-only PostgreSQL role, for
  ad-hoc data not exposed over HTTP (rare_events, telemetry, anything
  else without an endpoint).
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
from typing import Any
from urllib.parse import quote

import asyncpg
import httpx
import sqlglot
import sqlglot.errors
import sqlglot.expressions as exp

logger = logging.getLogger(__name__)

# Base URL of the dereth-tracker FastAPI app. The default is reachable from
# the host because docker-compose.yml publishes 127.0.0.1:8765:8765.
_DEFAULT_TRACKER_URL = "http://127.0.0.1:8765"
TRACKER_URL = os.environ.get("TRACKER_URL", _DEFAULT_TRACKER_URL)

# DSN for the read-only PostgreSQL role (see the deployment plan).
_DEFAULT_DB_DSN = "postgresql://overlord_agent_ro@127.0.0.1:5432/dereth"
DB_DSN = os.environ.get("AGENT_DB_DSN", _DEFAULT_DB_DSN)

# Hard caps for the SQL tool to keep the agent honest.
SQL_TIMEOUT_S = float(os.getenv("AGENT_SQL_TIMEOUT_S", "10"))  # per-query limit
SQL_MAX_ROWS = int(os.getenv("AGENT_SQL_MAX_ROWS", "200"))  # rows handed back


# ─── HTTP loopback helpers ──────────────────────────────────────────

_http_client: httpx.AsyncClient | None = None


async def _http() -> httpx.AsyncClient:
    """Lazily create + reuse a single httpx client (connection pool).

    There is no await between the None check and the assignment, so
    concurrent first callers cannot end up creating two clients.
    """
    global _http_client
    if _http_client is None:
        _http_client = httpx.AsyncClient(base_url=TRACKER_URL, timeout=30.0)
    return _http_client


async def _get_json(path: str, params: dict[str, Any] | None = None) -> Any:
    """GET ``path`` on the tracker and return the decoded JSON body.

    ``params`` (optional) is sent as query-string parameters.
    Raises ``httpx.HTTPStatusError`` on a non-2xx response.
    """
    client = await _http()
    resp = await client.get(path, params=params)
    resp.raise_for_status()
    return resp.json()


# ─── DB helpers ─────────────────────────────────────────────────────

_db_pool: asyncpg.Pool | None = None
# create_pool() awaits, so without a lock two concurrent first callers could
# both see None and each build a pool, leaking one of them. The lock itself
# is created lazily inside a coroutine so it binds to the running loop.
_db_pool_lock: asyncio.Lock | None = None


async def _db() -> asyncpg.Pool:
    """Lazily create + reuse the asyncpg pool for the read-only role."""
    global _db_pool, _db_pool_lock
    if _db_pool is not None:
        return _db_pool
    if _db_pool_lock is None:
        _db_pool_lock = asyncio.Lock()
    async with _db_pool_lock:
        # Re-check: another coroutine may have built the pool while we waited.
        if _db_pool is None:
            _db_pool = await asyncpg.create_pool(
                DB_DSN, min_size=1, max_size=4, command_timeout=SQL_TIMEOUT_S
            )
    return _db_pool


# ─── SQL safety ─────────────────────────────────────────────────────

_ALLOWED_TOPLEVEL = tuple(
    cls
    for cls in (
        getattr(exp, "Select", None),
        getattr(exp, "With", None),
        getattr(exp, "Union", None),
        getattr(exp, "Subquery", None),
        getattr(exp, "Intersect", None),
        getattr(exp, "Except", None),
    )
    if cls is not None
)

# sqlglot node types that mean "this query writes, changes schema, or takes
# locks", rejected anywhere in the tree. Looked up with getattr so version
# drift in sqlglot (renamed classes like AlterTable→Alter) doesn't crash the
# module. Computed once at import time rather than on every call.
_DENY_NAMES = (
    "Insert",
    "Update",
    "Delete",
    "Drop",
    "Create",
    "Merge",
    "Alter",
    "AlterTable",
    "AlterColumn",
    "AlterDatabase",
    "Truncate",
    "TruncateTable",
    "Grant",
    "Revoke",
    "Copy",  # PostgreSQL COPY can write files
    "Into",  # SELECT ... INTO new_table creates a table in PostgreSQL
    "Lock",  # SELECT ... FOR UPDATE / FOR SHARE takes row locks
)
_DENY_CLASSES = tuple(
    cls
    for cls in (getattr(exp, name, None) for name in _DENY_NAMES)
    if cls is not None
)


class SqlNotAllowed(ValueError):
    """Raised when the agent attempts a non-read-only SQL statement."""


def assert_read_only(sql: str) -> None:
    """Parse `sql` and reject anything that isn't a single read query.

    Belt-and-suspenders: the PG role is also read-only (GRANT SELECT only)
    and query_telemetry_db runs inside a READ ONLY transaction, so even a
    parser bypass can't actually mutate. This is the first line of
    defense — friendlier error messages and a faster reject.

    Raises:
        SqlNotAllowed: on tokenizer/parse errors, empty input, multiple
            statements, a non-SELECT top-level statement, or any
            write/DDL/lock node nested anywhere in the tree.
    """
    try:
        statements = sqlglot.parse(sql, read="postgres")
    except sqlglot.errors.SqlglotError as e:
        # SqlglotError covers ParseError *and* TokenError (e.g. an
        # unterminated string literal), so callers only ever see
        # SqlNotAllowed for bad SQL.
        raise SqlNotAllowed(f"SQL parse error: {e}") from e
    if not statements:
        raise SqlNotAllowed("empty SQL")
    if len(statements) > 1:
        raise SqlNotAllowed("only one statement allowed")
    stmt = statements[0]
    if stmt is None:
        raise SqlNotAllowed("empty parse result")
    if not isinstance(stmt, _ALLOWED_TOPLEVEL):
        raise SqlNotAllowed(
            f"only SELECT / WITH allowed, got {type(stmt).__name__}"
        )
    # Walk the tree and reject any DML/DDL hidden inside (e.g. CTE with
    # INSERT — yes, postgres allows that).
    for node in stmt.walk():
        # walk() yields the node itself, or in some sqlglot versions a
        # tuple of (node, parent, key). Normalize.
        actual = node[0] if isinstance(node, tuple) else node
        if isinstance(actual, _DENY_CLASSES):
            raise SqlNotAllowed(
                f"writes/DDL not allowed (found {type(actual).__name__})"
            )


# ─── Tools ──────────────────────────────────────────────────────────


async def get_live_players() -> dict[str, Any]:
    """Active characters (telemetry seen in the last ~30s).

    Returns the same shape as `GET /live`:
        { "players": [ { character_name, ew, ns, z, kills, ... } ] }
    """
    return await _get_json("/live")


async def get_recent_rares(hours: int = 24, limit: int = 100) -> dict[str, Any]:
    """Rare item finds in the last N hours, newest first.

    ``hours`` is clamped to [1, 720] (30 days) and ``limit`` to
    [1, SQL_MAX_ROWS]. Column values go through ``_json_safe`` so that
    NUMERIC columns (Decimal) or NULL timestamps don't break JSON encoding.
    """
    hours = max(1, min(int(hours), 24 * 30))  # cap at 30 days
    limit = max(1, min(int(limit), SQL_MAX_ROWS))
    pool = await _db()
    rows = await pool.fetch(
        """
        SELECT timestamp, character_name, name, ew, ns, z
        FROM rare_events
        WHERE timestamp >= NOW() - ($1::int || ' hours')::interval
        ORDER BY timestamp DESC
        LIMIT $2
        """,
        hours,
        limit,
    )
    fields = ("timestamp", "character_name", "name", "ew", "ns", "z")
    return {
        "hours": hours,
        "count": len(rows),
        "rares": [{f: _json_safe(r[f]) for f in fields} for r in rows],
    }


async def query_telemetry_db(sql: str) -> dict[str, Any]:
    """Run a read-only SQL statement against the telemetry DB.

    The query is parsed and any non-SELECT/WITH statement is rejected
    (see assert_read_only). It then runs inside a READ ONLY transaction,
    on a role that is GRANT SELECT only (defense in depth).

    At most SQL_MAX_ROWS rows are returned. Rows are streamed through a
    cursor, so a huge result set is never pulled fully into memory;
    ``truncated`` is True when more rows existed.

    Useful for ad-hoc questions: "top 5 KPH today", "kill count by
    character yesterday", etc.

    Raises:
        SqlNotAllowed: if the SQL is rejected or exceeds SQL_TIMEOUT_S.
    """
    assert_read_only(sql)
    pool = await _db()
    cap = max(0, SQL_MAX_ROWS)

    async def _run() -> list[asyncpg.Record]:
        async with pool.acquire() as conn:
            # Cursors require a transaction; making it READ ONLY also makes
            # Postgres refuse writes (SELECT INTO, nextval(), ...) that the
            # parser might have missed.
            async with conn.transaction(readonly=True):
                cursor = await conn.cursor(sql)
                # One extra row tells us whether the result was truncated.
                return await cursor.fetch(cap + 1)

    try:
        rows = await asyncio.wait_for(_run(), timeout=SQL_TIMEOUT_S)
    except asyncio.TimeoutError:
        raise SqlNotAllowed(
            f"query exceeded {SQL_TIMEOUT_S:.0f}s timeout"
        ) from None

    truncated = len(rows) > cap
    rows = rows[:cap]
    return {
        "row_count": len(rows),
        "truncated": truncated,
        "rows": [
            {k: _json_safe(v) for k, v in dict(r).items()} for r in rows
        ],
    }


def _json_safe(v: Any) -> Any:
    """Convert a DB value into something ``json.dumps`` can encode.

    datetime/date/time → ISO-8601 string, timedelta → seconds (float),
    Decimal → float, bytes-like → hex string, list/tuple → list and
    dict → dict (both converted recursively). Anything else falls back
    to ``str(v)``.
    """
    from datetime import date, datetime, time, timedelta
    from decimal import Decimal

    if v is None or isinstance(v, (str, int, float, bool)):
        return v
    if isinstance(v, (datetime, date, time)):
        return v.isoformat()
    if isinstance(v, timedelta):
        return v.total_seconds()
    if isinstance(v, Decimal):
        return float(v)
    if isinstance(v, (bytes, bytearray, memoryview)):
        # str(b"...") would yield a Python repr like "b'\\x00'".
        return bytes(v).hex()
    if isinstance(v, (list, tuple)):
        return [_json_safe(x) for x in v]
    if isinstance(v, dict):
        return {k: _json_safe(x) for k, x in v.items()}
    return str(v)


# ─── Per-character lookups (HTTP loopback) ──────────────────────────


async def get_player_state(character_name: str) -> dict[str, Any]:
    """Combined snapshot for one character: live telemetry + character stats.

    Returns:
        {
          "character_name": str,                 # stripped input name
          "online": bool,                        # telemetry found in /live
          "telemetry": {...} | None,             # /live entry, None if offline
          "character_stats": {...} | None,       # /character-stats/, None on
                                                 # non-200 or request failure
        }

    A failing /live request propagates; the character-stats lookup is
    best-effort (failures are logged and yield None).
    """
    name = character_name.strip()
    live = await _get_json("/live")
    players = live.get("players", []) if isinstance(live, dict) else []
    telemetry = next(
        (p for p in players if p.get("character_name") == name), None
    )

    char_stats: dict[str, Any] | None = None
    try:
        client = await _http()
        resp = await client.get(f"/character-stats/{quote(name, safe='')}")
        if resp.status_code == 200:
            char_stats = resp.json()
    except Exception:
        # Deliberately best-effort, but don't hide why it failed.
        logger.warning(
            "character-stats lookup failed for %r", name, exc_info=True
        )
        char_stats = None

    return {
        "character_name": name,
        "online": telemetry is not None,
        "telemetry": telemetry,
        "character_stats": char_stats,
    }


async def get_inventory(character_name: str) -> dict[str, Any]:
    """Full inventory for one character.

    Items only — for filtered queries use get_inventory_search.
    """
    return await _get_json(f"/inventory/{quote(character_name, safe='')}")


async def get_inventory_search(
    character_name: str, filters: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Filtered inventory search.

    `filters` is a dict of query params, e.g.
    {"name": "pearl", "armor_level_min": 500}. Caller is expected to know
    the supported filters from the dereth-tracker /inventory/{name}/search
    route — they are passed through opaquely.
    """
    return await _get_json(
        f"/inventory/{quote(character_name, safe='')}/search",
        params=filters or {},
    )


async def search_items_global(filters: dict[str, Any]) -> dict[str, Any]:
    """Cross-character item search via the inventory service's /search/items.

    Use this INSTEAD of looping per-character when the user asks
    "find an X on any of my chars" — one DB query vs. 60+ HTTP roundtrips.

    Common filter keys (passed straight through as query params):
        include_all_characters: bool (set true to search every char)
        character: str (single char) | characters: "A,B,C"
        text: str (name/description substring)
        has_spell: "Legendary Acid Ward" — exact spell name
        spell_contains: "Legendary" — substring match
        legendary_cantrips: "Foo,Bar"
        equipment_status: "equipped" | "unequipped"
        equipment_slot: int (bitmask: 4=chest, 2048=bracelet, 4096=ring, ...)
        slot_names: "Bracelet,Ring"
        armor_only / jewelry_only / weapon_only: bool
        min_armor / max_armor / min_damage / max_damage: int
        ...and many more — see /search/items endpoint docs.
    """
    params = dict(filters or {})
    # Default to all-character search if caller didn't scope; otherwise the
    # endpoint refuses with a 400.
    if not any(
        k in params for k in ("character", "characters", "include_all_characters")
    ):
        params["include_all_characters"] = True
    return await _get_json("/search/items", params=params)


async def get_combat_stats(character_name: str) -> dict[str, Any]:
    """Lifetime + session combat stats for one character.

    Includes per-element split, monster encounters and surge counts.
    """
    return await _get_json(f"/combat-stats/{quote(character_name, safe='')}")


async def get_equipment_cantrips(character_name: str) -> dict[str, Any]:
    """Currently-equipped items + their active cantrip/spell state."""
    return await _get_json(
        f"/equipment-cantrip-state/{quote(character_name, safe='')}"
    )


async def get_quest_status() -> dict[str, Any]:
    """All characters' active quest timers and progress."""
    return await _get_json("/quest-status")


async def get_server_health() -> dict[str, Any]:
    """Coldeve server status: up/down, latency, current player count, uptime."""
    return await _get_json("/server-health")


async def _iter_sse_events(resp: httpx.Response):
    """Yield ``(event_name, data)`` pairs from a text/event-stream response.

    Minimal SSE parser: handles ``event:`` and (multi-line) ``data:``
    fields, ignores comments and other fields, and dispatches on a blank
    line. An event still unterminated at EOF is discarded, as the SSE spec
    requires.
    """
    event_name = "message"
    data_lines: list[str] = []
    async for raw in resp.aiter_lines():
        # Strip "\n" too: some httpx versions keep line terminators.
        line = raw.rstrip("\r\n")
        if line.startswith("event:"):
            event_name = line[6:].strip()
        elif line.startswith("data:"):
            data_lines.append(line[5:].strip())
        elif line == "":
            if data_lines:
                yield event_name, "\n".join(data_lines)
            data_lines = []
            event_name = "message"


async def suitbuilder_search(
    params: dict[str, Any], max_phase_events: int = 50
) -> dict[str, Any]:
    """Drive a suitbuilder constraint search synchronously.

    The dereth-tracker /inv/suitbuilder/search endpoint is an SSE stream.
    We collect events until the stream closes and return:

        {
          "final_suits": [...],   # payloads of "result"/"final" events
          "phases": [...],        # the last `max_phase_events` other events
          "phase_count": int,     # total non-final events seen (untrimmed)
          "errors": [...],        # every "error" payload, never trimmed
        }

    `params` is the JSON body the suitbuilder expects. Call it like the
    /suitbuilder.html page does.

    Raises:
        httpx.HTTPStatusError: if the endpoint answers with a non-2xx status.
    """
    keep = max(0, int(max_phase_events))  # 0 means "keep no phase chatter"
    final: list[Any] = []
    errors: list[Any] = []
    phases: list[dict[str, Any]] = []
    phase_count = 0

    # Use a fresh long-timeout client for the SSE stream — don't tie up the
    # shared pool for a 5-minute search.
    async with httpx.AsyncClient(
        base_url=TRACKER_URL, timeout=httpx.Timeout(300.0, connect=10.0)
    ) as stream_client:
        async with stream_client.stream(
            "POST",
            "/inv/suitbuilder/search",
            json=params,
            headers={"Content-Type": "application/json"},
        ) as resp:
            if resp.is_error:
                # Read the body so the raised error's response is inspectable.
                await resp.aread()
            resp.raise_for_status()
            async for event_name, data in _iter_sse_events(resp):
                try:
                    payload = json.loads(data)
                except json.JSONDecodeError:
                    payload = {"raw": data}
                if event_name in ("result", "final"):
                    final.append(payload)
                    continue
                if event_name == "error":
                    errors.append(payload)
                phase_count += 1
                phases.append({"event": event_name, "data": payload})
                # Trim as we go so a long search doesn't grow memory.
                if len(phases) > keep:
                    del phases[: len(phases) - keep]

    return {
        "final_suits": final,
        "phases": phases,
        "phase_count": phase_count,
        "errors": errors,
    }


# ─── Cleanup ────────────────────────────────────────────────────────


async def shutdown() -> None:
    """Close shared resources. Call from MCP server lifespan / on exit.

    Both globals are cleared up front, and the pool is closed even if
    closing the HTTP client raises.
    """
    global _http_client, _db_pool
    client, _http_client = _http_client, None
    pool, _db_pool = _db_pool, None
    try:
        if client is not None:
            await client.aclose()
    finally:
        if pool is not None:
            await pool.close()