Files
iov_ana/core/sandbox.py
Jeason b7a27b12bd SQLite 持久连接 — sandbox 不再每次查询开关连接,改为 __init__ 时建连、close() 时释放
Explorer 的 system prompt 明确告知 sandbox 规则 — "每条 SQL 必须包含聚合函数或 LIMIT",减少 LLM 生成违规 SQL 浪费轮次
LLM 客户端单例 — 所有组件共享一个 openai.OpenAI 实例,不再各建各的
sanitize 顺序修复 — 小样本抑制放在 float round 之前,避免被 round 干扰
quick_detect 从 O(n²) 改为 O(n) — 按列聚合一次,加去重,不再对每行重复算整列统计
历史上下文实际生效 — get_context_for 的结果现在会注入到 Explorer 的初始 prompt 里,多轮分析时 LLM 能看到之前的发现
2026-03-20 13:20:31 +08:00

100 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Sandbox executor: runs SQL and returns only aggregated results.
"""
import sqlite3
import re
from typing import Any
from core.config import SANDBOX_RULES
class SandboxError(Exception):
    """Raised when a SQL statement violates the sandbox rules."""


class SandboxExecutor:
    """Execute validated, read-only SQL against one SQLite database.

    A single connection is opened in ``__init__`` and reused by every
    ``execute`` call until ``close`` is called (or the ``with`` block ends).
    Every statement is checked by ``_validate`` before it runs and every
    result set is post-processed by ``_sanitize`` before it is returned.
    """

    # Lower-cased column names whose values are treated as sample counts.
    _COUNT_COLUMNS = frozenset({"count", "cnt", "n", "total"})
    # An unaliased count expression (e.g. "COUNT(*)") used as the column name.
    _COUNT_EXPR_RE = re.compile(r"count\s*\(")
    # Keywords that satisfy the "aggregation required" rule.
    _AGG_KEYWORDS = (
        "COUNT", "SUM", "AVG", "MIN", "MAX", "GROUP BY",
        "DISTINCT", "HAVING", "ROUND", "CAST",
    )
    # "LIMIT n" or "LIMIT offset, n"; group 2 is the row count when present.
    _LIMIT_RE = re.compile(r"(?<!\w)LIMIT\s+(\d+)(?:\s*,\s*(\d+))?")

    def __init__(self, db_path: str, rules: "dict[str, Any] | None" = None):
        """Open the persistent connection.

        Args:
            db_path: Path handed to ``sqlite3.connect``.
            rules: Optional rule set with the keys ``banned_keywords``,
                ``require_aggregation``, ``max_result_rows``,
                ``suppress_small_n`` and ``round_floats``. Defaults to the
                module-level ``SANDBOX_RULES``.
        """
        self.db_path = db_path
        self.rules = SANDBOX_RULES if rules is None else rules
        self.execution_log: list[dict] = []
        # Persistent connection: avoids opening/closing once per query.
        self._conn = sqlite3.connect(db_path)
        self._conn.row_factory = sqlite3.Row
        # Defense in depth: validation is purely textual, so also ask SQLite
        # itself to refuse any statement that would modify the database.
        self._conn.execute("PRAGMA query_only = ON")

    def __enter__(self):
        """Return the executor itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection when the ``with`` block ends."""
        self.close()

    def execute(self, sql: str) -> dict[str, Any]:
        """Validate and run ``sql``, returning the sanitized rows.

        Returns:
            On success a dict with ``success=True``, ``columns``, ``rows``,
            ``row_count`` and ``sql``. On a database error, or when the
            connection has already been closed, a dict with
            ``success=False``, ``error`` and ``sql``.

        Raises:
            SandboxError: If ``sql`` violates the sandbox rules.
        """
        self._validate(sql)
        if self._conn is None:
            # Previously this crashed with AttributeError on None.
            return {"success": False, "error": "数据库连接已关闭", "sql": sql}
        cur = self._conn.cursor()
        try:
            cur.execute(sql)
            rows = cur.fetchall()
            columns = [desc[0] for desc in cur.description] if cur.description else []
        except sqlite3.Error as e:
            return {"success": False, "error": str(e), "sql": sql}
        finally:
            cur.close()
        results = [dict(row) for row in rows]
        sanitized = self._sanitize(results, columns)
        self.execution_log.append({
            "sql": sql, "rows_returned": len(results), "columns": columns,
        })
        return {
            "success": True, "columns": columns,
            "rows": sanitized, "row_count": len(sanitized), "sql": sql,
        }

    def close(self) -> None:
        """Close the connection; calling it again is a no-op."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    @staticmethod
    def _has_keyword(sql_upper: str, keyword: str) -> bool:
        """Return True if ``keyword`` occurs in ``sql_upper`` as a whole token.

        Whitespace inside the keyword ("GROUP BY") matches any run of
        whitespace. A word boundary is required only on an edge that is a
        letter or digit, so entries such as "--" or a prefix like "SQLITE_"
        still match the way a plain substring test would. Blank keywords
        never match.
        """
        parts = keyword.upper().split()
        if not parts:
            return False
        pattern = r"\s+".join(re.escape(part) for part in parts)
        if parts[0][0].isalnum():
            pattern = r"(?<!\w)" + pattern
        if parts[-1][-1].isalnum():
            pattern += r"(?!\w)"
        return re.search(pattern, sql_upper) is not None

    def _validate(self, sql: str) -> None:
        """Check ``sql`` against the sandbox rules.

        Raises:
            SandboxError: If the statement contains a banned keyword, holds
                more than one statement, is not a SELECT, lacks both an
                aggregation keyword and a LIMIT (when aggregation is
                required), or asks for more rows than ``max_result_rows``.
        """
        sql_upper = sql.upper().strip()
        # Token match, not substring: "updated_at" must not trip "UPDATE".
        for banned in self.rules["banned_keywords"]:
            if self._has_keyword(sql_upper, banned):
                raise SandboxError(f"禁止的 SQL 关键字: {banned}")
        statements = [s.strip() for s in sql.split(";") if s.strip()]
        if len(statements) > 1:
            raise SandboxError("禁止多语句执行")
        if not sql_upper.startswith("SELECT"):
            raise SandboxError("只允许 SELECT 查询")
        if self.rules["require_aggregation"]:
            # Token match: a table named "account" or a column named
            # "speed_limit" must not count as COUNT / LIMIT.
            has_agg = any(self._has_keyword(sql_upper, kw) for kw in self._AGG_KEYWORDS)
            if not has_agg and not self._has_keyword(sql_upper, "LIMIT"):
                raise SandboxError("要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT")
        max_rows = self.rules["max_result_rows"]
        # Check every LIMIT clause; in "LIMIT offset, n" the row count is n.
        for match in self._LIMIT_RE.finditer(sql_upper):
            requested = int(match.group(2) or match.group(1))
            if requested > max_rows:
                raise SandboxError(f"LIMIT 超过最大允许值 {self.rules['max_result_rows']}")

    def _is_count_column(self, col: str) -> bool:
        """Return True if ``col`` names a count (known alias or raw COUNT(...))."""
        name = col.lower()
        return name in self._COUNT_COLUMNS or self._COUNT_EXPR_RE.match(name) is not None

    def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
        """Truncate, suppress small counts and round floats, in place.

        Keeps at most ``max_result_rows`` rows. In a count column, a numeric
        value below ``suppress_small_n`` is replaced by the string
        ``"<N>"``-style marker (e.g. ``"<5"``). Any remaining float is
        rounded to ``round_floats`` digits.
        """
        rows = rows[:self.rules["max_result_rows"]]
        suppress_n = self.rules["suppress_small_n"]
        round_digits = self.rules["round_floats"]
        for row in rows:
            for col in columns:
                val = row.get(col)
                # Suppress first so rounding cannot push a value over the threshold.
                if (
                    self._is_count_column(col)
                    and isinstance(val, (int, float))
                    and val < suppress_n
                ):
                    row[col] = f"<{suppress_n}"
                    continue
                if isinstance(val, float):
                    row[col] = round(val, round_digits)
        return rows

    def get_execution_summary(self) -> str:
        """Return a numbered, human-readable list of the logged queries."""
        if not self.execution_log:
            return "尚未执行任何查询"
        lines = [f"共执行 {len(self.execution_log)} 条查询:"]
        for i, log in enumerate(self.execution_log, 1):
            lines.append(f" {i}. {log['sql'][:80]}... → {log['rows_returned']} 行结果")
        return "\n".join(lines)