""" Layer 1: 意图规划器 将用户问题解析为结构化的分析计划,替代硬编码模板。 """ import json import re from typing import Any PROMPT = """你是一个数据分析规划专家。 ## 你的任务 根据用户的分析问题和数据库 Schema,生成一个结构化的分析计划。 ## 输出格式(严格 JSON) ```json { "intent": "一句话描述用户想了解什么", "analysis_type": "ranking" | "distribution" | "trend" | "comparison" | "anomaly" | "overview", "primary_table": "主要分析的表名", "dimensions": ["分组维度列名"], "metrics": ["需要聚合的数值列名"], "aggregations": ["SUM", "AVG", "COUNT", ...], "filters": [{"column": "列名", "condition": "过滤条件(可选)"}], "join_needed": false, "join_info": {"tables": [], "on": ""}, "expected_rounds": 3, "rationale": "为什么这样规划,需要关注什么" } ``` ## 分析类型说明 - ranking: 按某维度排名(哪个地区最高) - distribution: 分布/占比(各地区占比多少) - trend: 时间趋势(月度变化) - comparison: 对比分析(A vs B) - anomaly: 异常检测(有没有异常值) - overview: 全局概览(整体情况如何) ## 规划原则 1. 只选择与问题相关的表和列 2. 如果需要 JOIN,说明关联条件 3. 预估需要几轮探索(1-6) 4. 标注可能的异常关注点 5. metrics 不要包含 id 列""" import openai from config import LLM_CONFIG class Planner: """意图规划器:将自然语言问题转为结构化分析计划""" def __init__(self): self.client = openai.OpenAI( api_key=LLM_CONFIG["api_key"], base_url=LLM_CONFIG["base_url"], ) self.model = LLM_CONFIG["model"] def plan(self, question: str, schema_text: str) -> dict[str, Any]: """ 生成分析计划 Returns: { "intent": str, "analysis_type": str, "primary_table": str, "dimensions": [str], "metrics": [str], "aggregations": [str], "filters": [dict], "join_needed": bool, "join_info": dict, "expected_rounds": int, "rationale": str, } """ response = self.client.chat.completions.create( model=self.model, messages=[ {"role": "system", "content": PROMPT}, { "role": "user", "content": ( f"## Schema\n{schema_text}\n\n" f"## 用户问题\n{question}" ), }, ], temperature=0.1, max_tokens=1024, ) content = response.choices[0].message.content.strip() plan = self._extract_json(content) # 补充默认值 plan.setdefault("analysis_type", "overview") plan.setdefault("expected_rounds", 3) plan.setdefault("filters", []) plan.setdefault("join_needed", False) return plan def _extract_json(self, text: str) -> dict: """从 LLM 输出提取 JSON""" try: return json.loads(text) except json.JSONDecodeError: pass for pattern in [r'```json\s*\n(.*?)\n```', r'```\s*\n(.*?)\n```']: match = re.search(pattern, text, re.DOTALL) if match: try: return json.loads(match.group(1)) except json.JSONDecodeError: continue # 找最外层 {} match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL) if match: try: return json.loads(match.group()) except json.JSONDecodeError: pass return {"intent": text[:100], "analysis_type": "overview"}