Files
iov_ana/planner.py

130 lines
3.9 KiB
Python
Raw Normal View History

"""
Layer 1: intent planner.
Parses the user's question into a structured analysis plan, replacing hard-coded templates.
"""
import json
import re
from typing import Any
# System prompt sent by Planner.plan(); it instructs the model to reply with the
# JSON analysis plan described below. Full-width punctuation matters here: the
# parenthesised examples keep each analysis-type description separate from its
# example question.
PROMPT = """你是一个数据分析规划专家。

## 你的任务
根据用户的分析问题和数据库 Schema，生成一个结构化的分析计划。

## 输出格式（严格 JSON）
```json
{
  "intent": "一句话描述用户想了解什么",
  "analysis_type": "ranking" | "distribution" | "trend" | "comparison" | "anomaly" | "overview",
  "primary_table": "主要分析的表名",
  "dimensions": ["分组维度列名"],
  "metrics": ["需要聚合的数值列名"],
  "aggregations": ["SUM", "AVG", "COUNT", ...],
  "filters": [{"column": "列名", "condition": "过滤条件(可选)"}],
  "join_needed": false,
  "join_info": {"tables": [], "on": ""},
  "expected_rounds": 3,
  "rationale": "为什么这样规划,需要关注什么"
}
```

## 分析类型说明
- ranking: 按某维度排名（哪个地区最高）
- distribution: 分布/占比（各地区占比多少）
- trend: 时间趋势（月度变化）
- comparison: 对比分析（A vs B）
- anomaly: 异常检测（有没有异常值）
- overview: 全局概览（整体情况如何）

## 规划原则
1. 只选择与问题相关的表和列
2. 如果需要 JOIN，说明关联条件
3. 预估需要几轮探索（1-6）
4. 标注可能的异常关注点
5. metrics 不要包含 id 列"""
import openai
from config import LLM_CONFIG
class Planner:
    """Intent planner: turn a natural-language question into a structured analysis plan."""

    # Analysis types the system prompt allows the model to choose from.
    _ANALYSIS_TYPES = frozenset(
        {"ranking", "distribution", "trend", "comparison", "anomaly", "overview"}
    )
    # Exploration-round bounds/default, mirroring the "1-6" range and the
    # example value 3 given in the system prompt.
    _MIN_ROUNDS = 1
    _MAX_ROUNDS = 6
    _DEFAULT_ROUNDS = 3
    # A fenced code block, optionally tagged ``json`` (any case); group 1 is its body.
    _FENCE_RE = re.compile(r"```(?:json)?\s*\n(.*?)\n```", re.DOTALL | re.IGNORECASE)

    def __init__(self, client=None, model=None):
        """Create the planner.

        Args:
            client: Optional pre-built OpenAI-compatible client (anything exposing
                ``chat.completions.create``). When omitted, an ``openai.OpenAI``
                client is built from ``LLM_CONFIG["api_key"]`` and
                ``LLM_CONFIG["base_url"]``.
            model: Optional model name; defaults to ``LLM_CONFIG["model"]``.
        """
        if client is None:
            client = openai.OpenAI(
                api_key=LLM_CONFIG["api_key"],
                base_url=LLM_CONFIG["base_url"],
            )
        self.client = client
        self.model = model if model is not None else LLM_CONFIG["model"]

    def plan(self, question: str, schema_text: str) -> dict[str, Any]:
        """Generate an analysis plan for ``question`` against ``schema_text``.

        Sends the module-level ``PROMPT`` as the system message and the schema
        plus question as the user message, then parses the model's reply.

        Returns:
            A dict containing at least ``intent``, ``analysis_type``,
            ``dimensions``, ``metrics``, ``aggregations``, ``filters``,
            ``join_needed``, ``join_info``, ``expected_rounds`` and
            ``rationale`` (``primary_table`` only if the model supplied it).
            Missing keys get neutral defaults, ``analysis_type`` is normalised
            to one of the allowed values (falling back to ``"overview"``), and
            ``expected_rounds`` is coerced to an int within [1, 6].
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": PROMPT},
                {
                    "role": "user",
                    "content": (
                        f"## Schema\n{schema_text}\n\n"
                        f"## 用户问题\n{question}"
                    ),
                },
            ],
            temperature=0.1,
            max_tokens=1024,
        )
        # The SDK types message.content as optional (e.g. refusals / tool calls);
        # treat a missing reply as empty text instead of crashing on .strip().
        content = (response.choices[0].message.content or "").strip()
        plan = self._extract_json(content)
        return self._apply_defaults(plan, question)

    def _apply_defaults(self, plan: dict[str, Any], question: str) -> dict[str, Any]:
        """Fill in omitted keys and normalise enum/int fields in place; return ``plan``."""
        plan.setdefault("intent", question)
        plan.setdefault("dimensions", [])
        plan.setdefault("metrics", [])
        plan.setdefault("aggregations", [])
        plan.setdefault("filters", [])
        plan.setdefault("join_needed", False)
        plan.setdefault("join_info", {"tables": [], "on": ""})
        plan.setdefault("rationale", "")

        # Only accept the documented analysis types; isinstance guard also keeps
        # unhashable values (lists/dicts) out of the frozenset membership test.
        raw_type = plan.get("analysis_type")
        normalized = raw_type.strip().lower() if isinstance(raw_type, str) else ""
        plan["analysis_type"] = (
            normalized if normalized in self._ANALYSIS_TYPES else "overview"
        )

        # The model may return "3", null, Infinity or an out-of-range number.
        try:
            rounds = int(plan.get("expected_rounds", self._DEFAULT_ROUNDS))
        except (TypeError, ValueError, OverflowError):
            rounds = self._DEFAULT_ROUNDS
        plan["expected_rounds"] = min(max(rounds, self._MIN_ROUNDS), self._MAX_ROUNDS)
        return plan

    def _extract_json(self, text: str) -> dict[str, Any]:
        """Extract a JSON object from LLM output.

        Tries, in order: the whole text, then the body of each fenced code
        block, then decoding from the first ``{`` (trailing prose is ignored,
        nesting depth is unlimited). Only dict results are accepted. If nothing
        parses, returns a minimal fallback plan whose intent is the first 100
        characters of ``text``.
        """
        candidates = [text, *(m.group(1) for m in self._FENCE_RE.finditer(text))]
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # Decode the outermost object starting at the first '{'. raw_decode stops
        # at the end of that object, so braces inside strings and deep nesting
        # are handled correctly (unlike a brace-matching regex).
        start = text.find("{")
        if start != -1:
            try:
                parsed, _end = json.JSONDecoder().raw_decode(text, start)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed

        return {"intent": text[:100], "analysis_type": "overview"}