feat: 四层架构数据分析 Agent
- Layer 1 Planner: 意图规划,将问题转为结构化分析计划
- Layer 2 Explorer: 自适应探索循环,多轮迭代动态生成 SQL
- Layer 3 InsightEngine: 异常检测 + 主动洞察
- Layer 4 ContextManager: 多轮对话上下文记忆

安全设计:AI 只看 Schema + 聚合结果,不接触原始数据。
支持任意 OpenAI 兼容 API(OpenAI / Ollama / DeepSeek / vLLM)。
This commit is contained in:
139
sandbox_executor.py
Normal file
139
sandbox_executor.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
沙箱执行器 —— 执行 SQL,只返回聚合结果
|
||||
"""
|
||||
import sqlite3
|
||||
import re
|
||||
from typing import Any
|
||||
from config import SANDBOX_RULES, DB_PATH
|
||||
|
||||
|
||||
class SandboxError(Exception):
    """Signals that a SQL statement violated the sandbox security rules."""
|
||||
|
||||
|
||||
class SandboxExecutor:
    """Validate and run SQL against a SQLite database, returning sanitized rows.

    Each statement is checked against the configured sandbox rules before it
    is executed. Result rows are capped, rounded and small-count-suppressed
    before being handed back, and every successful query is appended to
    ``execution_log``.
    """

    # Lower-cased column names treated as group sizes by small-sample suppression.
    _COUNT_COLUMN_NAMES = frozenset({"count", "cnt", "n", "total"})

    # Keywords that satisfy the "aggregation required" rule. They are matched
    # as whole tokens so identifiers such as ACCOUNTS or ADMIN do not count.
    _AGGREGATION_RE = re.compile(
        r"\b(?:COUNT|SUM|AVG|MIN|MAX|DISTINCT|HAVING|ROUND|CAST)\b|\bGROUP\s+BY\b"
    )
    _LIMIT_RE = re.compile(r"\bLIMIT\b")
    _LIMIT_VALUE_RE = re.compile(r"\bLIMIT\s+(\d+)")

    def __init__(self, db_path: str = DB_PATH):
        """Create an executor bound to the SQLite database at *db_path*."""
        self.db_path = db_path
        self.rules = SANDBOX_RULES
        self.execution_log: list[dict] = []

    def execute(self, sql: str) -> dict[str, Any]:
        """Run *sql* and return its sanitized result.

        Returns:
            On success, a dict with ``success=True``, ``columns``, ``rows``
            (sanitized), ``row_count`` and ``sql``. If SQLite reports an
            error (including a failure to open the database), a dict with
            ``success=False``, ``error`` and ``sql``.

        Raises:
            SandboxError: If *sql* violates a sandbox rule; nothing is
                executed in that case.
        """
        self._validate(sql)

        conn = None
        try:
            # Connecting inside the try block keeps connection failures on the
            # same "success: False" path as query failures.
            conn = sqlite3.connect(self.db_path)
            conn.row_factory = sqlite3.Row
            # Defense in depth: validation is purely textual, so also ask
            # SQLite itself to refuse any write on this connection.
            conn.execute("PRAGMA query_only = ON")

            cur = conn.cursor()
            cur.execute(sql)
            fetched = cur.fetchall()
            columns = [desc[0] for desc in cur.description] if cur.description else []
            results = [dict(row) for row in fetched]
        except sqlite3.Error as e:
            return {
                "success": False,
                "error": str(e),
                "sql": sql,
            }
        finally:
            if conn is not None:
                conn.close()

        sanitized = self._sanitize(results, columns)

        self.execution_log.append({
            "sql": sql,
            "rows_returned": len(results),
            "columns": columns,
        })

        return {
            "success": True,
            "columns": columns,
            "rows": sanitized,
            "row_count": len(sanitized),
            "sql": sql,
        }

    @staticmethod
    def _contains_keyword(sql_upper: str, keyword: str) -> bool:
        """Return True if *keyword* occurs in *sql_upper* as a whole token.

        Word-like keywords must not be glued to other identifier characters,
        so a banned ``UPDATE`` does not reject a column named ``updated_at``.
        Keywords that begin or end with a symbol (e.g. ``--``) keep plain
        substring semantics on that side.
        """
        token = keyword.upper()
        pattern = re.escape(token)
        if re.match(r"\w", token):
            pattern = r"(?<!\w)" + pattern
        if re.search(r"\w\Z", token):
            pattern += r"(?!\w)"
        return re.search(pattern, sql_upper) is not None

    def _validate(self, sql: str) -> None:
        """Check *sql* against the sandbox rules.

        Raises:
            SandboxError: If the statement contains a banned keyword, holds
                more than one statement, is not a SELECT, lacks an
                aggregation/LIMIT when aggregation is required, or uses a
                LIMIT above the configured maximum.
        """
        sql_upper = sql.upper().strip()

        # 1. Banned keywords (whole-token match).
        for banned in self.rules["banned_keywords"]:
            if self._contains_keyword(sql_upper, banned):
                raise SandboxError(f"禁止的 SQL 关键字: {banned}")

        # 2. Exactly one statement, and it must be a SELECT. A trailing
        #    semicolon is tolerated because empty fragments are dropped.
        statements = [part.strip() for part in sql.split(";") if part.strip()]
        if len(statements) > 1:
            raise SandboxError("禁止多语句执行")
        if not sql_upper.startswith("SELECT"):
            raise SandboxError("只允许 SELECT 查询")

        # 3. Optionally require an aggregation keyword or a LIMIT clause.
        if self.rules["require_aggregation"]:
            has_agg = self._AGGREGATION_RE.search(sql_upper) is not None
            has_limit = self._LIMIT_RE.search(sql_upper) is not None
            if not has_agg and not has_limit:
                raise SandboxError(
                    "要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT"
                )

        # 4. Every literal LIMIT (including ones in subqueries) must respect
        #    the configured row cap.
        for limit_match in self._LIMIT_VALUE_RE.finditer(sql_upper):
            limit_val = int(limit_match.group(1))
            if limit_val > self.rules["max_result_rows"]:
                raise SandboxError(
                    f"LIMIT {limit_val} 超过最大允许值 {self.rules['max_result_rows']}"
                )

    @classmethod
    def _is_count_column(cls, column: str) -> bool:
        """Return True if *column* looks like a group-size column.

        Besides the conventional aliases, this covers an unaliased
        ``COUNT(...)`` expression, whose SQLite column name is the expression
        text itself.
        """
        name = column.lower()
        return name in cls._COUNT_COLUMN_NAMES or name.replace(" ", "").startswith("count(")

    def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
        """Cap, round and small-count-suppress *rows*.

        The row dicts are modified in place; the returned list holds at most
        ``max_result_rows`` of them. Floats are rounded to ``round_floats``
        decimals, and numeric values in count-like columns that are below
        ``suppress_small_n`` are replaced by the string ``"<N"``.
        """
        if not rows:
            return rows

        # 1. Cap the number of rows handed back.
        rows = rows[:self.rules["max_result_rows"]]
        count_columns = [col for col in columns if self._is_count_column(col)]

        for row in rows:
            # 2. Round floats.
            for col in columns:
                val = row.get(col)
                if isinstance(val, float):
                    row[col] = round(val, self.rules["round_floats"])

            # 3. Small-sample suppression: hide group sizes below the threshold.
            for col in count_columns:
                val = row.get(col)
                if isinstance(val, (int, float)) and val < self.rules["suppress_small_n"]:
                    row[col] = f"<{self.rules['suppress_small_n']}"

        return rows

    def get_execution_summary(self) -> str:
        """Return a human-readable list of the queries executed so far.

        Each SQL text is shortened to 80 characters; an ellipsis is appended
        only when text was actually cut off.
        """
        if not self.execution_log:
            return "尚未执行任何查询"

        lines = [f"共执行 {len(self.execution_log)} 条查询:"]
        for i, entry in enumerate(self.execution_log, 1):
            sql_text = entry["sql"]
            preview = f"{sql_text[:80]}..." if len(sql_text) > 80 else sql_text
            lines.append(f" {i}. {preview} → {entry['rows_returned']} 行结果")
        return "\n".join(lines)
|
||||
Reference in New Issue
Block a user