"""
Layer 2: 自适应探索器
基于分析计划 + 已有发现,动态决定下一步查什么。
多轮迭代,直到 AI 判断"够了"或达到上限。
"""
|
|||
|
|
import json
|
|||
|
|
import re
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
import openai
|
|||
|
|
from config import LLM_CONFIG
|
|||
|
|
from sandbox_executor import SandboxExecutor
|
|||
|
|
|
|||
|
|
|
|||
|
|
# System prompt for the explorer LLM: describes what it sees each round, the
# strict JSON decision format ("query" with SQL, or "done"), SQL rules and the
# exploration strategy.
EXPLORER_SYSTEM = """你是一个数据分析执行者。你的上级给了你一个分析计划,你需要通过迭代执行 SQL 来完成分析。

## 你的工作方式
每一轮你看到:
1. 分析计划(上级给的目标)
2. 数据库 Schema(表结构、数据画像)
3. 之前的探索历史(查过什么、得到什么结果)

你决定下一步:
- 输出一条 SQL 继续探索
- 或者输出 done 表示分析足够

## 输出格式(严格 JSON)
```json
{
  "action": "query",
  "reasoning": "为什么要做这个查询",
  "sql": "SELECT ...",
  "purpose": "这个查询的目的"
}
```

或:
```json
{
  "action": "done",
  "reasoning": "为什么分析已经足够"
}
```

## SQL 规则
- 只用 SELECT,必须有聚合函数或 GROUP BY
- 禁止 SELECT *
- 用 ROUND 控制精度
- 合理使用 LIMIT(分组结果 15 行以内,时间序列 60 行以内)

## 探索策略
1. 第一轮:验证核心假设(计划中最关键的查询)
2. 后续轮:基于已有结果追问
   - 发现离群值 → 追问为什么
   - 发现异常比例 → 追问细分维度
   - 结果平淡 → 换个角度试试
3. 不要重复查已经确认的事
4. 每轮要有新发现,否则就该结束"""


