Adds an MCP tool that wraps the inventory-service /search/items endpoint with include_all_characters=true, so that questions like 'find me a bracelet with Legendary Acid Ward on any unequipped char' resolve in ONE tool call instead of looping get_inventory over 60+ chars (which timed out at 120s).
- agent/tools.py: search_items_global wrapper
- agent/mcp_overlord.py: register the new tool with a detailed schema doc
- agent/claude_wrapper.py: include it in the --allowed-tools whitelist; bump timeout 120s -> 240s
- nginx/overlord.conf: bump /api/agent/ proxy timeout 180s -> 300s
- CLAUDE.md: instruct Claude to USE search_items for cross-char searches
435 lines
15 KiB
Python
435 lines
15 KiB
Python
"""Tool implementations exposed to Claude via the MCP server.
|
|
|
|
These are pure functions — the MCP server (mcp_overlord.py) only handles
|
|
the protocol wrapping. Keep tool logic here so it's easy to test in
|
|
isolation and reuse from elsewhere (e.g. /agent/ask shortcuts).
|
|
|
|
Two flavors of data access:
|
|
* HTTP loopback to the dereth-tracker container (for endpoints that
|
|
already exist and have validated logic).
|
|
* Direct asyncpg to the read-only PG role for ad-hoc queries
|
|
(rare_events, telemetry, anything not exposed via HTTP).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Any
|
|
from urllib.parse import quote
|
|
|
|
import asyncpg
|
|
import httpx
|
|
import sqlglot
|
|
import sqlglot.errors
|
|
import sqlglot.expressions as exp
|
|
|
|
logger = logging.getLogger(__name__)

# Base URL of the dereth-tracker FastAPI app. It is reachable from the host
# because Docker port-forwards 127.0.0.1:8765:8765 in docker-compose.yml.
TRACKER_URL = os.environ.get("TRACKER_URL", "http://127.0.0.1:8765")

# DSN for the read-only PG role; see deployment plan.
DB_DSN = os.environ.get(
    "AGENT_DB_DSN", "postgresql://overlord_agent_ro@127.0.0.1:5432/dereth"
)

# Hard caps for the SQL tool to keep the agent honest: a per-query timeout
# in seconds and a ceiling on the number of rows handed back.
SQL_TIMEOUT_S = float(os.environ.get("AGENT_SQL_TIMEOUT_S", "10"))
SQL_MAX_ROWS = int(os.environ.get("AGENT_SQL_MAX_ROWS", "200"))
|
|
|
|
|
|
# ─── HTTP loopback helpers ──────────────────────────────────────────
|
|
|
|
|
|
# Process-wide HTTP client; stays None until the first call to _http().
_http_client: httpx.AsyncClient | None = None


async def _http() -> httpx.AsyncClient:
    """Return the shared httpx client, building it on first use.

    The client is bound to TRACKER_URL with a 30 second timeout and is
    cached in the module-level `_http_client`, so later calls reuse its
    connection pool.
    """
    global _http_client
    shared = _http_client
    if shared is None:
        shared = httpx.AsyncClient(base_url=TRACKER_URL, timeout=30.0)
        _http_client = shared
    return shared
|
|
|
|
|
|
async def _get_json(path: str) -> Any:
    """GET `path` on the tracker and return the decoded JSON body.

    Raises whatever `raise_for_status()` raises when the response carries
    an error status.
    """
    http = await _http()
    response = await http.get(path)
    response.raise_for_status()
    body = response.json()
    return body
|
|
|
|
|
|
# ─── DB helpers ─────────────────────────────────────────────────────
|
|
|
|
|
|
# Process-wide asyncpg pool; stays None until the first call to _db().
_db_pool: asyncpg.Pool | None = None


async def _db() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it on first use.

    The pool connects to DB_DSN with 1-4 connections and uses
    SQL_TIMEOUT_S as the per-command timeout.
    """
    global _db_pool
    if _db_pool is not None:
        return _db_pool
    _db_pool = await asyncpg.create_pool(
        DB_DSN,
        min_size=1,
        max_size=4,
        command_timeout=SQL_TIMEOUT_S,
    )
    return _db_pool
