SQLite 持久连接 — sandbox 不再每次查询开关连接,改为 __init__ 时建连、close() 时释放
- Explorer 的 system prompt 明确告知 sandbox 规则 — "每条 SQL 必须包含聚合函数或 LIMIT",减少 LLM 生成违规 SQL 浪费轮次
- LLM 客户端单例 — 所有组件共享一个 openai.OpenAI 实例,不再各建各的
- sanitize 顺序修复 — 小样本抑制放在 float round 之前,避免被 round 干扰
- quick_detect 从 O(n²) 改为 O(n) — 按列聚合一次,加去重,不再对每行重复算整列统计
- 历史上下文实际生效 — get_context_for 的结果现在会注入到 Explorer 的初始 prompt 里,多轮分析时 LLM 能看到之前的发现
This commit is contained in:
22
core/__init__.py
Normal file
22
core/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
iov_ana 核心包
|
||||
|
||||
项目结构:
|
||||
core/
|
||||
config.py - 配置
|
||||
utils.py - 公共工具(JSON 提取、LLM 客户端)
|
||||
schema.py - Schema 提取
|
||||
sandbox.py - SQL 沙箱执行器
|
||||
layers/
|
||||
planner.py - Layer 1: 意图规划
|
||||
playbook.py - Layer 1.5: 预设剧本
|
||||
explorer.py - Layer 2: 自适应探索
|
||||
insights.py - Layer 3: 异常洞察
|
||||
context.py - Layer 4: 上下文记忆
|
||||
output/
|
||||
reporter.py - 单次报告生成
|
||||
consolidator.py - 多次报告整合
|
||||
chart.py - 图表生成
|
||||
agent.py - Agent 编排层
|
||||
cli.py - 交互式 CLI
|
||||
"""
|
||||
38
core/config.py
Normal file
38
core/config.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
配置文件
|
||||
"""
|
||||
import os
|
||||
|
||||
# LLM 配置(兼容 OpenAI API 格式,包括 Ollama / vLLM / DeepSeek 等)
|
||||
LLM_CONFIG = {
|
||||
"api_key": os.getenv("LLM_API_KEY", "sk-c44i1hy64xgzwox6x08o4zug93frq6rgn84oqugf2pje1tg4"),
|
||||
"base_url": os.getenv("LLM_BASE_URL", "https://api.xiaomimimo.com/v1"),
|
||||
"model": os.getenv("LLM_MODEL", "mimo-v2-flash"),
|
||||
}
|
||||
|
||||
# 沙箱安全规则
|
||||
SANDBOX_RULES = {
|
||||
"max_result_rows": 1000,
|
||||
"round_floats": 2,
|
||||
"suppress_small_n": 5,
|
||||
"banned_keywords": [
|
||||
"SELECT *", "INSERT", "UPDATE", "DELETE",
|
||||
"DROP", "ALTER", "CREATE", "ATTACH", "PRAGMA",
|
||||
],
|
||||
"require_aggregation": True,
|
||||
}
|
||||
|
||||
# 项目根目录
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
# 数据库路径
|
||||
DB_PATH = os.getenv("DB_PATH", os.path.join(PROJECT_ROOT, "demo.db"))
|
||||
|
||||
# Playbook 目录
|
||||
PLAYBOOK_DIR = os.getenv("PLAYBOOK_DIR", os.path.join(PROJECT_ROOT, "playbooks"))
|
||||
|
||||
# 图表输出目录
|
||||
CHARTS_DIR = os.getenv("CHARTS_DIR", os.path.join(PROJECT_ROOT, "charts"))
|
||||
|
||||
# 分析控制
|
||||
MAX_EXPLORATION_ROUNDS = int(os.getenv("MAX_ROUNDS", "6"))
|
||||
99
core/sandbox.py
Normal file
99
core/sandbox.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
沙箱执行器 —— 执行 SQL,只返回聚合结果
|
||||
"""
|
||||
import sqlite3
|
||||
import re
|
||||
from typing import Any
|
||||
from core.config import SANDBOX_RULES
|
||||
class SandboxError(Exception):
    """Raised when a SQL statement violates a sandbox safety rule."""
|
||||
class SandboxExecutor:
    """Runs validated SELECT statements against SQLite and returns only
    sanitized, aggregate-style results.

    A single connection is opened in ``__init__`` and reused for every
    query (call :meth:`close` when done, or use the instance as a context
    manager). Every executed statement is recorded in ``execution_log``.
    """

    def __init__(self, db_path: str):
        """Open a persistent connection to the database at *db_path*."""
        self.db_path = db_path
        self.rules = SANDBOX_RULES
        self.execution_log: list[dict] = []
        # Persistent connection — avoids the open/close cost per query.
        self._conn = sqlite3.connect(db_path)
        self._conn.row_factory = sqlite3.Row

    def execute(self, sql: str) -> dict[str, Any]:
        """Validate and run *sql*, returning a sanitized result dict.

        Returns ``{"success": True, "columns", "rows", "row_count", "sql"}``
        on success or ``{"success": False, "error", "sql"}`` on a SQLite
        error. Raises :class:`SandboxError` when validation fails.
        """
        self._validate(sql)

        cur = self._conn.cursor()
        try:
            cur.execute(sql)
            rows = cur.fetchall()
            columns = [desc[0] for desc in cur.description] if cur.description else []
            results = [dict(row) for row in rows]
            sanitized = self._sanitize(results, columns)

            self.execution_log.append({
                "sql": sql, "rows_returned": len(results), "columns": columns,
            })

            return {
                "success": True, "columns": columns,
                "rows": sanitized, "row_count": len(sanitized), "sql": sql,
            }
        except sqlite3.Error as e:
            return {"success": False, "error": str(e), "sql": sql}

    def close(self):
        """Release the persistent connection (safe to call twice)."""
        if self._conn:
            self._conn.close()
            self._conn = None

    # Context-manager support so callers can write
    # ``with SandboxExecutor(path) as sb: ...`` and never leak the connection.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    @staticmethod
    def _keyword_pattern(keyword: str) -> str:
        """Build a regex matching *keyword* as whole words only.

        FIX: plain substring matching falsely rejected legitimate queries —
        e.g. a column named ``created_at`` contains "CREATE" and ``updated``
        contains "UPDATE". Alphabetic tokens are wrapped in word boundaries;
        non-alphabetic tokens (the "*" in "SELECT *") are matched literally,
        with flexible whitespace between tokens of a multi-word phrase.
        """
        parts = []
        for token in keyword.upper().split():
            escaped = re.escape(token)
            if token[0].isalpha():
                escaped = r"\b" + escaped
            if token[-1].isalpha():
                escaped = escaped + r"\b"
            parts.append(escaped)
        return r"\s+".join(parts)

    def _validate(self, sql: str):
        """SQL safety validation — raises SandboxError on any violation."""
        sql_upper = sql.upper().strip()

        for banned in self.rules["banned_keywords"]:
            # Whole-word match: "CREATE" must not hit "created_at", etc.
            if re.search(self._keyword_pattern(banned), sql_upper):
                raise SandboxError(f"禁止的 SQL 关键字: {banned}")

        # Reject stacked statements ("SELECT ...; DROP ...").
        statements = [s.strip() for s in sql.split(";") if s.strip()]
        if len(statements) > 1:
            raise SandboxError("禁止多语句执行")
        if not sql_upper.startswith("SELECT"):
            raise SandboxError("只允许 SELECT 查询")

        if self.rules["require_aggregation"]:
            agg_keywords = ["COUNT", "SUM", "AVG", "MIN", "MAX", "GROUP BY",
                            "DISTINCT", "HAVING", "ROUND", "CAST"]
            has_agg = any(kw in sql_upper for kw in agg_keywords)
            if not has_agg and "LIMIT" not in sql_upper:
                raise SandboxError("要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT")

        limit_match = re.search(r'LIMIT\s+(\d+)', sql_upper)
        if limit_match and int(limit_match.group(1)) > self.rules["max_result_rows"]:
            raise SandboxError(f"LIMIT 超过最大允许值 {self.rules['max_result_rows']}")

    def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
        """Apply privacy rules in place: row cap, small-count suppression,
        then float rounding."""
        rows = rows[:self.rules["max_result_rows"]]
        suppress_n = self.rules["suppress_small_n"]
        round_digits = self.rules["round_floats"]

        for row in rows:
            for col in columns:
                val = row.get(col)
                # Small-sample suppression runs BEFORE rounding so the raw
                # value — not a rounded one — is compared to the threshold.
                if col.lower() in ("count", "cnt", "n", "total"):
                    if isinstance(val, (int, float)) and val < suppress_n:
                        row[col] = f"<{suppress_n}"
                        continue
                # Limit float precision to avoid leaking exact values.
                if isinstance(val, float):
                    row[col] = round(val, round_digits)
        return rows

    def get_execution_summary(self) -> str:
        """Return a short human-readable log of all executed queries."""
        if not self.execution_log:
            return "尚未执行任何查询"
        lines = [f"共执行 {len(self.execution_log)} 条查询:"]
        for i, log in enumerate(self.execution_log, 1):
            lines.append(f"  {i}. {log['sql'][:80]}... → {log['rows_returned']} 行结果")
        return "\n".join(lines)
|
||||
98
core/schema.py
Normal file
98
core/schema.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Schema 提取器 —— 只提取表结构,不碰数据
|
||||
"""
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_schema(db_path: str) -> dict[str, Any]:
    """Extract structural metadata from the SQLite database at *db_path*.

    Returns ``{"tables": [...]}`` where each entry carries columns, foreign
    keys, the row count and a light data profile. Only structure and
    aggregate statistics are exposed — never raw rows, except up to 20
    distinct values for genuinely low-cardinality text columns.
    """
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    try:
        cur = conn.cursor()

        cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
        tables = [row["name"] for row in cur.fetchall()]

        schema: dict[str, Any] = {"tables": []}

        for table in tables:
            cur.execute(f"PRAGMA table_info('{table}')")
            columns = []
            for col in cur.fetchall():
                columns.append({
                    "name": col["name"],
                    "type": col["type"],
                    "nullable": col["notnull"] == 0,
                    # FIX: PRAGMA table_info reports the 1-based position of
                    # a column within a composite primary key (0 = not part
                    # of the PK), so "== 1" missed the 2nd+ columns of a
                    # composite key. Any non-zero value means PK member.
                    "is_primary_key": col["pk"] != 0,
                })

            cur.execute(f"PRAGMA foreign_key_list('{table}')")
            fks = [
                {"column": fk["from"], "references_table": fk["table"], "references_column": fk["to"]}
                for fk in cur.fetchall()
            ]

            cur.execute(f"SELECT COUNT(*) AS cnt FROM '{table}'")
            row_count = cur.fetchone()["cnt"]

            data_profile: dict[str, Any] = {}
            for col in columns:
                col_name = col["name"]
                col_type = (col["type"] or "").upper()

                if any(t in col_type for t in ("VARCHAR", "TEXT", "CHAR")):
                    # FIX: fetch 21 values — seeing a 21st proves the column
                    # has more than 20 distinct values and is NOT an enum.
                    # The old LIMIT 20 made the "<= 20" check always true and
                    # capped distinct_count at 20 for every text column.
                    cur.execute(f'SELECT DISTINCT "{col_name}" FROM "{table}" WHERE "{col_name}" IS NOT NULL LIMIT 21')
                    vals = [row[0] for row in cur.fetchall()]
                    if len(vals) <= 20:
                        data_profile[col_name] = {"type": "enum", "distinct_count": len(vals), "values": vals}
                elif any(t in col_type for t in ("INT", "REAL", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC")):
                    cur.execute(f'''
                        SELECT MIN("{col_name}") AS min_val, MAX("{col_name}") AS max_val,
                               AVG("{col_name}") AS avg_val, COUNT(DISTINCT "{col_name}") AS distinct_count
                        FROM "{table}" WHERE "{col_name}" IS NOT NULL
                    ''')
                    row = cur.fetchone()
                    if row and row["min_val"] is not None:
                        data_profile[col_name] = {
                            "type": "numeric",
                            "min": round(row["min_val"], 2), "max": round(row["max_val"], 2),
                            "avg": round(row["avg_val"], 2), "distinct_count": row["distinct_count"],
                        }

            schema["tables"].append({
                "name": table, "columns": columns, "foreign_keys": fks,
                "row_count": row_count, "data_profile": data_profile,
            })

        return schema
    finally:
        # FIX: close the connection even when a query raises, so the handle
        # never leaks.
        conn.close()
|
||||
|
||||
|
||||
def schema_to_text(schema: dict) -> str:
    """Render a schema dict (as produced by extract_schema) into readable
    text so an LLM can reason about the database structure."""
    parts: list[str] = ["=== 数据库 Schema ===\n"]
    for tbl in schema["tables"]:
        parts.append(f"📋 表: {tbl['name']} (共 {tbl['row_count']} 行)")
        parts.append(" 列:")
        for column in tbl["columns"]:
            flags = " [PK]" if column["is_primary_key"] else ""
            flags += " NULL" if column["nullable"] else " NOT NULL"
            parts.append(f' - {column["name"]}: {column["type"]}{flags}')
        if tbl["foreign_keys"]:
            parts.append(" 外键:")
            parts.extend(
                f' - {fk["column"]} → {fk["references_table"]}.{fk["references_column"]}'
                for fk in tbl["foreign_keys"]
            )
        if tbl["data_profile"]:
            parts.append(" 数据画像:")
            for name, prof in tbl["data_profile"].items():
                if prof["type"] == "enum":
                    preview = ", ".join(str(v) for v in prof["values"][:10])
                    parts.append(f' - {name}: 枚举值({prof["distinct_count"]}个) = [{preview}]')
                elif prof["type"] == "numeric":
                    parts.append(
                        f' - {name}: 范围[{prof["min"]} ~ {prof["max"]}], '
                        f'均值{prof["avg"]}, {prof["distinct_count"]}个不同值'
                    )
        parts.append("")
    return "\n".join(parts)
|
||||
93
core/utils.py
Normal file
93
core/utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
公共工具 —— JSON 提取、LLM 客户端单例
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
|
||||
# ── LLM 客户端单例 ──────────────────────────────────
|
||||
|
||||
_llm_client: openai.OpenAI | None = None
_llm_model: str = ""


def get_llm_client(config: dict) -> tuple[openai.OpenAI, str]:
    """Return the process-wide LLM client and model name.

    The client is created lazily on the first call and shared afterwards,
    so every component reuses one openai.OpenAI instance instead of each
    building its own. Note that *config* is only consulted on that first
    call; later calls return the already-built singleton.
    """
    global _llm_client, _llm_model
    if _llm_client is not None:
        return _llm_client, _llm_model
    _llm_client = openai.OpenAI(
        api_key=config["api_key"],
        base_url=config["base_url"],
    )
    _llm_model = config["model"]
    return _llm_client, _llm_model
|
||||
|
||||
|
||||
# ── JSON 提取 ────────────────────────────────────────
|
||||
|
||||
def extract_json_object(text: str) -> dict:
|
||||
"""从 LLM 输出提取 JSON 对象"""
|
||||
text = _clean_json_text(text)
|
||||
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']:
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(_clean_json_text(match.group(1)))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(_clean_json_text(match.group()))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def extract_json_array(text: str) -> list[dict]:
|
||||
"""从 LLM 输出提取 JSON 数组(处理尾逗号、注释等)"""
|
||||
text = _clean_json_text(text)
|
||||
|
||||
try:
|
||||
result = json.loads(text)
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']:
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
result = json.loads(_clean_json_text(match.group(1)))
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
match = re.search(r'\[.*\]', text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(_clean_json_text(match.group()))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def _clean_json_text(s: str) -> str:
|
||||
"""清理 LLM 常见的非标准 JSON"""
|
||||
s = re.sub(r'//.*?\n', '\n', s)
|
||||
s = re.sub(r'/\*.*?\*/', '', s, flags=re.DOTALL)
|
||||
s = re.sub(r',\s*([}\]])', r'\1', s)
|
||||
return s.strip()
|
||||
Reference in New Issue
Block a user