feat: 四层架构数据分析 Agent
- Layer 1 Planner: 意图规划,将问题转为结构化分析计划 - Layer 2 Explorer: 自适应探索循环,多轮迭代动态生成 SQL - Layer 3 InsightEngine: 异常检测 + 主动洞察 - Layer 4 ContextManager: 多轮对话上下文记忆 安全设计:AI 只看 Schema + 聚合结果,不接触原始数据。 支持任意 OpenAI 兼容 API(OpenAI / Ollama / DeepSeek / vLLM)
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
demo.db
|
||||||
117
README.md
Normal file
117
README.md
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
# 数据分析 Agent —— Schema-Only + 四层架构自适应分析
|
||||||
|
|
||||||
|
**AI 只看表结构,不碰原始数据。通过四层架构自适应探索,生成深度分析报告。**
|
||||||
|
|
||||||
|
## 架构
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ Agent (编排层) │
|
||||||
|
│ 接收问题 → 调度四层 → 输出报告 │
|
||||||
|
└──────┬──────────┬──────────┬──────────┬─────────────────┘
|
||||||
|
│ │ │ │
|
||||||
|
┌────▼────┐ ┌──▼─────┐ ┌─▼──────┐ ┌▼─────────┐
|
||||||
|
│ Planner │ │Explorer│ │Insight │ │ Context │
|
||||||
|
│ 意图规划 │ │探索循环 │ │异常检测 │ │ 上下文记忆 │
|
||||||
|
└─────────┘ └────────┘ └────────┘ └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 四层分工
|
||||||
|
|
||||||
|
| 层 | 组件 | 职责 |
|
||||||
|
|---|---|---|
|
||||||
|
| L1 | **Planner** | 理解用户意图,生成结构化分析计划(类型、维度、指标) |
|
||||||
|
| L2 | **Explorer** | 基于计划多轮迭代探索,每轮根据上一轮结果决定下一步 |
|
||||||
|
| L3 | **InsightEngine** | 从探索结果中检测异常、趋势、关联,输出主动洞察 |
|
||||||
|
| L4 | **ContextManager** | 管理多轮对话历史,后续问题可引用之前的分析 |
|
||||||
|
|
||||||
|
### 安全隔离
|
||||||
|
|
||||||
|
```
|
||||||
|
用户提问 → Agent 看 Schema 生成 SQL → 沙箱执行 → 聚合结果 → Agent 生成报告
|
||||||
|
↑
|
||||||
|
原始数据永远留在这里
|
||||||
|
```
|
||||||
|
|
||||||
|
- **Schema 提取器**:只提取表结构、列类型、行数、枚举值,不碰数据
|
||||||
|
- **沙箱执行器**:禁止 SELECT * / DDL / DML,必须聚合函数,小样本抑制(n<5)
|
||||||
|
- **AI 的视角**:只有 Schema + 聚合统计结果,从未接触任何一行原始数据
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. 安装依赖
|
||||||
|
pip install openai
|
||||||
|
|
||||||
|
# 2. 配置 LLM(兼容 OpenAI API 格式)
|
||||||
|
|
||||||
|
# OpenAI
|
||||||
|
export LLM_API_KEY=sk-xxx
|
||||||
|
export LLM_BASE_URL=https://api.openai.com/v1
|
||||||
|
export LLM_MODEL=gpt-4o
|
||||||
|
|
||||||
|
# Ollama(本地部署,隐私优先)
|
||||||
|
export LLM_API_KEY=ollama
|
||||||
|
export LLM_BASE_URL=http://localhost:11434/v1
|
||||||
|
export LLM_MODEL=qwen2.5-coder:7b
|
||||||
|
|
||||||
|
# DeepSeek
|
||||||
|
export LLM_API_KEY=sk-xxx
|
||||||
|
export LLM_BASE_URL=https://api.deepseek.com
|
||||||
|
export LLM_MODEL=deepseek-chat
|
||||||
|
|
||||||
|
# 3. 运行演示(自动创建 5 万条示例数据 + 3 个分析任务)
|
||||||
|
python3 demo.py
|
||||||
|
|
||||||
|
# 4. 交互式分析
|
||||||
|
python3 cli.py
|
||||||
|
|
||||||
|
# 5. 分析你自己的数据库
|
||||||
|
python3 cli.py /path/to/your.db
|
||||||
|
```
|
||||||
|
|
||||||
|
## 文件结构
|
||||||
|
|
||||||
|
```
|
||||||
|
├── config.py # 配置(LLM、安全规则、探索轮数)
|
||||||
|
├── schema_extractor.py # Schema 提取器(只提取结构)
|
||||||
|
├── sandbox_executor.py # 沙箱执行器(SQL 验证 + 结果脱敏)
|
||||||
|
├── planner.py # [L1] 意图规划器
|
||||||
|
├── explorer.py # [L2] 自适应探索器
|
||||||
|
├── insights.py # [L3] 洞察引擎(异常检测)
|
||||||
|
├── context.py # [L4] 上下文管理器
|
||||||
|
├── reporter.py # 报告生成器
|
||||||
|
├── agent.py # Agent 编排层
|
||||||
|
├── demo.py # 一键演示
|
||||||
|
├── cli.py # 交互式 CLI
|
||||||
|
├── requirements.txt # 依赖
|
||||||
|
└── README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
## 对比预制脚本
|
||||||
|
|
||||||
|
| | 预制脚本 / 模板 | 本方案(四层架构) |
|
||||||
|
|---|---|---|
|
||||||
|
| SQL 生成 | 模板拼接 | LLM 动态生成 |
|
||||||
|
| 查询数量 | 固定 | 1-6 轮,AI 自适应 |
|
||||||
|
| 后续追问 | 无 | AI 看到结果后判断是否深挖 |
|
||||||
|
| 异常发现 | 无 | 主动检测 + 主动输出 |
|
||||||
|
| 多轮对话 | 无 | 上下文记忆,可引用历史分析 |
|
||||||
|
| 适用场景 | 已知分析模式 | 探索性分析、开放性问题 |
|
||||||
|
|
||||||
|
## CLI 命令
|
||||||
|
|
||||||
|
```
|
||||||
|
📊 > 帮我分析各地区的销售表现 # 分析问题
|
||||||
|
📊 > rounds=3 最近的趋势怎么样 # 限制探索轮数
|
||||||
|
📊 > schema # 查看数据库 Schema
|
||||||
|
📊 > history # 查看分析历史
|
||||||
|
📊 > audit # 查看 SQL 审计日志
|
||||||
|
📊 > clear # 清空历史
|
||||||
|
📊 > help # 帮助
|
||||||
|
📊 > quit # 退出
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
139
agent.py
Normal file
139
agent.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
Agent 编排层 —— 调度四层架构完成分析
|
||||||
|
|
||||||
|
Layer 1: Planner 意图规划
|
||||||
|
Layer 2: Explorer 自适应探索
|
||||||
|
Layer 3: Insight 异常洞察
|
||||||
|
Layer 4: Context 上下文记忆
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from config import DB_PATH, MAX_EXPLORATION_ROUNDS
|
||||||
|
from schema_extractor import extract_schema, schema_to_text
|
||||||
|
from sandbox_executor import SandboxExecutor
|
||||||
|
from planner import Planner
|
||||||
|
from explorer import Explorer
|
||||||
|
from insights import InsightEngine, quick_detect
|
||||||
|
from reporter import ReportGenerator
|
||||||
|
from context import ContextManager
|
||||||
|
|
||||||
|
|
||||||
|
class DataAnalysisAgent:
|
||||||
|
"""
|
||||||
|
数据分析 Agent
|
||||||
|
|
||||||
|
四层架构:
|
||||||
|
1. Planner - 理解用户意图,生成分析计划
|
||||||
|
2. Explorer - 基于计划多轮迭代探索
|
||||||
|
3. Insights - 从结果中检测异常、输出主动洞察
|
||||||
|
4. Context - 管理多轮对话上下文
|
||||||
|
|
||||||
|
Agent 负责编排这四层,从问题到报告。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str):
|
||||||
|
# 数据层
|
||||||
|
self.db_path = db_path
|
||||||
|
self.schema = extract_schema(db_path)
|
||||||
|
self.schema_text = schema_to_text(self.schema)
|
||||||
|
self.executor = SandboxExecutor(db_path)
|
||||||
|
|
||||||
|
# 四层组件
|
||||||
|
self.planner = Planner()
|
||||||
|
self.explorer = Explorer(self.executor)
|
||||||
|
self.insight_engine = InsightEngine()
|
||||||
|
self.reporter = ReportGenerator()
|
||||||
|
self.context = ContextManager()
|
||||||
|
|
||||||
|
def analyze(self, question: str, max_rounds: Optional[int] = None) -> str:
|
||||||
|
"""
|
||||||
|
完整分析流程
|
||||||
|
|
||||||
|
Args:
|
||||||
|
question: 用户分析问题
|
||||||
|
max_rounds: 最大探索轮数(默认用配置值)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化的分析报告
|
||||||
|
"""
|
||||||
|
max_rounds = max_rounds or MAX_EXPLORATION_ROUNDS
|
||||||
|
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"📊 {question}")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
|
||||||
|
# ── Layer 0: 检查上下文 ──────────────────────────
|
||||||
|
prev_context = self.context.get_context_for(question)
|
||||||
|
if prev_context:
|
||||||
|
print("📎 发现历史分析上下文,将结合之前的发现")
|
||||||
|
|
||||||
|
# ── Layer 1: 意图规划 ────────────────────────────
|
||||||
|
print("\n🎯 [Layer 1] 意图规划...")
|
||||||
|
plan = self.planner.plan(question, self.schema_text)
|
||||||
|
|
||||||
|
analysis_type = plan.get("analysis_type", "unknown")
|
||||||
|
dimensions = plan.get("dimensions", [])
|
||||||
|
rationale = plan.get("rationale", "")
|
||||||
|
print(f" 类型: {analysis_type}")
|
||||||
|
print(f" 维度: {', '.join(dimensions) if dimensions else '自动发现'}")
|
||||||
|
print(f" 理由: {rationale[:80]}{'...' if len(rationale) > 80 else ''}")
|
||||||
|
|
||||||
|
# ── Layer 2: 自适应探索 ──────────────────────────
|
||||||
|
print(f"\n🔍 [Layer 2] 自适应探索 (最多 {max_rounds} 轮)...")
|
||||||
|
steps = self.explorer.explore(plan, self.schema_text, max_rounds=max_rounds)
|
||||||
|
|
||||||
|
successful = sum(1 for s in steps if s.success)
|
||||||
|
print(f"\n 完成: {len(steps)} 轮, {successful} 条成功查询")
|
||||||
|
|
||||||
|
# ── Layer 3: 异常洞察 ────────────────────────────
|
||||||
|
print("\n🔎 [Layer 3] 异常洞察...")
|
||||||
|
|
||||||
|
# 先做规则检测
|
||||||
|
rule_alerts = quick_detect(steps)
|
||||||
|
for alert in rule_alerts:
|
||||||
|
print(f" {alert}")
|
||||||
|
|
||||||
|
# 再做 LLM 深度分析
|
||||||
|
insights = self.insight_engine.analyze(steps, question)
|
||||||
|
if insights:
|
||||||
|
print(f" 发现 {len(insights)} 条洞察")
|
||||||
|
for insight in insights:
|
||||||
|
print(f" {insight}")
|
||||||
|
else:
|
||||||
|
print(" 未发现异常")
|
||||||
|
|
||||||
|
# ── 生成报告 ────────────────────────────────────
|
||||||
|
print("\n📝 正在生成报告...")
|
||||||
|
report = self.reporter.generate(question, plan, steps, insights)
|
||||||
|
|
||||||
|
# 追加主动洞察
|
||||||
|
if insights:
|
||||||
|
insight_text = self.insight_engine.format_insights(insights)
|
||||||
|
report += f"\n\n---\n\n{insight_text}"
|
||||||
|
|
||||||
|
# ── Layer 4: 记录上下文 ──────────────────────────
|
||||||
|
self.context.add_session(
|
||||||
|
question=question,
|
||||||
|
plan=plan,
|
||||||
|
steps=steps,
|
||||||
|
insights=insights,
|
||||||
|
report=report,
|
||||||
|
)
|
||||||
|
|
||||||
|
return report
|
||||||
|
|
||||||
|
def get_schema(self) -> str:
|
||||||
|
"""获取 Schema 文本"""
|
||||||
|
return self.schema_text
|
||||||
|
|
||||||
|
def get_history(self) -> str:
|
||||||
|
"""获取分析历史摘要"""
|
||||||
|
return self.context.get_history_summary()
|
||||||
|
|
||||||
|
def get_audit(self) -> str:
|
||||||
|
"""获取执行审计日志"""
|
||||||
|
return self.executor.get_execution_summary()
|
||||||
|
|
||||||
|
def clear_history(self):
|
||||||
|
"""清空分析历史"""
|
||||||
|
self.context.clear()
|
||||||
112
cli.py
Normal file
112
cli.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
"""
|
||||||
|
交互式 CLI —— 四层架构自适应分析
|
||||||
|
用法: python3 cli.py [数据库路径]
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
|
||||||
|
from config import DB_PATH, LLM_CONFIG, MAX_EXPLORATION_ROUNDS
|
||||||
|
from agent import DataAnalysisAgent
|
||||||
|
|
||||||
|
|
||||||
|
def print_help():
|
||||||
|
print("""
|
||||||
|
可用命令:
|
||||||
|
<问题> 分析一个问题
|
||||||
|
rounds=<N> <问题> 设置探索轮数
|
||||||
|
schema 查看数据库 Schema
|
||||||
|
history 查看分析历史
|
||||||
|
audit 查看 SQL 审计日志
|
||||||
|
clear 清空分析历史
|
||||||
|
help 显示帮助
|
||||||
|
quit / q 退出
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
db_path = sys.argv[1] if len(sys.argv) > 1 else DB_PATH
|
||||||
|
|
||||||
|
if not os.path.exists(db_path):
|
||||||
|
print(f"❌ 数据库不存在: {db_path}")
|
||||||
|
print(f" 请先运行 python3 demo.py 创建示例数据库,或指定已有数据库路径")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not LLM_CONFIG["api_key"]:
|
||||||
|
print("⚠️ 未配置 LLM_API_KEY,请先设置环境变量:")
|
||||||
|
print(" export LLM_API_KEY=sk-xxx")
|
||||||
|
print(" export LLM_BASE_URL=https://api.openai.com/v1 # 或 Ollama 等")
|
||||||
|
print(" export LLM_MODEL=gpt-4o")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 初始化 Agent
|
||||||
|
agent = DataAnalysisAgent(db_path)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print(" 🤖 数据分析 Agent —— 四层架构")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"\n🔗 LLM: {LLM_CONFIG['model']} @ {LLM_CONFIG['base_url']}")
|
||||||
|
print(f"🔄 最大探索轮数: {MAX_EXPLORATION_ROUNDS}")
|
||||||
|
print(f"💾 数据库: {db_path}")
|
||||||
|
print(f"\n💬 输入分析问题(help 查看命令)\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
user_input = input("📊 > ").strip()
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
print("\n👋 再见!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cmd = user_input.lower()
|
||||||
|
|
||||||
|
if cmd in ("quit", "exit", "q"):
|
||||||
|
print("👋 再见!")
|
||||||
|
break
|
||||||
|
elif cmd == "help":
|
||||||
|
print_help()
|
||||||
|
continue
|
||||||
|
elif cmd == "schema":
|
||||||
|
print(agent.get_schema())
|
||||||
|
continue
|
||||||
|
elif cmd == "history":
|
||||||
|
print(agent.get_history())
|
||||||
|
continue
|
||||||
|
elif cmd == "audit":
|
||||||
|
print(agent.get_audit())
|
||||||
|
continue
|
||||||
|
elif cmd == "clear":
|
||||||
|
agent.clear_history()
|
||||||
|
print("✅ 历史已清空")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解析可选参数:rounds=3
|
||||||
|
max_rounds = MAX_EXPLORATION_ROUNDS
|
||||||
|
question = user_input
|
||||||
|
if "rounds=" in question.lower():
|
||||||
|
parts = question.split("rounds=")
|
||||||
|
question = parts[0].strip()
|
||||||
|
try:
|
||||||
|
max_rounds = int(parts[1].strip().split()[0])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
report = agent.analyze(question, max_rounds=max_rounds)
|
||||||
|
print("\n" + report)
|
||||||
|
print("\n" + "~" * 60)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ 分析出错: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# 退出时显示审计
|
||||||
|
print("\n📋 本次会话审计:")
|
||||||
|
print(agent.get_audit())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
36
config.py
Normal file
36
config.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""
|
||||||
|
配置文件
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# LLM 配置(兼容 OpenAI API 格式,包括 Ollama / vLLM / DeepSeek 等)
|
||||||
|
LLM_CONFIG = {
|
||||||
|
"api_key": os.getenv("LLM_API_KEY", ""),
|
||||||
|
"base_url": os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"),
|
||||||
|
"model": os.getenv("LLM_MODEL", "gpt-4o"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 沙箱安全规则
|
||||||
|
SANDBOX_RULES = {
|
||||||
|
"max_result_rows": 1000, # 聚合结果最大行数
|
||||||
|
"round_floats": 2, # 浮点数保留位数
|
||||||
|
"suppress_small_n": 5, # 分组样本 < n 时模糊处理
|
||||||
|
"banned_keywords": [ # 禁止的 SQL 关键字
|
||||||
|
"SELECT *",
|
||||||
|
"INSERT",
|
||||||
|
"UPDATE",
|
||||||
|
"DELETE",
|
||||||
|
"DROP",
|
||||||
|
"ALTER",
|
||||||
|
"CREATE",
|
||||||
|
"ATTACH",
|
||||||
|
"PRAGMA",
|
||||||
|
],
|
||||||
|
"require_aggregation": True, # 是否要求使用聚合函数
|
||||||
|
}
|
||||||
|
|
||||||
|
# 数据库路径
|
||||||
|
DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "demo.db"))
|
||||||
|
|
||||||
|
# 分析控制
|
||||||
|
MAX_EXPLORATION_ROUNDS = int(os.getenv("MAX_ROUNDS", "6")) # 最大探索轮数
|
||||||
134
context.py
Normal file
134
context.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
Layer 4: 上下文管理器
|
||||||
|
管理多轮对话的分析上下文,让后续问题可以引用之前的发现。
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from explorer import ExplorationStep
|
||||||
|
from insights import Insight
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AnalysisSession:
|
||||||
|
"""一次分析的完整记录"""
|
||||||
|
question: str
|
||||||
|
plan: dict
|
||||||
|
steps: list[ExplorationStep]
|
||||||
|
insights: list[Insight]
|
||||||
|
report: str
|
||||||
|
timestamp: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
def summary(self) -> str:
|
||||||
|
"""生成本次分析的摘要(供后续对话引用)"""
|
||||||
|
parts = [f"**问题**: {self.question}"]
|
||||||
|
|
||||||
|
if self.plan:
|
||||||
|
parts.append(f"**分析类型**: {self.plan.get('analysis_type', 'unknown')}")
|
||||||
|
parts.append(f"**关注维度**: {', '.join(self.plan.get('dimensions', []))}")
|
||||||
|
|
||||||
|
# 核心发现(从成功步骤中提取)
|
||||||
|
key_findings = []
|
||||||
|
for step in self.steps:
|
||||||
|
if step.success and step.rows:
|
||||||
|
# 提取最突出的值
|
||||||
|
top_row = step.rows[0] if step.rows else {}
|
||||||
|
finding = f"{step.purpose}: "
|
||||||
|
finding += ", ".join(
|
||||||
|
f"{k}={v}" for k, v in top_row.items() if k.lower() not in ("id",)
|
||||||
|
)
|
||||||
|
key_findings.append(finding)
|
||||||
|
|
||||||
|
if key_findings:
|
||||||
|
parts.append("**核心发现**:")
|
||||||
|
for f in key_findings[:5]:
|
||||||
|
parts.append(f" - {f}")
|
||||||
|
|
||||||
|
# 洞察
|
||||||
|
if self.insights:
|
||||||
|
parts.append("**主动洞察**:")
|
||||||
|
for insight in self.insights[:3]:
|
||||||
|
parts.append(f" - {insight}")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def to_reference_text(self) -> str:
|
||||||
|
"""生成供 LLM 使用的上下文文本"""
|
||||||
|
return (
|
||||||
|
f"## 之前的分析\n\n"
|
||||||
|
f"### 问题\n{self.question}\n\n"
|
||||||
|
f"### 摘要\n{self.summary()}\n\n"
|
||||||
|
f"### 详细发现\n"
|
||||||
|
+ "\n".join(
|
||||||
|
f"- {step.purpose}: {step.row_count} 行结果"
|
||||||
|
for step in self.steps if step.success
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextManager:
|
||||||
|
"""上下文管理器:管理多轮对话的分析历史"""
|
||||||
|
|
||||||
|
def __init__(self, max_history: int = 10):
|
||||||
|
self.sessions: list[AnalysisSession] = []
|
||||||
|
self.max_history = max_history
|
||||||
|
|
||||||
|
def add_session(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
plan: dict,
|
||||||
|
steps: list[ExplorationStep],
|
||||||
|
insights: list[Insight],
|
||||||
|
report: str,
|
||||||
|
) -> AnalysisSession:
|
||||||
|
"""记录一次分析"""
|
||||||
|
session = AnalysisSession(
|
||||||
|
question=question,
|
||||||
|
plan=plan,
|
||||||
|
steps=steps,
|
||||||
|
insights=insights,
|
||||||
|
report=report,
|
||||||
|
)
|
||||||
|
self.sessions.append(session)
|
||||||
|
|
||||||
|
# 保持历史大小
|
||||||
|
if len(self.sessions) > self.max_history:
|
||||||
|
self.sessions = self.sessions[-self.max_history:]
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def get_context_for(self, new_question: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
根据新问题,从历史中找到相关上下文。
|
||||||
|
简单实现:取最近的分析会话。
|
||||||
|
"""
|
||||||
|
if not self.sessions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 取最近 2 轮分析的摘要
|
||||||
|
recent = self.sessions[-2:]
|
||||||
|
parts = []
|
||||||
|
for session in recent:
|
||||||
|
parts.append(session.to_reference_text())
|
||||||
|
|
||||||
|
return "\n\n---\n\n".join(parts)
|
||||||
|
|
||||||
|
def get_history_summary(self) -> str:
|
||||||
|
"""获取所有历史的摘要"""
|
||||||
|
if not self.sessions:
|
||||||
|
return "(无历史分析)"
|
||||||
|
|
||||||
|
lines = [f"共 {len(self.sessions)} 次分析:"]
|
||||||
|
for i, session in enumerate(self.sessions, 1):
|
||||||
|
ts = time.strftime("%H:%M", time.localtime(session.timestamp))
|
||||||
|
lines.append(f" {i}. [{ts}] {session.question}")
|
||||||
|
if session.insights:
|
||||||
|
for insight in session.insights[:2]:
|
||||||
|
lines.append(f" {insight}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""清空历史"""
|
||||||
|
self.sessions.clear()
|
||||||
160
demo.py
Normal file
160
demo.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
"""
|
||||||
|
演示脚本 —— 创建示例数据,运行四层架构分析
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import sqlite3
|
||||||
|
import random
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
|
||||||
|
from config import DB_PATH, LLM_CONFIG
|
||||||
|
from agent import DataAnalysisAgent
|
||||||
|
|
||||||
|
|
||||||
|
def create_demo_data(db_path: str):
|
||||||
|
"""创建示例数据库:电商订单 + 用户 + 商品"""
|
||||||
|
if os.path.exists(db_path):
|
||||||
|
os.remove(db_path)
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
cur.execute("""
|
||||||
|
CREATE TABLE users (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
region TEXT NOT NULL,
|
||||||
|
tier TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cur.execute("""
|
||||||
|
CREATE TABLE products (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
category TEXT NOT NULL,
|
||||||
|
price REAL NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cur.execute("""
|
||||||
|
CREATE TABLE orders (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
product_id INTEGER NOT NULL,
|
||||||
|
amount REAL NOT NULL,
|
||||||
|
quantity INTEGER NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
region TEXT NOT NULL,
|
||||||
|
order_date TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (user_id) REFERENCES users(id),
|
||||||
|
FOREIGN KEY (product_id) REFERENCES products(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
regions = ["华东", "华南", "华北", "西南", "华中", "东北"]
|
||||||
|
tiers = ["普通", "银卡", "金卡", "钻石"]
|
||||||
|
categories = ["电子产品", "服饰", "食品", "家居", "图书", "美妆"]
|
||||||
|
statuses = ["已完成", "已发货", "待发货", "已取消", "退款中"]
|
||||||
|
|
||||||
|
users = []
|
||||||
|
for i in range(1, 2001):
|
||||||
|
reg = random.choice(regions)
|
||||||
|
tier = random.choices(tiers, weights=[50, 25, 15, 10])[0]
|
||||||
|
created = datetime(2024, 1, 1) + timedelta(days=random.randint(0, 400))
|
||||||
|
users.append((i, f"用户{i}", reg, tier, created.strftime("%Y-%m-%d")))
|
||||||
|
cur.executemany("INSERT INTO users VALUES (?, ?, ?, ?, ?)", users)
|
||||||
|
|
||||||
|
products = []
|
||||||
|
for i in range(1, 201):
|
||||||
|
cat = random.choice(categories)
|
||||||
|
price = round(random.uniform(9.9, 9999.99), 2)
|
||||||
|
products.append((i, f"商品{i}", cat, price))
|
||||||
|
cur.executemany("INSERT INTO products VALUES (?, ?, ?, ?)", products)
|
||||||
|
|
||||||
|
orders = []
|
||||||
|
for i in range(1, 50001):
|
||||||
|
uid = random.randint(1, 2000)
|
||||||
|
pid = random.randint(1, 200)
|
||||||
|
qty = random.choices([1, 2, 3, 4, 5], weights=[50, 25, 15, 7, 3])[0]
|
||||||
|
status = random.choices(statuses, weights=[60, 15, 10, 10, 5])[0]
|
||||||
|
region = random.choices(regions, weights=[35, 25, 15, 10, 10, 5])[0]
|
||||||
|
order_date = datetime(2025, 1, 1) + timedelta(days=random.randint(0, 440))
|
||||||
|
cur.execute("SELECT price FROM products WHERE id = ?", (pid,))
|
||||||
|
price = cur.fetchone()[0]
|
||||||
|
amount = round(price * qty, 2)
|
||||||
|
orders.append((i, uid, pid, amount, qty, status, region, order_date.strftime("%Y-%m-%d")))
|
||||||
|
cur.executemany("INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?, ?, ?)", orders)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
print(f"✅ 示例数据库已创建: {db_path}")
|
||||||
|
print(f" - users: 2,000 条")
|
||||||
|
print(f" - products: 200 条")
|
||||||
|
print(f" - orders: 50,000 条")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("=" * 60)
|
||||||
|
print(" 🤖 数据分析 Agent —— 四层架构自适应分析")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if not LLM_CONFIG["api_key"]:
|
||||||
|
print("\n⚠️ 未配置 LLM_API_KEY,请设置环境变量:")
|
||||||
|
print()
|
||||||
|
print(" # OpenAI")
|
||||||
|
print(" export LLM_API_KEY=sk-xxx")
|
||||||
|
print(" export LLM_BASE_URL=https://api.openai.com/v1")
|
||||||
|
print(" export LLM_MODEL=gpt-4o")
|
||||||
|
print()
|
||||||
|
print(" # Ollama (本地)")
|
||||||
|
print(" export LLM_API_KEY=ollama")
|
||||||
|
print(" export LLM_BASE_URL=http://localhost:11434/v1")
|
||||||
|
print(" export LLM_MODEL=qwen2.5-coder:7b")
|
||||||
|
print()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"\n🔗 LLM: {LLM_CONFIG['model']} @ {LLM_CONFIG['base_url']}")
|
||||||
|
|
||||||
|
if not os.path.exists(DB_PATH):
|
||||||
|
create_demo_data(DB_PATH)
|
||||||
|
else:
|
||||||
|
print(f"\n📂 使用已有数据库: {DB_PATH}")
|
||||||
|
|
||||||
|
# 初始化 Agent(自动加载 Schema)
|
||||||
|
agent = DataAnalysisAgent(DB_PATH)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(" ⬆️ AI 只看到 Schema:表结构 + 数据画像")
|
||||||
|
print(" ⬇️ 四层架构分析:规划 → 探索 → 洞察 → 报告")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
questions = [
|
||||||
|
"各地区的销售表现如何?帮我全面分析一下",
|
||||||
|
"不同商品类别的销售情况和利润贡献",
|
||||||
|
"整体订单的完成率怎么样?有没有什么异常需要关注?",
|
||||||
|
]
|
||||||
|
|
||||||
|
for q in questions:
|
||||||
|
try:
|
||||||
|
report = agent.analyze(q)
|
||||||
|
print("\n" + report)
|
||||||
|
print("\n" + "~" * 60)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ 分析出错: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# 最终审计
|
||||||
|
print("\n📋 会话审计:")
|
||||||
|
print(agent.get_audit())
|
||||||
|
print("\n📋 分析历史:")
|
||||||
|
print(agent.get_history())
|
||||||
|
print("\n✅ 演示完成!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
238
explorer.py
Normal file
238
explorer.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""
|
||||||
|
Layer 2: 自适应探索器
|
||||||
|
基于分析计划 + 已有发现,动态决定下一步查什么。
|
||||||
|
多轮迭代,直到 AI 判断"够了"或达到上限。
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from config import LLM_CONFIG
|
||||||
|
from sandbox_executor import SandboxExecutor
|
||||||
|
|
||||||
|
|
||||||
|
EXPLORER_SYSTEM = """你是一个数据分析执行者。你的上级给了你一个分析计划,你需要通过迭代执行 SQL 来完成分析。
|
||||||
|
|
||||||
|
## 你的工作方式
|
||||||
|
每一轮你看到:
|
||||||
|
1. 分析计划(上级给的目标)
|
||||||
|
2. 数据库 Schema(表结构、数据画像)
|
||||||
|
3. 之前的探索历史(查过什么、得到什么结果)
|
||||||
|
|
||||||
|
你决定下一步:
|
||||||
|
- 输出一条 SQL 继续探索
|
||||||
|
- 或者输出 done 表示分析足够
|
||||||
|
|
||||||
|
## 输出格式(严格 JSON)
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"action": "query",
|
||||||
|
"reasoning": "为什么要做这个查询",
|
||||||
|
"sql": "SELECT ...",
|
||||||
|
"purpose": "这个查询的目的"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
或:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"action": "done",
|
||||||
|
"reasoning": "为什么分析已经足够"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## SQL 规则
|
||||||
|
- 只用 SELECT,必须有聚合函数或 GROUP BY
|
||||||
|
- 禁止 SELECT *
|
||||||
|
- 用 ROUND 控制精度
|
||||||
|
- 合理使用 LIMIT(分组结果 15 行以内,时间序列 60 行以内)
|
||||||
|
|
||||||
|
## 探索策略
|
||||||
|
1. 第一轮:验证核心假设(计划中最关键的查询)
|
||||||
|
2. 后续轮:基于已有结果追问
|
||||||
|
- 发现离群值 → 追问为什么
|
||||||
|
- 发现异常比例 → 追问细分维度
|
||||||
|
- 结果平淡 → 换个角度试试
|
||||||
|
3. 不要重复查已经确认的事
|
||||||
|
4. 每轮要有新发现,否则就该结束"""
|
||||||
|
|
||||||
|
|
||||||
|
EXPLORER_CONTINUE = """查询结果:
|
||||||
|
|
||||||
|
{result_text}
|
||||||
|
|
||||||
|
请基于这个结果决定下一步。如果发现异常或值得深挖的点,继续查询。如果分析足够,输出 done。"""
|
||||||
|
|
||||||
|
|
||||||
|
class ExplorationStep:
|
||||||
|
"""单步探索结果"""
|
||||||
|
def __init__(self, round_num: int, decision: dict, result: dict):
|
||||||
|
self.round_num = round_num
|
||||||
|
self.reasoning = decision.get("reasoning", "")
|
||||||
|
self.purpose = decision.get("purpose", "")
|
||||||
|
self.sql = decision.get("sql", "")
|
||||||
|
self.action = decision.get("action", "query")
|
||||||
|
self.success = result.get("success", False)
|
||||||
|
self.error = result.get("error")
|
||||||
|
self.columns = result.get("columns", [])
|
||||||
|
self.rows = result.get("rows", [])
|
||||||
|
self.row_count = result.get("row_count", 0)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
d = {
|
||||||
|
"round": self.round_num,
|
||||||
|
"action": self.action,
|
||||||
|
"reasoning": self.reasoning,
|
||||||
|
"purpose": self.purpose,
|
||||||
|
"sql": self.sql,
|
||||||
|
"success": self.success,
|
||||||
|
}
|
||||||
|
if self.success:
|
||||||
|
d["result"] = {
|
||||||
|
"columns": self.columns,
|
||||||
|
"rows": self.rows,
|
||||||
|
"row_count": self.row_count,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
d["result"] = {"error": self.error}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class Explorer:
|
||||||
|
"""自适应探索器:多轮迭代执行 SQL"""
|
||||||
|
|
||||||
|
def __init__(self, executor: SandboxExecutor):
|
||||||
|
self.executor = executor
|
||||||
|
self.client = openai.OpenAI(
|
||||||
|
api_key=LLM_CONFIG["api_key"],
|
||||||
|
base_url=LLM_CONFIG["base_url"],
|
||||||
|
)
|
||||||
|
self.model = LLM_CONFIG["model"]
|
||||||
|
|
||||||
|
def explore(
|
||||||
|
self,
|
||||||
|
plan: dict,
|
||||||
|
schema_text: str,
|
||||||
|
max_rounds: int = 6,
|
||||||
|
) -> list[ExplorationStep]:
|
||||||
|
"""
|
||||||
|
执行探索循环
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan: Planner 生成的分析计划
|
||||||
|
schema_text: Schema 文本描述
|
||||||
|
max_rounds: 最大探索轮数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
探步列表
|
||||||
|
"""
|
||||||
|
steps: list[ExplorationStep] = []
|
||||||
|
|
||||||
|
# 构建初始消息
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": EXPLORER_SYSTEM},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
f"## 分析计划\n```json\n{json.dumps(plan, ensure_ascii=False, indent=2)}\n```\n\n"
|
||||||
|
f"## 数据库 Schema\n{schema_text}\n\n"
|
||||||
|
f"请开始第一轮探索。根据计划,先执行最关键的查询。"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for round_num in range(1, max_rounds + 1):
|
||||||
|
print(f"\n 🔄 探索第 {round_num}/{max_rounds} 轮")
|
||||||
|
|
||||||
|
# LLM 决策
|
||||||
|
decision = self._llm_decide(messages)
|
||||||
|
action = decision.get("action", "query")
|
||||||
|
reasoning = decision.get("reasoning", "")
|
||||||
|
|
||||||
|
print(f" 💭 {reasoning[:80]}{'...' if len(reasoning) > 80 else ''}")
|
||||||
|
|
||||||
|
if action == "done":
|
||||||
|
print(f" ✅ 探索完成")
|
||||||
|
steps.append(ExplorationStep(round_num, decision, {"success": True}))
|
||||||
|
break
|
||||||
|
|
||||||
|
# 执行 SQL
|
||||||
|
sql = decision.get("sql", "")
|
||||||
|
purpose = decision.get("purpose", "")
|
||||||
|
|
||||||
|
if not sql:
|
||||||
|
print(f" ⚠️ 未生成 SQL,跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f" 📝 {purpose}")
|
||||||
|
result = self.executor.execute(sql)
|
||||||
|
|
||||||
|
if result["success"]:
|
||||||
|
print(f" ✅ {result['row_count']} 行")
|
||||||
|
else:
|
||||||
|
print(f" ❌ {result['error']}")
|
||||||
|
|
||||||
|
step = ExplorationStep(round_num, decision, result)
|
||||||
|
steps.append(step)
|
||||||
|
|
||||||
|
# 更新对话历史
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": json.dumps(decision, ensure_ascii=False),
|
||||||
|
})
|
||||||
|
|
||||||
|
result_text = self._format_result(result)
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": EXPLORER_CONTINUE.format(result_text=result_text),
|
||||||
|
})
|
||||||
|
|
||||||
|
return steps
|
||||||
|
|
||||||
|
def _llm_decide(self, messages: list[dict]) -> dict:
|
||||||
|
"""LLM 决策"""
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=1024,
|
||||||
|
)
|
||||||
|
content = response.choices[0].message.content.strip()
|
||||||
|
return self._extract_json(content)
|
||||||
|
|
||||||
|
def _format_result(self, result: dict) -> str:
|
||||||
|
"""格式化查询结果"""
|
||||||
|
if not result.get("success"):
|
||||||
|
return f"❌ 执行失败: {result.get('error', '未知错误')}"
|
||||||
|
|
||||||
|
rows = result["rows"][:20]
|
||||||
|
return (
|
||||||
|
f"✅ 查询成功,返回 {result['row_count']} 行\n"
|
||||||
|
f"列: {result['columns']}\n"
|
||||||
|
f"数据:\n{json.dumps(rows, ensure_ascii=False, indent=2)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_json(self, text: str) -> dict:
|
||||||
|
"""从 LLM 输出提取 JSON"""
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']:
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
return json.loads(match.group(1))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
return json.loads(match.group())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {"action": "done", "reasoning": f"无法解析输出: {text[:100]}"}
|
||||||
239
insights.py
Normal file
239
insights.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
"""
|
||||||
|
Layer 3: 洞察引擎
|
||||||
|
对探索结果进行异常检测 + 主动洞察,输出用户没问但值得知道的事。
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import statistics
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from config import LLM_CONFIG
|
||||||
|
from explorer import ExplorationStep
|
||||||
|
|
||||||
|
|
||||||
|
INSIGHT_SYSTEM = """你是一个数据洞察专家。你会收到探索过程的所有结果,你需要:
|
||||||
|
|
||||||
|
1. 从结果中发现异常和有趣现象
|
||||||
|
2. 对比不同维度,找出差异
|
||||||
|
3. 输出用户可能没问但值得知道的洞察
|
||||||
|
|
||||||
|
## 输出格式(严格 JSON 数组)
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "outlier" | "trend" | "distribution" | "correlation" | "recommendation",
|
||||||
|
"severity": "high" | "medium" | "low",
|
||||||
|
"title": "简短标题",
|
||||||
|
"detail": "详细描述,包含具体数字",
|
||||||
|
"evidence": "支撑这个洞察的数据来源"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 洞察类型
|
||||||
|
- outlier: 离群值(某个分组异常高/低)
|
||||||
|
- trend: 趋势发现(增长/下降、季节性)
|
||||||
|
- distribution: 分布异常(不均衡、集中度过高)
|
||||||
|
- correlation: 关联发现(两个维度的意外关联)
|
||||||
|
- recommendation: 行动建议(基于数据的建议)
|
||||||
|
|
||||||
|
## 分析原则
|
||||||
|
- 每个洞察必须有具体数字支撑
|
||||||
|
- 用对比来说话(A 比 B 高 X%)
|
||||||
|
- 关注异常,不描述平淡的事实
|
||||||
|
- 如果没有异常,返回空数组"""
|
||||||
|
|
||||||
|
|
||||||
|
class Insight:
|
||||||
|
"""单条洞察"""
|
||||||
|
def __init__(self, data: dict):
|
||||||
|
self.type = data.get("type", "unknown")
|
||||||
|
self.severity = data.get("severity", "low")
|
||||||
|
self.title = data.get("title", "")
|
||||||
|
self.detail = data.get("detail", "")
|
||||||
|
self.evidence = data.get("evidence", "")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def emoji(self) -> str:
|
||||||
|
return {
|
||||||
|
"outlier": "⚠️",
|
||||||
|
"trend": "📈",
|
||||||
|
"distribution": "📊",
|
||||||
|
"correlation": "🔗",
|
||||||
|
"recommendation": "💡",
|
||||||
|
}.get(self.type, "📌")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def severity_emoji(self) -> str:
|
||||||
|
return {"high": "🔴", "medium": "🟡", "low": "🟢"}.get(self.severity, "")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.emoji} {self.severity_emoji} {self.title}: {self.detail}"
|
||||||
|
|
||||||
|
|
||||||
|
class InsightEngine:
|
||||||
|
"""洞察引擎:自动检测异常 + 主动输出"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.client = openai.OpenAI(
|
||||||
|
api_key=LLM_CONFIG["api_key"],
|
||||||
|
base_url=LLM_CONFIG["base_url"],
|
||||||
|
)
|
||||||
|
self.model = LLM_CONFIG["model"]
|
||||||
|
|
||||||
|
def analyze(self, steps: list[ExplorationStep], question: str) -> list[Insight]:
|
||||||
|
"""
|
||||||
|
对探索结果进行洞察分析
|
||||||
|
|
||||||
|
Args:
|
||||||
|
steps: 探索步骤列表
|
||||||
|
question: 原始用户问题
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
洞察列表
|
||||||
|
"""
|
||||||
|
if not steps:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 构建探索历史文本
|
||||||
|
history = self._build_history(steps)
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": INSIGHT_SYSTEM},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
f"## 用户原始问题\n{question}\n\n"
|
||||||
|
f"## 探索历史\n{history}\n\n"
|
||||||
|
f"请分析以上数据,输出异常和洞察。"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0.3,
|
||||||
|
max_tokens=2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
content = response.choices[0].message.content.strip()
|
||||||
|
insights_data = self._extract_json_array(content)
|
||||||
|
|
||||||
|
return [Insight(d) for d in insights_data]
|
||||||
|
|
||||||
|
def format_insights(self, insights: list[Insight]) -> str:
|
||||||
|
"""格式化洞察为可读文本"""
|
||||||
|
if not insights:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 按严重程度排序
|
||||||
|
severity_order = {"high": 0, "medium": 1, "low": 2}
|
||||||
|
sorted_insights = sorted(insights, key=lambda i: severity_order.get(i.severity, 9))
|
||||||
|
|
||||||
|
lines = ["## 💡 主动洞察", ""]
|
||||||
|
lines.append("_以下是你没问但数据告诉我们的事:_\n")
|
||||||
|
|
||||||
|
for insight in sorted_insights:
|
||||||
|
lines.append(f"**{insight.emoji} {insight.title}** {insight.severity_emoji}")
|
||||||
|
lines.append(f" {insight.detail}")
|
||||||
|
lines.append(f" _数据来源: {insight.evidence}_")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _build_history(self, steps: list[ExplorationStep]) -> str:
|
||||||
|
"""构建探索历史文本"""
|
||||||
|
parts = []
|
||||||
|
for step in steps:
|
||||||
|
if step.action == "done":
|
||||||
|
parts.append(f"### 结束\n{step.reasoning}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if step.success:
|
||||||
|
parts.append(
|
||||||
|
f"### 第 {step.round_num} 轮:{step.purpose}\n"
|
||||||
|
f"思考: {step.reasoning}\n"
|
||||||
|
f"SQL: `{step.sql}`\n"
|
||||||
|
f"结果 ({step.row_count} 行):\n"
|
||||||
|
f"列: {step.columns}\n"
|
||||||
|
f"数据: {json.dumps(step.rows, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append(
|
||||||
|
f"### 第 {step.round_num} 轮:{step.purpose}\n"
|
||||||
|
f"SQL: `{step.sql}`\n"
|
||||||
|
f"结果: 执行失败 - {step.error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
def _extract_json_array(self, text: str) -> list[dict]:
|
||||||
|
"""从 LLM 输出提取 JSON 数组"""
|
||||||
|
try:
|
||||||
|
result = json.loads(text)
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']:
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
result = json.loads(match.group(1))
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 找最外层 []
|
||||||
|
match = re.search(r'\[.*\]', text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
return json.loads(match.group())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
# ── 基于规则的快速异常检测(无需 LLM)────────────────
|
||||||
|
|
||||||
|
def quick_detect(steps: list[ExplorationStep]) -> list[str]:
|
||||||
|
"""
|
||||||
|
基于规则的快速异常检测,不调 LLM。
|
||||||
|
检测离群值、不均衡分布等。
|
||||||
|
"""
|
||||||
|
alerts = []
|
||||||
|
|
||||||
|
for step in steps:
|
||||||
|
if not step.success or not step.rows:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for row in step.rows:
|
||||||
|
for col in step.columns:
|
||||||
|
val = row.get(col)
|
||||||
|
if not isinstance(val, (int, float)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查 pct 列:某个分组占比异常
|
||||||
|
if col.lower() in ("pct", "percent", "percentage", "占比"):
|
||||||
|
if isinstance(val, (int, float)) and val > 50:
|
||||||
|
alerts.append(
|
||||||
|
f"⚠️ {step.purpose} 中某个分组占比 {val}%,超过 50%,集中度过高"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查 count 列:极值差异
|
||||||
|
if col.lower() in ("count", "cnt", "n", "total", "order_count"):
|
||||||
|
count_vals = [
|
||||||
|
r.get(col) for r in step.rows
|
||||||
|
if isinstance(r.get(col), (int, float))
|
||||||
|
]
|
||||||
|
if len(count_vals) >= 3 and max(count_vals) > 0:
|
||||||
|
ratio = max(count_vals) / (sum(count_vals) / len(count_vals))
|
||||||
|
if ratio > 3:
|
||||||
|
alerts.append(
|
||||||
|
f"⚠️ {step.purpose} 中最大值是均值的 {ratio:.1f} 倍,分布极不均衡"
|
||||||
|
)
|
||||||
|
|
||||||
|
return alerts
|
||||||
129
planner.py
Normal file
129
planner.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""
|
||||||
|
Layer 1: 意图规划器
|
||||||
|
将用户问题解析为结构化的分析计划,替代硬编码模板。
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
PROMPT = """你是一个数据分析规划专家。
|
||||||
|
|
||||||
|
## 你的任务
|
||||||
|
根据用户的分析问题和数据库 Schema,生成一个结构化的分析计划。
|
||||||
|
|
||||||
|
## 输出格式(严格 JSON)
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"intent": "一句话描述用户想了解什么",
|
||||||
|
"analysis_type": "ranking" | "distribution" | "trend" | "comparison" | "anomaly" | "overview",
|
||||||
|
"primary_table": "主要分析的表名",
|
||||||
|
"dimensions": ["分组维度列名"],
|
||||||
|
"metrics": ["需要聚合的数值列名"],
|
||||||
|
"aggregations": ["SUM", "AVG", "COUNT", ...],
|
||||||
|
"filters": [{"column": "列名", "condition": "过滤条件(可选)"}],
|
||||||
|
"join_needed": false,
|
||||||
|
"join_info": {"tables": [], "on": ""},
|
||||||
|
"expected_rounds": 3,
|
||||||
|
"rationale": "为什么这样规划,需要关注什么"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 分析类型说明
|
||||||
|
- ranking: 按某维度排名(哪个地区最高)
|
||||||
|
- distribution: 分布/占比(各地区占比多少)
|
||||||
|
- trend: 时间趋势(月度变化)
|
||||||
|
- comparison: 对比分析(A vs B)
|
||||||
|
- anomaly: 异常检测(有没有异常值)
|
||||||
|
- overview: 全局概览(整体情况如何)
|
||||||
|
|
||||||
|
## 规划原则
|
||||||
|
1. 只选择与问题相关的表和列
|
||||||
|
2. 如果需要 JOIN,说明关联条件
|
||||||
|
3. 预估需要几轮探索(1-6)
|
||||||
|
4. 标注可能的异常关注点
|
||||||
|
5. metrics 不要包含 id 列"""
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from config import LLM_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
class Planner:
|
||||||
|
"""意图规划器:将自然语言问题转为结构化分析计划"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.client = openai.OpenAI(
|
||||||
|
api_key=LLM_CONFIG["api_key"],
|
||||||
|
base_url=LLM_CONFIG["base_url"],
|
||||||
|
)
|
||||||
|
self.model = LLM_CONFIG["model"]
|
||||||
|
|
||||||
|
def plan(self, question: str, schema_text: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成分析计划
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"intent": str,
|
||||||
|
"analysis_type": str,
|
||||||
|
"primary_table": str,
|
||||||
|
"dimensions": [str],
|
||||||
|
"metrics": [str],
|
||||||
|
"aggregations": [str],
|
||||||
|
"filters": [dict],
|
||||||
|
"join_needed": bool,
|
||||||
|
"join_info": dict,
|
||||||
|
"expected_rounds": int,
|
||||||
|
"rationale": str,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": PROMPT},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
f"## Schema\n{schema_text}\n\n"
|
||||||
|
f"## 用户问题\n{question}"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0.1,
|
||||||
|
max_tokens=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
content = response.choices[0].message.content.strip()
|
||||||
|
plan = self._extract_json(content)
|
||||||
|
|
||||||
|
# 补充默认值
|
||||||
|
plan.setdefault("analysis_type", "overview")
|
||||||
|
plan.setdefault("expected_rounds", 3)
|
||||||
|
plan.setdefault("filters", [])
|
||||||
|
plan.setdefault("join_needed", False)
|
||||||
|
|
||||||
|
return plan
|
||||||
|
|
||||||
|
def _extract_json(self, text: str) -> dict:
|
||||||
|
"""从 LLM 输出提取 JSON"""
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']:
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
return json.loads(match.group(1))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 找最外层 {}
|
||||||
|
match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
return json.loads(match.group())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {"intent": text[:100], "analysis_type": "overview"}
|
||||||
110
reporter.py
Normal file
110
reporter.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""
|
||||||
|
报告生成器
|
||||||
|
将分析计划 + 探索结果 + 洞察,综合为一份可读报告。
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from config import LLM_CONFIG
|
||||||
|
from explorer import ExplorationStep
|
||||||
|
from insights import Insight
|
||||||
|
|
||||||
|
|
||||||
|
REPORT_PROMPT = """你是一个数据分析报告撰写专家。基于以下信息撰写报告。
|
||||||
|
|
||||||
|
## 用户问题
|
||||||
|
{question}
|
||||||
|
|
||||||
|
## 分析计划
|
||||||
|
{plan}
|
||||||
|
|
||||||
|
## 探索过程
|
||||||
|
{exploration}
|
||||||
|
|
||||||
|
## 主动洞察
|
||||||
|
{insights_text}
|
||||||
|
|
||||||
|
## 撰写要求
|
||||||
|
1. **开头**:一句话总结核心结论(先给答案)
|
||||||
|
2. **核心发现**:按重要性排列,每个发现带具体数字
|
||||||
|
3. **深入洞察**:异常、趋势、关联(从洞察中提取)
|
||||||
|
4. **建议**:基于数据的行动建议
|
||||||
|
5. **审计**:末尾附上所有执行的 SQL
|
||||||
|
|
||||||
|
使用 Markdown,中文撰写。
|
||||||
|
语气:专业但不枯燥,像一个聪明的分析师在做简报。"""
|
||||||
|
|
||||||
|
|
||||||
|
class ReportGenerator:
|
||||||
|
"""报告生成器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.client = openai.OpenAI(
|
||||||
|
api_key=LLM_CONFIG["api_key"],
|
||||||
|
base_url=LLM_CONFIG["base_url"],
|
||||||
|
)
|
||||||
|
self.model = LLM_CONFIG["model"]
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
plan: dict,
|
||||||
|
steps: list[ExplorationStep],
|
||||||
|
insights: list[Insight],
|
||||||
|
) -> str:
|
||||||
|
"""生成分析报告"""
|
||||||
|
|
||||||
|
# 构建探索过程文本
|
||||||
|
exploration = self._build_exploration(steps)
|
||||||
|
|
||||||
|
# 构建洞察文本
|
||||||
|
insights_text = self._build_insights(insights)
|
||||||
|
|
||||||
|
prompt = REPORT_PROMPT.format(
|
||||||
|
question=question,
|
||||||
|
plan=json.dumps(plan, ensure_ascii=False, indent=2),
|
||||||
|
exploration=exploration,
|
||||||
|
insights_text=insights_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "你是专业的数据分析师,撰写清晰、有洞察力的分析报告。"},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
temperature=0.3,
|
||||||
|
max_tokens=4096,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
def _build_exploration(self, steps: list[ExplorationStep]) -> str:
|
||||||
|
parts = []
|
||||||
|
for step in steps:
|
||||||
|
if step.action == "done":
|
||||||
|
parts.append(f"### 探索结束\n{step.reasoning}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if step.success:
|
||||||
|
parts.append(
|
||||||
|
f"### 第 {step.round_num} 轮:{step.purpose}\n"
|
||||||
|
f"思考: {step.reasoning}\n"
|
||||||
|
f"SQL: `{step.sql}`\n"
|
||||||
|
f"结果 ({step.row_count} 行):\n"
|
||||||
|
f"列: {step.columns}\n"
|
||||||
|
f"数据: {json.dumps(step.rows, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append(
|
||||||
|
f"### 第 {step.round_num} 轮:{step.purpose}\n"
|
||||||
|
f"SQL: `{step.sql}`\n"
|
||||||
|
f"执行失败: {step.error}"
|
||||||
|
)
|
||||||
|
return "\n\n".join(parts) if parts else "无探索步骤"
|
||||||
|
|
||||||
|
def _build_insights(self, insights: list[Insight]) -> str:
|
||||||
|
if not insights:
|
||||||
|
return "未检测到异常。"
|
||||||
|
return "\n".join(str(i) for i in insights)
|
||||||
1
requirements.txt
Normal file
1
requirements.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
openai>=1.0.0
|
||||||
139
sandbox_executor.py
Normal file
139
sandbox_executor.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
沙箱执行器 —— 执行 SQL,只返回聚合结果
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
from config import SANDBOX_RULES, DB_PATH
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxError(Exception):
|
||||||
|
"""沙箱安全违规"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxExecutor:
|
||||||
|
def __init__(self, db_path: str = DB_PATH):
|
||||||
|
self.db_path = db_path
|
||||||
|
self.rules = SANDBOX_RULES
|
||||||
|
self.execution_log: list[dict] = []
|
||||||
|
|
||||||
|
def execute(self, sql: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行 SQL,返回脱敏后的聚合结果。
|
||||||
|
如果违反安全规则,抛出 SandboxError。
|
||||||
|
"""
|
||||||
|
# 验证 SQL 安全性
|
||||||
|
self._validate(sql)
|
||||||
|
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
cur.execute(sql)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
|
||||||
|
# 转为字典列表
|
||||||
|
columns = [desc[0] for desc in cur.description] if cur.description else []
|
||||||
|
results = [dict(row) for row in rows]
|
||||||
|
|
||||||
|
# 脱敏处理
|
||||||
|
sanitized = self._sanitize(results, columns)
|
||||||
|
|
||||||
|
# 记录执行日志
|
||||||
|
self.execution_log.append({
|
||||||
|
"sql": sql,
|
||||||
|
"rows_returned": len(results),
|
||||||
|
"columns": columns,
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"columns": columns,
|
||||||
|
"rows": sanitized,
|
||||||
|
"row_count": len(sanitized),
|
||||||
|
"sql": sql,
|
||||||
|
}
|
||||||
|
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"sql": sql,
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _validate(self, sql: str):
|
||||||
|
"""SQL 安全验证"""
|
||||||
|
sql_upper = sql.upper().strip()
|
||||||
|
|
||||||
|
# 1. 检查禁止的关键字
|
||||||
|
for banned in self.rules["banned_keywords"]:
|
||||||
|
if banned.upper() in sql_upper:
|
||||||
|
raise SandboxError(f"禁止的 SQL 关键字: {banned}")
|
||||||
|
|
||||||
|
# 2. 只允许 SELECT(不能有多语句)
|
||||||
|
statements = [s.strip() for s in sql.split(";") if s.strip()]
|
||||||
|
if len(statements) > 1:
|
||||||
|
raise SandboxError("禁止多语句执行")
|
||||||
|
if not sql_upper.startswith("SELECT"):
|
||||||
|
raise SandboxError("只允许 SELECT 查询")
|
||||||
|
|
||||||
|
# 3. 检查是否使用了聚合函数或 GROUP BY(可选要求)
|
||||||
|
if self.rules["require_aggregation"]:
|
||||||
|
agg_keywords = ["COUNT", "SUM", "AVG", "MIN", "MAX", "GROUP BY",
|
||||||
|
"DISTINCT", "HAVING", "ROUND", "CAST"]
|
||||||
|
has_agg = any(kw in sql_upper for kw in agg_keywords)
|
||||||
|
has_limit = "LIMIT" in sql_upper
|
||||||
|
|
||||||
|
if not has_agg and not has_limit:
|
||||||
|
raise SandboxError(
|
||||||
|
"要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. LIMIT 检查
|
||||||
|
limit_match = re.search(r'LIMIT\s+(\d+)', sql_upper)
|
||||||
|
if limit_match:
|
||||||
|
limit_val = int(limit_match.group(1))
|
||||||
|
if limit_val > self.rules["max_result_rows"]:
|
||||||
|
raise SandboxError(
|
||||||
|
f"LIMIT {limit_val} 超过最大允许值 {self.rules['max_result_rows']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]:
|
||||||
|
"""对结果进行脱敏处理"""
|
||||||
|
if not rows:
|
||||||
|
return rows
|
||||||
|
|
||||||
|
# 1. 限制行数
|
||||||
|
rows = rows[:self.rules["max_result_rows"]]
|
||||||
|
|
||||||
|
# 2. 浮点数四舍五入
|
||||||
|
for row in rows:
|
||||||
|
for col in columns:
|
||||||
|
val = row.get(col)
|
||||||
|
if isinstance(val, float):
|
||||||
|
row[col] = round(val, self.rules["round_floats"])
|
||||||
|
|
||||||
|
# 3. 小样本抑制(k-anonymity)
|
||||||
|
# 如果某个分组的 count 小于阈值,标记为 "<n"
|
||||||
|
for row in rows:
|
||||||
|
for col in columns:
|
||||||
|
if col.lower() in ("count", "cnt", "n", "total"):
|
||||||
|
val = row.get(col)
|
||||||
|
if isinstance(val, (int, float)) and val < self.rules["suppress_small_n"]:
|
||||||
|
row[col] = f"<{self.rules['suppress_small_n']}"
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def get_execution_summary(self) -> str:
|
||||||
|
"""获取执行摘要"""
|
||||||
|
if not self.execution_log:
|
||||||
|
return "尚未执行任何查询"
|
||||||
|
|
||||||
|
lines = [f"共执行 {len(self.execution_log)} 条查询:"]
|
||||||
|
for i, log in enumerate(self.execution_log, 1):
|
||||||
|
lines.append(f" {i}. {log['sql'][:80]}... → {log['rows_returned']} 行结果")
|
||||||
|
return "\n".join(lines)
|
||||||
126
schema_extractor.py
Normal file
126
schema_extractor.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""
|
||||||
|
Schema 提取器 —— 只提取表结构,不碰数据
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def extract_schema(db_path: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
从数据库提取 Schema,只返回结构信息:
|
||||||
|
- 表名、列名、类型
|
||||||
|
- 主键、外键
|
||||||
|
- 行数
|
||||||
|
- 枚举列的去重值(不含原始数据)
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
# 获取所有表
|
||||||
|
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
|
||||||
|
tables = [row["name"] for row in cur.fetchall()]
|
||||||
|
|
||||||
|
schema = {"tables": []}
|
||||||
|
|
||||||
|
for table in tables:
|
||||||
|
# 列信息
|
||||||
|
cur.execute(f"PRAGMA table_info('{table}')")
|
||||||
|
columns = []
|
||||||
|
for col in cur.fetchall():
|
||||||
|
columns.append({
|
||||||
|
"name": col["name"],
|
||||||
|
"type": col["type"],
|
||||||
|
"nullable": col["notnull"] == 0,
|
||||||
|
"is_primary_key": col["pk"] == 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
# 外键
|
||||||
|
cur.execute(f"PRAGMA foreign_key_list('{table}')")
|
||||||
|
fks = []
|
||||||
|
for fk in cur.fetchall():
|
||||||
|
fks.append({
|
||||||
|
"column": fk["from"],
|
||||||
|
"references_table": fk["table"],
|
||||||
|
"references_column": fk["to"],
|
||||||
|
})
|
||||||
|
|
||||||
|
# 行数
|
||||||
|
cur.execute(f"SELECT COUNT(*) AS cnt FROM '{table}'")
|
||||||
|
row_count = cur.fetchone()["cnt"]
|
||||||
|
|
||||||
|
# 对 VARCHAR / TEXT 类型列,提取去重枚举值(最多 20 个)
|
||||||
|
data_profile = {}
|
||||||
|
for col in columns:
|
||||||
|
col_name = col["name"]
|
||||||
|
col_type = (col["type"] or "").upper()
|
||||||
|
|
||||||
|
if any(t in col_type for t in ("VARCHAR", "TEXT", "CHAR")):
|
||||||
|
cur.execute(f'SELECT DISTINCT "{col_name}" FROM "{table}" WHERE "{col_name}" IS NOT NULL LIMIT 20')
|
||||||
|
vals = [row[0] for row in cur.fetchall()]
|
||||||
|
if len(vals) <= 20:
|
||||||
|
data_profile[col_name] = {
|
||||||
|
"type": "enum",
|
||||||
|
"distinct_count": len(vals),
|
||||||
|
"values": vals,
|
||||||
|
}
|
||||||
|
elif any(t in col_type for t in ("INT", "REAL", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC")):
|
||||||
|
cur.execute(f'''
|
||||||
|
SELECT MIN("{col_name}") AS min_val, MAX("{col_name}") AS max_val,
|
||||||
|
AVG("{col_name}") AS avg_val, COUNT(DISTINCT "{col_name}") AS distinct_count
|
||||||
|
FROM "{table}" WHERE "{col_name}" IS NOT NULL
|
||||||
|
''')
|
||||||
|
row = cur.fetchone()
|
||||||
|
if row and row["min_val"] is not None:
|
||||||
|
data_profile[col_name] = {
|
||||||
|
"type": "numeric",
|
||||||
|
"min": round(row["min_val"], 2),
|
||||||
|
"max": round(row["max_val"], 2),
|
||||||
|
"avg": round(row["avg_val"], 2),
|
||||||
|
"distinct_count": row["distinct_count"],
|
||||||
|
}
|
||||||
|
|
||||||
|
schema["tables"].append({
|
||||||
|
"name": table,
|
||||||
|
"columns": columns,
|
||||||
|
"foreign_keys": fks,
|
||||||
|
"row_count": row_count,
|
||||||
|
"data_profile": data_profile,
|
||||||
|
})
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def schema_to_text(schema: dict) -> str:
|
||||||
|
"""将 Schema 转为可读文本,供 LLM 理解"""
|
||||||
|
lines = ["=== 数据库 Schema ===\n"]
|
||||||
|
|
||||||
|
for table in schema["tables"]:
|
||||||
|
lines.append(f"📋 表: {table['name']} (共 {table['row_count']} 行)")
|
||||||
|
lines.append(" 列:")
|
||||||
|
for col in table["columns"]:
|
||||||
|
pk = " [PK]" if col["is_primary_key"] else ""
|
||||||
|
null = " NULL" if col["nullable"] else " NOT NULL"
|
||||||
|
lines.append(f' - {col["name"]}: {col["type"]}{pk}{null}')
|
||||||
|
|
||||||
|
if table["foreign_keys"]:
|
||||||
|
lines.append(" 外键:")
|
||||||
|
for fk in table["foreign_keys"]:
|
||||||
|
lines.append(f' - {fk["column"]} → {fk["references_table"]}.{fk["references_column"]}')
|
||||||
|
|
||||||
|
if table["data_profile"]:
|
||||||
|
lines.append(" 数据画像:")
|
||||||
|
for col_name, profile in table["data_profile"].items():
|
||||||
|
if profile["type"] == "enum":
|
||||||
|
vals = ", ".join(str(v) for v in profile["values"][:10])
|
||||||
|
lines.append(f' - {col_name}: 枚举值({profile["distinct_count"]}个) = [{vals}]')
|
||||||
|
elif profile["type"] == "numeric":
|
||||||
|
lines.append(
|
||||||
|
f' - {col_name}: 范围[{profile["min"]} ~ {profile["max"]}], '
|
||||||
|
f'均值{profile["avg"]}, {profile["distinct_count"]}个不同值'
|
||||||
|
)
|
||||||
|
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
Reference in New Issue
Block a user