commit 96927a789dc762772dceb327bed7944bea03c8cb Author: OpenClaw Agent Date: Thu Mar 19 12:21:04 2026 +0800 feat: 四层架构数据分析 Agent - Layer 1 Planner: 意图规划,将问题转为结构化分析计划 - Layer 2 Explorer: 自适应探索循环,多轮迭代动态生成 SQL - Layer 3 InsightEngine: 异常检测 + 主动洞察 - Layer 4 ContextManager: 多轮对话上下文记忆 安全设计:AI 只看 Schema + 聚合结果,不接触原始数据。 支持任意 OpenAI 兼容 API(OpenAI / Ollama / DeepSeek / vLLM) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1c580aa --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +*.pyc +demo.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..2b3d117 --- /dev/null +++ b/README.md @@ -0,0 +1,117 @@ +# 数据分析 Agent —— Schema-Only + 四层架构自适应分析 + +**AI 只看表结构,不碰原始数据。通过四层架构自适应探索,生成深度分析报告。** + +## 架构 + +``` +┌─────────────────────────────────────────────────────────┐ +│ Agent (编排层) │ +│ 接收问题 → 调度四层 → 输出报告 │ +└──────┬──────────┬──────────┬──────────┬─────────────────┘ + │ │ │ │ + ┌────▼────┐ ┌──▼─────┐ ┌─▼──────┐ ┌▼─────────┐ + │ Planner │ │Explorer│ │Insight │ │ Context │ + │ 意图规划 │ │探索循环 │ │异常检测 │ │ 上下文记忆 │ + └─────────┘ └────────┘ └────────┘ └──────────┘ +``` + +### 四层分工 + +| 层 | 组件 | 职责 | +|---|---|---| +| L1 | **Planner** | 理解用户意图,生成结构化分析计划(类型、维度、指标) | +| L2 | **Explorer** | 基于计划多轮迭代探索,每轮根据上一轮结果决定下一步 | +| L3 | **InsightEngine** | 从探索结果中检测异常、趋势、关联,输出主动洞察 | +| L4 | **ContextManager** | 管理多轮对话历史,后续问题可引用之前的分析 | + +### 安全隔离 + +``` +用户提问 → Agent 看 Schema 生成 SQL → 沙箱执行 → 聚合结果 → Agent 生成报告 + ↑ + 原始数据永远留在这里 +``` + +- **Schema 提取器**:只提取表结构、列类型、行数、枚举值,不碰数据 +- **沙箱执行器**:禁止 SELECT * / DDL / DML,必须聚合函数,小样本抑制(n<5) +- **AI 的视角**:只有 Schema + 聚合统计结果,从未接触任何一行原始数据 + +## 快速开始 + +```bash +# 1. 安装依赖 +pip install openai + +# 2. 配置 LLM(兼容 OpenAI API 格式) + +# OpenAI +export LLM_API_KEY=sk-xxx +export LLM_BASE_URL=https://api.openai.com/v1 +export LLM_MODEL=gpt-4o + +# Ollama(本地部署,隐私优先) +export LLM_API_KEY=ollama +export LLM_BASE_URL=http://localhost:11434/v1 +export LLM_MODEL=qwen2.5-coder:7b + +# DeepSeek +export LLM_API_KEY=sk-xxx +export LLM_BASE_URL=https://api.deepseek.com +export LLM_MODEL=deepseek-chat + +# 3. 运行演示(自动创建 5 万条示例数据 + 3 个分析任务) +python3 demo.py + +# 4. 交互式分析 +python3 cli.py + +# 5. 分析你自己的数据库 +python3 cli.py /path/to/your.db +``` + +## 文件结构 + +``` +├── config.py # 配置(LLM、安全规则、探索轮数) +├── schema_extractor.py # Schema 提取器(只提取结构) +├── sandbox_executor.py # 沙箱执行器(SQL 验证 + 结果脱敏) +├── planner.py # [L1] 意图规划器 +├── explorer.py # [L2] 自适应探索器 +├── insights.py # [L3] 洞察引擎(异常检测) +├── context.py # [L4] 上下文管理器 +├── reporter.py # 报告生成器 +├── agent.py # Agent 编排层 +├── demo.py # 一键演示 +├── cli.py # 交互式 CLI +├── requirements.txt # 依赖 +└── README.md +``` + +## 对比预制脚本 + +| | 预制脚本 / 模板 | 本方案(四层架构) | +|---|---|---| +| SQL 生成 | 模板拼接 | LLM 动态生成 | +| 查询数量 | 固定 | 1-6 轮,AI 自适应 | +| 后续追问 | 无 | AI 看到结果后判断是否深挖 | +| 异常发现 | 无 | 主动检测 + 主动输出 | +| 多轮对话 | 无 | 上下文记忆,可引用历史分析 | +| 适用场景 | 已知分析模式 | 探索性分析、开放性问题 | + +## CLI 命令 + +``` +📊 > 帮我分析各地区的销售表现 # 分析问题 +📊 > rounds=3 最近的趋势怎么样 # 限制探索轮数 +📊 > schema # 查看数据库 Schema +📊 > history # 查看分析历史 +📊 > audit # 查看 SQL 审计日志 +📊 > clear # 清空历史 +📊 > help # 帮助 +📊 > quit # 退出 +``` + +## License + +MIT diff --git a/agent.py b/agent.py new file mode 100644 index 0000000..4a288a1 --- /dev/null +++ b/agent.py @@ -0,0 +1,139 @@ +""" +Agent 编排层 —— 调度四层架构完成分析 + +Layer 1: Planner 意图规划 +Layer 2: Explorer 自适应探索 +Layer 3: Insight 异常洞察 +Layer 4: Context 上下文记忆 +""" +from typing import Optional + +from config import DB_PATH, MAX_EXPLORATION_ROUNDS +from schema_extractor import extract_schema, schema_to_text +from sandbox_executor import SandboxExecutor +from planner import Planner +from explorer import Explorer +from insights import InsightEngine, quick_detect +from reporter import ReportGenerator +from context import ContextManager + + +class DataAnalysisAgent: + """ + 数据分析 Agent + + 四层架构: + 1. Planner - 理解用户意图,生成分析计划 + 2. Explorer - 基于计划多轮迭代探索 + 3. Insights - 从结果中检测异常、输出主动洞察 + 4. Context - 管理多轮对话上下文 + + Agent 负责编排这四层,从问题到报告。 + """ + + def __init__(self, db_path: str): + # 数据层 + self.db_path = db_path + self.schema = extract_schema(db_path) + self.schema_text = schema_to_text(self.schema) + self.executor = SandboxExecutor(db_path) + + # 四层组件 + self.planner = Planner() + self.explorer = Explorer(self.executor) + self.insight_engine = InsightEngine() + self.reporter = ReportGenerator() + self.context = ContextManager() + + def analyze(self, question: str, max_rounds: Optional[int] = None) -> str: + """ + 完整分析流程 + + Args: + question: 用户分析问题 + max_rounds: 最大探索轮数(默认用配置值) + + Returns: + 格式化的分析报告 + """ + max_rounds = max_rounds or MAX_EXPLORATION_ROUNDS + + print(f"\n{'='*60}") + print(f"📊 {question}") + print(f"{'='*60}") + + # ── Layer 0: 检查上下文 ────────────────────────── + prev_context = self.context.get_context_for(question) + if prev_context: + print("📎 发现历史分析上下文,将结合之前的发现") + + # ── Layer 1: 意图规划 ──────────────────────────── + print("\n🎯 [Layer 1] 意图规划...") + plan = self.planner.plan(question, self.schema_text) + + analysis_type = plan.get("analysis_type", "unknown") + dimensions = plan.get("dimensions", []) + rationale = plan.get("rationale", "") + print(f" 类型: {analysis_type}") + print(f" 维度: {', '.join(dimensions) if dimensions else '自动发现'}") + print(f" 理由: {rationale[:80]}{'...' if len(rationale) > 80 else ''}") + + # ── Layer 2: 自适应探索 ────────────────────────── + print(f"\n🔍 [Layer 2] 自适应探索 (最多 {max_rounds} 轮)...") + steps = self.explorer.explore(plan, self.schema_text, max_rounds=max_rounds) + + successful = sum(1 for s in steps if s.success) + print(f"\n 完成: {len(steps)} 轮, {successful} 条成功查询") + + # ── Layer 3: 异常洞察 ──────────────────────────── + print("\n🔎 [Layer 3] 异常洞察...") + + # 先做规则检测 + rule_alerts = quick_detect(steps) + for alert in rule_alerts: + print(f" {alert}") + + # 再做 LLM 深度分析 + insights = self.insight_engine.analyze(steps, question) + if insights: + print(f" 发现 {len(insights)} 条洞察") + for insight in insights: + print(f" {insight}") + else: + print(" 未发现异常") + + # ── 生成报告 ──────────────────────────────────── + print("\n📝 正在生成报告...") + report = self.reporter.generate(question, plan, steps, insights) + + # 追加主动洞察 + if insights: + insight_text = self.insight_engine.format_insights(insights) + report += f"\n\n---\n\n{insight_text}" + + # ── Layer 4: 记录上下文 ────────────────────────── + self.context.add_session( + question=question, + plan=plan, + steps=steps, + insights=insights, + report=report, + ) + + return report + + def get_schema(self) -> str: + """获取 Schema 文本""" + return self.schema_text + + def get_history(self) -> str: + """获取分析历史摘要""" + return self.context.get_history_summary() + + def get_audit(self) -> str: + """获取执行审计日志""" + return self.executor.get_execution_summary() + + def clear_history(self): + """清空分析历史""" + self.context.clear() diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..d0636f4 --- /dev/null +++ b/cli.py @@ -0,0 +1,112 @@ +""" +交互式 CLI —— 四层架构自适应分析 +用法: python3 cli.py [数据库路径] +""" +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +from config import DB_PATH, LLM_CONFIG, MAX_EXPLORATION_ROUNDS +from agent import DataAnalysisAgent + + +def print_help(): + print(""" +可用命令: + <问题> 分析一个问题 + rounds= <问题> 设置探索轮数 + schema 查看数据库 Schema + history 查看分析历史 + audit 查看 SQL 审计日志 + clear 清空分析历史 + help 显示帮助 + quit / q 退出 +""") + + +def main(): + db_path = sys.argv[1] if len(sys.argv) > 1 else DB_PATH + + if not os.path.exists(db_path): + print(f"❌ 数据库不存在: {db_path}") + print(f" 请先运行 python3 demo.py 创建示例数据库,或指定已有数据库路径") + sys.exit(1) + + if not LLM_CONFIG["api_key"]: + print("⚠️ 未配置 LLM_API_KEY,请先设置环境变量:") + print(" export LLM_API_KEY=sk-xxx") + print(" export LLM_BASE_URL=https://api.openai.com/v1 # 或 Ollama 等") + print(" export LLM_MODEL=gpt-4o") + sys.exit(1) + + # 初始化 Agent + agent = DataAnalysisAgent(db_path) + + print("=" * 60) + print(" 🤖 数据分析 Agent —— 四层架构") + print("=" * 60) + print(f"\n🔗 LLM: {LLM_CONFIG['model']} @ {LLM_CONFIG['base_url']}") + print(f"🔄 最大探索轮数: {MAX_EXPLORATION_ROUNDS}") + print(f"💾 数据库: {db_path}") + print(f"\n💬 输入分析问题(help 查看命令)\n") + + while True: + try: + user_input = input("📊 > ").strip() + except (EOFError, KeyboardInterrupt): + print("\n👋 再见!") + break + + if not user_input: + continue + + cmd = user_input.lower() + + if cmd in ("quit", "exit", "q"): + print("👋 再见!") + break + elif cmd == "help": + print_help() + continue + elif cmd == "schema": + print(agent.get_schema()) + continue + elif cmd == "history": + print(agent.get_history()) + continue + elif cmd == "audit": + print(agent.get_audit()) + continue + elif cmd == "clear": + agent.clear_history() + print("✅ 历史已清空") + continue + + # 解析可选参数:rounds=3 + max_rounds = MAX_EXPLORATION_ROUNDS + question = user_input + if "rounds=" in question.lower(): + parts = question.split("rounds=") + question = parts[0].strip() + try: + max_rounds = int(parts[1].strip().split()[0]) + except (ValueError, IndexError): + pass + + try: + report = agent.analyze(question, max_rounds=max_rounds) + print("\n" + report) + print("\n" + "~" * 60) + except Exception as e: + print(f"\n❌ 分析出错: {e}") + import traceback + traceback.print_exc() + + # 退出时显示审计 + print("\n📋 本次会话审计:") + print(agent.get_audit()) + + +if __name__ == "__main__": + main() diff --git a/config.py b/config.py new file mode 100644 index 0000000..f8c60b3 --- /dev/null +++ b/config.py @@ -0,0 +1,36 @@ +""" +配置文件 +""" +import os + +# LLM 配置(兼容 OpenAI API 格式,包括 Ollama / vLLM / DeepSeek 等) +LLM_CONFIG = { + "api_key": os.getenv("LLM_API_KEY", ""), + "base_url": os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + "model": os.getenv("LLM_MODEL", "gpt-4o"), +} + +# 沙箱安全规则 +SANDBOX_RULES = { + "max_result_rows": 1000, # 聚合结果最大行数 + "round_floats": 2, # 浮点数保留位数 + "suppress_small_n": 5, # 分组样本 < n 时模糊处理 + "banned_keywords": [ # 禁止的 SQL 关键字 + "SELECT *", + "INSERT", + "UPDATE", + "DELETE", + "DROP", + "ALTER", + "CREATE", + "ATTACH", + "PRAGMA", + ], + "require_aggregation": True, # 是否要求使用聚合函数 +} + +# 数据库路径 +DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "demo.db")) + +# 分析控制 +MAX_EXPLORATION_ROUNDS = int(os.getenv("MAX_ROUNDS", "6")) # 最大探索轮数 diff --git a/context.py b/context.py new file mode 100644 index 0000000..faa22ef --- /dev/null +++ b/context.py @@ -0,0 +1,134 @@ +""" +Layer 4: 上下文管理器 +管理多轮对话的分析上下文,让后续问题可以引用之前的发现。 +""" +import json +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +from explorer import ExplorationStep +from insights import Insight + + +@dataclass +class AnalysisSession: + """一次分析的完整记录""" + question: str + plan: dict + steps: list[ExplorationStep] + insights: list[Insight] + report: str + timestamp: float = field(default_factory=time.time) + + def summary(self) -> str: + """生成本次分析的摘要(供后续对话引用)""" + parts = [f"**问题**: {self.question}"] + + if self.plan: + parts.append(f"**分析类型**: {self.plan.get('analysis_type', 'unknown')}") + parts.append(f"**关注维度**: {', '.join(self.plan.get('dimensions', []))}") + + # 核心发现(从成功步骤中提取) + key_findings = [] + for step in self.steps: + if step.success and step.rows: + # 提取最突出的值 + top_row = step.rows[0] if step.rows else {} + finding = f"{step.purpose}: " + finding += ", ".join( + f"{k}={v}" for k, v in top_row.items() if k.lower() not in ("id",) + ) + key_findings.append(finding) + + if key_findings: + parts.append("**核心发现**:") + for f in key_findings[:5]: + parts.append(f" - {f}") + + # 洞察 + if self.insights: + parts.append("**主动洞察**:") + for insight in self.insights[:3]: + parts.append(f" - {insight}") + + return "\n".join(parts) + + def to_reference_text(self) -> str: + """生成供 LLM 使用的上下文文本""" + return ( + f"## 之前的分析\n\n" + f"### 问题\n{self.question}\n\n" + f"### 摘要\n{self.summary()}\n\n" + f"### 详细发现\n" + + "\n".join( + f"- {step.purpose}: {step.row_count} 行结果" + for step in self.steps if step.success + ) + ) + + +class ContextManager: + """上下文管理器:管理多轮对话的分析历史""" + + def __init__(self, max_history: int = 10): + self.sessions: list[AnalysisSession] = [] + self.max_history = max_history + + def add_session( + self, + question: str, + plan: dict, + steps: list[ExplorationStep], + insights: list[Insight], + report: str, + ) -> AnalysisSession: + """记录一次分析""" + session = AnalysisSession( + question=question, + plan=plan, + steps=steps, + insights=insights, + report=report, + ) + self.sessions.append(session) + + # 保持历史大小 + if len(self.sessions) > self.max_history: + self.sessions = self.sessions[-self.max_history:] + + return session + + def get_context_for(self, new_question: str) -> Optional[str]: + """ + 根据新问题,从历史中找到相关上下文。 + 简单实现:取最近的分析会话。 + """ + if not self.sessions: + return None + + # 取最近 2 轮分析的摘要 + recent = self.sessions[-2:] + parts = [] + for session in recent: + parts.append(session.to_reference_text()) + + return "\n\n---\n\n".join(parts) + + def get_history_summary(self) -> str: + """获取所有历史的摘要""" + if not self.sessions: + return "(无历史分析)" + + lines = [f"共 {len(self.sessions)} 次分析:"] + for i, session in enumerate(self.sessions, 1): + ts = time.strftime("%H:%M", time.localtime(session.timestamp)) + lines.append(f" {i}. [{ts}] {session.question}") + if session.insights: + for insight in session.insights[:2]: + lines.append(f" {insight}") + return "\n".join(lines) + + def clear(self): + """清空历史""" + self.sessions.clear() diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..beb2c53 --- /dev/null +++ b/demo.py @@ -0,0 +1,160 @@ +""" +演示脚本 —— 创建示例数据,运行四层架构分析 +""" +import os +import sys +import sqlite3 +import random +from datetime import datetime, timedelta + +sys.path.insert(0, os.path.dirname(__file__)) + +from config import DB_PATH, LLM_CONFIG +from agent import DataAnalysisAgent + + +def create_demo_data(db_path: str): + """创建示例数据库:电商订单 + 用户 + 商品""" + if os.path.exists(db_path): + os.remove(db_path) + + conn = sqlite3.connect(db_path) + cur = conn.cursor() + + cur.execute(""" + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + region TEXT NOT NULL, + tier TEXT NOT NULL, + created_at TEXT NOT NULL + ) + """) + + cur.execute(""" + CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category TEXT NOT NULL, + price REAL NOT NULL + ) + """) + + cur.execute(""" + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + product_id INTEGER NOT NULL, + amount REAL NOT NULL, + quantity INTEGER NOT NULL, + status TEXT NOT NULL, + region TEXT NOT NULL, + order_date TEXT NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id), + FOREIGN KEY (product_id) REFERENCES products(id) + ) + """) + + regions = ["华东", "华南", "华北", "西南", "华中", "东北"] + tiers = ["普通", "银卡", "金卡", "钻石"] + categories = ["电子产品", "服饰", "食品", "家居", "图书", "美妆"] + statuses = ["已完成", "已发货", "待发货", "已取消", "退款中"] + + users = [] + for i in range(1, 2001): + reg = random.choice(regions) + tier = random.choices(tiers, weights=[50, 25, 15, 10])[0] + created = datetime(2024, 1, 1) + timedelta(days=random.randint(0, 400)) + users.append((i, f"用户{i}", reg, tier, created.strftime("%Y-%m-%d"))) + cur.executemany("INSERT INTO users VALUES (?, ?, ?, ?, ?)", users) + + products = [] + for i in range(1, 201): + cat = random.choice(categories) + price = round(random.uniform(9.9, 9999.99), 2) + products.append((i, f"商品{i}", cat, price)) + cur.executemany("INSERT INTO products VALUES (?, ?, ?, ?)", products) + + orders = [] + for i in range(1, 50001): + uid = random.randint(1, 2000) + pid = random.randint(1, 200) + qty = random.choices([1, 2, 3, 4, 5], weights=[50, 25, 15, 7, 3])[0] + status = random.choices(statuses, weights=[60, 15, 10, 10, 5])[0] + region = random.choices(regions, weights=[35, 25, 15, 10, 10, 5])[0] + order_date = datetime(2025, 1, 1) + timedelta(days=random.randint(0, 440)) + cur.execute("SELECT price FROM products WHERE id = ?", (pid,)) + price = cur.fetchone()[0] + amount = round(price * qty, 2) + orders.append((i, uid, pid, amount, qty, status, region, order_date.strftime("%Y-%m-%d"))) + cur.executemany("INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?, ?, ?)", orders) + + conn.commit() + conn.close() + print(f"✅ 示例数据库已创建: {db_path}") + print(f" - users: 2,000 条") + print(f" - products: 200 条") + print(f" - orders: 50,000 条") + + +def main(): + print("=" * 60) + print(" 🤖 数据分析 Agent —— 四层架构自适应分析") + print("=" * 60) + + if not LLM_CONFIG["api_key"]: + print("\n⚠️ 未配置 LLM_API_KEY,请设置环境变量:") + print() + print(" # OpenAI") + print(" export LLM_API_KEY=sk-xxx") + print(" export LLM_BASE_URL=https://api.openai.com/v1") + print(" export LLM_MODEL=gpt-4o") + print() + print(" # Ollama (本地)") + print(" export LLM_API_KEY=ollama") + print(" export LLM_BASE_URL=http://localhost:11434/v1") + print(" export LLM_MODEL=qwen2.5-coder:7b") + print() + sys.exit(1) + + print(f"\n🔗 LLM: {LLM_CONFIG['model']} @ {LLM_CONFIG['base_url']}") + + if not os.path.exists(DB_PATH): + create_demo_data(DB_PATH) + else: + print(f"\n📂 使用已有数据库: {DB_PATH}") + + # 初始化 Agent(自动加载 Schema) + agent = DataAnalysisAgent(DB_PATH) + + print("\n" + "=" * 60) + print(" ⬆️ AI 只看到 Schema:表结构 + 数据画像") + print(" ⬇️ 四层架构分析:规划 → 探索 → 洞察 → 报告") + print("=" * 60) + + questions = [ + "各地区的销售表现如何?帮我全面分析一下", + "不同商品类别的销售情况和利润贡献", + "整体订单的完成率怎么样?有没有什么异常需要关注?", + ] + + for q in questions: + try: + report = agent.analyze(q) + print("\n" + report) + print("\n" + "~" * 60) + except Exception as e: + print(f"\n❌ 分析出错: {e}") + import traceback + traceback.print_exc() + + # 最终审计 + print("\n📋 会话审计:") + print(agent.get_audit()) + print("\n📋 分析历史:") + print(agent.get_history()) + print("\n✅ 演示完成!") + + +if __name__ == "__main__": + main() diff --git a/explorer.py b/explorer.py new file mode 100644 index 0000000..1538655 --- /dev/null +++ b/explorer.py @@ -0,0 +1,238 @@ +""" +Layer 2: 自适应探索器 +基于分析计划 + 已有发现,动态决定下一步查什么。 +多轮迭代,直到 AI 判断"够了"或达到上限。 +""" +import json +import re +from typing import Any + +import openai +from config import LLM_CONFIG +from sandbox_executor import SandboxExecutor + + +EXPLORER_SYSTEM = """你是一个数据分析执行者。你的上级给了你一个分析计划,你需要通过迭代执行 SQL 来完成分析。 + +## 你的工作方式 +每一轮你看到: +1. 分析计划(上级给的目标) +2. 数据库 Schema(表结构、数据画像) +3. 之前的探索历史(查过什么、得到什么结果) + +你决定下一步: +- 输出一条 SQL 继续探索 +- 或者输出 done 表示分析足够 + +## 输出格式(严格 JSON) +```json +{ + "action": "query", + "reasoning": "为什么要做这个查询", + "sql": "SELECT ...", + "purpose": "这个查询的目的" +} +``` + +或: +```json +{ + "action": "done", + "reasoning": "为什么分析已经足够" +} +``` + +## SQL 规则 +- 只用 SELECT,必须有聚合函数或 GROUP BY +- 禁止 SELECT * +- 用 ROUND 控制精度 +- 合理使用 LIMIT(分组结果 15 行以内,时间序列 60 行以内) + +## 探索策略 +1. 第一轮:验证核心假设(计划中最关键的查询) +2. 后续轮:基于已有结果追问 + - 发现离群值 → 追问为什么 + - 发现异常比例 → 追问细分维度 + - 结果平淡 → 换个角度试试 +3. 不要重复查已经确认的事 +4. 每轮要有新发现,否则就该结束""" + + +EXPLORER_CONTINUE = """查询结果: + +{result_text} + +请基于这个结果决定下一步。如果发现异常或值得深挖的点,继续查询。如果分析足够,输出 done。""" + + +class ExplorationStep: + """单步探索结果""" + def __init__(self, round_num: int, decision: dict, result: dict): + self.round_num = round_num + self.reasoning = decision.get("reasoning", "") + self.purpose = decision.get("purpose", "") + self.sql = decision.get("sql", "") + self.action = decision.get("action", "query") + self.success = result.get("success", False) + self.error = result.get("error") + self.columns = result.get("columns", []) + self.rows = result.get("rows", []) + self.row_count = result.get("row_count", 0) + + def to_dict(self) -> dict: + d = { + "round": self.round_num, + "action": self.action, + "reasoning": self.reasoning, + "purpose": self.purpose, + "sql": self.sql, + "success": self.success, + } + if self.success: + d["result"] = { + "columns": self.columns, + "rows": self.rows, + "row_count": self.row_count, + } + else: + d["result"] = {"error": self.error} + return d + + +class Explorer: + """自适应探索器:多轮迭代执行 SQL""" + + def __init__(self, executor: SandboxExecutor): + self.executor = executor + self.client = openai.OpenAI( + api_key=LLM_CONFIG["api_key"], + base_url=LLM_CONFIG["base_url"], + ) + self.model = LLM_CONFIG["model"] + + def explore( + self, + plan: dict, + schema_text: str, + max_rounds: int = 6, + ) -> list[ExplorationStep]: + """ + 执行探索循环 + + Args: + plan: Planner 生成的分析计划 + schema_text: Schema 文本描述 + max_rounds: 最大探索轮数 + + Returns: + 探步列表 + """ + steps: list[ExplorationStep] = [] + + # 构建初始消息 + messages = [ + {"role": "system", "content": EXPLORER_SYSTEM}, + { + "role": "user", + "content": ( + f"## 分析计划\n```json\n{json.dumps(plan, ensure_ascii=False, indent=2)}\n```\n\n" + f"## 数据库 Schema\n{schema_text}\n\n" + f"请开始第一轮探索。根据计划,先执行最关键的查询。" + ), + }, + ] + + for round_num in range(1, max_rounds + 1): + print(f"\n 🔄 探索第 {round_num}/{max_rounds} 轮") + + # LLM 决策 + decision = self._llm_decide(messages) + action = decision.get("action", "query") + reasoning = decision.get("reasoning", "") + + print(f" 💭 {reasoning[:80]}{'...' if len(reasoning) > 80 else ''}") + + if action == "done": + print(f" ✅ 探索完成") + steps.append(ExplorationStep(round_num, decision, {"success": True})) + break + + # 执行 SQL + sql = decision.get("sql", "") + purpose = decision.get("purpose", "") + + if not sql: + print(f" ⚠️ 未生成 SQL,跳过") + continue + + print(f" 📝 {purpose}") + result = self.executor.execute(sql) + + if result["success"]: + print(f" ✅ {result['row_count']} 行") + else: + print(f" ❌ {result['error']}") + + step = ExplorationStep(round_num, decision, result) + steps.append(step) + + # 更新对话历史 + messages.append({ + "role": "assistant", + "content": json.dumps(decision, ensure_ascii=False), + }) + + result_text = self._format_result(result) + messages.append({ + "role": "user", + "content": EXPLORER_CONTINUE.format(result_text=result_text), + }) + + return steps + + def _llm_decide(self, messages: list[dict]) -> dict: + """LLM 决策""" + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.2, + max_tokens=1024, + ) + content = response.choices[0].message.content.strip() + return self._extract_json(content) + + def _format_result(self, result: dict) -> str: + """格式化查询结果""" + if not result.get("success"): + return f"❌ 执行失败: {result.get('error', '未知错误')}" + + rows = result["rows"][:20] + return ( + f"✅ 查询成功,返回 {result['row_count']} 行\n" + f"列: {result['columns']}\n" + f"数据:\n{json.dumps(rows, ensure_ascii=False, indent=2)}" + ) + + def _extract_json(self, text: str) -> dict: + """从 LLM 输出提取 JSON""" + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: + match = re.search(pattern, text, re.DOTALL) + if match: + try: + return json.loads(match.group(1)) + except json.JSONDecodeError: + continue + + match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except json.JSONDecodeError: + pass + + return {"action": "done", "reasoning": f"无法解析输出: {text[:100]}"} diff --git a/insights.py b/insights.py new file mode 100644 index 0000000..39639aa --- /dev/null +++ b/insights.py @@ -0,0 +1,239 @@ +""" +Layer 3: 洞察引擎 +对探索结果进行异常检测 + 主动洞察,输出用户没问但值得知道的事。 +""" +import json +import re +import statistics +from typing import Any + +import openai +from config import LLM_CONFIG +from explorer import ExplorationStep + + +INSIGHT_SYSTEM = """你是一个数据洞察专家。你会收到探索过程的所有结果,你需要: + +1. 从结果中发现异常和有趣现象 +2. 对比不同维度,找出差异 +3. 输出用户可能没问但值得知道的洞察 + +## 输出格式(严格 JSON 数组) +```json +[ + { + "type": "outlier" | "trend" | "distribution" | "correlation" | "recommendation", + "severity": "high" | "medium" | "low", + "title": "简短标题", + "detail": "详细描述,包含具体数字", + "evidence": "支撑这个洞察的数据来源" + } +] +``` + +## 洞察类型 +- outlier: 离群值(某个分组异常高/低) +- trend: 趋势发现(增长/下降、季节性) +- distribution: 分布异常(不均衡、集中度过高) +- correlation: 关联发现(两个维度的意外关联) +- recommendation: 行动建议(基于数据的建议) + +## 分析原则 +- 每个洞察必须有具体数字支撑 +- 用对比来说话(A 比 B 高 X%) +- 关注异常,不描述平淡的事实 +- 如果没有异常,返回空数组""" + + +class Insight: + """单条洞察""" + def __init__(self, data: dict): + self.type = data.get("type", "unknown") + self.severity = data.get("severity", "low") + self.title = data.get("title", "") + self.detail = data.get("detail", "") + self.evidence = data.get("evidence", "") + + @property + def emoji(self) -> str: + return { + "outlier": "⚠️", + "trend": "📈", + "distribution": "📊", + "correlation": "🔗", + "recommendation": "💡", + }.get(self.type, "📌") + + @property + def severity_emoji(self) -> str: + return {"high": "🔴", "medium": "🟡", "low": "🟢"}.get(self.severity, "") + + def __str__(self): + return f"{self.emoji} {self.severity_emoji} {self.title}: {self.detail}" + + +class InsightEngine: + """洞察引擎:自动检测异常 + 主动输出""" + + def __init__(self): + self.client = openai.OpenAI( + api_key=LLM_CONFIG["api_key"], + base_url=LLM_CONFIG["base_url"], + ) + self.model = LLM_CONFIG["model"] + + def analyze(self, steps: list[ExplorationStep], question: str) -> list[Insight]: + """ + 对探索结果进行洞察分析 + + Args: + steps: 探索步骤列表 + question: 原始用户问题 + + Returns: + 洞察列表 + """ + if not steps: + return [] + + # 构建探索历史文本 + history = self._build_history(steps) + + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": INSIGHT_SYSTEM}, + { + "role": "user", + "content": ( + f"## 用户原始问题\n{question}\n\n" + f"## 探索历史\n{history}\n\n" + f"请分析以上数据,输出异常和洞察。" + ), + }, + ], + temperature=0.3, + max_tokens=2048, + ) + + content = response.choices[0].message.content.strip() + insights_data = self._extract_json_array(content) + + return [Insight(d) for d in insights_data] + + def format_insights(self, insights: list[Insight]) -> str: + """格式化洞察为可读文本""" + if not insights: + return "" + + # 按严重程度排序 + severity_order = {"high": 0, "medium": 1, "low": 2} + sorted_insights = sorted(insights, key=lambda i: severity_order.get(i.severity, 9)) + + lines = ["## 💡 主动洞察", ""] + lines.append("_以下是你没问但数据告诉我们的事:_\n") + + for insight in sorted_insights: + lines.append(f"**{insight.emoji} {insight.title}** {insight.severity_emoji}") + lines.append(f" {insight.detail}") + lines.append(f" _数据来源: {insight.evidence}_") + lines.append("") + + return "\n".join(lines) + + def _build_history(self, steps: list[ExplorationStep]) -> str: + """构建探索历史文本""" + parts = [] + for step in steps: + if step.action == "done": + parts.append(f"### 结束\n{step.reasoning}") + continue + + if step.success: + parts.append( + f"### 第 {step.round_num} 轮:{step.purpose}\n" + f"思考: {step.reasoning}\n" + f"SQL: `{step.sql}`\n" + f"结果 ({step.row_count} 行):\n" + f"列: {step.columns}\n" + f"数据: {json.dumps(step.rows, ensure_ascii=False)}" + ) + else: + parts.append( + f"### 第 {step.round_num} 轮:{step.purpose}\n" + f"SQL: `{step.sql}`\n" + f"结果: 执行失败 - {step.error}" + ) + + return "\n\n".join(parts) + + def _extract_json_array(self, text: str) -> list[dict]: + """从 LLM 输出提取 JSON 数组""" + try: + result = json.loads(text) + if isinstance(result, list): + return result + except json.JSONDecodeError: + pass + + for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: + match = re.search(pattern, text, re.DOTALL) + if match: + try: + result = json.loads(match.group(1)) + if isinstance(result, list): + return result + except json.JSONDecodeError: + continue + + # 找最外层 [] + match = re.search(r'\[.*\]', text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except json.JSONDecodeError: + pass + + return [] + + +# ── 基于规则的快速异常检测(无需 LLM)──────────────── + +def quick_detect(steps: list[ExplorationStep]) -> list[str]: + """ + 基于规则的快速异常检测,不调 LLM。 + 检测离群值、不均衡分布等。 + """ + alerts = [] + + for step in steps: + if not step.success or not step.rows: + continue + + for row in step.rows: + for col in step.columns: + val = row.get(col) + if not isinstance(val, (int, float)): + continue + + # 检查 pct 列:某个分组占比异常 + if col.lower() in ("pct", "percent", "percentage", "占比"): + if isinstance(val, (int, float)) and val > 50: + alerts.append( + f"⚠️ {step.purpose} 中某个分组占比 {val}%,超过 50%,集中度过高" + ) + + # 检查 count 列:极值差异 + if col.lower() in ("count", "cnt", "n", "total", "order_count"): + count_vals = [ + r.get(col) for r in step.rows + if isinstance(r.get(col), (int, float)) + ] + if len(count_vals) >= 3 and max(count_vals) > 0: + ratio = max(count_vals) / (sum(count_vals) / len(count_vals)) + if ratio > 3: + alerts.append( + f"⚠️ {step.purpose} 中最大值是均值的 {ratio:.1f} 倍,分布极不均衡" + ) + + return alerts diff --git a/planner.py b/planner.py new file mode 100644 index 0000000..6a671e8 --- /dev/null +++ b/planner.py @@ -0,0 +1,129 @@ +""" +Layer 1: 意图规划器 +将用户问题解析为结构化的分析计划,替代硬编码模板。 +""" +import json +import re +from typing import Any + +PROMPT = """你是一个数据分析规划专家。 + +## 你的任务 +根据用户的分析问题和数据库 Schema,生成一个结构化的分析计划。 + +## 输出格式(严格 JSON) +```json +{ + "intent": "一句话描述用户想了解什么", + "analysis_type": "ranking" | "distribution" | "trend" | "comparison" | "anomaly" | "overview", + "primary_table": "主要分析的表名", + "dimensions": ["分组维度列名"], + "metrics": ["需要聚合的数值列名"], + "aggregations": ["SUM", "AVG", "COUNT", ...], + "filters": [{"column": "列名", "condition": "过滤条件(可选)"}], + "join_needed": false, + "join_info": {"tables": [], "on": ""}, + "expected_rounds": 3, + "rationale": "为什么这样规划,需要关注什么" +} +``` + +## 分析类型说明 +- ranking: 按某维度排名(哪个地区最高) +- distribution: 分布/占比(各地区占比多少) +- trend: 时间趋势(月度变化) +- comparison: 对比分析(A vs B) +- anomaly: 异常检测(有没有异常值) +- overview: 全局概览(整体情况如何) + +## 规划原则 +1. 只选择与问题相关的表和列 +2. 如果需要 JOIN,说明关联条件 +3. 预估需要几轮探索(1-6) +4. 标注可能的异常关注点 +5. metrics 不要包含 id 列""" + +import openai +from config import LLM_CONFIG + + +class Planner: + """意图规划器:将自然语言问题转为结构化分析计划""" + + def __init__(self): + self.client = openai.OpenAI( + api_key=LLM_CONFIG["api_key"], + base_url=LLM_CONFIG["base_url"], + ) + self.model = LLM_CONFIG["model"] + + def plan(self, question: str, schema_text: str) -> dict[str, Any]: + """ + 生成分析计划 + + Returns: + { + "intent": str, + "analysis_type": str, + "primary_table": str, + "dimensions": [str], + "metrics": [str], + "aggregations": [str], + "filters": [dict], + "join_needed": bool, + "join_info": dict, + "expected_rounds": int, + "rationale": str, + } + """ + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": PROMPT}, + { + "role": "user", + "content": ( + f"## Schema\n{schema_text}\n\n" + f"## 用户问题\n{question}" + ), + }, + ], + temperature=0.1, + max_tokens=1024, + ) + + content = response.choices[0].message.content.strip() + plan = self._extract_json(content) + + # 补充默认值 + plan.setdefault("analysis_type", "overview") + plan.setdefault("expected_rounds", 3) + plan.setdefault("filters", []) + plan.setdefault("join_needed", False) + + return plan + + def _extract_json(self, text: str) -> dict: + """从 LLM 输出提取 JSON""" + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: + match = re.search(pattern, text, re.DOTALL) + if match: + try: + return json.loads(match.group(1)) + except json.JSONDecodeError: + continue + + # 找最外层 {} + match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except json.JSONDecodeError: + pass + + return {"intent": text[:100], "analysis_type": "overview"} diff --git a/reporter.py b/reporter.py new file mode 100644 index 0000000..5a29c53 --- /dev/null +++ b/reporter.py @@ -0,0 +1,110 @@ +""" +报告生成器 +将分析计划 + 探索结果 + 洞察,综合为一份可读报告。 +""" +import json +from typing import Any + +import openai +from config import LLM_CONFIG +from explorer import ExplorationStep +from insights import Insight + + +REPORT_PROMPT = """你是一个数据分析报告撰写专家。基于以下信息撰写报告。 + +## 用户问题 +{question} + +## 分析计划 +{plan} + +## 探索过程 +{exploration} + +## 主动洞察 +{insights_text} + +## 撰写要求 +1. **开头**:一句话总结核心结论(先给答案) +2. **核心发现**:按重要性排列,每个发现带具体数字 +3. **深入洞察**:异常、趋势、关联(从洞察中提取) +4. **建议**:基于数据的行动建议 +5. **审计**:末尾附上所有执行的 SQL + +使用 Markdown,中文撰写。 +语气:专业但不枯燥,像一个聪明的分析师在做简报。""" + + +class ReportGenerator: + """报告生成器""" + + def __init__(self): + self.client = openai.OpenAI( + api_key=LLM_CONFIG["api_key"], + base_url=LLM_CONFIG["base_url"], + ) + self.model = LLM_CONFIG["model"] + + def generate( + self, + question: str, + plan: dict, + steps: list[ExplorationStep], + insights: list[Insight], + ) -> str: + """生成分析报告""" + + # 构建探索过程文本 + exploration = self._build_exploration(steps) + + # 构建洞察文本 + insights_text = self._build_insights(insights) + + prompt = REPORT_PROMPT.format( + question=question, + plan=json.dumps(plan, ensure_ascii=False, indent=2), + exploration=exploration, + insights_text=insights_text, + ) + + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "你是专业的数据分析师,撰写清晰、有洞察力的分析报告。"}, + {"role": "user", "content": prompt}, + ], + temperature=0.3, + max_tokens=4096, + ) + + return response.choices[0].message.content + + def _build_exploration(self, steps: list[ExplorationStep]) -> str: + parts = [] + for step in steps: + if step.action == "done": + parts.append(f"### 探索结束\n{step.reasoning}") + continue + + if step.success: + parts.append( + f"### 第 {step.round_num} 轮:{step.purpose}\n" + f"思考: {step.reasoning}\n" + f"SQL: `{step.sql}`\n" + f"结果 ({step.row_count} 行):\n" + f"列: {step.columns}\n" + f"数据: {json.dumps(step.rows, ensure_ascii=False)}" + ) + else: + parts.append( + f"### 第 {step.round_num} 轮:{step.purpose}\n" + f"SQL: `{step.sql}`\n" + f"执行失败: {step.error}" + ) + return "\n\n".join(parts) if parts else "无探索步骤" + + def _build_insights(self, insights: list[Insight]) -> str: + if not insights: + return "未检测到异常。" + return "\n".join(str(i) for i in insights) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..aa2b704 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +openai>=1.0.0 diff --git a/sandbox_executor.py b/sandbox_executor.py new file mode 100644 index 0000000..c1094d6 --- /dev/null +++ b/sandbox_executor.py @@ -0,0 +1,139 @@ +""" +沙箱执行器 —— 执行 SQL,只返回聚合结果 +""" +import sqlite3 +import re +from typing import Any +from config import SANDBOX_RULES, DB_PATH + + +class SandboxError(Exception): + """沙箱安全违规""" + pass + + +class SandboxExecutor: + def __init__(self, db_path: str = DB_PATH): + self.db_path = db_path + self.rules = SANDBOX_RULES + self.execution_log: list[dict] = [] + + def execute(self, sql: str) -> dict[str, Any]: + """ + 执行 SQL,返回脱敏后的聚合结果。 + 如果违反安全规则,抛出 SandboxError。 + """ + # 验证 SQL 安全性 + self._validate(sql) + + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + try: + cur.execute(sql) + rows = cur.fetchall() + + # 转为字典列表 + columns = [desc[0] for desc in cur.description] if cur.description else [] + results = [dict(row) for row in rows] + + # 脱敏处理 + sanitized = self._sanitize(results, columns) + + # 记录执行日志 + self.execution_log.append({ + "sql": sql, + "rows_returned": len(results), + "columns": columns, + }) + + return { + "success": True, + "columns": columns, + "rows": sanitized, + "row_count": len(sanitized), + "sql": sql, + } + + except sqlite3.Error as e: + return { + "success": False, + "error": str(e), + "sql": sql, + } + finally: + conn.close() + + def _validate(self, sql: str): + """SQL 安全验证""" + sql_upper = sql.upper().strip() + + # 1. 检查禁止的关键字 + for banned in self.rules["banned_keywords"]: + if banned.upper() in sql_upper: + raise SandboxError(f"禁止的 SQL 关键字: {banned}") + + # 2. 只允许 SELECT(不能有多语句) + statements = [s.strip() for s in sql.split(";") if s.strip()] + if len(statements) > 1: + raise SandboxError("禁止多语句执行") + if not sql_upper.startswith("SELECT"): + raise SandboxError("只允许 SELECT 查询") + + # 3. 检查是否使用了聚合函数或 GROUP BY(可选要求) + if self.rules["require_aggregation"]: + agg_keywords = ["COUNT", "SUM", "AVG", "MIN", "MAX", "GROUP BY", + "DISTINCT", "HAVING", "ROUND", "CAST"] + has_agg = any(kw in sql_upper for kw in agg_keywords) + has_limit = "LIMIT" in sql_upper + + if not has_agg and not has_limit: + raise SandboxError( + "要求使用聚合函数 (COUNT/SUM/AVG/MIN/MAX/GROUP BY) 或 LIMIT" + ) + + # 4. LIMIT 检查 + limit_match = re.search(r'LIMIT\s+(\d+)', sql_upper) + if limit_match: + limit_val = int(limit_match.group(1)) + if limit_val > self.rules["max_result_rows"]: + raise SandboxError( + f"LIMIT {limit_val} 超过最大允许值 {self.rules['max_result_rows']}" + ) + + def _sanitize(self, rows: list[dict], columns: list[str]) -> list[dict]: + """对结果进行脱敏处理""" + if not rows: + return rows + + # 1. 限制行数 + rows = rows[:self.rules["max_result_rows"]] + + # 2. 浮点数四舍五入 + for row in rows: + for col in columns: + val = row.get(col) + if isinstance(val, float): + row[col] = round(val, self.rules["round_floats"]) + + # 3. 小样本抑制(k-anonymity) + # 如果某个分组的 count 小于阈值,标记为 " str: + """获取执行摘要""" + if not self.execution_log: + return "尚未执行任何查询" + + lines = [f"共执行 {len(self.execution_log)} 条查询:"] + for i, log in enumerate(self.execution_log, 1): + lines.append(f" {i}. {log['sql'][:80]}... → {log['rows_returned']} 行结果") + return "\n".join(lines) diff --git a/schema_extractor.py b/schema_extractor.py new file mode 100644 index 0000000..90b6492 --- /dev/null +++ b/schema_extractor.py @@ -0,0 +1,126 @@ +""" +Schema 提取器 —— 只提取表结构,不碰数据 +""" +import sqlite3 +from typing import Any + + +def extract_schema(db_path: str) -> dict[str, Any]: + """ + 从数据库提取 Schema,只返回结构信息: + - 表名、列名、类型 + - 主键、外键 + - 行数 + - 枚举列的去重值(不含原始数据) + """ + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + # 获取所有表 + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") + tables = [row["name"] for row in cur.fetchall()] + + schema = {"tables": []} + + for table in tables: + # 列信息 + cur.execute(f"PRAGMA table_info('{table}')") + columns = [] + for col in cur.fetchall(): + columns.append({ + "name": col["name"], + "type": col["type"], + "nullable": col["notnull"] == 0, + "is_primary_key": col["pk"] == 1, + }) + + # 外键 + cur.execute(f"PRAGMA foreign_key_list('{table}')") + fks = [] + for fk in cur.fetchall(): + fks.append({ + "column": fk["from"], + "references_table": fk["table"], + "references_column": fk["to"], + }) + + # 行数 + cur.execute(f"SELECT COUNT(*) AS cnt FROM '{table}'") + row_count = cur.fetchone()["cnt"] + + # 对 VARCHAR / TEXT 类型列,提取去重枚举值(最多 20 个) + data_profile = {} + for col in columns: + col_name = col["name"] + col_type = (col["type"] or "").upper() + + if any(t in col_type for t in ("VARCHAR", "TEXT", "CHAR")): + cur.execute(f'SELECT DISTINCT "{col_name}" FROM "{table}" WHERE "{col_name}" IS NOT NULL LIMIT 20') + vals = [row[0] for row in cur.fetchall()] + if len(vals) <= 20: + data_profile[col_name] = { + "type": "enum", + "distinct_count": len(vals), + "values": vals, + } + elif any(t in col_type for t in ("INT", "REAL", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC")): + cur.execute(f''' + SELECT MIN("{col_name}") AS min_val, MAX("{col_name}") AS max_val, + AVG("{col_name}") AS avg_val, COUNT(DISTINCT "{col_name}") AS distinct_count + FROM "{table}" WHERE "{col_name}" IS NOT NULL + ''') + row = cur.fetchone() + if row and row["min_val"] is not None: + data_profile[col_name] = { + "type": "numeric", + "min": round(row["min_val"], 2), + "max": round(row["max_val"], 2), + "avg": round(row["avg_val"], 2), + "distinct_count": row["distinct_count"], + } + + schema["tables"].append({ + "name": table, + "columns": columns, + "foreign_keys": fks, + "row_count": row_count, + "data_profile": data_profile, + }) + + conn.close() + return schema + + +def schema_to_text(schema: dict) -> str: + """将 Schema 转为可读文本,供 LLM 理解""" + lines = ["=== 数据库 Schema ===\n"] + + for table in schema["tables"]: + lines.append(f"📋 表: {table['name']} (共 {table['row_count']} 行)") + lines.append(" 列:") + for col in table["columns"]: + pk = " [PK]" if col["is_primary_key"] else "" + null = " NULL" if col["nullable"] else " NOT NULL" + lines.append(f' - {col["name"]}: {col["type"]}{pk}{null}') + + if table["foreign_keys"]: + lines.append(" 外键:") + for fk in table["foreign_keys"]: + lines.append(f' - {fk["column"]} → {fk["references_table"]}.{fk["references_column"]}') + + if table["data_profile"]: + lines.append(" 数据画像:") + for col_name, profile in table["data_profile"].items(): + if profile["type"] == "enum": + vals = ", ".join(str(v) for v in profile["values"][:10]) + lines.append(f' - {col_name}: 枚举值({profile["distinct_count"]}个) = [{vals}]') + elif profile["type"] == "numeric": + lines.append( + f' - {col_name}: 范围[{profile["min"]} ~ {profile["max"]}], ' + f'均值{profile["avg"]}, {profile["distinct_count"]}个不同值' + ) + + lines.append("") + + return "\n".join(lines)