Explorer 的 system prompt 明确告知 sandbox 规则 — "每条 SQL 必须包含聚合函数或 LIMIT",减少 LLM 生成违规 SQL 浪费轮次；LLM 客户端单例 — 所有组件共享一个 openai.OpenAI 实例,不再各建各的；sanitize 顺序修复 — 小样本抑制放在 float round 之前,避免被 round 干扰；quick_detect 从 O(n²) 改为 O(n) — 按列聚合一次,加去重,不再对每行重复算整列统计；历史上下文实际生效 — get_context_for 的结果现在会注入到 Explorer 的初始 prompt 里,多轮分析时 LLM 能看到之前的发现。
162 lines
5.3 KiB
Python
162 lines
5.3 KiB
Python
"""
|
||
演示脚本 —— 创建示例数据,运行四层架构分析
|
||
"""
|
||
import os
|
||
import sys
|
||
import sqlite3
|
||
import random
|
||
from datetime import datetime, timedelta
|
||
|
||
sys.path.insert(0, os.path.dirname(__file__))
|
||
|
||
from core.config import DB_PATH, LLM_CONFIG
|
||
from agent import DataAnalysisAgent
|
||
|
||
|
||
def create_demo_data(
    db_path: str,
    *,
    n_users: int = 2000,
    n_products: int = 200,
    n_orders: int = 50000,
) -> None:
    """Create a demo e-commerce SQLite database with users, products and orders.

    Any existing file at ``db_path`` is deleted first. Rows are generated
    with the module-level ``random`` generator, so seeding ``random``
    beforehand makes the output reproducible.

    Args:
        db_path: Filesystem path of the SQLite database to (re)create.
        n_users: Number of rows inserted into ``users``.
        n_products: Number of rows inserted into ``products``.
        n_orders: Number of rows inserted into ``orders``.

    Raises:
        ValueError: If a size is negative, or if orders are requested while
            there are no users or no products to reference.
        sqlite3.Error: Propagated from SQLite; the connection is closed
            before the exception leaves this function.
    """
    if min(n_users, n_products, n_orders) < 0:
        raise ValueError("n_users, n_products and n_orders must be >= 0")
    if n_orders and not (n_users and n_products):
        raise ValueError("orders need at least one user and one product")

    if os.path.exists(db_path):
        os.remove(db_path)

    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()

        cur.execute("""
            CREATE TABLE users (
                id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                region TEXT NOT NULL,
                tier TEXT NOT NULL,
                created_at TEXT NOT NULL
            )
        """)

        cur.execute("""
            CREATE TABLE products (
                id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                category TEXT NOT NULL,
                price REAL NOT NULL
            )
        """)

        cur.execute("""
            CREATE TABLE orders (
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                product_id INTEGER NOT NULL,
                amount REAL NOT NULL,
                quantity INTEGER NOT NULL,
                status TEXT NOT NULL,
                region TEXT NOT NULL,
                order_date TEXT NOT NULL,
                FOREIGN KEY (user_id) REFERENCES users(id),
                FOREIGN KEY (product_id) REFERENCES products(id)
            )
        """)

        regions = ["华东", "华南", "华北", "西南", "华中", "东北"]
        tiers = ["普通", "银卡", "金卡", "钻石"]
        categories = ["电子产品", "服饰", "食品", "家居", "图书", "美妆"]
        statuses = ["已完成", "已发货", "待发货", "已取消", "退款中"]

        # NOTE: the order of the random draws below is kept stable so that a
        # seeded run keeps producing the same rows.
        users = []
        for i in range(1, n_users + 1):
            reg = random.choice(regions)
            tier = random.choices(tiers, weights=[50, 25, 15, 10])[0]
            created = datetime(2024, 1, 1) + timedelta(days=random.randint(0, 400))
            users.append((i, f"用户{i}", reg, tier, created.strftime("%Y-%m-%d")))
        cur.executemany("INSERT INTO users VALUES (?, ?, ?, ?, ?)", users)

        products = []
        for i in range(1, n_products + 1):
            cat = random.choice(categories)
            price = round(random.uniform(9.9, 9999.99), 2)
            products.append((i, f"商品{i}", cat, price))
        cur.executemany("INSERT INTO products VALUES (?, ?, ?, ?)", products)

        # Prices were just generated in memory; look them up here instead of
        # issuing one SELECT per order.
        price_by_id = {prod_id: prod_price for prod_id, _, _, prod_price in products}

        orders = []
        for i in range(1, n_orders + 1):
            uid = random.randint(1, n_users)
            pid = random.randint(1, n_products)
            qty = random.choices([1, 2, 3, 4, 5], weights=[50, 25, 15, 7, 3])[0]
            status = random.choices(statuses, weights=[60, 15, 10, 10, 5])[0]
            region = random.choices(regions, weights=[35, 25, 15, 10, 10, 5])[0]
            order_date = datetime(2025, 1, 1) + timedelta(days=random.randint(0, 440))
            amount = round(price_by_id[pid] * qty, 2)
            orders.append(
                (i, uid, pid, amount, qty, status, region, order_date.strftime("%Y-%m-%d"))
            )
        cur.executemany("INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?, ?, ?)", orders)

        conn.commit()
    finally:
        # Release the file handle even when table creation or an insert fails.
        conn.close()

    print(f"✅ 示例数据库已创建: {db_path}")
    print(f" - users: {n_users:,} 条")
    print(f" - products: {n_products:,} 条")
    print(f" - orders: {n_orders:,} 条")
