"""Sandbox executor: runs validated SQL and returns only aggregated, sanitized results."""

import re
import sqlite3
from typing import Any

from core.config import SANDBOX_RULES

# Constructs that satisfy the ``require_aggregation`` rule. Function names must
# be followed by "(" and keywords must be whole words, so identifiers such as
# ``admin`` (contains MIN), ``account`` (contains COUNT) or ``forecast``
# (contains CAST) no longer masquerade as aggregation.
_AGGREGATION_RE = re.compile(
    r"\b(?:COUNT|SUM|AVG|MIN|MAX|ROUND|CAST)\s*\("
    r"|\bGROUP\s+BY\b"
    r"|\bDISTINCT\b"
    r"|\bHAVING\b"
)
# Whole-word LIMIT, so a column such as ``credit_limit`` does not count.
_LIMIT_KEYWORD_RE = re.compile(r"\bLIMIT\b")
# Literal LIMIT clauses: ``LIMIT n``, ``LIMIT n OFFSET m`` and SQLite's
# ``LIMIT offset, count`` form (where the *second* number is the row count).
_LIMIT_VALUE_RE = re.compile(r"\bLIMIT\s+(-?\d+)(?:\s*,\s*(-?\d+))?")

# Result-column names (lower-cased) treated as counts for small-sample suppression.
_COUNT_COLUMN_NAMES = frozenset({"count", "cnt", "n", "total"})
# Un-aliased COUNT(...) columns: SQLite names them after the expression text,
# e.g. "count(*)" or "count(distinct patient_id)" once lower-cased. Only a bare
# COUNT(...) matches, so derived expressions such as "count(*) * 100.0 / ..."
# (percentages) are not suppressed.
_COUNT_EXPR_RE = re.compile(r"count\s*\([^()]*\)")


class SandboxError(Exception):
    """Raised when a SQL statement violates a sandbox security rule."""


