diff --git a/agent/tools.py b/agent/tools.py index 554f587f..c7b3b8fe 100644 --- a/agent/tools.py +++ b/agent/tools.py @@ -82,7 +82,17 @@ async def _db() -> asyncpg.Pool: # ─── SQL safety ───────────────────────────────────────────────────── -_ALLOWED_TOPLEVEL = (exp.Select, exp.With, exp.Union, exp.Subquery) +_ALLOWED_TOPLEVEL = tuple( + cls for cls in ( + getattr(exp, "Select", None), + getattr(exp, "With", None), + getattr(exp, "Union", None), + getattr(exp, "Subquery", None), + getattr(exp, "Intersect", None), + getattr(exp, "Except", None), + ) + if cls is not None +) class SqlNotAllowed(ValueError): @@ -107,29 +117,35 @@ def assert_read_only(sql: str) -> None: raise SqlNotAllowed("only one statement allowed") stmt = statements[0] + if stmt is None: + raise SqlNotAllowed("empty parse result") if not isinstance(stmt, _ALLOWED_TOPLEVEL): raise SqlNotAllowed( f"only SELECT / WITH allowed, got {type(stmt).__name__}" ) # Walk the tree and reject any DML/DDL hidden inside (e.g. CTE with - # INSERT — yes, postgres allows that). + # INSERT — yes, postgres allows that). Use getattr so version drift + # in sqlglot (renamed classes like AlterTable→Alter) doesn't crash + # the whole tool. + _DENY_NAMES = ( + "Insert", "Update", "Delete", "Drop", "Create", "Merge", + "Alter", "AlterTable", "AlterColumn", "AlterDatabase", + "Truncate", "TruncateTable", + "Grant", "Revoke", + "Copy", # PostgreSQL COPY can write files + ) + deny_classes = tuple( + cls for cls in (getattr(exp, name, None) for name in _DENY_NAMES) + if cls is not None + ) for node in stmt.walk(): - if isinstance( - node, - ( - exp.Insert, - exp.Update, - exp.Delete, - exp.Drop, - exp.AlterTable, - exp.Create, - exp.TruncateTable, - exp.Merge, - ), - ): + # walk() returns the node, then in some sqlglot versions a tuple of + # (node, parent, key). Normalize. + actual = node[0] if isinstance(node, tuple) else node + if isinstance(actual, deny_classes): raise SqlNotAllowed( - f"writes/DDL not allowed (found {type(node).__name__})" + f"writes/DDL not allowed (found {type(actual).__name__})" )