Files
iov_ana/sandbox_executor.py
OpenClaw Agent 96927a789d feat: 四层架构数据分析 Agent
- Layer 1 Planner: 意图规划,将问题转为结构化分析计划
- Layer 2 Explorer: 自适应探索循环,多轮迭代动态生成 SQL
- Layer 3 InsightEngine: 异常检测 + 主动洞察
- Layer 4 ContextManager: 多轮对话上下文记忆

安全设计:AI 只看 Schema + 聚合结果,不接触原始数据。
支持任意 OpenAI 兼容 API(OpenAI / Ollama / DeepSeek / vLLM)
2026-03-19 12:21:04 +08:00

140 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
沙箱执行器 —— 执行 SQL,只返回聚合结果
"""
import re
import sqlite3
from typing import Any, Optional

from config import SANDBOX_RULES, DB_PATH
class SandboxError(Exception):
    """Raised when a SQL statement violates the sandbox safety rules."""


class SandboxExecutor:
    """Validate and run SQL against a SQLite database, returning sanitised rows.

    Each statement is checked against the configured rules before execution,
    and the rows that come back are capped, rounded and small-count-suppressed
    before they are handed to the caller.

    Instance fields:
        db_path: path passed to ``sqlite3.connect`` for every query.
        rules: rule mapping; the keys read by this class are
            ``banned_keywords``, ``require_aggregation``, ``max_result_rows``,
            ``round_floats`` and ``suppress_small_n``.
        execution_log: one dict per successfully executed query.
    """

    # Whole-word keywords that satisfy the "query must aggregate" requirement.
    # Matching whole words keeps identifiers such as ACCOUNT or ADMIN from
    # being mistaken for COUNT / MIN.
    _AGGREGATION_RE = re.compile(
        r"\b(?:COUNT|SUM|AVG|MIN|MAX|GROUP\s+BY|DISTINCT|HAVING|ROUND|CAST)\b"
    )
    _LIMIT_WORD_RE = re.compile(r"\bLIMIT\b")
    # SQLite accepts both "LIMIT count" and "LIMIT offset, count"; the second
    # group, when present, is the row count.
    _LIMIT_CLAUSE_RE = re.compile(r"\bLIMIT\s+(\d+)(?:\s*,\s*(\d+))?")
    # Result columns treated as group sizes for small-sample suppression.
    _COUNT_COLUMNS = frozenset({"count", "cnt", "n", "total"})

    def __init__(self, db_path: Optional[str] = None):
        """Create an executor.

        Args:
            db_path: SQLite database path. When omitted (or ``None``) the
                module-level ``DB_PATH`` is used; it is looked up at call time
                rather than when the class is defined.
        """
        self.db_path = DB_PATH if db_path is None else db_path
        self.rules = SANDBOX_RULES
        self.execution_log: list[dict] = []

    def execute(self, sql: str) -> dict[str, Any]:
        """Validate and run *sql*, returning the sanitised result.

        Returns:
            On success: ``success`` (True), ``columns``, ``rows`` (sanitised),
            ``row_count`` and ``sql``. When SQLite reports an error:
            ``success`` (False), ``error`` (the error text) and ``sql``.

        Raises:
            SandboxError: if *sql* violates a sandbox rule. Validation happens
                before any database connection is opened.
        """
        self._validate(sql)

        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        try:
            cur.execute(sql)
            rows = cur.fetchall()

            # Convert sqlite3.Row objects into plain dicts.
            columns = [desc[0] for desc in cur.description] if cur.description else []
            results = [dict(row) for row in rows]

            sanitized = self._sanitize(results, columns)

            # The log records the row count before truncation.
            self.execution_log.append({
                "sql": sql,
                "rows_returned": len(results),
                "columns": columns,
            })

            return {
                "success": True,
                "columns": columns,
                "rows": sanitized,
                "row_count": len(sanitized),
                "sql": sql,
            }
        except sqlite3.Error as e:
            return {
                "success": False,
                "error": str(e),
                "sql": sql,
            }
        finally:
            conn.close()

    @staticmethod
    def _has_token(sql_upper: str, token: str) -> bool:
        """Return True when *token* occurs in *sql_upper* as a whole token.

        A word boundary is only required on the sides of *token* that start or
        end with a word character, so ``CREATE`` does not match ``CREATED_AT``
        while punctuation tokens such as ``--`` still match anywhere.
        """
        prefix = r"(?<!\w)" if re.match(r"\w", token) else ""
        suffix = r"(?!\w)" if re.search(r"\w\Z", token) else ""
        return re.search(prefix + re.escape(token) + suffix, sql_upper) is not None

    def _validate(self, sql: str) -> None:
        """Check *sql* against the sandbox rules.

        Raises:
            SandboxError: if the statement contains a banned keyword, holds
                more than one statement, is not a SELECT, lacks the required
                aggregation/LIMIT, or asks for more rows than allowed.
        """
        sql_upper = sql.upper().strip()

        # 1. Banned keywords, matched as whole tokens.
        for banned in self.rules["banned_keywords"]:
            if self._has_token(sql_upper, banned.upper()):
                raise SandboxError(f"禁止的 SQL 关键字: {banned}")

        # 2. A single SELECT statement only (a trailing ';' is tolerated).
        statements = [s.strip() for s in sql.split(";") if s.strip()]
        if len(statements) > 1:
            raise SandboxError("禁止多语句执行")
        if not sql_upper.startswith("SELECT"):
            raise SandboxError("只允许 SELECT 查询")

        # 3. Optionally require an aggregate keyword or a LIMIT.
        if self.rules["require_aggregation"]:
            has_agg = self._AGGREGATION_RE.search(sql_upper) is not None
            has_limit = self._LIMIT_WORD_RE.search(sql_upper) is not None
            if not has_agg and not has_limit:
                raise SandboxError(
                    "要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT"
                )

        # 4. Every literal LIMIT (including those in subqueries) must respect
        #    the row cap. Non-literal LIMIT expressions are not parsed here;
        #    _sanitize still truncates the final result.
        for match in self._LIMIT_CLAUSE_RE.finditer(sql_upper):
            count_text = match.group(2) if match.group(2) is not None else match.group(1)
            limit_val = int(count_text)
            if limit_val > self.rules["max_result_rows"]:
                raise SandboxError(
                    f"LIMIT {limit_val} 超过最大允许值 {self.rules['max_result_rows']}"
                )

    def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
        """Cap, round and suppress values in *rows*.

        The row dicts are modified in place; the returned list holds at most
        ``max_result_rows`` of them.
        """
        if not rows:
            return rows

        # 1. Cap the number of rows.
        rows = rows[:self.rules["max_result_rows"]]

        digits = self.rules["round_floats"]
        threshold = self.rules["suppress_small_n"]
        for row in rows:
            for col in columns:
                val = row.get(col)
                # 2. Round floats to the configured precision.
                if isinstance(val, float):
                    val = round(val, digits)
                    row[col] = val
                # 3. Small-sample suppression: a group size below the
                #    threshold is replaced by the string "<threshold".
                if (
                    col.lower() in self._COUNT_COLUMNS
                    and isinstance(val, (int, float))
                    and val < threshold
                ):
                    row[col] = f"<{self.rules['suppress_small_n']}"
        return rows

    def get_execution_summary(self) -> str:
        """Return a human-readable list of the queries executed so far."""
        if not self.execution_log:
            return "尚未执行任何查询"
        lines = [f"共执行 {len(self.execution_log)} 条查询:"]
        for i, log in enumerate(self.execution_log, 1):
            # Only the first 80 characters of each statement are shown.
            lines.append(f" {i}. {log['sql'][:80]}... → {log['rows_returned']} 行结果")
        return "\n".join(lines)