|
|
|
|
|
|
# ─── SQL safety ─────────────────────────────────────────────────────
|
|
|
|
|
|
# Expression types accepted as the root of the single allowed statement.
_ALLOWED_TOPLEVEL = (exp.Select, exp.With, exp.Union, exp.Subquery)

# Expression types that must not appear anywhere in the tree. The first group
# is referenced directly so a missing class fails loudly at import time. The
# second group is looked up by name because availability appears to vary
# between sqlglot releases (e.g. ALTER exposed as `AlterTable` or `Alter` —
# TODO confirm against the pinned version). `Into` covers
# `SELECT ... INTO new_table`, which creates a table despite being a Select.
_FORBIDDEN_NODES = (
    exp.Insert,
    exp.Update,
    exp.Delete,
    exp.Drop,
    exp.Create,
    exp.Merge,
) + tuple(
    node_type
    for node_type in (
        getattr(exp, type_name, None)
        for type_name in ("AlterTable", "Alter", "TruncateTable", "Into")
    )
    if node_type is not None
)


class SqlNotAllowed(ValueError):
    """Raised when the agent attempts a non-read-only SQL statement."""


def assert_read_only(sql: str) -> None:
    """Parse `sql` and reject anything that isn't a single read query.

    Belt-and-suspenders: the PG role is also read-only (GRANT SELECT only),
    so even a parser bypass can't actually mutate. This is the first line
    of defense — friendlier error messages and faster reject.

    Args:
        sql: The statement text supplied by the agent.

    Raises:
        SqlNotAllowed: if the text cannot be tokenized/parsed, is empty,
            contains more than one statement, is not rooted at a
            SELECT/WITH-style expression, or contains a write/DDL node
            anywhere in its tree.
    """
    try:
        parsed = sqlglot.parse(sql, read="postgres")
    except sqlglot.errors.SqlglotError as e:
        # Catch the library's base error so tokenizer failures (e.g. an
        # unterminated string) are reported the same way as parse failures.
        raise SqlNotAllowed(f"SQL parse error: {e}") from e

    # Ignore empty statements (None entries) so they neither count towards
    # the one-statement limit nor reach the isinstance check below.
    statements = [s for s in parsed if s is not None]

    if not statements:
        raise SqlNotAllowed("empty SQL")
    if len(statements) > 1:
        raise SqlNotAllowed("only one statement allowed")

    stmt = statements[0]
    if not isinstance(stmt, _ALLOWED_TOPLEVEL):
        raise SqlNotAllowed(
            f"only SELECT / WITH allowed, got {type(stmt).__name__}"
        )

    # Search the whole tree for DML/DDL hidden inside (e.g. CTE with
    # INSERT — yes, postgres allows that). `find` is used rather than a
    # manual loop over `walk()` so the check does not depend on what
    # `walk()` yields per item.
    offender = stmt.find(*_FORBIDDEN_NODES)
    if offender is not None:
        raise SqlNotAllowed(
            f"writes/DDL not allowed (found {type(offender).__name__})"
        )
|
|
|
|
|
|
# ─── Tools ──────────────────────────────────────────────────────────
|
|
|
|
|
|
async def get_live_players() -> dict[str, Any]:
    """Return the currently active characters (telemetry seen in the last ~30s).

    The payload is passed through unchanged from `GET /live`:
        { "players": [ { character_name, ew, ns, z, kills, ... } ] }
    """
    live = await _get_json("/live")
    return live
|
|
|
|
|
|
async def get_recent_rares(hours: int = 24, limit: int = 100) -> dict[str, Any]:
    """Return rare item finds from the last `hours` hours, newest first.

    `hours` is clamped to 1..720 (30 days) and `limit` to 1..SQL_MAX_ROWS
    before the query runs. The result holds the effective `hours`, the
    number of rows in `count`, and one dict per row under `rares`.
    """
    hours = max(1, min(int(hours), 24 * 30))  # cap at 30 days
    limit = max(1, min(int(limit), SQL_MAX_ROWS))

    query = """
        SELECT timestamp, character_name, name, ew, ns, z
        FROM rare_events
        WHERE timestamp >= NOW() - ($1::int || ' hours')::interval
        ORDER BY timestamp DESC
        LIMIT $2
        """
    pool = await _db()
    rows = await pool.fetch(query, hours, limit)

    # Columns copied through as-is; the timestamp is rendered as ISO text.
    passthrough = ("character_name", "name", "ew", "ns", "z")
    rares = []
    for row in rows:
        entry = {"timestamp": row["timestamp"].isoformat()}
        for column in passthrough:
            entry[column] = row[column]
        rares.append(entry)

    return {"hours": hours, "count": len(rows), "rares": rares}
