Files
iov_ana/core/sandbox.py

100 lines
3.8 KiB
Python
Raw Normal View History

"""
沙箱执行器：执行 SQL，只返回聚合结果。
"""
import sqlite3
import re
from typing import Any
from core.config import SANDBOX_RULES
class SandboxError(Exception):
    """Raised when a SQL statement violates one of the sandbox security rules."""
class SandboxExecutor:
    """Run SQL against a SQLite database and return only sanitized results.

    Every statement is checked by ``_validate`` against ``self.rules`` before
    it reaches the database. Every result set is passed through ``_sanitize``,
    which applies a row cap, small-count suppression and float rounding.

    The executor keeps one open connection. Call ``close()`` when done, or use
    it as a context manager: ``with SandboxExecutor(path) as ex: ...``.
    """

    # Aggregation markers, matched as whole SQL tokens rather than raw
    # substrings. This stops identifiers such as FUEL_CONSUMPTION (contains
    # SUM), ACCOUNT_ID (contains COUNT) or ADMIN (contains MIN) from satisfying
    # the aggregation rule. Function names only count when followed by "(".
    _AGG_RE = re.compile(
        r"\b(?:COUNT|SUM|AVG|MIN|MAX|ROUND|CAST)\s*\("
        r"|\bGROUP\s+BY\b|\bDISTINCT\b|\bHAVING\b"
    )
    # The LIMIT keyword itself, not e.g. the identifier SPEED_LIMIT.
    _LIMIT_KEYWORD_RE = re.compile(r"\bLIMIT\b")
    # "LIMIT n" or SQLite's "LIMIT offset, n". A negative n means unbounded.
    _LIMIT_VALUE_RE = re.compile(r"\bLIMIT\s+(-?\d+)(?:\s*,\s*(-?\d+))?")
    # Column names treated as counts for small-n suppression.
    _COUNT_ALIASES = ("count", "cnt", "n", "total")
    # SQLite names an unaliased aggregate after its expression text,
    # e.g. "COUNT(*)", so such columns are recognized as counts too.
    _COUNT_EXPR_RE = re.compile(r"count\s*\(")

    def __init__(self, db_path: str):
        """Open a persistent connection to the SQLite database at ``db_path``.

        Args:
            db_path: Path passed unchanged to ``sqlite3.connect``.
        """
        self.db_path = db_path
        self.rules = SANDBOX_RULES
        # One entry per successful execute(): sql, rows_returned, columns.
        self.execution_log: list[dict] = []
        # Persistent connection, so each query does not reopen the database.
        self._conn = sqlite3.connect(db_path)
        # sqlite3.Row lets execute() convert each row with dict(row).
        self._conn.row_factory = sqlite3.Row

    def __enter__(self) -> "SandboxExecutor":
        """Return self so the executor can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection on leaving the ``with`` block; never suppress errors."""
        self.close()

    def execute(self, sql: str) -> dict[str, Any]:
        """Validate and run one SELECT, returning sanitized aggregate results.

        Args:
            sql: A single SELECT statement.

        Returns:
            On success, ``{"success": True, "columns", "rows", "row_count",
            "sql"}``. ``rows`` are sanitized dicts, capped at
            ``rules["max_result_rows"]``. On a SQLite error, or after
            ``close()``, returns ``{"success": False, "error", "sql"}``.

        Raises:
            SandboxError: if ``sql`` violates a sandbox rule. This is raised
                before anything is executed.
        """
        self._validate(sql)
        if self._conn is None:
            # close() has already been called. Report it like any other
            # database error instead of failing with an AttributeError on None.
            return {"success": False, "error": "Cannot operate on a closed database.", "sql": sql}
        cur = self._conn.cursor()
        try:
            cur.execute(sql)
            # NOTE(review): fetchall() materializes the whole result before
            # _sanitize caps it. fetchmany() would bound memory, but it would
            # also change what rows_returned records.
            rows = cur.fetchall()
            columns = [desc[0] for desc in cur.description] if cur.description else []
        except sqlite3.Error as e:
            return {"success": False, "error": str(e), "sql": sql}
        finally:
            cur.close()
        results = [dict(row) for row in rows]
        sanitized = self._sanitize(results, columns)
        # rows_returned is the full result size, before the row cap.
        self.execution_log.append({
            "sql": sql, "rows_returned": len(results), "columns": columns,
        })
        return {
            "success": True, "columns": columns,
            "rows": sanitized, "row_count": len(sanitized), "sql": sql,
        }

    def close(self):
        """Close the connection. Safe to call more than once."""
        if self._conn:
            self._conn.close()
            self._conn = None

    def _validate(self, sql: str):
        """Raise SandboxError unless ``sql`` satisfies every sandbox rule.

        Checks, in order:
          1. banned keywords;
          2. a single statement;
          3. SELECT only;
          4. aggregation or LIMIT, when ``rules["require_aggregation"]`` is set;
          5. the LIMIT row count against ``rules["max_result_rows"]``.
        """
        sql_upper = sql.upper().strip()
        # Plain substring match. This is conservative: it may also reject
        # identifiers that merely contain a banned word (e.g. UPDATED_AT, if
        # UPDATE is banned).
        for banned in self.rules["banned_keywords"]:
            if banned.upper() in sql_upper:
                raise SandboxError(f"禁止的 SQL 关键字: {banned}")
        statements = [s.strip() for s in sql.split(";") if s.strip()]
        if len(statements) > 1:
            raise SandboxError("禁止多语句执行")
        if not sql_upper.startswith("SELECT"):
            raise SandboxError("只允许 SELECT 查询")
        if self.rules["require_aggregation"]:
            has_agg = self._AGG_RE.search(sql_upper) is not None
            has_limit = self._LIMIT_KEYWORD_RE.search(sql_upper) is not None
            if not has_agg and not has_limit:
                raise SandboxError("要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT")
        max_rows = self.rules["max_result_rows"]
        # Only the first LIMIT in the text is checked. _sanitize still caps
        # the rows that are actually returned.
        limit_match = self._LIMIT_VALUE_RE.search(sql_upper)
        if limit_match:
            # In "LIMIT a, b" SQLite reads a as the offset and b as the count.
            row_limit = int(limit_match.group(2) or limit_match.group(1))
            # A negative LIMIT means "no limit" in SQLite, so it exceeds any cap.
            if row_limit < 0 or row_limit > max_rows:
                raise SandboxError(f"LIMIT 超过最大允许值 {max_rows}")

    @classmethod
    def _is_count_column(cls, col: str) -> bool:
        """Return True if ``col`` names a count, either by alias or as COUNT(...)."""
        name = col.lower()
        return name in cls._COUNT_ALIASES or cls._COUNT_EXPR_RE.match(name) is not None

    def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
        """Apply privacy protections to query results.

        The steps are:
          1. Cap the rows at ``rules["max_result_rows"]``.
          2. Replace numeric count values below ``rules["suppress_small_n"]``
             with the string ``"<k"``.
          3. Round remaining floats to ``rules["round_floats"]`` digits.

        The row dicts are modified in place; the capped list is returned.
        """
        rows = rows[:self.rules["max_result_rows"]]
        suppress_n = self.rules["suppress_small_n"]
        round_digits = self.rules["round_floats"]
        count_columns = {col for col in columns if self._is_count_column(col)}
        for row in rows:
            for col in columns:
                val = row.get(col)
                # Small-sample suppression runs first, so rounding cannot
                # affect the comparison.
                if col in count_columns and isinstance(val, (int, float)) and val < suppress_n:
                    row[col] = f"<{suppress_n}"
                    continue
                if isinstance(val, float):
                    row[col] = round(val, round_digits)
        return rows

    def get_execution_summary(self) -> str:
        """Return a numbered, human-readable list of the executed queries.

        Whitespace in each SQL statement is collapsed to single spaces. An
        ellipsis is added only when the SQL is cut at 80 characters.
        """
        if not self.execution_log:
            return "尚未执行任何查询"
        lines = [f"共执行 {len(self.execution_log)} 条查询:"]
        for i, entry in enumerate(self.execution_log, 1):
            sql_text = " ".join(entry["sql"].split())
            preview = sql_text if len(sql_text) <= 80 else sql_text[:80] + "..."
            lines.append(f" {i}. {preview} → {entry['rows_returned']} 行结果")
        return "\n".join(lines)