MosswartOverlord/agent/claude_wrapper.py

149 lines
4.9 KiB
Python

"""Subprocess wrapper around `claude -p` (Claude Code in headless JSON mode).
Run from cwd=/home/erik/MosswartOverlord so:
• Sessions persist at ~/.claude/projects/-home-erik-MosswartOverlord/<uuid>.jsonl
• Project-level .mcp.json is auto-loaded
• CLAUDE.md in the repo root briefs the agent
The `--session-id` flag both creates a new session (first call) and resumes
an existing one (subsequent calls), so we don't need separate code paths.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# Deployment-specific locations; each may be overridden through the
# environment (e.g. to point at a different checkout for non-prod testing).
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "/home/erik/.local/bin/claude")
CLAUDE_CWD = os.environ.get("CLAUDE_CWD", "/home/erik/MosswartOverlord")

# Upper bound, in seconds, on one `claude -p` invocation. A turn that chains
# many tool calls can be slow, but a stuck CLI must not run forever, so the
# subprocess is killed once this budget is spent. A non-integer value makes
# int() raise at import time.
CLAUDE_TIMEOUT_S = int(
    os.environ.get("CLAUDE_TIMEOUT_S", "120")
)
@dataclass()
class ClaudeResult:
    """Decoded outcome of a single `claude -p --output-format json` call."""

    # Text taken from the envelope's "result" field.
    result: str
    # Session identifier from the envelope (or the one that was requested).
    session_id: str
    # Value of the envelope's "duration_ms" field.
    duration_ms: int
    # Value of the envelope's "num_turns" field.
    num_turns: int
    # Value of the envelope's "is_error" flag.
    is_error: bool
    # The complete decoded envelope, for fields not surfaced above.
    raw: dict[str, Any]
class ClaudeError(RuntimeError):
    """Signals that a `claude -p` call could not produce a usable result.

    Raised for a missing CLI binary or working directory, a non-zero exit
    status, a timeout, empty stdout, or stdout that cannot be decoded as the
    JSON envelope.
    """
def _parse_envelope(raw_text: str) -> dict[str, Any]:
    """Decode the JSON object printed by `claude --output-format json`.

    Expects the stripped, non-empty stdout text. If the whole text does not
    parse, the last non-blank line is tried instead, which tolerates earlier
    non-JSON lines (e.g. progress output) before the envelope.

    Raises:
        ClaudeError: neither attempt yields JSON, or the JSON is not an object.
    """
    try:
        envelope = json.loads(raw_text)
    except json.JSONDecodeError:
        last_line = next(
            (line for line in reversed(raw_text.splitlines()) if line.strip()),
            "",
        )
        try:
            envelope = json.loads(last_line)
        except json.JSONDecodeError as e:
            raise ClaudeError(
                f"claude stdout was not JSON: {raw_text[:500]}"
            ) from e
    if not isinstance(envelope, dict):
        # A bare JSON scalar/array would otherwise fail later with an
        # AttributeError on `.get`, escaping the ClaudeError contract.
        raise ClaudeError(
            f"claude stdout was not a JSON object: {raw_text[:500]}"
        )
    return envelope


async def _kill_and_reap(
    proc: asyncio.subprocess.Process, grace_s: float = 5.0
) -> None:
    """Kill ``proc`` and wait up to ``grace_s`` seconds for it to exit.

    Awaiting ``wait()`` after ``kill()`` lets asyncio observe the exit and
    finish with the process instead of leaving it unobserved.
    """
    # NOTE(review): only the direct child is killed. Anything it spawned
    # (presumably MCP servers from .mcp.json) is not; if orphans show up,
    # consider start_new_session=True plus os.killpg — confirm first.
    try:
        proc.kill()
    except ProcessLookupError:
        pass  # already exited on its own
    try:
        await asyncio.wait_for(proc.wait(), timeout=grace_s)
    except asyncio.TimeoutError:
        logger.warning(
            "claude pid=%s still running %.1fs after kill", proc.pid, grace_s
        )


async def ask_claude(message: str, session_id: str) -> ClaudeResult:
    """Send `message` to `claude -p` under `session_id`; return the parsed result.

    The prompt is written to the CLI's stdin, and the JSON envelope it prints
    on stdout is decoded into a ClaudeResult.

    Args:
        message: Prompt text; sent UTF-8 encoded on stdin.
        session_id: Passed verbatim to ``--session-id``.

    Returns:
        ClaudeResult populated from the envelope's ``result``, ``session_id``,
        ``duration_ms``, ``num_turns`` and ``is_error`` keys, with the whole
        envelope kept in ``raw``. Missing or null values fall back to ``""``,
        the requested ``session_id``, ``0``, ``0`` and ``False`` respectively.

    Raises:
        ClaudeError: the binary or working directory is missing, the process
            exits non-zero, exceeds CLAUDE_TIMEOUT_S, prints nothing, or its
            stdout is not a JSON object.
        asyncio.CancelledError: re-raised if the awaiting task is cancelled;
            the child process is killed first.
    """
    if not Path(CLAUDE_BIN).exists():
        raise ClaudeError(f"claude binary not found at {CLAUDE_BIN}")
    if not Path(CLAUDE_CWD).is_dir():
        raise ClaudeError(f"CLAUDE_CWD does not exist: {CLAUDE_CWD}")

    # Whitelist only our MCP tools so Claude Code can call them without
    # human approval. Names follow the convention mcp__<server>__<tool>.
    # We deliberately omit built-in tools (Bash, Write, Edit, Read, etc.)
    # — the assistant doesn't need them for live-state Q&A and they'd be a
    # security/permissions footgun on an unattended service.
    # NOTE(review): Claude Code documents `bypassPermissions` as skipping all
    # permission prompts, so in that mode --allowed-tools may not actually
    # keep built-in tools out of reach. If that restriction matters, consider
    # `--disallowed-tools` or a different permission mode — confirm against
    # the installed CLI version.
    allowed_tools = ",".join(
        [
            "mcp__overlord__get_live_players",
            "mcp__overlord__get_recent_rares",
            "mcp__overlord__query_telemetry_db",
            "mcp__overlord__get_player_state",
            "mcp__overlord__get_inventory",
            "mcp__overlord__get_inventory_search",
            "mcp__overlord__get_combat_stats",
            "mcp__overlord__get_equipment_cantrips",
            "mcp__overlord__get_quest_status",
            "mcp__overlord__get_server_health",
            "mcp__overlord__suitbuilder_search",
        ]
    )
    args = [
        CLAUDE_BIN,
        "-p",
        "--session-id",
        session_id,
        "--output-format",
        "json",
        "--allowed-tools",
        allowed_tools,
        # Run unattended with no permission prompts (see NOTE above).
        "--permission-mode",
        "bypassPermissions",
    ]
    logger.info(
        "claude exec: session=%s msg_len=%d cwd=%s", session_id, len(message), CLAUDE_CWD
    )
    proc = await asyncio.create_subprocess_exec(
        *args,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=CLAUDE_CWD,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(input=message.encode("utf-8")),
            timeout=CLAUDE_TIMEOUT_S,
        )
    except asyncio.TimeoutError:
        await _kill_and_reap(proc)
        raise ClaudeError(f"claude timed out after {CLAUDE_TIMEOUT_S}s")
    except asyncio.CancelledError:
        # The awaiting task was cancelled (e.g. the caller gave up). Without
        # this the CLI would keep running with no timeout left to stop it.
        await _kill_and_reap(proc)
        raise

    if proc.returncode != 0:
        # Some failures may only be reported on stdout; use it when stderr
        # is empty so the error message is not blank.
        detail = (stderr.strip() or stdout.strip()).decode("utf-8", "replace")
        raise ClaudeError(f"claude exited {proc.returncode}: {detail[:500]}")

    raw_text = stdout.decode("utf-8", "replace").strip()
    if not raw_text:
        raise ClaudeError("claude produced empty stdout")
    envelope = _parse_envelope(raw_text)

    # `or` fallbacks also cover keys that are present but null, which would
    # otherwise make int() raise TypeError.
    return ClaudeResult(
        result=envelope.get("result") or "",
        session_id=envelope.get("session_id") or session_id,
        duration_ms=int(envelope.get("duration_ms") or 0),
        num_turns=int(envelope.get("num_turns") or 0),
        is_error=bool(envelope.get("is_error", False)),
        raw=envelope,
    )