|
|
|
|
|
|
async def query_telemetry_db(sql: str) -> dict[str, Any]:
    """Run a read-only SQL statement against the telemetry DB.

    The query is parsed and any non-SELECT/WITH statement is rejected.
    The connection role is also GRANT SELECT only (defense in depth).

    Useful for ad-hoc questions: "top 5 KPH today", "kill count by character
    yesterday", etc.

    Args:
        sql: A single SELECT/WITH statement.

    Returns:
        {"row_count": int, "truncated": bool, "rows": [ {column: value} ]}
        with at most SQL_MAX_ROWS rows; `truncated` is True when the query
        produced more rows than that.

    Raises:
        SqlNotAllowed: if the statement is rejected by `assert_read_only`
            or does not finish within SQL_TIMEOUT_S seconds.
    """
    assert_read_only(sql)
    pool = await _db()

    async def _fetch_capped() -> list[Any]:
        # Pull rows through a server-side cursor so at most
        # SQL_MAX_ROWS + 1 rows are ever transferred; the extra row is only
        # there to detect truncation. asyncpg cursors require an open
        # transaction, and marking it read-only adds one more guard.
        async with pool.acquire() as conn:
            async with conn.transaction(readonly=True):
                cursor = await conn.cursor(sql)
                return await cursor.fetch(SQL_MAX_ROWS + 1)

    try:
        rows = await asyncio.wait_for(_fetch_capped(), timeout=SQL_TIMEOUT_S)
    except asyncio.TimeoutError as e:
        raise SqlNotAllowed(
            f"query exceeded {SQL_TIMEOUT_S:.0f}s timeout"
        ) from e

    truncated = len(rows) > SQL_MAX_ROWS
    if truncated:
        rows = rows[:SQL_MAX_ROWS]

    return {
        "row_count": len(rows),
        "truncated": truncated,
        "rows": [
            {k: _json_safe(v) for k, v in dict(r).items()} for r in rows
        ],
    }
|
|
|
|
|
|
def _json_safe(v: Any) -> Any:
|
|
"""Convert datetime / Decimal / etc. to JSON-friendly types."""
|
|
from datetime import date, datetime, timedelta
|
|
from decimal import Decimal
|
|
|
|
if v is None:
|
|
return None
|
|
if isinstance(v, (str, int, float, bool)):
|
|
return v
|
|
if isinstance(v, (datetime, date)):
|
|
return v.isoformat()
|
|
if isinstance(v, timedelta):
|
|
return v.total_seconds()
|
|
if isinstance(v, Decimal):
|
|
return float(v)
|
|
if isinstance(v, (list, tuple)):
|
|
return [_json_safe(x) for x in v]
|
|
if isinstance(v, dict):
|
|
return {k: _json_safe(x) for k, x in v.items()}
|
|
return str(v)
|
|
|
|
|
|
# ─── Per-character lookups (HTTP loopback) ──────────────────────────
|
|
|
|
|
|
async def get_player_state(character_name: str) -> dict[str, Any]:
    """Combined snapshot for one character: live telemetry + character stats.

    The character is matched against `/live` by exact `character_name`
    after stripping surrounding whitespace from the argument. The stats
    lookup is best-effort: any failure or non-200 response leaves
    `character_stats` as None instead of failing the whole call.

    Returns:
        {
          "character_name": str,
          "online": bool,                   # whether telemetry was found in /live
          "telemetry": {...} | None,        # from /live, or None if offline
          "character_stats": {...} | None,  # from /character-stats/<name>
        }

    Raises:
        Whatever `_get_json("/live")` raises; errors from the stats lookup
        are logged and swallowed.
    """
    name = character_name.strip()
    live = await _get_json("/live")
    # Tolerate an unexpected payload shape: a non-dict body, a null
    # "players" value, or non-dict entries all count as "not online".
    players = (live.get("players") or []) if isinstance(live, dict) else []
    telemetry = next(
        (
            p
            for p in players
            if isinstance(p, dict) and p.get("character_name") == name
        ),
        None,
    )

    char_stats = None
    try:
        client = await _http()
        resp = await client.get(f"/character-stats/{quote(name, safe='')}")
        if resp.status_code == 200:
            char_stats = resp.json()
        else:
            logger.debug(
                "character-stats for %r returned HTTP %s",
                name,
                resp.status_code,
            )
    except Exception:
        # Deliberately best-effort, but leave a trace for debugging rather
        # than discarding the error silently.
        logger.warning(
            "character-stats lookup failed for %r", name, exc_info=True
        )
        char_stats = None

    return {
        "character_name": name,
        "online": telemetry is not None,
        "telemetry": telemetry,
        "character_stats": char_stats,
    }
