- Layer 1 Planner: 意图规划,将问题转为结构化分析计划
- Layer 2 Explorer: 自适应探索循环,多轮迭代动态生成 SQL
- Layer 3 InsightEngine: 异常检测 + 主动洞察
- Layer 4 ContextManager: 多轮对话上下文记忆

安全设计:AI 只看 Schema 和聚合结果,不接触原始数据。

支持任意 OpenAI 兼容 API(OpenAI / Ollama / DeepSeek / vLLM)。
130 lines
3.9 KiB
Python
130 lines
3.9 KiB
Python
"""
|
||
Layer 1: 意图规划器
|
||
将用户问题解析为结构化的分析计划,替代硬编码模板。
|
||
"""
|
||
import json
|
||
import re
|
||
from typing import Any
|
||
|
||
# System prompt sent (as the "system" message) by Planner.plan(). In Chinese,
# it asks the model for a strict-JSON analysis plan with the keys intent,
# analysis_type, primary_table, dimensions, metrics, aggregations, filters,
# join_needed, join_info, expected_rounds and rationale. analysis_type is one
# of ranking / distribution / trend / comparison / anomaly / overview.
# expected_rounds should be 1-6, and id columns must not appear in metrics.
# NOTE(review): the "|" / "||" lines inside this literal look like paste
# artifacts from a code viewer, but as written they are part of the prompt
# string. Confirm against the upstream file before relying on the exact text.
PROMPT = """你是一个数据分析规划专家。
|
||
|
||
## 你的任务
|
||
根据用户的分析问题和数据库 Schema,生成一个结构化的分析计划。
|
||
|
||
## 输出格式(严格 JSON)
|
||
```json
|
||
{
|
||
"intent": "一句话描述用户想了解什么",
|
||
"analysis_type": "ranking" | "distribution" | "trend" | "comparison" | "anomaly" | "overview",
|
||
"primary_table": "主要分析的表名",
|
||
"dimensions": ["分组维度列名"],
|
||
"metrics": ["需要聚合的数值列名"],
|
||
"aggregations": ["SUM", "AVG", "COUNT", ...],
|
||
"filters": [{"column": "列名", "condition": "过滤条件(可选)"}],
|
||
"join_needed": false,
|
||
"join_info": {"tables": [], "on": ""},
|
||
"expected_rounds": 3,
|
||
"rationale": "为什么这样规划,需要关注什么"
|
||
}
|
||
```
|
||
|
||
## 分析类型说明
|
||
- ranking: 按某维度排名(哪个地区最高)
|
||
- distribution: 分布/占比(各地区占比多少)
|
||
- trend: 时间趋势(月度变化)
|
||
- comparison: 对比分析(A vs B)
|
||
- anomaly: 异常检测(有没有异常值)
|
||
- overview: 全局概览(整体情况如何)
|
||
|
||
## 规划原则
|
||
1. 只选择与问题相关的表和列
|
||
2. 如果需要 JOIN,说明关联条件
|
||
3. 预估需要几轮探索(1-6)
|
||
4. 标注可能的异常关注点
|
||
5. metrics 不要包含 id 列"""
|
||
|
||
import openai
|
||
from config import LLM_CONFIG
|
||
|
||
|
||
class Planner:
    """Intent planner: turn a natural-language question into a structured analysis plan.

    The plan is produced by an OpenAI-compatible chat-completions endpoint
    configured through ``LLM_CONFIG`` (``api_key``, ``base_url``, ``model``).
    """

    # A Markdown code fence with an optional language tag (```json, ```JSON
    # or a bare ```), tolerating a single-line fence; group 1 is the body.
    _FENCE_RE = re.compile(r"```[A-Za-z]*[ \t]*\n?(.*?)```", re.DOTALL)

    def __init__(self):
        # Client for any OpenAI-compatible API (base_url selects the provider).
        self.client = openai.OpenAI(
            api_key=LLM_CONFIG["api_key"],
            base_url=LLM_CONFIG["base_url"],
        )
        self.model = LLM_CONFIG["model"]

    def plan(self, question: str, schema_text: str) -> dict[str, Any]:
        """Generate an analysis plan for ``question`` over the given schema.

        Args:
            question: The user's natural-language analysis question.
            schema_text: Text description of the database schema; it is
                placed in the user message ahead of the question.

        Returns:
            The plan parsed from the model's JSON reply. Expected keys are
            ``intent``, ``analysis_type``, ``primary_table``, ``dimensions``,
            ``metrics``, ``aggregations``, ``filters``, ``join_needed``,
            ``join_info``, ``expected_rounds`` and ``rationale``.

            The structural keys are always present: missing or null values
            get defaults. Those keys and defaults are ``analysis_type``
            ("overview"), ``expected_rounds`` (3), ``filters`` ([]),
            ``join_needed`` (False), ``dimensions`` / ``metrics`` /
            ``aggregations`` ([]) and ``join_info``
            ({"tables": [], "on": ""}).

            Free-text keys appear only if the model supplied them. The one
            exception is ``intent``, which the parse fallback also sets.

        Raises:
            Any exception raised by ``client.chat.completions.create``; API
            and network errors are not caught here.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": PROMPT},
                {
                    "role": "user",
                    "content": (
                        f"## Schema\n{schema_text}\n\n"
                        f"## 用户问题\n{question}"
                    ),
                },
            ],
            temperature=0.1,
            max_tokens=1024,
        )

        # ``message.content`` is Optional in the OpenAI SDK types; treat None
        # as empty text so the parse fallback applies instead of crashing.
        content = (response.choices[0].message.content or "").strip()
        plan = self._extract_json(content)

        # Fill structural fields the model may omit or set to null, so later
        # layers can index them directly. The dict is built per call, so the
        # list/dict defaults are never shared between plans.
        defaults: dict[str, Any] = {
            "analysis_type": "overview",
            "expected_rounds": 3,
            "filters": [],
            "join_needed": False,
            "dimensions": [],
            "metrics": [],
            "aggregations": [],
            "join_info": {"tables": [], "on": ""},
        }
        for key, value in defaults.items():
            if plan.get(key) is None:
                plan[key] = value

        return plan

    def _extract_json(self, text: str) -> dict:
        """Extract a JSON object from raw LLM output.

        Strategies, tried in order:

        1. Parse the whole text as JSON.
        2. Parse the body of each Markdown code fence, in order of appearance.
        3. Find the first ``{`` from which a complete JSON object decodes.
           Any nesting depth is handled, as are braces inside string values.

        Only JSON objects are accepted. Arrays, strings and numbers are
        skipped, so callers can always treat the result as a dict.

        Returns:
            The parsed object. If nothing parses, returns the minimal
            fallback plan ``{"intent": text[:100], "analysis_type": "overview"}``.
        """
        candidates = [text]
        candidates.extend(m.group(1) for m in self._FENCE_RE.finditer(text))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # raw_decode parses one complete JSON value starting at ``idx`` and
        # ignores trailing text. The first "{" that decodes is therefore the
        # leftmost (outermost) object embedded in surrounding prose.
        decoder = json.JSONDecoder()
        for match in re.finditer(r"\{", text):
            try:
                obj, _end = decoder.raw_decode(text, match.start())
            except json.JSONDecodeError:
                continue
            return obj

        return {"intent": text[:100], "analysis_type": "overview"}
|