From 96927a789dc762772dceb327bed7944bea03c8cb Mon Sep 17 00:00:00 2001 From: OpenClaw Agent Date: Thu, 19 Mar 2026 12:21:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=9B=9B=E5=B1=82=E6=9E=B6=E6=9E=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=88=86=E6=9E=90=20Agent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Layer 1 Planner: 意图规划,将问题转为结构化分析计划 - Layer 2 Explorer: 自适应探索循环,多轮迭代动态生成 SQL - Layer 3 InsightEngine: 异常检测 + 主动洞察 - Layer 4 ContextManager: 多轮对话上下文记忆 安全设计:AI 只看 Schema + 聚合结果,不接触原始数据。 支持任意 OpenAI 兼容 API(OpenAI / Ollama / DeepSeek / vLLM) --- .gitignore | 3 + README.md | 117 ++++++++++++++++++++++ agent.py | 139 ++++++++++++++++++++++++++ cli.py | 112 +++++++++++++++++++++ config.py | 36 +++++++ context.py | 134 +++++++++++++++++++++++++ demo.py | 160 +++++++++++++++++++++++++++++ explorer.py | 238 +++++++++++++++++++++++++++++++++++++++++++ insights.py | 239 ++++++++++++++++++++++++++++++++++++++++++++ planner.py | 129 ++++++++++++++++++++++++ reporter.py | 110 ++++++++++++++++++++ requirements.txt | 1 + sandbox_executor.py | 139 ++++++++++++++++++++++++++ schema_extractor.py | 126 +++++++++++++++++++++++ 14 files changed, 1683 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 agent.py create mode 100644 cli.py create mode 100644 config.py create mode 100644 context.py create mode 100644 demo.py create mode 100644 explorer.py create mode 100644 insights.py create mode 100644 planner.py create mode 100644 reporter.py create mode 100644 requirements.txt create mode 100644 sandbox_executor.py create mode 100644 schema_extractor.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1c580aa --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +*.pyc +demo.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..2b3d117 --- /dev/null +++ b/README.md @@ -0,0 +1,117 @@ +# 数据分析 Agent —— Schema-Only + 四层架构自适应分析 + +**AI 只看表结构,不碰原始数据。通过四层架构自适应探索,生成深度分析报告。** + +## 架构 + +``` +┌─────────────────────────────────────────────────────────┐ +│ Agent (编排层) │ +│ 接收问题 → 调度四层 → 输出报告 │ +└──────┬──────────┬──────────┬──────────┬─────────────────┘ + │ │ │ │ + ┌────▼────┐ ┌──▼─────┐ ┌─▼──────┐ ┌▼─────────┐ + │ Planner │ │Explorer│ │Insight │ │ Context │ + │ 意图规划 │ │探索循环 │ │异常检测 │ │ 上下文记忆 │ + └─────────┘ └────────┘ └────────┘ └──────────┘ +``` + +### 四层分工 + +| 层 | 组件 | 职责 | +|---|---|---| +| L1 | **Planner** | 理解用户意图,生成结构化分析计划(类型、维度、指标) | +| L2 | **Explorer** | 基于计划多轮迭代探索,每轮根据上一轮结果决定下一步 | +| L3 | **InsightEngine** | 从探索结果中检测异常、趋势、关联,输出主动洞察 | +| L4 | **ContextManager** | 管理多轮对话历史,后续问题可引用之前的分析 | + +### 安全隔离 + +``` +用户提问 → Agent 看 Schema 生成 SQL → 沙箱执行 → 聚合结果 → Agent 生成报告 + ↑ + 原始数据永远留在这里 +``` + +- **Schema 提取器**:只提取表结构、列类型、行数、枚举值,不碰数据 +- **沙箱执行器**:禁止 SELECT * / DDL / DML,必须聚合函数,小样本抑制(n<5) +- **AI 的视角**:只有 Schema + 聚合统计结果,从未接触任何一行原始数据 + +## 快速开始 + +```bash +# 1. 安装依赖 +pip install openai + +# 2. 配置 LLM(兼容 OpenAI API 格式) + +# OpenAI +export LLM_API_KEY=sk-xxx +export LLM_BASE_URL=https://api.openai.com/v1 +export LLM_MODEL=gpt-4o + +# Ollama(本地部署,隐私优先) +export LLM_API_KEY=ollama +export LLM_BASE_URL=http://localhost:11434/v1 +export LLM_MODEL=qwen2.5-coder:7b + +# DeepSeek +export LLM_API_KEY=sk-xxx +export LLM_BASE_URL=https://api.deepseek.com +export LLM_MODEL=deepseek-chat + +# 3. 运行演示(自动创建 5 万条示例数据 + 3 个分析任务) +python3 demo.py + +# 4. 交互式分析 +python3 cli.py + +# 5. 分析你自己的数据库 +python3 cli.py /path/to/your.db +``` + +## 文件结构 + +``` +├── config.py # 配置(LLM、安全规则、探索轮数) +├── schema_extractor.py # Schema 提取器(只提取结构) +├── sandbox_executor.py # 沙箱执行器(SQL 验证 + 结果脱敏) +├── planner.py # [L1] 意图规划器 +├── explorer.py # [L2] 自适应探索器 +├── insights.py # [L3] 洞察引擎(异常检测) +├── context.py # [L4] 上下文管理器 +├── reporter.py # 报告生成器 +├── agent.py # Agent 编排层 +├── demo.py # 一键演示 +├── cli.py # 交互式 CLI +├── requirements.txt # 依赖 +└── README.md +``` + +## 对比预制脚本 + +| | 预制脚本 / 模板 | 本方案(四层架构) | +|---|---|---| +| SQL 生成 | 模板拼接 | LLM 动态生成 | +| 查询数量 | 固定 | 1-6 轮,AI 自适应 | +| 后续追问 | 无 | AI 看到结果后判断是否深挖 | +| 异常发现 | 无 | 主动检测 + 主动输出 | +| 多轮对话 | 无 | 上下文记忆,可引用历史分析 | +| 适用场景 | 已知分析模式 | 探索性分析、开放性问题 | + +## CLI 命令 + +``` +📊 > 帮我分析各地区的销售表现 # 分析问题 +📊 > rounds=3 最近的趋势怎么样 # 限制探索轮数 +📊 > schema # 查看数据库 Schema +📊 > history # 查看分析历史 +📊 > audit # 查看 SQL 审计日志 +📊 > clear # 清空历史 +📊 > help # 帮助 +📊 > quit # 退出 +``` + +## License + +MIT diff --git a/agent.py b/agent.py new file mode 100644 index 0000000..4a288a1 --- /dev/null +++ b/agent.py @@ -0,0 +1,139 @@ +""" +Agent 编排层 —— 调度四层架构完成分析 + +Layer 1: Planner 意图规划 +Layer 2: Explorer 自适应探索 +Layer 3: Insight 异常洞察 +Layer 4: Context 上下文记忆 +""" +from typing import Optional + +from config import DB_PATH, MAX_EXPLORATION_ROUNDS +from schema_extractor import extract_schema, schema_to_text +from sandbox_executor import SandboxExecutor +from planner import Planner +from explorer import Explorer +from insights import InsightEngine, quick_detect +from reporter import ReportGenerator +from context import ContextManager + + +class DataAnalysisAgent: + """ + 数据分析 Agent + + 四层架构: + 1. Planner - 理解用户意图,生成分析计划 + 2. Explorer - 基于计划多轮迭代探索 + 3. Insights - 从结果中检测异常、输出主动洞察 + 4. Context - 管理多轮对话上下文 + + Agent 负责编排这四层,从问题到报告。 + """ + + def __init__(self, db_path: str): + # 数据层 + self.db_path = db_path + self.schema = extract_schema(db_path) + self.schema_text = schema_to_text(self.schema) + self.executor = SandboxExecutor(db_path) + + # 四层组件 + self.planner = Planner() + self.explorer = Explorer(self.executor) + self.insight_engine = InsightEngine() + self.reporter = ReportGenerator() + self.context = ContextManager() + + def analyze(self, question: str, max_rounds: Optional[int] = None) -> str: + """ + 完整分析流程 + + Args: + question: 用户分析问题 + max_rounds: 最大探索轮数(默认用配置值) + + Returns: + 格式化的分析报告 + """ + max_rounds = max_rounds or MAX_EXPLORATION_ROUNDS + + print(f"\n{'='*60}") + print(f"📊 {question}") + print(f"{'='*60}") + + # ── Layer 0: 检查上下文 ────────────────────────── + prev_context = self.context.get_context_for(question) + if prev_context: + print("📎 发现历史分析上下文,将结合之前的发现") + + # ── Layer 1: 意图规划 ──────────────────────────── + print("\n🎯 [Layer 1] 意图规划...") + plan = self.planner.plan(question, self.schema_text) + + analysis_type = plan.get("analysis_type", "unknown") + dimensions = plan.get("dimensions", []) + rationale = plan.get("rationale", "") + print(f" 类型: {analysis_type}") + print(f" 维度: {', '.join(dimensions) if dimensions else '自动发现'}") + print(f" 理由: {rationale[:80]}{'...' if len(rationale) > 80 else ''}") + + # ── Layer 2: 自适应探索 ────────────────────────── + print(f"\n🔍 [Layer 2] 自适应探索 (最多 {max_rounds} 轮)...") + steps = self.explorer.explore(plan, self.schema_text, max_rounds=max_rounds) + + successful = sum(1 for s in steps if s.success) + print(f"\n 完成: {len(steps)} 轮, {successful} 条成功查询") + + # ── Layer 3: 异常洞察 ──────────────────────────── + print("\n🔎 [Layer 3] 异常洞察...") + + # 先做规则检测 + rule_alerts = quick_detect(steps) + for alert in rule_alerts: + print(f" {alert}") + + # 再做 LLM 深度分析 + insights = self.insight_engine.analyze(steps, question) + if insights: + print(f" 发现 {len(insights)} 条洞察") + for insight in insights: + print(f" {insight}") + else: + print(" 未发现异常") + + # ── 生成报告 ──────────────────────────────────── + print("\n📝 正在生成报告...") + report = self.reporter.generate(question, plan, steps, insights) + + # 追加主动洞察 + if insights: + insight_text = self.insight_engine.format_insights(insights) + report += f"\n\n---\n\n{insight_text}" + + # ── Layer 4: 记录上下文 ────────────────────────── + self.context.add_session( + question=question, + plan=plan, + steps=steps, + insights=insights, + report=report, + ) + + return report + + def get_schema(self) -> str: + """获取 Schema 文本""" + return self.schema_text + + def get_history(self) -> str: + """获取分析历史摘要""" + return self.context.get_history_summary() + + def get_audit(self) -> str: + """获取执行审计日志""" + return self.executor.get_execution_summary() + + def clear_history(self): + """清空分析历史""" + self.context.clear() diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..d0636f4 --- /dev/null +++ b/cli.py @@ -0,0 +1,112 @@ +""" +交互式 CLI —— 四层架构自适应分析 +用法: python3 cli.py [数据库路径] +""" +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +from config import DB_PATH, LLM_CONFIG, MAX_EXPLORATION_ROUNDS +from agent import DataAnalysisAgent + + +def print_help(): + print(""" +可用命令: + <问题> 分析一个问题 + rounds= <问题> 设置探索轮数 + schema 查看数据库 Schema + history 查看分析历史 + audit 查看 SQL 审计日志 + clear 清空分析历史 + help 显示帮助 + quit / q 退出 +""") + + +def main(): + db_path = sys.argv[1] if len(sys.argv) > 1 else DB_PATH + + if not os.path.exists(db_path): + print(f"❌ 数据库不存在: {db_path}") + print(f" 请先运行 python3 demo.py 创建示例数据库,或指定已有数据库路径") + sys.exit(1) + + if not LLM_CONFIG["api_key"]: + print("⚠️ 未配置 LLM_API_KEY,请先设置环境变量:") + print(" export LLM_API_KEY=sk-xxx") + print(" export LLM_BASE_URL=https://api.openai.com/v1 # 或 Ollama 等") + print(" export LLM_MODEL=gpt-4o") + sys.exit(1) + + # 初始化 Agent + agent = DataAnalysisAgent(db_path) + + print("=" * 60) + print(" 🤖 数据分析 Agent —— 四层架构") + print("=" * 60) + print(f"\n🔗 LLM: {LLM_CONFIG['model']} @ {LLM_CONFIG['base_url']}") + print(f"🔄 最大探索轮数: {MAX_EXPLORATION_ROUNDS}") + print(f"💾 数据库: {db_path}") + print(f"\n💬 输入分析问题(help 查看命令)\n") + + while True: + try: + user_input = input("📊 > ").strip() + except (EOFError, KeyboardInterrupt): + print("\n👋 再见!") + break + + if not user_input: + continue + + cmd = user_input.lower() + + if cmd in ("quit", "exit", "q"): + print("👋 再见!") + break + elif cmd == "help": + print_help() + continue + elif cmd == "schema": + print(agent.get_schema()) + continue + elif cmd == "history": + print(agent.get_history()) + continue + elif cmd == "audit": + print(agent.get_audit()) + continue + elif cmd == "clear": + agent.clear_history() + print("✅ 历史已清空") + continue + + # 解析可选参数:rounds=3 + max_rounds = MAX_EXPLORATION_ROUNDS + question = user_input + if "rounds=" in question.lower(): + parts = question.split("rounds=") + question = parts[0].strip() + try: + max_rounds = int(parts[1].strip().split()[0]) + except (ValueError, IndexError): + pass + + try: + report = agent.analyze(question, max_rounds=max_rounds) + print("\n" + report) + print("\n" + "~" * 60) + except Exception as e: + print(f"\n❌ 分析出错: {e}") + import traceback + traceback.print_exc() + + # 退出时显示审计 + print("\n📋 本次会话审计:") + print(agent.get_audit()) + + +if __name__ == "__main__": + main() diff --git a/config.py b/config.py new file mode 100644 index 0000000..f8c60b3 --- /dev/null +++ b/config.py @@ -0,0 +1,36 @@ +""" +配置文件 +""" +import os + +# LLM 配置(兼容 OpenAI API 格式,包括 Ollama / vLLM / DeepSeek 等) +LLM_CONFIG = { + "api_key": os.getenv("LLM_API_KEY", ""), + "base_url": os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + "model": os.getenv("LLM_MODEL", "gpt-4o"), +} + +# 沙箱安全规则 +SANDBOX_RULES = { + "max_result_rows": 1000, # 聚合结果最大行数 + "round_floats": 2, # 浮点数保留位数 + "suppress_small_n": 5, # 分组样本 < n 时模糊处理 + "banned_keywords": [ # 禁止的 SQL 关键字 + "SELECT *", + "INSERT", + "UPDATE", + "DELETE", + "DROP", + "ALTER", + "CREATE", + "ATTACH", + "PRAGMA", + ], + "require_aggregation": True, # 是否要求使用聚合函数 +} + +# 数据库路径 +DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "demo.db")) + +# 分析控制 +MAX_EXPLORATION_ROUNDS = int(os.getenv("MAX_ROUNDS", "6")) # 最大探索轮数 diff --git a/context.py b/context.py new file mode 100644 index 0000000..faa22ef --- /dev/null +++ b/context.py @@ -0,0 +1,134 @@ +""" +Layer 4: 上下文管理器 +管理多轮对话的分析上下文,让后续问题可以引用之前的发现。 +""" +import json +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +from explorer import ExplorationStep +from insights import Insight + + +@dataclass +class AnalysisSession: + """一次分析的完整记录""" + question: str + plan: dict + steps: list[ExplorationStep] + insights: list[Insight] + report: str + timestamp: float = field(default_factory=time.time) + + def summary(self) -> str: + """生成本次分析的摘要(供后续对话引用)""" + parts = [f"**问题**: {self.question}"] + + if self.plan: + parts.append(f"**分析类型**: {self.plan.get('analysis_type', 'unknown')}") + parts.append(f"**关注维度**: {', '.join(self.plan.get('dimensions', []))}") + + # 核心发现(从成功步骤中提取) + key_findings = [] + for step in self.steps: + if step.success and step.rows: + # 提取最突出的值 + top_row = step.rows[0] if step.rows else {} + finding = f"{step.purpose}: " + finding += ", ".join( + f"{k}={v}" for k, v in top_row.items() if k.lower() not in ("id",) + ) + key_findings.append(finding) + + if key_findings: + parts.append("**核心发现**:") + for f in key_findings[:5]: + parts.append(f" - {f}") + + # 洞察 + if self.insights: + parts.append("**主动洞察**:") + for insight in self.insights[:3]: + parts.append(f" - {insight}") + + return "\n".join(parts) + + def to_reference_text(self) -> str: + """生成供 LLM 使用的上下文文本""" + return ( + f"## 之前的分析\n\n" + f"### 问题\n{self.question}\n\n" + f"### 摘要\n{self.summary()}\n\n" + f"### 详细发现\n" + + "\n".join( + f"- {step.purpose}: {step.row_count} 行结果" + for step in self.steps if step.success + ) + ) + + +class ContextManager: + """上下文管理器:管理多轮对话的分析历史""" + + def __init__(self, max_history: int = 10): + self.sessions: list[AnalysisSession] = [] + self.max_history = max_history + + def add_session( + self, + question: str, + plan: dict, + steps: list[ExplorationStep], + insights: list[Insight], + report: str, + ) -> AnalysisSession: + """记录一次分析""" + session = AnalysisSession( + question=question, + plan=plan, + steps=steps, + insights=insights, + report=report, + ) + self.sessions.append(session) + + # 保持历史大小 + if len(self.sessions) > self.max_history: + self.sessions = self.sessions[-self.max_history:] + + return session + + def get_context_for(self, new_question: str) -> Optional[str]: + """ + 根据新问题,从历史中找到相关上下文。 + 简单实现:取最近的分析会话。 + """ + if not self.sessions: + return None + + # 取最近 2 轮分析的摘要 + recent = self.sessions[-2:] + parts = [] + for session in recent: + parts.append(session.to_reference_text()) + + return "\n\n---\n\n".join(parts) + + def get_history_summary(self) -> str: + """获取所有历史的摘要""" + if not self.sessions: + return "(无历史分析)" + + lines = [f"共 {len(self.sessions)} 次分析:"] + for i, session in enumerate(self.sessions, 1): + ts = time.strftime("%H:%M", time.localtime(session.timestamp)) + lines.append(f" {i}. [{ts}] {session.question}") + if session.insights: + for insight in session.insights[:2]: + lines.append(f" {insight}") + return "\n".join(lines) + + def clear(self): + """清空历史""" + self.sessions.clear() diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..beb2c53 --- /dev/null +++ b/demo.py @@ -0,0 +1,160 @@ +""" +演示脚本 —— 创建示例数据,运行四层架构分析 +""" +import os +import sys +import sqlite3 +import random +from datetime import datetime, timedelta + +sys.path.insert(0, os.path.dirname(__file__)) + +from config import DB_PATH, LLM_CONFIG +from agent import DataAnalysisAgent + + +def create_demo_data(db_path: str): + """创建示例数据库:电商订单 + 用户 + 商品""" + if os.path.exists(db_path): + os.remove(db_path) + + conn = sqlite3.connect(db_path) + cur = conn.cursor() + + cur.execute(""" + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + region TEXT NOT NULL, + tier TEXT NOT NULL, + created_at TEXT NOT NULL + ) + """) + + cur.execute(""" + CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category TEXT NOT NULL, + price REAL NOT NULL + ) + """) + + cur.execute(""" + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + product_id INTEGER NOT NULL, + amount REAL NOT NULL, + quantity INTEGER NOT NULL, + status TEXT NOT NULL, + region TEXT NOT NULL, + order_date TEXT NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (product_id) REFERENCES products(id) + ) + """) + + regions = ["华东", "华南", "华北", "西南", "华中", "东北"] + tiers = ["普通", "银卡", "金卡", "钻石"] + categories = ["电子产品", "服饰", "食品", "家居", "图书", "美妆"] + statuses = ["已完成", "已发货", "待发货", "已取消", "退款中"] + + users = [] + for i in range(1, 2001): + reg = random.choice(regions) + tier = random.choices(tiers, weights=[50, 25, 15, 10])[0] + created = datetime(2024, 1, 1) + timedelta(days=random.randint(0, 400)) + users.append((i, f"用户{i}", reg, tier, created.strftime("%Y-%m-%d"))) + cur.executemany("INSERT INTO users VALUES (?, ?, ?, ?, ?)", users) + + products = [] + for i in range(1, 201): + cat = random.choice(categories) + price = round(random.uniform(9.9, 9999.99), 2) + products.append((i, f"商品{i}", cat, price)) + cur.executemany("INSERT INTO products VALUES (?, ?, ?, ?)", products) + + orders = [] + for i in range(1, 50001): + uid = random.randint(1, 2000) + pid = random.randint(1, 200) + qty = random.choices([1, 2, 3, 4, 5], weights=[50, 25, 15, 7, 3])[0] + status = random.choices(statuses, weights=[60, 15, 10, 10, 5])[0] + region = random.choices(regions, weights=[35, 25, 15, 10, 10, 5])[0] + order_date = datetime(2025, 1, 1) + timedelta(days=random.randint(0, 440)) + cur.execute("SELECT price FROM products WHERE id = ?", (pid,)) + price = cur.fetchone()[0] + amount = round(price * qty, 2) + orders.append((i, uid, pid, amount, qty, status, region, order_date.strftime("%Y-%m-%d"))) + cur.executemany("INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?, ?, ?)", orders) + + conn.commit() + conn.close() + print(f"✅ 示例数据库已创建: {db_path}") + print(f" - users: 2,000 条") + print(f" - products: 200 条") + print(f" - orders: 50,000 条") + + +def main(): + print("=" * 60) + print(" 🤖 数据分析 Agent —— 四层架构自适应分析") + print("=" * 60) + + if not LLM_CONFIG["api_key"]: + print("\n⚠️ 未配置 LLM_API_KEY,请设置环境变量:") + print() + print(" # OpenAI") + print(" export LLM_API_KEY=sk-xxx") + print(" export LLM_BASE_URL=https://api.openai.com/v1") + print(" export LLM_MODEL=gpt-4o") + print() + print(" # Ollama (本地)") + print(" export LLM_API_KEY=ollama") + print(" export LLM_BASE_URL=http://localhost:11434/v1") + print(" export LLM_MODEL=qwen2.5-coder:7b") + print() + sys.exit(1) + + print(f"\n🔗 LLM: {LLM_CONFIG['model']} @ {LLM_CONFIG['base_url']}") + + if not os.path.exists(DB_PATH): + create_demo_data(DB_PATH) + else: + print(f"\n📂 使用已有数据库: {DB_PATH}") + + # 初始化 Agent(自动加载 Schema) + agent = DataAnalysisAgent(DB_PATH) + + print("\n" + "=" * 60) + print(" ⬆️ AI 只看到 Schema:表结构 + 数据画像") + print(" ⬇️ 四层架构分析:规划 → 探索 → 洞察 → 报告") + print("=" * 60) + + questions = [ + "各地区的销售表现如何?帮我全面分析一下", + "不同商品类别的销售情况和利润贡献", + "整体订单的完成率怎么样?有没有什么异常需要关注?", + ] + + for q in questions: + try: + report = agent.analyze(q) + print("\n" + report) + print("\n" + "~" * 60) + except Exception as e: + print(f"\n❌ 分析出错: {e}") + import traceback + traceback.print_exc() + + # 最终审计 + print("\n📋 会话审计:") + print(agent.get_audit()) + print("\n📋 分析历史:") + print(agent.get_history()) + print("\n✅ 演示完成!") + + +if __name__ == "__main__": + main() diff --git a/explorer.py b/explorer.py new file mode 100644 index 0000000..1538655 --- /dev/null +++ b/explorer.py @@ -0,0 +1,238 @@ +""" +Layer 2: 自适应探索器 +基于分析计划 + 已有发现,动态决定下一步查什么。 +多轮迭代,直到 AI 判断"够了"或达到上限。 +""" +import json +import re +from typing import Any + +import openai +from config import LLM_CONFIG +from sandbox_executor import SandboxExecutor + + +EXPLORER_SYSTEM = """你是一个数据分析执行者。你的上级给了你一个分析计划,你需要通过迭代执行 SQL 来完成分析。 + +## 你的工作方式 +每一轮你看到: +1. 分析计划(上级给的目标) +2. 数据库 Schema(表结构、数据画像) +3. 之前的探索历史(查过什么、得到什么结果) + +你决定下一步: +- 输出一条 SQL 继续探索 +- 或者输出 done 表示分析足够 + +## 输出格式(严格 JSON) +```json +{ + "action": "query", + "reasoning": "为什么要做这个查询", + "sql": "SELECT ...", + "purpose": "这个查询的目的" +} +``` + +或: +```json +{ + "action": "done", + "reasoning": "为什么分析已经足够" +} +``` + +## SQL 规则 +- 只用 SELECT,必须有聚合函数或 GROUP BY +- 禁止 SELECT * +- 用 ROUND 控制精度 +- 合理使用 LIMIT(分组结果 15 行以内,时间序列 60 行以内) + +## 探索策略 +1. 第一轮:验证核心假设(计划中最关键的查询) +2. 后续轮:基于已有结果追问 + - 发现离群值 → 追问为什么 + - 发现异常比例 → 追问细分维度 + - 结果平淡 → 换个角度试试 +3. 不要重复查已经确认的事 +4. 每轮要有新发现,否则就该结束""" + + +EXPLORER_CONTINUE = """查询结果: + +{result_text} + +请基于这个结果决定下一步。如果发现异常或值得深挖的点,继续查询。如果分析足够,输出 done。""" + + +class ExplorationStep: + """单步探索结果""" + def __init__(self, round_num: int, decision: dict, result: dict): + self.round_num = round_num + self.reasoning = decision.get("reasoning", "") + self.purpose = decision.get("purpose", "") + self.sql = decision.get("sql", "") + self.action = decision.get("action", "query") + self.success = result.get("success", False) + self.error = result.get("error") + self.columns = result.get("columns", []) + self.rows = result.get("rows", []) + self.row_count = result.get("row_count", 0) + + def to_dict(self) -> dict: + d = { + "round": self.round_num, + "action": self.action, + "reasoning": self.reasoning, + "purpose": self.purpose, + "sql": self.sql, + "success": self.success, + } + if self.success: + d["result"] = { + "columns": self.columns, + "rows": self.rows, + "row_count": self.row_count, + } + else: + d["result"] = {"error": self.error} + return d + + +class Explorer: + """自适应探索器:多轮迭代执行 SQL""" + + def __init__(self, executor: SandboxExecutor): + self.executor = executor + self.client = openai.OpenAI( + api_key=LLM_CONFIG["api_key"], + base_url=LLM_CONFIG["base_url"], + ) + self.model = LLM_CONFIG["model"] + + def explore( + self, + plan: dict, + schema_text: str, + max_rounds: int = 6, + ) -> list[ExplorationStep]: + """ + 执行探索循环 + + Args: + plan: Planner 生成的分析计划 + schema_text: Schema 文本描述 + max_rounds: 最大探索轮数 + + Returns: + 探步列表 + """ + steps: list[ExplorationStep] = [] + + # 构建初始消息 + messages = [ + {"role": "system", "content": EXPLORER_SYSTEM}, + { + "role": "user", + "content": ( + f"## 分析计划\n```json\n{json.dumps(plan, ensure_ascii=False, indent=2)}\n```\n\n" + f"## 数据库 Schema\n{schema_text}\n\n" + f"请开始第一轮探索。根据计划,先执行最关键的查询。" + ), + }, + ] + + for round_num in range(1, max_rounds + 1): + print(f"\n 🔄 探索第 {round_num}/{max_rounds} 轮") + + # LLM 决策 + decision = self._llm_decide(messages) + action = decision.get("action", "query") + reasoning = decision.get("reasoning", "") + + print(f" 💭 {reasoning[:80]}{'...' if len(reasoning) > 80 else ''}") + + if action == "done": + print(f" ✅ 探索完成") + steps.append(ExplorationStep(round_num, decision, {"success": True})) + break + + # 执行 SQL + sql = decision.get("sql", "") + purpose = decision.get("purpose", "") + + if not sql: + print(f" ⚠️ 未生成 SQL,跳过") + continue + + print(f" 📝 {purpose}") + result = self.executor.execute(sql) + + if result["success"]: + print(f" ✅ {result['row_count']} 行") + else: + print(f" ❌ {result['error']}") + + step = ExplorationStep(round_num, decision, result) + steps.append(step) + + # 更新对话历史 + messages.append({ + "role": "assistant", + "content": json.dumps(decision, ensure_ascii=False), + }) + + result_text = self._format_result(result) + messages.append({ + "role": "user", + "content": EXPLORER_CONTINUE.format(result_text=result_text), + }) + + return steps + + def _llm_decide(self, messages: list[dict]) -> dict: + """LLM 决策""" + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.2, + max_tokens=1024, + ) + content = response.choices[0].message.content.strip() + return self._extract_json(content) + + def _format_result(self, result: dict) -> str: + """格式化查询结果""" + if not result.get("success"): + return f"❌ 执行失败: {result.get('error', '未知错误')}" + + rows = result["rows"][:20] + return ( + f"✅ 查询成功,返回 {result['row_count']} 行\n" + f"列: {result['columns']}\n" + f"数据:\n{json.dumps(rows, ensure_ascii=False, indent=2)}" + ) + + def _extract_json(self, text: str) -> dict: + """从 LLM 输出提取 JSON""" + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: + match = re.search(pattern, text, re.DOTALL) + if match: + try: + return json.loads(match.group(1)) + except json.JSONDecodeError: + continue + + match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except json.JSONDecodeError: + pass + + return {"action": "done", "reasoning": f"无法解析输出: {text[:100]}"} diff --git a/insights.py b/insights.py new file mode 100644 index 0000000..39639aa --- /dev/null +++ b/insights.py @@ -0,0 +1,239 @@ +""" +Layer 3: 洞察引擎 +对探索结果进行异常检测 + 主动洞察,输出用户没问但值得知道的事。 +""" +import json +import re +import statistics +from typing import Any + +import openai +from config import LLM_CONFIG +from explorer import ExplorationStep + + +INSIGHT_SYSTEM = """你是一个数据洞察专家。你会收到探索过程的所有结果,你需要: + +1. 从结果中发现异常和有趣现象 +2. 对比不同维度,找出差异 +3. 输出用户可能没问但值得知道的洞察 + +## 输出格式(严格 JSON 数组) +```json +[ + { + "type": "outlier" | "trend" | "distribution" | "correlation" | "recommendation", + "severity": "high" | "medium" | "low", + "title": "简短标题", + "detail": "详细描述,包含具体数字", + "evidence": "支撑这个洞察的数据来源" + } +] +``` + +## 洞察类型 +- outlier: 离群值(某个分组异常高/低) +- trend: 趋势发现(增长/下降、季节性) +- distribution: 分布异常(不均衡、集中度过高) +- correlation: 关联发现(两个维度的意外关联) +- recommendation: 行动建议(基于数据的建议) + +## 分析原则 +- 每个洞察必须有具体数字支撑 +- 用对比来说话(A 比 B 高 X%) +- 关注异常,不描述平淡的事实 +- 如果没有异常,返回空数组""" + + +class Insight: + """单条洞察""" + def __init__(self, data: dict): + self.type = data.get("type", "unknown") + self.severity = data.get("severity", "low") + self.title = data.get("title", "") + self.detail = data.get("detail", "") + self.evidence = data.get("evidence", "") + + @property + def emoji(self) -> str: + return { + "outlier": "⚠️", + "trend": "📈", + "distribution": "📊", + "correlation": "🔗", + "recommendation": "💡", + }.get(self.type, "📌") + + @property + def severity_emoji(self) -> str: + return {"high": "🔴", "medium": "🟡", "low": "🟢"}.get(self.severity, "") + + def __str__(self): + return f"{self.emoji} {self.severity_emoji} {self.title}: {self.detail}" + + +class InsightEngine: + """洞察引擎:自动检测异常 + 主动输出""" + + def __init__(self): + self.client = openai.OpenAI( + api_key=LLM_CONFIG["api_key"], + base_url=LLM_CONFIG["base_url"], + ) + self.model = LLM_CONFIG["model"] + + def analyze(self, steps: list[ExplorationStep], question: str) -> list[Insight]: + """ + 对探索结果进行洞察分析 + + Args: + steps: 探索步骤列表 + question: 原始用户问题 + + Returns: + 洞察列表 + """ + if not steps: + return [] + + # 构建探索历史文本 + history = self._build_history(steps) + + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": INSIGHT_SYSTEM}, + { + "role": "user", + "content": ( + f"## 用户原始问题\n{question}\n\n" + f"## 探索历史\n{history}\n\n" + f"请分析以上数据,输出异常和洞察。" + ), + }, + ], + temperature=0.3, + max_tokens=2048, + ) + + content = response.choices[0].message.content.strip() + insights_data = self._extract_json_array(content) + + return [Insight(d) for d in insights_data] + + def format_insights(self, insights: list[Insight]) -> str: + """格式化洞察为可读文本""" + if not insights: + return "" + + # 按严重程度排序 + severity_order = {"high": 0, "medium": 1, "low": 2} + sorted_insights = sorted(insights, key=lambda i: severity_order.get(i.severity, 9)) + + lines = ["## 💡 主动洞察", ""] + lines.append("_以下是你没问但数据告诉我们的事:_\n") + + for insight in sorted_insights: + lines.append(f"**{insight.emoji} {insight.title}** {insight.severity_emoji}") + lines.append(f" {insight.detail}") + lines.append(f" _数据来源: {insight.evidence}_") + lines.append("") + + return "\n".join(lines) + + def _build_history(self, steps: list[ExplorationStep]) -> str: + """构建探索历史文本""" + parts = [] + for step in steps: + if step.action == "done": + parts.append(f"### 结束\n{step.reasoning}") + continue + + if step.success: + parts.append( + f"### 第 {step.round_num} 轮:{step.purpose}\n" + f"思考: {step.reasoning}\n" + f"SQL: `{step.sql}`\n" + f"结果 ({step.row_count} 行):\n" + f"列: {step.columns}\n" + f"数据: {json.dumps(step.rows, ensure_ascii=False)}" + ) + else: + parts.append( + f"### 第 {step.round_num} 轮:{step.purpose}\n" + f"SQL: `{step.sql}`\n" + f"结果: 执行失败 - {step.error}" + ) + + return "\n\n".join(parts) + + def _extract_json_array(self, text: str) -> list[dict]: + """从 LLM 输出提取 JSON 数组""" + try: + result = json.loads(text) + if isinstance(result, list): + return result + except json.JSONDecodeError: + pass + + for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: + match = re.search(pattern, text, re.DOTALL) + if match: + try: + result = json.loads(match.group(1)) + if isinstance(result, list): + return result + except json.JSONDecodeError: + continue + + # 找最外层 [] + match = re.search(r'\[.*\]', text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except json.JSONDecodeError: + pass + + return [] + + +# ── 基于规则的快速异常检测(无需 LLM)──────────────── + +def quick_detect(steps: list[ExplorationStep]) -> list[str]: + """ + 基于规则的快速异常检测,不调 LLM。 + 检测离群值、不均衡分布等。 + """ + alerts = [] + + for step in steps: + if not step.success or not step.rows: + continue + + for row in step.rows: + for col in step.columns: + val = row.get(col) + if not isinstance(val, (int, float)): + continue + + # 检查 pct 列:某个分组占比异常 + if col.lower() in ("pct", "percent", "percentage", "占比"): + if isinstance(val, (int, float)) and val > 50: + alerts.append( + f"⚠️ {step.purpose} 中某个分组占比 {val}%,超过 50%,集中度过高" + ) + + # 检查 count 列:极值差异 + if col.lower() in ("count", "cnt", "n", "total", "order_count"): + count_vals = [ + r.get(col) for r in step.rows + if isinstance(r.get(col), (int, float)) + ] + if len(count_vals) >= 3 and max(count_vals) > 0: + ratio = max(count_vals) / (sum(count_vals) / len(count_vals)) + if ratio > 3: + alerts.append( + f"⚠️ {step.purpose} 中最大值是均值的 {ratio:.1f} 倍,分布极不均衡" + ) + + return alerts diff --git a/planner.py b/planner.py new file mode 100644 index 0000000..6a671e8 --- /dev/null +++ b/planner.py @@ -0,0 +1,129 @@ +""" +Layer 1: 意图规划器 +将用户问题解析为结构化的分析计划,替代硬编码模板。 +""" +import json +import re +from typing import Any + +PROMPT = """你是一个数据分析规划专家。 + +## 你的任务 +根据用户的分析问题和数据库 Schema,生成一个结构化的分析计划。 + +## 输出格式(严格 JSON) +```json +{ + "intent": "一句话描述用户想了解什么", + "analysis_type": "ranking" | "distribution" | "trend" | "comparison" | "anomaly" | "overview", + "primary_table": "主要分析的表名", + "dimensions": ["分组维度列名"], + "metrics": ["需要聚合的数值列名"], + "aggregations": ["SUM", "AVG", "COUNT", ...], + "filters": [{"column": "列名", "condition": "过滤条件(可选)"}], + "join_needed": false, + "join_info": {"tables": [], "on": ""}, + "expected_rounds": 3, + "rationale": "为什么这样规划,需要关注什么" +} +``` + +## 分析类型说明 +- ranking: 按某维度排名(哪个地区最高) +- distribution: 分布/占比(各地区占比多少) +- trend: 时间趋势(月度变化) +- comparison: 对比分析(A vs B) +- anomaly: 异常检测(有没有异常值) +- overview: 全局概览(整体情况如何) + +## 规划原则 +1. 只选择与问题相关的表和列 +2. 如果需要 JOIN,说明关联条件 +3. 预估需要几轮探索(1-6) +4. 标注可能的异常关注点 +5. metrics 不要包含 id 列""" + +import openai +from config import LLM_CONFIG + + +class Planner: + """意图规划器:将自然语言问题转为结构化分析计划""" + + def __init__(self): + self.client = openai.OpenAI( + api_key=LLM_CONFIG["api_key"], + base_url=LLM_CONFIG["base_url"], + ) + self.model = LLM_CONFIG["model"] + + def plan(self, question: str, schema_text: str) -> dict[str, Any]: + """ + 生成分析计划 + + Returns: + { + "intent": str, + "analysis_type": str, + "primary_table": str, + "dimensions": [str], + "metrics": [str], + "aggregations": [str], + "filters": [dict], + "join_needed": bool, + "join_info": dict, + "expected_rounds": int, + "rationale": str, + } + """ + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": PROMPT}, + { + "role": "user", + "content": ( + f"## Schema\n{schema_text}\n\n" + f"## 用户问题\n{question}" + ), + }, + ], + temperature=0.1, + max_tokens=1024, + ) + + content = response.choices[0].message.content.strip() + plan = self._extract_json(content) + + # 补充默认值 + plan.setdefault("analysis_type", "overview") + plan.setdefault("expected_rounds", 3) + plan.setdefault("filters", []) + plan.setdefault("join_needed", False) + + return plan + + def _extract_json(self, text: str) -> dict: + """从 LLM 输出提取 JSON""" + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: + match = re.search(pattern, text, re.DOTALL) + if match: + try: + return json.loads(match.group(1)) + except json.JSONDecodeError: + continue + + # 找最外层 {} + match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except json.JSONDecodeError: + pass + + return {"intent": text[:100], "analysis_type": "overview"} diff --git a/reporter.py b/reporter.py new file mode 100644 index 0000000..5a29c53 --- /dev/null +++ b/reporter.py @@ -0,0 +1,110 @@ +""" +报告生成器 +将分析计划 + 探索结果 + 洞察,综合为一份可读报告。 +""" +import json +from typing import Any + +import openai +from config import LLM_CONFIG +from explorer import ExplorationStep +from insights import Insight + + +REPORT_PROMPT = """你是一个数据分析报告撰写专家。基于以下信息撰写报告。 + +## 用户问题 +{question} + +## 分析计划 +{plan} + +## 探索过程 +{exploration} + +## 主动洞察 +{insights_text} + +## 撰写要求 +1. **开头**:一句话总结核心结论(先给答案) +2. **核心发现**:按重要性排列,每个发现带具体数字 +3. **深入洞察**:异常、趋势、关联(从洞察中提取) +4. **建议**:基于数据的行动建议 +5. **审计**:末尾附上所有执行的 SQL + +使用 Markdown,中文撰写。 +语气:专业但不枯燥,像一个聪明的分析师在做简报。""" + + +class ReportGenerator: + """报告生成器""" + + def __init__(self): + self.client = openai.OpenAI( + api_key=LLM_CONFIG["api_key"], + base_url=LLM_CONFIG["base_url"], + ) + self.model = LLM_CONFIG["model"] + + def generate( + self, + question: str, + plan: dict, + steps: list[ExplorationStep], + insights: list[Insight], + ) -> str: + """生成分析报告""" + + # 构建探索过程文本 + exploration = self._build_exploration(steps) + + # 构建洞察文本 + insights_text = self._build_insights(insights) + + prompt = REPORT_PROMPT.format( + question=question, + plan=json.dumps(plan, ensure_ascii=False, indent=2), + exploration=exploration, + insights_text=insights_text, + ) + + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "你是专业的数据分析师,撰写清晰、有洞察力的分析报告。"}, + {"role": "user", "content": prompt}, + ], + temperature=0.3, + max_tokens=4096, + ) + + return response.choices[0].message.content + + def _build_exploration(self, steps: list[ExplorationStep]) -> str: + parts = [] + for step in steps: + if step.action == "done": + parts.append(f"### 探索结束\n{step.reasoning}") + continue + + if step.success: + parts.append( + f"### 第 {step.round_num} 轮:{step.purpose}\n" + f"思考: {step.reasoning}\n" + f"SQL: `{step.sql}`\n" + f"结果 ({step.row_count} 行):\n" + f"列: {step.columns}\n" + f"数据: {json.dumps(step.rows, ensure_ascii=False)}" + ) + else: + parts.append( + f"### 第 {step.round_num} 轮:{step.purpose}\n" + f"SQL: `{step.sql}`\n" + f"执行失败: {step.error}" + ) + return "\n\n".join(parts) if parts else "无探索步骤" + + def _build_insights(self, insights: list[Insight]) -> str: + if not insights: + return "未检测到异常。" + return "\n".join(str(i) for i in insights) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..aa2b704 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +openai>=1.0.0 diff --git a/sandbox_executor.py b/sandbox_executor.py new file mode 100644 index 0000000..c1094d6 --- /dev/null +++ b/sandbox_executor.py @@ -0,0 +1,139 @@ +""" +沙箱执行器 —— 执行 SQL,只返回聚合结果 +""" +import sqlite3 +import re +from typing import Any +from config import SANDBOX_RULES, DB_PATH + + +class SandboxError(Exception): + """沙箱安全违规""" + pass + + +class SandboxExecutor: + def __init__(self, db_path: str = DB_PATH): + self.db_path = db_path + self.rules = SANDBOX_RULES + self.execution_log: list[dict] = [] + + def execute(self, sql: str) -> dict[str, Any]: + """ + 执行 SQL,返回脱敏后的聚合结果。 + 如果违反安全规则,抛出 SandboxError。 + """ + # 验证 SQL 安全性 + self._validate(sql) + + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + try: + cur.execute(sql) + rows = cur.fetchall() + + # 转为字典列表 + columns = [desc[0] for desc in cur.description] if cur.description else [] + results = [dict(row) for row in rows] + + # 脱敏处理 + sanitized = self._sanitize(results, columns) + + # 记录执行日志 + self.execution_log.append({ + "sql": sql, + "rows_returned": len(results), + "columns": columns, + }) + + return { + "success": True, + "columns": columns, + "rows": sanitized, + "row_count": len(sanitized), + "sql": sql, + } + + except sqlite3.Error as e: + return { + "success": False, + "error": str(e), + "sql": sql, + } + finally: + conn.close() + + def _validate(self, sql: str): + """SQL 安全验证""" + sql_upper = sql.upper().strip() + + # 1. 检查禁止的关键字 + for banned in self.rules["banned_keywords"]: + if banned.upper() in sql_upper: + raise SandboxError(f"禁止的 SQL 关键字: {banned}") + + # 2. 只允许 SELECT(不能有多语句) + statements = [s.strip() for s in sql.split(";") if s.strip()] + if len(statements) > 1: + raise SandboxError("禁止多语句执行") + if not sql_upper.startswith("SELECT"): + raise SandboxError("只允许 SELECT 查询") + + # 3. 检查是否使用了聚合函数或 GROUP BY(可选要求) + if self.rules["require_aggregation"]: + agg_keywords = ["COUNT", "SUM", "AVG", "MIN", "MAX", "GROUP BY", + "DISTINCT", "HAVING", "ROUND", "CAST"] + has_agg = any(kw in sql_upper for kw in agg_keywords) + has_limit = "LIMIT" in sql_upper + + if not has_agg and not has_limit: + raise SandboxError( + "要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT" + ) + + # 4. LIMIT 检查 + limit_match = re.search(r'LIMIT\s+(\d+)', sql_upper) + if limit_match: + limit_val = int(limit_match.group(1)) + if limit_val > self.rules["max_result_rows"]: + raise SandboxError( + f"LIMIT {limit_val} 超过最大允许值 {self.rules['max_result_rows']}" + ) + + def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]: + """对结果进行脱敏处理""" + if not rows: + return rows + + # 1. 限制行数 + rows = rows[:self.rules["max_result_rows"]] + + # 2. 浮点数四舍五入 + for row in rows: + for col in columns: + val = row.get(col) + if isinstance(val, float): + row[col] = round(val, self.rules["round_floats"]) + + # 3. 小样本抑制(k-anonymity) + # 如果某个分组的 count 小于阈值,标记为 " str: + """获取执行摘要""" + if not self.execution_log: + return "尚未执行任何查询" + + lines = [f"共执行 {len(self.execution_log)} 条查询:"] + for i, log in enumerate(self.execution_log, 1): + lines.append(f" {i}. {log['sql'][:80]}... → {log['rows_returned']} 行结果") + return "\n".join(lines) diff --git a/schema_extractor.py b/schema_extractor.py new file mode 100644 index 0000000..90b6492 --- /dev/null +++ b/schema_extractor.py @@ -0,0 +1,126 @@ +""" +Schema 提取器 —— 只提取表结构,不碰数据 +""" +import sqlite3 +from typing import Any + + +def extract_schema(db_path: str) -> dict[str, Any]: + """ + 从数据库提取 Schema,只返回结构信息: + - 表名、列名、类型 + - 主键、外键 + - 行数 + - 枚举列的去重值(不含原始数据) + """ + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + # 获取所有表 + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") + tables = [row["name"] for row in cur.fetchall()] + + schema = {"tables": []} + + for table in tables: + # 列信息 + cur.execute(f"PRAGMA table_info('{table}')") + columns = [] + for col in cur.fetchall(): + columns.append({ + "name": col["name"], + "type": col["type"], + "nullable": col["notnull"] == 0, + "is_primary_key": col["pk"] == 1, + }) + + # 外键 + cur.execute(f"PRAGMA foreign_key_list('{table}')") + fks = [] + for fk in cur.fetchall(): + fks.append({ + "column": fk["from"], + "references_table": fk["table"], + "references_column": fk["to"], + }) + + # 行数 + cur.execute(f"SELECT COUNT(*) AS cnt FROM '{table}'") + row_count = cur.fetchone()["cnt"] + + # 对 VARCHAR / TEXT 类型列,提取去重枚举值(最多 20 个) + data_profile = {} + for col in columns: + col_name = col["name"] + col_type = (col["type"] or "").upper() + + if any(t in col_type for t in ("VARCHAR", "TEXT", "CHAR")): + cur.execute(f'SELECT DISTINCT "{col_name}" FROM "{table}" WHERE "{col_name}" IS NOT NULL LIMIT 20') + vals = [row[0] for row in cur.fetchall()] + if len(vals) <= 20: + data_profile[col_name] = { + "type": "enum", + "distinct_count": len(vals), + "values": vals, + } + elif any(t in col_type for t in ("INT", "REAL", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC")): + cur.execute(f''' + SELECT MIN("{col_name}") AS min_val, MAX("{col_name}") AS max_val, + AVG("{col_name}") AS avg_val, COUNT(DISTINCT "{col_name}") AS distinct_count + FROM "{table}" WHERE "{col_name}" IS NOT NULL + ''') + row = cur.fetchone() + if row and row["min_val"] is not None: + data_profile[col_name] = { + "type": "numeric", + "min": round(row["min_val"], 2), + "max": round(row["max_val"], 2), + "avg": round(row["avg_val"], 2), + "distinct_count": row["distinct_count"], + } + + schema["tables"].append({ + "name": table, + "columns": columns, + "foreign_keys": fks, + "row_count": row_count, + "data_profile": data_profile, + }) + + conn.close() + return schema + + +def schema_to_text(schema: dict) -> str: + """将 Schema 转为可读文本,供 LLM 理解""" + lines = ["=== 数据库 Schema ===\n"] + + for table in schema["tables"]: + lines.append(f"📋 表: {table['name']} (共 {table['row_count']} 行)") + lines.append(" 列:") + for col in table["columns"]: + pk = " [PK]" if col["is_primary_key"] else "" + null = " NULL" if col["nullable"] else " NOT NULL" + lines.append(f' - {col["name"]}: {col["type"]}{pk}{null}') + + if table["foreign_keys"]: + lines.append(" 外键:") + for fk in table["foreign_keys"]: + lines.append(f' - {fk["column"]} → {fk["references_table"]}.{fk["references_column"]}') + + if table["data_profile"]: + lines.append(" 数据画像:") + for col_name, profile in table["data_profile"].items(): + if profile["type"] == "enum": + vals = ", ".join(str(v) for v in profile["values"][:10]) + lines.append(f' - {col_name}: 枚举值({profile["distinct_count"]}个) = [{vals}]') + elif profile["type"] == "numeric": + lines.append( + f' - {col_name}: 范围[{profile["min"]} ~ {profile["max"]}], ' + f'均值{profile["avg"]}, {profile["distinct_count"]}个不同值' + ) + + lines.append("") + + return "\n".join(lines)