225 lines
8.4 KiB
Python
225 lines
8.4 KiB
Python
|
|
"""
|
|||
|
|
Layer 2: 自适应探索器
|
|||
|
|
"""
|
|||
|
|
import json
|
|||
|
|
from typing import Any
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
|
|||
|
|
from core.config import LLM_CONFIG
|
|||
|
|
from core.utils import get_llm_client, extract_json_object
|
|||
|
|
from core.sandbox import SandboxExecutor
|
|||
|
|
|
|||
|
|
|
|||
|
|
EXPLORER_SYSTEM = """你是一个数据分析执行者。你的上级给了你一个分析计划,你需要通过迭代执行 SQL 来完成分析。
|
|||
|
|
|
|||
|
|
## 你的工作方式
|
|||
|
|
每一轮你看到:
|
|||
|
|
1. 分析计划(上级给的目标)
|
|||
|
|
2. 数据库 Schema(表结构、数据画像)
|
|||
|
|
3. 之前的探索历史(查过什么、得到什么结果)
|
|||
|
|
|
|||
|
|
你决定下一步:
|
|||
|
|
- 输出一条 SQL 继续探索
|
|||
|
|
- 或者输出 done 表示分析足够
|
|||
|
|
|
|||
|
|
## 输出格式(严格 JSON)
|
|||
|
|
```json
|
|||
|
|
{
|
|||
|
|
"action": "query",
|
|||
|
|
"reasoning": "为什么要做这个查询",
|
|||
|
|
"sql": "SELECT ...",
|
|||
|
|
"purpose": "这个查询的目的"
|
|||
|
|
}
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
或:
|
|||
|
|
```json
|
|||
|
|
{
|
|||
|
|
"action": "done",
|
|||
|
|
"reasoning": "为什么分析已经足够"
|
|||
|
|
}
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
## SQL 规则(严格遵守,否则会被沙箱拒绝)
|
|||
|
|
- 只用 SELECT
|
|||
|
|
- 每条 SQL 必须包含聚合函数(COUNT/SUM/AVG/MIN/MAX)或 GROUP BY 或 LIMIT
|
|||
|
|
- 禁止 SELECT *
|
|||
|
|
- 用 ROUND 控制精度
|
|||
|
|
- 合理使用 LIMIT(分组结果 15 行以内,时间序列 60 行以内)
|
|||
|
|
- 如果需要查看明细数据,必须加 LIMIT
|
|||
|
|
|
|||
|
|
## 探索策略
|
|||
|
|
1. 第一轮:验证核心假设
|
|||
|
|
2. 后续轮:基于已有结果追问
|
|||
|
|
3. 不要重复查已经确认的事
|
|||
|
|
4. 每轮要有新发现,否则就该结束"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class ExplorationStep:
|
|||
|
|
"""单步探索结果"""
|
|||
|
|
round_num: int = 0
|
|||
|
|
reasoning: str = ""
|
|||
|
|
purpose: str = ""
|
|||
|
|
sql: str = ""
|
|||
|
|
action: str = "query"
|
|||
|
|
success: bool = False
|
|||
|
|
error: str | None = None
|
|||
|
|
columns: list[str] = field(default_factory=list)
|
|||
|
|
rows: list[dict] = field(default_factory=list)
|
|||
|
|
row_count: int = 0
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def from_decision(cls, round_num: int, decision: dict, result: dict) -> "ExplorationStep":
|
|||
|
|
return cls(
|
|||
|
|
round_num=round_num,
|
|||
|
|
reasoning=decision.get("reasoning", ""),
|
|||
|
|
purpose=decision.get("purpose", ""),
|
|||
|
|
sql=decision.get("sql", ""),
|
|||
|
|
action=decision.get("action", "query"),
|
|||
|
|
success=result.get("success", False),
|
|||
|
|
error=result.get("error"),
|
|||
|
|
columns=result.get("columns", []),
|
|||
|
|
rows=result.get("rows", []),
|
|||
|
|
row_count=result.get("row_count", 0),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def to_dict(self) -> dict:
|
|||
|
|
d = {
|
|||
|
|
"round": self.round_num, "action": self.action,
|
|||
|
|
"reasoning": self.reasoning, "purpose": self.purpose,
|
|||
|
|
"sql": self.sql, "success": self.success,
|
|||
|
|
}
|
|||
|
|
if self.success:
|
|||
|
|
d["result"] = {"columns": self.columns, "rows": self.rows, "row_count": self.row_count}
|
|||
|
|
else:
|
|||
|
|
d["result"] = {"error": self.error}
|
|||
|
|
return d
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Explorer:
|
|||
|
|
"""自适应探索器"""
|
|||
|
|
|
|||
|
|
def __init__(self, executor: SandboxExecutor):
|
|||
|
|
self.executor = executor
|
|||
|
|
self.client, self.model = get_llm_client(LLM_CONFIG)
|
|||
|
|
|
|||
|
|
def explore(
|
|||
|
|
self, plan: dict, schema_text: str,
|
|||
|
|
max_rounds: int = 6, playbook_result: dict | None = None,
|
|||
|
|
) -> list[ExplorationStep]:
|
|||
|
|
steps: list[ExplorationStep] = []
|
|||
|
|
|
|||
|
|
# 阶段 A: 预设查询
|
|||
|
|
preset_context = ""
|
|||
|
|
if playbook_result and playbook_result.get("preset_queries"):
|
|||
|
|
preset_steps = self._run_preset_queries(playbook_result["preset_queries"])
|
|||
|
|
steps.extend(preset_steps)
|
|||
|
|
preset_context = self._build_preset_context(preset_steps, playbook_result)
|
|||
|
|
|
|||
|
|
# 阶段 B: 自适应探索
|
|||
|
|
preset_used = len([s for s in steps if s.success])
|
|||
|
|
remaining = max(1, max_rounds - preset_used)
|
|||
|
|
|
|||
|
|
initial = (
|
|||
|
|
f"## 分析计划\n```json\n{json.dumps(plan, ensure_ascii=False, indent=2)}\n```\n\n"
|
|||
|
|
f"## 数据库 Schema\n{schema_text}\n\n"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 注入历史上下文
|
|||
|
|
prev_context = plan.pop("_prev_context", None)
|
|||
|
|
if prev_context:
|
|||
|
|
initial += f"## 历史分析参考\n{prev_context}\n\n"
|
|||
|
|
|
|||
|
|
if preset_context:
|
|||
|
|
initial += (
|
|||
|
|
f"## 预设分析结果(已执行)\n{preset_context}\n\n"
|
|||
|
|
f"请基于这些已有数据,决定是否需要进一步探索。\n"
|
|||
|
|
f"重点关注:预设结果中的异常、值得深挖的点。\n"
|
|||
|
|
f"如果预设结果已经足够,直接输出 done。"
|
|||
|
|
)
|
|||
|
|
if playbook_result.get("exploration_hints"):
|
|||
|
|
initial += f"\n\n## 探索提示\n{playbook_result['exploration_hints']}"
|
|||
|
|
else:
|
|||
|
|
initial += "请开始第一轮探索。根据计划,先执行最关键的查询。"
|
|||
|
|
|
|||
|
|
messages = [
|
|||
|
|
{"role": "system", "content": EXPLORER_SYSTEM},
|
|||
|
|
{"role": "user", "content": initial},
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
offset = len(steps)
|
|||
|
|
for round_num in range(offset + 1, offset + remaining + 1):
|
|||
|
|
print(f"\n 🔄 探索第 {round_num}/{max_rounds} 轮")
|
|||
|
|
|
|||
|
|
decision = self._llm_decide(messages)
|
|||
|
|
reasoning = decision.get("reasoning", "")
|
|||
|
|
print(f" 💭 {reasoning[:80]}{'...' if len(reasoning) > 80 else ''}")
|
|||
|
|
|
|||
|
|
if decision.get("action") == "done":
|
|||
|
|
print(f" ✅ 探索完成")
|
|||
|
|
steps.append(ExplorationStep.from_decision(round_num, decision, {"success": True}))
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
sql = decision.get("sql", "")
|
|||
|
|
if not sql:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
print(f" 📝 {decision.get('purpose', '')}")
|
|||
|
|
try:
|
|||
|
|
result = self.executor.execute(sql)
|
|||
|
|
except Exception as e:
|
|||
|
|
result = {"success": False, "error": str(e), "sql": sql}
|
|||
|
|
print(f" {'✅' if result['success'] else '❌'} {result.get('row_count', result.get('error', ''))}")
|
|||
|
|
|
|||
|
|
steps.append(ExplorationStep.from_decision(round_num, decision, result))
|
|||
|
|
|
|||
|
|
messages.append({"role": "assistant", "content": json.dumps(decision, ensure_ascii=False)})
|
|||
|
|
messages.append({"role": "user", "content": self._format_result(result)})
|
|||
|
|
|
|||
|
|
return steps
|
|||
|
|
|
|||
|
|
def _run_preset_queries(self, preset_queries: list[dict]) -> list[ExplorationStep]:
|
|||
|
|
steps = []
|
|||
|
|
for i, pq in enumerate(preset_queries, 1):
|
|||
|
|
sql, purpose = pq["sql"], pq.get("purpose", f"预设查询 {i}")
|
|||
|
|
print(f"\n 📌 预设查询 {i}/{len(preset_queries)}: {purpose}")
|
|||
|
|
try:
|
|||
|
|
result = self.executor.execute(sql)
|
|||
|
|
except Exception as e:
|
|||
|
|
result = {"success": False, "error": str(e), "sql": sql}
|
|||
|
|
decision = {"action": "query", "reasoning": f"[预设] {purpose}", "sql": sql, "purpose": purpose}
|
|||
|
|
steps.append(ExplorationStep.from_decision(i, decision, result))
|
|||
|
|
print(f" {'✅' if result['success'] else '❌'} {result.get('row_count', result.get('error', ''))}")
|
|||
|
|
return steps
|
|||
|
|
|
|||
|
|
def _build_preset_context(self, steps: list[ExplorationStep], playbook_result: dict) -> str:
|
|||
|
|
parts = [f"Playbook: {playbook_result.get('playbook_name', '未知')}"]
|
|||
|
|
for step in steps:
|
|||
|
|
if step.success:
|
|||
|
|
parts.append(
|
|||
|
|
f"### {step.purpose}\nSQL: `{step.sql}`\n"
|
|||
|
|
f"结果 ({step.row_count} 行): {json.dumps(step.rows[:15], ensure_ascii=False)}"
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
parts.append(f"### {step.purpose}\nSQL: `{step.sql}`\n执行失败: {step.error}")
|
|||
|
|
return "\n\n".join(parts)
|
|||
|
|
|
|||
|
|
def _llm_decide(self, messages: list[dict]) -> dict:
|
|||
|
|
response = self.client.chat.completions.create(
|
|||
|
|
model=self.model, messages=messages, temperature=0.2, max_tokens=1024,
|
|||
|
|
)
|
|||
|
|
content = response.choices[0].message.content.strip()
|
|||
|
|
result = extract_json_object(content)
|
|||
|
|
return result if result else {"action": "done", "reasoning": f"无法解析: {content[:100]}"}
|
|||
|
|
|
|||
|
|
def _format_result(self, result: dict) -> str:
|
|||
|
|
if not result.get("success"):
|
|||
|
|
return f"❌ 执行失败: {result.get('error', '未知错误')}"
|
|||
|
|
rows = result["rows"][:20]
|
|||
|
|
return (
|
|||
|
|
f"查询结果:\n\n✅ 返回 {result['row_count']} 行\n"
|
|||
|
|
f"列: {result['columns']}\n数据:\n{json.dumps(rows, ensure_ascii=False, indent=2)}\n\n"
|
|||
|
|
f"请基于这个结果决定下一步。如果发现异常或值得深挖的点,继续查询。如果分析足够,输出 done。"
|
|||
|
|
)
|