2026-03-19 12:21:04 +08:00
|
|
|
|
"""
沙箱执行器 —— 执行 SQL，只返回聚合结果。

Sandbox executor: validates and runs single SELECT statements and returns only
aggregated, sanitized results.
"""
|
|
|
|
|
|
import sqlite3
|
|
|
|
|
|
import re
|
|
|
|
|
|
from typing import Any
|
2026-03-20 13:20:31 +08:00
|
|
|
|
from core.config import SANDBOX_RULES
|
2026-03-19 12:21:04 +08:00
|
|
|
|
class SandboxError(Exception):
    """Raised when a SQL statement breaks one of the sandbox rules."""
|
|
|
|
|
|
class SandboxExecutor:
    """Run validated, single SELECT statements against a SQLite database and
    return only sanitized results.

    Behaviour is driven by ``self.rules`` (``SANDBOX_RULES``); the keys read
    here are ``banned_keywords``, ``require_aggregation``, ``max_result_rows``,
    ``suppress_small_n`` and ``round_floats``.
    """

    # Column names (compared case-insensitively) treated as counts for
    # small-n suppression.
    _COUNT_COLUMN_NAMES = frozenset({"count", "cnt", "n", "total"})
    # SQLite names an un-aliased expression column after its source text
    # (e.g. "COUNT(*)", "count(DISTINCT id)"), so such columns are counts too.
    _COUNT_EXPR_RE = re.compile(r"\s*count\s*\(", re.IGNORECASE)
    # Whole-word aggregation markers, matched against the upper-cased SQL.
    # Word boundaries stop identifiers such as "ACCOUNTS" or "ADMIN" from
    # satisfying the requirement by accident.
    _AGGREGATION_RE = re.compile(
        r"\b(?:COUNT|SUM|AVG|MIN|MAX|ROUND|CAST)\s*\("
        r"|\bGROUP\s+BY\b|\bDISTINCT\b|\bHAVING\b"
    )
    _LIMIT_WORD_RE = re.compile(r"\bLIMIT\b")
    # Captures both "LIMIT <count>" and SQLite's "LIMIT <offset>, <count>".
    _LIMIT_VALUE_RE = re.compile(r"\bLIMIT\s+(\d+)(?:\s*,\s*(\d+))?")

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.rules = SANDBOX_RULES
        # One entry per successful execute(): sql, rows_returned, columns.
        self.execution_log: list[dict] = []
        # Persistent connection so each query does not reopen the database.
        self._conn = sqlite3.connect(db_path)
        self._conn.row_factory = sqlite3.Row
        # Defense in depth for untrusted SQL: even if a statement slips past
        # _validate, SQLite refuses to modify the database via this connection.
        self._conn.execute("PRAGMA query_only = ON")

    def execute(self, sql: str) -> dict[str, Any]:
        """Validate and run ``sql``, returning sanitized results.

        Raises:
            SandboxError: if ``sql`` violates a sandbox rule; nothing is
                executed in that case.

        Returns:
            On success ``{"success": True, "columns", "rows", "row_count",
            "sql"}`` where ``rows`` holds at most ``max_result_rows``
            sanitized rows; on a SQLite error
            ``{"success": False, "error", "sql"}``.
        """
        self._validate(sql)

        max_rows = self.rules["max_result_rows"]
        cur = self._conn.cursor()
        try:
            cur.execute(sql)
            columns = [desc[0] for desc in cur.description] if cur.description else []
            # Stream the result set: keep at most max_rows rows in memory but
            # still count every row so the log records the true result size.
            kept: list[dict] = []
            total = 0
            for row in cur:
                if total < max_rows:
                    kept.append(dict(row))
                total += 1
        except sqlite3.Error as e:
            return {"success": False, "error": str(e), "sql": sql}
        finally:
            cur.close()

        sanitized = self._sanitize(kept, columns)
        self.execution_log.append({
            "sql": sql, "rows_returned": total, "columns": columns,
        })
        return {
            "success": True, "columns": columns,
            "rows": sanitized, "row_count": len(sanitized), "sql": sql,
        }

    def close(self):
        """Close the connection; calling it again is a no-op."""
        if self._conn:
            self._conn.close()
            self._conn = None

    def _validate(self, sql: str):
        """Reject ``sql`` unless it satisfies every sandbox rule.

        Raises:
            SandboxError: on the first violated rule.
        """
        sql_upper = sql.upper().strip()

        # Deny-list via plain substring match: deliberately errs toward
        # rejection (a banned word inside an identifier also trips it).
        for banned in self.rules["banned_keywords"]:
            if banned.upper() in sql_upper:
                raise SandboxError(f"禁止的 SQL 关键字: {banned}")

        # Naive split: a ";" inside a string literal also counts, which again
        # errs toward rejection. A single trailing ";" is allowed.
        statements = [s.strip() for s in sql.split(";") if s.strip()]
        if len(statements) > 1:
            raise SandboxError("禁止多语句执行")
        if not sql_upper.startswith("SELECT"):
            raise SandboxError("只允许 SELECT 查询")

        if self.rules["require_aggregation"]:
            has_agg = self._AGGREGATION_RE.search(sql_upper) is not None
            has_limit = self._LIMIT_WORD_RE.search(sql_upper) is not None
            if not has_agg and not has_limit:
                raise SandboxError("要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT")

        limit_match = self._LIMIT_VALUE_RE.search(sql_upper)
        if limit_match:
            # In "LIMIT <offset>, <count>" the row count is the second number.
            limit = int(limit_match.group(2) or limit_match.group(1))
            if limit > self.rules["max_result_rows"]:
                raise SandboxError(f"LIMIT 超过最大允许值 {self.rules['max_result_rows']}")

    @classmethod
    def _is_count_column(cls, col: str) -> bool:
        """Return True if ``col`` holds a count, by alias or COUNT(...) text."""
        return (
            col.lower() in cls._COUNT_COLUMN_NAMES
            or cls._COUNT_EXPR_RE.match(col) is not None
        )

    def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
        """Cap, suppress and round result rows.

        Keeps at most ``max_result_rows`` rows, replaces numeric values below
        ``suppress_small_n`` in count columns with ``"<N"``, and rounds other
        floats to ``round_floats`` digits. Row dicts are modified in place;
        the capped list is returned.
        """
        rows = rows[:self.rules["max_result_rows"]]
        suppress_n = self.rules["suppress_small_n"]
        round_digits = self.rules["round_floats"]
        # Column classification does not depend on the row; compute it once.
        count_cols = {col for col in columns if self._is_count_column(col)}

        for row in rows:
            for col in columns:
                val = row.get(col)
                # Small-n suppression runs first so rounding cannot move a
                # value across the threshold.
                if col in count_cols and isinstance(val, (int, float)) and val < suppress_n:
                    row[col] = f"<{suppress_n}"
                    continue
                if isinstance(val, float):
                    row[col] = round(val, round_digits)
        return rows

    def get_execution_summary(self) -> str:
        """Return a human-readable, numbered list of executed queries."""
        if not self.execution_log:
            return "尚未执行任何查询"
        lines = [f"共执行 {len(self.execution_log)} 条查询:"]
        for i, log in enumerate(self.execution_log, 1):
            # Collapse newlines/indentation so each query stays on one line;
            # only mark truncation when something was actually cut off.
            sql = " ".join(log["sql"].split())
            preview = sql[:80] + "..." if len(sql) > 80 else sql
            lines.append(f" {i}. {preview} → {log['rows_returned']} 行结果")
        return "\n".join(lines)
|