Files
iov_ana/sandbox_executor.py

140 lines
4.5 KiB
Python
Raw Normal View History

"""
沙箱执行器 执行 SQL只返回聚合结果
"""
import sqlite3
import re
from typing import Any
from config import SANDBOX_RULES, DB_PATH
class SandboxError(Exception):
    """Raised when a SQL statement violates one of the sandbox security rules."""
class SandboxExecutor:
    """Execute SQL against a SQLite database and return only sanitized results.

    Each query is validated against ``SANDBOX_RULES`` before it runs. The checks
    are: banned keywords, a single SELECT statement, an optional aggregation
    requirement, and a LIMIT cap. Result rows are then post-processed: capped
    to ``max_result_rows``, floats rounded, and small counts suppressed.
    """

    # Aggregate-style constructs that satisfy ``require_aggregation``. Functions
    # must be followed by "(" and keywords must stand alone. Without this,
    # identifiers such as FUEL_CONSUMPTION (SUM), ACCOUNT_ID (COUNT) or
    # AVG_SPEED (AVG) would count as aggregation and let raw rows through.
    _AGGREGATION_RE = re.compile(
        r"\b(?:COUNT|SUM|AVG|MIN|MAX|ROUND|CAST)\s*\("
        r"|\bGROUP\s+BY\b|\bDISTINCT\b|\bHAVING\b"
    )
    # The LIMIT keyword itself, not e.g. a SPEED_LIMIT column.
    _LIMIT_KEYWORD_RE = re.compile(r"\bLIMIT\b")
    # Matches "LIMIT n", "LIMIT n OFFSET m" and SQLite's "LIMIT offset, n" forms.
    _LIMIT_CLAUSE_RE = re.compile(r"\bLIMIT\s+(\d+)(?:\s*,\s*(\d+))?")
    # Unaliased COUNT(...) result columns, e.g. "COUNT(*)" or "count(DISTINCT vin)".
    _COUNT_EXPR_RE = re.compile(r"count\s*\(", re.IGNORECASE)
    # Aliased column names treated as group sizes for small-n suppression.
    _COUNT_ALIASES = ("count", "cnt", "n", "total")

    def __init__(self, db_path: str = DB_PATH):
        self.db_path = db_path
        # Rule dict from config. Keys read by this class: banned_keywords,
        # require_aggregation, max_result_rows, round_floats, suppress_small_n.
        self.rules = SANDBOX_RULES
        # One entry per successful execute(): {"sql", "rows_returned", "columns"}.
        self.execution_log: list[dict] = []

    def execute(self, sql: str) -> dict[str, Any]:
        """Validate and run *sql*, returning sanitized result rows.

        Returns a dict with ``success`` True plus ``columns``, ``rows``
        (sanitized), ``row_count`` and ``sql``. If SQLite reports an error,
        the dict has ``success`` False plus ``error`` and ``sql``.

        Raises:
            SandboxError: if *sql* violates a sandbox rule. It is raised before
                any database connection is opened.
        """
        self._validate(sql)

        conn = sqlite3.connect(self.db_path)
        try:
            # Defense in depth: refuse any write on this connection even if a
            # statement slips past _validate.
            conn.execute("PRAGMA query_only = ON")
            conn.row_factory = sqlite3.Row
            cur = conn.cursor()
            cur.execute(sql)
            rows = cur.fetchall()

            columns = [desc[0] for desc in cur.description] if cur.description else []
            results = [dict(row) for row in rows]
            sanitized = self._sanitize(results, columns)

            # The log records the row count produced before truncation.
            self.execution_log.append({
                "sql": sql,
                "rows_returned": len(results),
                "columns": columns,
            })
            return {
                "success": True,
                "columns": columns,
                "rows": sanitized,
                "row_count": len(sanitized),
                "sql": sql,
            }
        except sqlite3.Error as e:
            return {
                "success": False,
                "error": str(e),
                "sql": sql,
            }
        finally:
            conn.close()

    @staticmethod
    def _contains_keyword(sql_upper: str, keyword: str) -> bool:
        """Return True if *keyword* appears in *sql_upper* as a whole token.

        Word-boundary anchors are added only on a side where the keyword starts
        or ends with a word character. This stops DROP from matching
        DROPOFF_TIME or CREATE from matching CREATED_AT. Symbolic entries such
        as "--" keep plain substring semantics.
        """
        kw = keyword.upper()
        head = r"\b" if kw[:1].isalnum() or kw[:1] == "_" else ""
        tail = r"\b" if kw[-1:].isalnum() or kw[-1:] == "_" else ""
        return re.search(f"{head}{re.escape(kw)}{tail}", sql_upper) is not None

    def _validate(self, sql: str) -> None:
        """Check *sql* against the sandbox rules.

        Raises:
            SandboxError: on the first rule the statement violates.
        """
        sql_upper = sql.upper().strip()

        # 1. Banned keywords, matched as whole SQL tokens.
        for banned in self.rules["banned_keywords"]:
            if self._contains_keyword(sql_upper, banned):
                raise SandboxError(f"禁止的 SQL 关键字: {banned}")

        # 2. Only a single SELECT statement is allowed.
        statements = [s.strip() for s in sql.split(";") if s.strip()]
        if len(statements) > 1:
            raise SandboxError("禁止多语句执行")
        if not sql_upper.startswith("SELECT"):
            raise SandboxError("只允许 SELECT 查询")

        # 3. Optionally require an aggregate construct or an explicit LIMIT.
        if self.rules["require_aggregation"]:
            has_agg = self._AGGREGATION_RE.search(sql_upper) is not None
            has_limit = self._LIMIT_KEYWORD_RE.search(sql_upper) is not None
            if not has_agg and not has_limit:
                raise SandboxError(
                    "要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT"
                )

        # 4. Every LIMIT clause must stay within max_result_rows. In SQLite's
        #    "LIMIT a, b" form, a is the offset and b is the row count.
        for match in self._LIMIT_CLAUSE_RE.finditer(sql_upper):
            limit_val = int(match.group(2) or match.group(1))
            if limit_val > self.rules["max_result_rows"]:
                raise SandboxError(
                    f"LIMIT {limit_val} 超过最大允许值 {self.rules['max_result_rows']}"
                )

    def _is_count_column(self, col: str) -> bool:
        """Return True if result column *col* holds a group size to suppress."""
        return col.lower() in self._COUNT_ALIASES or self._COUNT_EXPR_RE.match(col) is not None

    def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
        """Apply output-side privacy rules to *rows* and return the kept rows.

        The row dicts are modified in place.

        1. Keep at most ``max_result_rows`` rows.
        2. Round float values to ``round_floats`` decimal places.
        3. Small-sample suppression. This applies to count-like columns: the
           aliases count/cnt/n/total, or an unaliased COUNT(...) expression.
           In those columns, numbers below ``suppress_small_n`` (k) are
           replaced by the string "<k".
        """
        if not rows:
            return rows

        rows = rows[:self.rules["max_result_rows"]]
        count_columns = [col for col in columns if self._is_count_column(col)]

        for row in rows:
            for col in columns:
                val = row.get(col)
                if isinstance(val, float):
                    row[col] = round(val, self.rules["round_floats"])
            # Suppression runs after rounding, matching the documented order.
            for col in count_columns:
                val = row.get(col)
                if isinstance(val, (int, float)) and val < self.rules["suppress_small_n"]:
                    row[col] = f"<{self.rules['suppress_small_n']}"
        return rows

    def get_execution_summary(self) -> str:
        """Return a human-readable summary of the successful queries run so far."""
        if not self.execution_log:
            return "尚未执行任何查询"
        lines = [f"共执行 {len(self.execution_log)} 条查询:"]
        for i, log in enumerate(self.execution_log, 1):
            sql = log["sql"]
            # Only mark the SQL as truncated when it actually was.
            preview = sql if len(sql) <= 80 else f"{sql[:80]}..."
            lines.append(f" {i}. {preview} → {log['rows_returned']} 行结果")
        return "\n".join(lines)