|
|
|
|
|
|
async def get_inventory(character_name: str) -> dict[str, Any]:
    """Return the full inventory for one character.

    Items only — for filtered queries use get_inventory_search. The
    character name is percent-encoded into the `/inventory/<name>` path.
    """
    http = await _http()
    path = "/inventory/" + quote(character_name, safe="")
    response = await http.get(path)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
async def get_inventory_search(
    character_name: str, filters: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Run a filtered inventory search for one character.

    `filters` is a dict of query params, e.g.
    {"name": "pearl", "armor_level_min": 500}; a missing or empty dict
    sends no query params.

    Caller is expected to know the supported filters from the dereth-tracker
    /inventory/{name}/search route — they are passed through opaquely.
    """
    http = await _http()
    path = "/inventory/" + quote(character_name, safe="") + "/search"
    query = filters if filters else {}
    response = await http.get(path, params=query)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
async def search_items_global(
    filters: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Cross-character item search via the inventory service's /search/items.

    Use this INSTEAD of looping per-character when the user asks "find an X
    on any of my chars" — one DB query vs. 60+ HTTP roundtrips.

    Common filter keys (passed straight through as query params):
      include_all_characters: bool   (set true to search every char)
      character: str (single char)   |  characters: "A,B,C"
      text: str                       (name/description substring)
      has_spell: "Legendary Acid Ward" — exact spell name
      spell_contains: "Legendary"      — substring match
      legendary_cantrips: "Foo,Bar"
      equipment_status: "equipped" | "unequipped"
      equipment_slot: int (bitmask: 4=chest, 2048=bracelet, 4096=ring, ...)
      slot_names: "Bracelet,Ring"
      armor_only / jewelry_only / weapon_only: bool
      min_armor / max_armor / min_damage / max_damage: int
      ...and many more — see /search/items endpoint docs.

    Args:
        filters: Query params for the endpoint. May be omitted or None; the
            caller's dict is never mutated.

    Returns:
        The decoded JSON body of the /search/items response.

    Raises:
        Whatever `raise_for_status()` raises on an error response.
    """
    client = await _http()
    params = dict(filters or {})

    # Default to all-character search if caller didn't scope; otherwise the
    # endpoint refuses with a 400. A scope key that is present but empty
    # (None, "", False) is treated as "not scoped" and dropped —
    # NOTE(review): presumably the endpoint rejects those the same way as a
    # missing key; confirm against the /search/items route.
    scope_keys = ("character", "characters", "include_all_characters")
    if not any(params.get(key) for key in scope_keys):
        for key in scope_keys:
            params.pop(key, None)
        params["include_all_characters"] = True

    resp = await client.get("/search/items", params=params)
    resp.raise_for_status()
    return resp.json()
|
|
|
|
|
|
async def get_combat_stats(character_name: str) -> dict[str, Any]:
    """Return lifetime + session combat stats for one character.

    Covers the per-element split, monster encounters and surge counts, as
    served by `/combat-stats/<name>`.
    """
    http = await _http()
    path = "/combat-stats/" + quote(character_name, safe="")
    response = await http.get(path)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
async def get_equipment_cantrips(character_name: str) -> dict[str, Any]:
    """Return currently-equipped items and their active cantrip/spell state.

    Fetched from `/equipment-cantrip-state/<name>`.
    """
    http = await _http()
    path = "/equipment-cantrip-state/" + quote(character_name, safe="")
    response = await http.get(path)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
async def get_quest_status() -> dict[str, Any]:
    """Return active quest timers and progress for all characters.

    Passes through the body of `GET /quest-status`.
    """
    quests = await _get_json("/quest-status")
    return quests
|
|
|
|
|
|
async def get_server_health() -> dict[str, Any]:
    """Return Coldeve server status: up/down, latency, player count, uptime.

    Passes through the body of `GET /server-health`.
    """
    health = await _get_json("/server-health")
    return health
|
|
|
|
|
|
async def suitbuilder_search(
    params: dict[str, Any], max_phase_events: int = 50
) -> dict[str, Any]:
    """Drive a suitbuilder constraint search synchronously.

    The dereth-tracker /inv/suitbuilder/search endpoint is an SSE stream.
    We collect events until the stream closes, drop intermediate phase
    chatter (keeping the last N), and return:

        {
          "final_suits": [...],   # payloads of "result" / "final" events
          "phases": [...],        # the last `max_phase_events` other events
          "phase_count": int,     # total number of other events seen
        }

    Args:
        params: The JSON body the suitbuilder expects. Call it like the
            /suitbuilder.html page does.
        max_phase_events: How many non-result events to retain. Zero or a
            negative value retains none.

    Raises:
        httpx.HTTPStatusError: if the endpoint answers with an error status
            instead of an event stream.
    """
    from collections import deque

    final: list[Any] = []
    # A bounded deque keeps only the newest events without re-slicing a
    # list on every event, and handles a cap of 0 correctly (slicing with
    # [-0:] would keep everything).
    phases: deque[dict[str, Any]] = deque(maxlen=max(0, int(max_phase_events)))
    phase_total = 0

    # Use a fresh long-timeout client for the SSE stream — don't tie up the
    # shared pool for a 5-minute search.
    async with httpx.AsyncClient(
        base_url=TRACKER_URL, timeout=httpx.Timeout(300.0, connect=10.0)
    ) as stream_client:
        async with stream_client.stream(
            "POST",
            "/inv/suitbuilder/search",
            json=params,
            headers={"Content-Type": "application/json"},
        ) as resp:
            # Without this an error response has no SSE events and would be
            # reported as an empty, apparently successful search.
            resp.raise_for_status()

            event_name = "message"
            data_lines: list[str] = []
            async for raw_line in resp.aiter_lines():
                # Strip both terminators so the blank-line check below works
                # whether or not the line iterator keeps the newline.
                line = raw_line.rstrip("\r\n")
                if line.startswith("event:"):
                    event_name = line[6:].strip()
                elif line.startswith("data:"):
                    data_lines.append(line[5:].strip())
                elif line == "":
                    # A blank line terminates one event: dispatch it.
                    if data_lines:
                        body = "\n".join(data_lines)
                        try:
                            payload = json.loads(body)
                        except json.JSONDecodeError:
                            payload = {"raw": body}
                        if event_name in ("result", "final"):
                            final.append(payload)
                        else:
                            # Includes "error" events, recorded under their
                            # own event name like any other phase.
                            phases.append(
                                {"event": event_name, "data": payload}
                            )
                            phase_total += 1
                    data_lines = []
                    event_name = "message"

    return {
        "final_suits": final,
        "phases": list(phases),
        "phase_count": phase_total,
    }
|
|
|
|
|
|
# ─── Cleanup ────────────────────────────────────────────────────────
|
|
|
|
|
|
async def shutdown() -> None:
    """Release the shared HTTP client and DB pool, if they were created.

    Each resource is closed and its module-level reference reset to None.
    Call from MCP server lifespan / on exit.
    """
    global _http_client, _db_pool

    http_client = _http_client
    if http_client is not None:
        await http_client.aclose()
        _http_client = None

    db_pool = _db_pool
    if db_pool is not None:
        await db_pool.close()
        _db_pool = None
|