""" 沙箱执行器 —— 执行 SQL,只返回聚合结果 """ import sqlite3 import re from typing import Any from config import SANDBOX_RULES, DB_PATH class SandboxError(Exception): """沙箱安全违规""" pass class SandboxExecutor: def __init__(self, db_path: str = DB_PATH): self.db_path = db_path self.rules = SANDBOX_RULES self.execution_log: list[dict] = [] def execute(self, sql: str) -> dict[str, Any]: """ 执行 SQL,返回脱敏后的聚合结果。 如果违反安全规则,抛出 SandboxError。 """ # 验证 SQL 安全性 self._validate(sql) conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row cur = conn.cursor() try: cur.execute(sql) rows = cur.fetchall() # 转为字典列表 columns = [desc[0] for desc in cur.description] if cur.description else [] results = [dict(row) for row in rows] # 脱敏处理 sanitized = self._sanitize(results, columns) # 记录执行日志 self.execution_log.append({ "sql": sql, "rows_returned": len(results), "columns": columns, }) return { "success": True, "columns": columns, "rows": sanitized, "row_count": len(sanitized), "sql": sql, } except sqlite3.Error as e: return { "success": False, "error": str(e), "sql": sql, } finally: conn.close() def _validate(self, sql: str): """SQL 安全验证""" sql_upper = sql.upper().strip() # 1. 检查禁止的关键字 for banned in self.rules["banned_keywords"]: if banned.upper() in sql_upper: raise SandboxError(f"禁止的 SQL 关键字: {banned}") # 2. 只允许 SELECT(不能有多语句) statements = [s.strip() for s in sql.split(";") if s.strip()] if len(statements) > 1: raise SandboxError("禁止多语句执行") if not sql_upper.startswith("SELECT"): raise SandboxError("只允许 SELECT 查询") # 3. 检查是否使用了聚合函数或 GROUP BY(可选要求) if self.rules["require_aggregation"]: agg_keywords = ["COUNT", "SUM", "AVG", "MIN", "MAX", "GROUP BY", "DISTINCT", "HAVING", "ROUND", "CAST"] has_agg = any(kw in sql_upper for kw in agg_keywords) has_limit = "LIMIT" in sql_upper if not has_agg and not has_limit: raise SandboxError( "要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT" ) # 4. LIMIT 检查 limit_match = re.search(r'LIMIT\s+(\d+)', sql_upper) if limit_match: limit_val = int(limit_match.group(1)) if limit_val > self.rules["max_result_rows"]: raise SandboxError( f"LIMIT {limit_val} 超过最大允许值 {self.rules['max_result_rows']}" ) def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]: """对结果进行脱敏处理""" if not rows: return rows # 1. 限制行数 rows = rows[:self.rules["max_result_rows"]] # 2. 浮点数四舍五入 for row in rows: for col in columns: val = row.get(col) if isinstance(val, float): row[col] = round(val, self.rules["round_floats"]) # 3. 小样本抑制(k-anonymity) # 如果某个分组的 count 小于阈值,标记为 " str: """获取执行摘要""" if not self.execution_log: return "尚未执行任何查询" lines = [f"共执行 {len(self.execution_log)} 条查询:"] for i, log in enumerate(self.execution_log, 1): lines.append(f" {i}. {log['sql'][:80]}... → {log['rows_returned']} 行结果") return "\n".join(lines)