|
||
|
||
|
||
def main():
    """Run the demo: check LLM settings, prepare the database, analyse questions.

    Exits with status 1 when ``LLM_CONFIG["api_key"]`` is falsy. Otherwise it
    creates the demo database if ``DB_PATH`` does not exist, builds a
    ``DataAnalysisAgent`` and prints one report per sample question, followed
    by the agent's audit log and history.
    """
    rule = "=" * 60
    print(rule)
    print(" 🤖 数据分析 Agent —— 四层架构自适应分析")
    print(rule)

    if not LLM_CONFIG["api_key"]:
        # Each entry is printed on its own line; "" yields a blank line.
        setup_help = (
            "\n⚠️ 未配置 LLM_API_KEY,请设置环境变量:",
            "",
            " # OpenAI",
            " export LLM_API_KEY=sk-xxx",
            " export LLM_BASE_URL=https://api.openai.com/v1",
            " export LLM_MODEL=gpt-4o",
            "",
            " # Ollama (本地)",
            " export LLM_API_KEY=ollama",
            " export LLM_BASE_URL=http://localhost:11434/v1",
            " export LLM_MODEL=qwen2.5-coder:7b",
            "",
        )
        for help_line in setup_help:
            print(help_line)
        sys.exit(1)

    model_name = LLM_CONFIG['model']
    endpoint = LLM_CONFIG['base_url']
    print(f"\n🔗 LLM: {model_name} @ {endpoint}")

    if os.path.exists(DB_PATH):
        print(f"\n📂 使用已有数据库: {DB_PATH}")
    else:
        create_demo_data(DB_PATH)

    # Initialise the agent (per the original note, the schema is loaded
    # automatically).
    agent = DataAnalysisAgent(DB_PATH)

    playbook_count = len(agent.playbook_mgr.playbooks)
    print("\n" + rule)
    print(" ⬆️ AI 只看到 Schema:表结构 + 数据画像")
    print(" ⬇️ 四层架构分析:规划 → 预设匹配 → 探索 → 洞察 → 报告")
    print(f" 📋 已加载 {playbook_count} 个预设剧本")
    print(rule)

    sample_questions = (
        "各地区的销售表现如何?帮我全面分析一下",
        "不同商品类别的销售情况和利润贡献",
        "整体订单的完成率怎么样?有没有什么异常需要关注?",
    )
    separator = "\n" + "~" * 60

    for question in sample_questions:
        # A failure on one question is reported and the demo moves on.
        try:
            report = agent.analyze(question)
            print("\n" + report)
            print(separator)
        except Exception as err:
            print(f"\n❌ 分析出错: {err}")
            import traceback
            traceback.print_exc()

    # Final audit
    print("\n📋 会话审计:")
    print(agent.get_audit())
    print("\n📋 分析历史:")
    print(agent.get_history())
    print("\n✅ 演示完成!")
|
||
|
||
|
||
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|