# Follow-up user turn sent after each executed query. ``{result_text}`` is
# filled in with ``str.format``, so any other literal braces added here would
# need to be doubled.
EXPLORER_CONTINUE = """查询结果:

{result_text}

请基于这个结果决定下一步。如果发现异常或值得深挖的点,继续查询。如果分析足够,输出 done。"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ExplorationStep:
    """Outcome of a single exploration round.

    Flattens the LLM decision dict and the executor result dict into plain
    attributes, and serializes them back with ``to_dict``.

    Args:
        round_num: Round index this step belongs to.
        decision: Parsed LLM decision; reads ``reasoning``, ``purpose``,
            ``sql`` (default ``""``) and ``action`` (default ``"query"``).
        result: Executor result; reads ``success`` (default ``False``),
            ``error`` (default ``None``), ``columns``/``rows`` (default ``[]``)
            and ``row_count`` (default ``0``).
    """

    def __init__(self, round_num: int, decision: dict, result: dict):
        self.round_num = round_num

        # What the LLM decided; absent text fields become "".
        self.reasoning, self.purpose, self.sql = (
            decision.get(field, "") for field in ("reasoning", "purpose", "sql")
        )
        self.action = decision.get("action", "query")

        # What the executor returned; absent keys mean failure / no data.
        self.success = result.get("success", False)
        self.error = result.get("error")
        self.columns = result.get("columns", [])
        self.rows = result.get("rows", [])
        self.row_count = result.get("row_count", 0)

    def to_dict(self) -> dict:
        """Serialize the step.

        ``result`` carries columns/rows/row_count when the step succeeded,
        otherwise only the error.
        """
        outcome = (
            {"columns": self.columns, "rows": self.rows, "row_count": self.row_count}
            if self.success
            else {"error": self.error}
        )
        return {
            "round": self.round_num,
            "action": self.action,
            "reasoning": self.reasoning,
            "purpose": self.purpose,
            "sql": self.sql,
            "success": self.success,
            "result": outcome,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Explorer:
    """Adaptive explorer: iteratively asks the LLM for the next SQL query and runs it.

    Each round the LLM sees the plan, the schema and the conversation so far,
    and replies with a JSON decision: ``{"action": "query", "sql": ...}`` to run
    another query, or ``{"action": "done"}`` to stop.
    """

    # Follow-up sent when the LLM chose to query but supplied no SQL, so the
    # next round is not an identical request.
    _MISSING_SQL_PROMPT = "上一轮没有给出可执行的 SQL。请输出一条 SQL 继续探索,或输出 done 表示分析足够。"

    # Fenced code blocks (``` or ```json) that may wrap the JSON decision.
    _FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL)

    def __init__(self, executor: "SandboxExecutor"):
        """Bind the SQL executor and create the LLM client from ``LLM_CONFIG``."""
        self.executor = executor
        self.client = openai.OpenAI(
            api_key=LLM_CONFIG["api_key"],
            base_url=LLM_CONFIG["base_url"],
        )
        self.model = LLM_CONFIG["model"]

    def explore(
        self,
        plan: dict,
        schema_text: str,
        max_rounds: int = 6,
    ) -> "list[ExplorationStep]":
        """Run the exploration loop.

        Each round asks the LLM for a decision. ``"done"`` ends the loop;
        otherwise the decision's SQL is executed and its (possibly truncated)
        result is fed back as the next user turn. A decision without SQL is
        recorded in the conversation together with a nudge. It consumes the
        round but produces no step.

        Args:
            plan: Analysis plan produced by the Planner (embedded as JSON).
            schema_text: Text description of the database schema.
            max_rounds: Maximum number of LLM rounds.

        Returns:
            The exploration steps in round order. If the LLM said ``"done"``,
            that decision is the last step.
        """
        steps: list = []
        messages = self._initial_messages(plan, schema_text)

        for round_num in range(1, max_rounds + 1):
            print(f"\n 🔄 探索第 {round_num}/{max_rounds} 轮")

            decision = self._llm_decide(messages)
            action = decision.get("action", "query")
            # A JSON null / non-string reasoning must not break the preview below.
            reasoning = str(decision.get("reasoning") or "")
            print(f" 💭 {reasoning[:80]}{'...' if len(reasoning) > 80 else ''}")

            if action == "done":
                print(" ✅ 探索完成")
                steps.append(ExplorationStep(round_num, decision, {"success": True}))
                break

            sql = decision.get("sql", "")
            purpose = decision.get("purpose", "")

            if not sql:
                print(" ⚠️ 未生成 SQL,跳过")
                # Without extending the history, the next round would resend
                # exactly the same messages and likely get the same reply.
                self._append_turn(messages, decision, self._MISSING_SQL_PROMPT)
                continue

            print(f" 📝 {purpose}")
            result = self.executor.execute(sql)

            if result["success"]:
                print(f" ✅ {result['row_count']} 行")
            else:
                print(f" ❌ {result['error']}")

            steps.append(ExplorationStep(round_num, decision, result))

            self._append_turn(
                messages,
                decision,
                EXPLORER_CONTINUE.format(result_text=self._format_result(result)),
            )

        return steps

    @staticmethod
    def _initial_messages(plan: dict, schema_text: str) -> list:
        """Build the system prompt plus the first user turn (plan JSON + schema)."""
        return [
            {"role": "system", "content": EXPLORER_SYSTEM},
            {
                "role": "user",
                "content": (
                    f"## 分析计划\n```json\n{json.dumps(plan, ensure_ascii=False, indent=2)}\n```\n\n"
                    f"## 数据库 Schema\n{schema_text}\n\n"
                    "请开始第一轮探索。根据计划,先执行最关键的查询。"
                ),
            },
        ]

    @staticmethod
    def _append_turn(messages: list, decision: dict, user_content: str) -> None:
        """Append the assistant's decision (as JSON) and the next user message."""
        messages.append({
            "role": "assistant",
            "content": json.dumps(decision, ensure_ascii=False),
        })
        messages.append({"role": "user", "content": user_content})

    def _llm_decide(self, messages: list) -> dict:
        """Ask the LLM for the next decision and parse it into a dict."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.2,
            max_tokens=1024,
        )
        # The SDK types message.content as optional; treat a missing body as "".
        content = (response.choices[0].message.content or "").strip()
        return self._extract_json(content)

    def _format_result(self, result: dict, max_rows: int = 20) -> str:
        """Render an executor result as text for the next LLM turn.

        Args:
            result: Executor result dict (``success``, ``error``, ``columns``,
                ``rows``, ``row_count``).
            max_rows: How many rows to include in the preview.

        Returns:
            A failure line with the error, or a header plus the first
            *max_rows* rows as JSON. When rows were cut off, a note says so,
            so the model does not mistake the preview for the full result.
        """
        if not result.get("success"):
            return f"❌ 执行失败: {result.get('error', '未知错误')}"

        rows = result["rows"][:max_rows]
        # default=str: values json cannot encode natively (e.g. Decimal or
        # datetime, if the executor returns them) are rendered via str().
        rows_json = json.dumps(rows, ensure_ascii=False, indent=2, default=str)
        text = (
            f"✅ 查询成功,返回 {result['row_count']} 行\n"
            f"列: {result['columns']}\n"
            f"数据:\n{rows_json}"
        )
        if result["row_count"] > len(rows):
            text += f"\n(仅展示前 {len(rows)} 行)"
        return text

    def _extract_json(self, text: str) -> dict:
        """Extract the JSON decision object from raw LLM output.

        Tried in order: the whole reply, each fenced code block, then every
        ``{`` position in the text until one starts a decodable JSON object.
        Only dicts are accepted. If nothing parses, a ``"done"`` decision is
        returned so the exploration loop stops instead of crashing.
        """
        for candidate in (text, *self._FENCE_RE.findall(text)):
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # raw_decode parses one JSON value from the start of the string and
        # ignores what follows. That handles any nesting depth and braces
        # inside string values, while stray braces in prose are just skipped.
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _ = decoder.raw_decode(text[start:])
            except json.JSONDecodeError:
                obj = None
            if isinstance(obj, dict):
                return obj
            start = text.find("{", start + 1)

        return {"action": "done", "reasoning": f"无法解析输出: {text[:100]}"}
|