class SandboxExecutor:
    """Run SELECT queries against a SQLite database under ``SANDBOX_RULES``.

    Each query is validated before execution (``_validate``). Each result is
    truncated, has small counts suppressed and floats rounded (``_sanitize``)
    before being returned. Can be used as a context manager so the connection
    is closed automatically.
    """

    # Maximum number of SQL characters shown per entry in the execution summary.
    _SQL_PREVIEW_CHARS = 80

    def __init__(self, db_path: str):
        """Open a persistent connection to ``db_path``.

        Args:
            db_path: Database path passed to ``sqlite3.connect``.
        """
        self.db_path = db_path
        self.rules = SANDBOX_RULES
        self.execution_log: list[dict] = []
        # Persistent connection, so queries do not reopen the database each time.
        self._conn = sqlite3.connect(db_path)
        # sqlite3.Row lets each row be converted with dict(row), keyed by column name.
        self._conn.row_factory = sqlite3.Row

    def __enter__(self) -> "SandboxExecutor":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any in-flight exception propagate.
        self.close()

    def execute(self, sql: str) -> dict[str, Any]:
        """Validate and run ``sql``, returning the sanitized result.

        Args:
            sql: A single SELECT statement.

        Returns:
            On success: ``{"success": True, "columns", "rows", "row_count", "sql"}``.
            ``rows`` holds the sanitized row dicts.
            On a SQLite error: ``{"success": False, "error", "sql"}``.

        Raises:
            SandboxError: If ``sql`` violates a sandbox rule.
            sqlite3.ProgrammingError: If the executor has already been closed.
        """
        self._validate(sql)
        if self._conn is None:
            raise sqlite3.ProgrammingError("Cannot operate on a closed database.")

        cur = self._conn.cursor()
        try:
            cur.execute(sql)
            rows = cur.fetchall()
            columns = [desc[0] for desc in cur.description] if cur.description else []
        except sqlite3.Error as e:
            return {"success": False, "error": str(e), "sql": sql}
        finally:
            # Release the statement handle whether or not the query succeeded.
            cur.close()

        results = [dict(row) for row in rows]
        sanitized = self._sanitize(results, columns)

        # rows_returned records the row count *before* truncation by _sanitize.
        self.execution_log.append({
            "sql": sql,
            "rows_returned": len(results),
            "columns": columns,
        })
        return {
            "success": True,
            "columns": columns,
            "rows": sanitized,
            "row_count": len(sanitized),
            "sql": sql,
        }

    def close(self) -> None:
        """Close the database connection; calling it again is a no-op."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    def _validate(self, sql: str) -> None:
        """Reject SQL that violates the sandbox rules.

        Checks, in order:
        1. Banned keywords.
        2. Single statement.
        3. SELECT-only.
        4. Aggregation or LIMIT (when ``require_aggregation`` is set).
        5. Every literal LIMIT within ``max_result_rows``.

        Raises:
            SandboxError: On the first rule that is violated.
        """
        sql_upper = sql.upper().strip()

        # Deliberately a plain substring test: over-blocking is preferable to
        # letting a banned keyword slip through.
        for banned in self.rules["banned_keywords"]:
            if banned.upper() in sql_upper:
                raise SandboxError(f"禁止的 SQL 关键字: {banned}")

        # Conservative: a ';' inside a string literal is rejected as well.
        statements = [s.strip() for s in sql.split(";") if s.strip()]
        if len(statements) > 1:
            raise SandboxError("禁止多语句执行")

        if not sql_upper.startswith("SELECT"):
            raise SandboxError("只允许 SELECT 查询")

        if self.rules["require_aggregation"]:
            has_agg = _AGGREGATION_RE.search(sql_upper) is not None
            has_limit = _LIMIT_KEYWORD_RE.search(sql_upper) is not None
            if not has_agg and not has_limit:
                raise SandboxError("要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT")

        max_rows = self.rules["max_result_rows"]
        # Check every LIMIT (including sub-queries), not just the first one.
        for match in _LIMIT_VALUE_RE.finditer(sql_upper):
            first, second = match.groups()
            # In "LIMIT offset, count" the row count is the second number.
            count = int(second if second is not None else first)
            # SQLite treats a negative LIMIT as "no upper bound", so reject it too.
            if count < 0 or count > max_rows:
                raise SandboxError(f"LIMIT 超过最大允许值 {max_rows}")

    @staticmethod
    def _is_count_column(col: str) -> bool:
        """Return True if result column ``col`` holds a count subject to suppression."""
        name = col.lower()
        return name in _COUNT_COLUMN_NAMES or _COUNT_EXPR_RE.fullmatch(name) is not None

    def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
        """Apply de-identification rules to ``rows``.

        - Keep at most ``max_result_rows`` rows.
        - Replace count values below ``suppress_small_n`` with ``"<{n}"``.
        - Round remaining floats to ``round_floats`` digits.

        The kept row dicts are modified in place and returned.
        """
        rows = rows[:self.rules["max_result_rows"]]
        suppress_n = self.rules["suppress_small_n"]
        round_digits = self.rules["round_floats"]

        for row in rows:
            for col in columns:
                val = row.get(col)
                # Small-sample suppression runs first. The `continue` stops the
                # rounding step below from overwriting the "<n" marker, because
                # `val` still holds the original number.
                if (
                    self._is_count_column(col)
                    and isinstance(val, (int, float))
                    and val < suppress_n
                ):
                    row[col] = f"<{suppress_n}"
                    continue
                if isinstance(val, float):
                    row[col] = round(val, round_digits)
        return rows

    def get_execution_summary(self) -> str:
        """Return a human-readable, numbered summary of every executed query."""
        if not self.execution_log:
            return "尚未执行任何查询"

        limit = self._SQL_PREVIEW_CHARS
        lines = [f"共执行 {len(self.execution_log)} 条查询:"]
        for i, log in enumerate(self.execution_log, 1):
            # Collapse newlines/indentation so multi-line SQL stays on one line.
            sql_text = " ".join(log["sql"].split())
            # Only show an ellipsis when the SQL was actually truncated.
            preview = sql_text if len(sql_text) <= limit else f"{sql_text[:limit]}..."
            lines.append(f" {i}. {preview} → {log['rows_returned']} 行结果")
        return "\n".join(lines)