diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..e291a40 --- /dev/null +++ b/.env.example @@ -0,0 +1,21 @@ +# LLM 配置(兼容 OpenAI API 格式) +LLM_API_KEY=sk-your-api-key-here +LLM_BASE_URL=https://api.openai.com/v1 +LLM_MODEL=gpt-4o-mini + +# 其他可用配置: +# Ollama(本地部署) +# LLM_API_KEY=ollama +# LLM_BASE_URL=http://localhost:11434/v1 +# LLM_MODEL=qwen2.5-coder:7b + +# DeepSeek +# LLM_API_KEY=sk-xxx +# LLM_BASE_URL=https://api.deepseek.com +# LLM_MODEL=deepseek-chat + +# 数据库路径(可选,默认 demo.db) +# DB_PATH=/path/to/your.db + +# 探索轮数(可选,默认 6) +# MAX_ROUNDS=6 diff --git a/agent.py b/agent.py index 8c35a47..59122ae 100644 --- a/agent.py +++ b/agent.py @@ -9,10 +9,11 @@ Layer 4: Context 上下文记忆 Output: Reporter + Chart + Consolidator """ import os +import time from typing import Optional from concurrent.futures import ThreadPoolExecutor, as_completed -from core.config import DB_PATH, MAX_EXPLORATION_ROUNDS, PLAYBOOK_DIR, CHARTS_DIR +from core.config import DB_PATH, MAX_EXPLORATION_ROUNDS, PLAYBOOK_DIR, CHARTS_DIR, PROJECT_ROOT from core.schema import extract_schema, schema_to_text from core.sandbox import SandboxExecutor from layers.planner import Planner @@ -47,6 +48,10 @@ class DataAnalysisAgent: # 累积图表 self._all_charts: list[dict] = [] + # 报告输出目录 + self.reports_dir = os.path.join(PROJECT_ROOT, "reports") + os.makedirs(self.reports_dir, exist_ok=True) + # 自动生成 Playbook if not self.playbook_mgr.playbooks: print("\n🤖 [Playbook] 未发现预设剧本,AI 自动生成中...") @@ -145,6 +150,9 @@ class DataAnalysisAgent: # Layer 4: 记录上下文 self.context.add_session(question=question, plan=plan, steps=steps, insights=insights, report=report) + # 自动保存报告 + self.save_report(report, question, charts=charts) + return report def full_report(self, question: str = "") -> str: @@ -173,3 +181,54 @@ class DataAnalysisAgent: def close(self): """释放资源""" self.executor.close() + + def save_report(self, report: str, question: str, charts: list[dict] | None = None) -> str: + """将报告保存为 Markdown 文件,返回文件路径""" + ts = time.strftime("%Y%m%d_%H%M%S") + # 取问题前 20 字作为文件名 + import re + safe_q = re.sub(r'[^\w\u4e00-\u9fff]', '_', question)[:20].strip('_') + fname = f"{ts}_{safe_q}.md" + fpath = os.path.join(self.reports_dir, fname) + + with open(fpath, "w", encoding="utf-8") as f: + f.write(f"# 分析报告: {question}\n\n") + f.write(f"_生成时间: {time.strftime('%Y-%m-%d %H:%M:%S')}_\n\n") + f.write(report) + if charts: + f.write("\n\n---\n\n## 📊 图表索引\n\n") + for c in charts: + f.write(f"### {c['title']}\n![{c['title']}]({os.path.abspath(c['path'])})\n\n") + + print(f" 💾 报告已保存: {fpath}") + return fpath + + def export_data(self, steps: list, format: str = "csv") -> str | None: + """导出探索结果为 CSV""" + import csv + import io + + all_rows = [] + all_cols = set() + for step in steps: + if step.success and step.rows: + for row in step.rows: + row["_query"] = step.purpose + all_rows.append(row) + all_cols.update(row.keys()) + + if not all_rows: + return None + + ts = time.strftime("%Y%m%d_%H%M%S") + fname = f"export_{ts}.csv" + fpath = os.path.join(self.reports_dir, fname) + + cols = sorted(all_cols) + with open(fpath, "w", encoding="utf-8-sig", newline="") as f: + writer = csv.DictWriter(f, fieldnames=cols) + writer.writeheader() + writer.writerows(all_rows) + + print(f" 📁 数据已导出: {fpath} ({len(all_rows)} 行)") + return fpath diff --git a/cli.py b/cli.py index d536367..c035e47 100644 --- a/cli.py +++ b/cli.py @@ -1,5 +1,5 @@ """ -交互式 CLI —— 四层架构自适应分析 +交互式 CLI —— 四层架构自适应分析(增强版) 用法: python cli.py [数据库路径] """ import os @@ -7,7 +7,7 @@ import sys sys.path.insert(0, os.path.dirname(__file__)) -from core.config import DB_PATH, LLM_CONFIG, MAX_EXPLORATION_ROUNDS, PLAYBOOK_DIR +from core.config import DB_PATH, LLM_CONFIG, MAX_EXPLORATION_ROUNDS, PLAYBOOK_DIR, PROJECT_ROOT from agent import DataAnalysisAgent @@ -17,6 +17,8 @@ def print_help(): <问题> 分析一个问题 rounds= <问题> 设置探索轮数 report [主题] 整合所有分析,生成综合报告 + export 导出最近一次分析结果为 CSV + reports 列出已保存的报告文件 schema 查看数据库 Schema playbooks 查看已加载的预设剧本 regen 重新生成预设剧本 @@ -28,6 +30,39 @@ def print_help(): """) +def cmd_reports(agent): + """列出已保存的报告""" + reports_dir = agent.reports_dir + if not os.path.isdir(reports_dir): + print("(reports 目录不存在)") + return + files = sorted([f for f in os.listdir(reports_dir) if f.endswith(".md")]) + if not files: + print("(尚无保存的报告)") + return + print(f"\n📁 已保存 {len(files)} 份报告:") + for f in files: + fpath = os.path.join(reports_dir, f) + size = os.path.getsize(fpath) + print(f" 📄 {f} ({size/1024:.1f} KB)") + + +def setup_readline(): + """启用命令历史(Linux/macOS)""" + try: + import readline + histfile = os.path.join(PROJECT_ROOT, ".cli_history") + try: + readline.read_history_file(histfile) + except FileNotFoundError: + pass + import atexit + atexit.register(readline.write_history_file, histfile) + readline.set_history_length(100) + except ImportError: + pass + + def main(): db_path = sys.argv[1] if len(sys.argv) > 1 else DB_PATH @@ -36,9 +71,15 @@ def main(): sys.exit(1) if not LLM_CONFIG["api_key"]: - print("⚠️ 未配置 LLM_API_KEY") + print("❌ LLM_API_KEY 未配置!") + print(" 请设置环境变量或创建 .env 文件:") + print(" LLM_API_KEY=your-key") + print(" LLM_BASE_URL=https://api.openai.com/v1") + print(" LLM_MODEL=gpt-4o-mini") sys.exit(1) + setup_readline() + agent = DataAnalysisAgent(db_path) print("=" * 60) @@ -50,6 +91,8 @@ def main(): print(f"📋 预设剧本: {len(agent.playbook_mgr.playbooks)} 个") print(f"\n💬 输入分析问题(help 查看命令)\n") + last_steps = None # 记录最近一次分析的 steps,用于 export + while True: try: user_input = input("📊 > ").strip() @@ -75,7 +118,19 @@ def main(): print(agent.get_audit()) elif cmd == "clear": agent.clear_history() + last_steps = None print("✅ 历史已清空") + elif cmd == "reports": + cmd_reports(agent) + elif cmd == "export": + if last_steps: + fpath = agent.export_data(last_steps) + if fpath: + print(f"✅ 导出成功: {fpath}") + else: + print("⚠️ 无数据可导出") + else: + print("⚠️ 请先执行一次分析") elif cmd.startswith("report"): topic = user_input[6:].strip() try: @@ -120,6 +175,9 @@ def main(): try: report = agent.analyze(question, max_rounds=max_rounds) + # 保存 steps 用于 export + if agent.context.sessions: + last_steps = agent.context.sessions[-1].steps print("\n" + report) print("\n" + "~" * 60) except Exception as e: diff --git a/core/config.py b/core/config.py index aa979d0..f7a4b4b 100644 --- a/core/config.py +++ b/core/config.py @@ -1,13 +1,38 @@ """ -配置文件 +配置文件 —— 支持环境变量 + .env 文件 """ import os + +def _load_dotenv(path: str = ".env"): + """简易 .env 加载器,不依赖 python-dotenv""" + if not os.path.isfile(path): + return + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" not in line: + continue + key, _, val = line.partition("=") + key, val = key.strip(), val.strip().strip('"').strip("'") + if key and key not in os.environ: # 环境变量优先 + os.environ[key] = val + + +# 项目根目录(先定义,.env 加载需要用到) +PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) + +# 加载 .env(项目根目录优先,其次当前目录) +_load_dotenv(os.path.join(PROJECT_ROOT, ".env")) +_load_dotenv(".env") + # LLM 配置(兼容 OpenAI API 格式,包括 Ollama / vLLM / DeepSeek 等) LLM_CONFIG = { - "api_key": os.getenv("LLM_API_KEY", "sk-c44i1hy64xgzwox6x08o4zug93frq6rgn84oqugf2pje1tg4"), - "base_url": os.getenv("LLM_BASE_URL", "https://api.xiaomimimo.com/v1"), - "model": os.getenv("LLM_MODEL", "mimo-v2-flash"), + "api_key": os.getenv("LLM_API_KEY", ""), + "base_url": os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + "model": os.getenv("LLM_MODEL", "gpt-4o-mini"), } # 沙箱安全规则 diff --git a/core/utils.py b/core/utils.py index c4a51f7..009fc4c 100644 --- a/core/utils.py +++ b/core/utils.py @@ -1,8 +1,9 @@ """ -公共工具 —— JSON 提取、LLM 客户端单例 +公共工具 —— JSON 提取、LLM 客户端单例、重试机制 """ import json import re +import time from typing import Any import openai @@ -17,14 +18,77 @@ def get_llm_client(config: dict) -> tuple[openai.OpenAI, str]: """获取 LLM 客户端(单例),避免每个组件各建一个""" global _llm_client, _llm_model if _llm_client is None: + api_key = config.get("api_key", "") + if not api_key: + raise RuntimeError( + "LLM_API_KEY 未配置!请设置环境变量或在 .env 文件中添加:\n" + " LLM_API_KEY=your-key\n" + " LLM_BASE_URL=https://api.openai.com/v1\n" + " LLM_MODEL=gpt-4o-mini" + ) _llm_client = openai.OpenAI( - api_key=config["api_key"], + api_key=api_key, base_url=config["base_url"], ) _llm_model = config["model"] return _llm_client, _llm_model +# ── LLM 调用重试包装 ──────────────────────────────── + +class LLMCallError(Exception): + """LLM 调用最终失败""" + pass + + +def llm_chat(client: openai.OpenAI, model: str, messages: list[dict], + max_retries: int = 3, **kwargs) -> str: + """ + 带指数退避重试的 LLM 调用。 + 处理 429 限频、5xx 超时、网络错误。 + """ + last_err = None + for attempt in range(max_retries): + try: + response = client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + return response.choices[0].message.content.strip() + except openai.RateLimitError as e: + last_err = e + # 读取 Retry-After 或使用默认退避 + wait = _get_retry_delay(e, attempt) + print(f" ⏳ 限频,等待 {wait:.1f}s 后重试 ({attempt+1}/{max_retries})...") + time.sleep(wait) + except (openai.APITimeoutError, openai.APIConnectionError, openai.APIStatusError) as e: + last_err = e + wait = min(2 ** attempt * 2, 30) + print(f" ⚠️ API 错误: {type(e).__name__},等待 {wait:.1f}s ({attempt+1}/{max_retries})...") + time.sleep(wait) + except Exception as e: + last_err = e + if attempt < max_retries - 1: + wait = min(2 ** attempt * 2, 30) + time.sleep(wait) + continue + raise + + raise LLMCallError(f"LLM 调用失败({max_retries} 次重试): {last_err}") + + +def _get_retry_delay(error, attempt: int) -> float: + """从错误响应中提取重试等待时间""" + try: + if hasattr(error, 'response') and error.response is not None: + retry_after = error.response.headers.get('Retry-After') + if retry_after: + return float(retry_after) + except Exception: + pass + # 指数退避: 2s, 4s, 8s, 最大 30s + return min(2 ** (attempt + 1), 30) + + # ── JSON 提取 ──────────────────────────────────────── def extract_json_object(text: str) -> dict: diff --git a/import_csv.py b/import_csv.py index 950c2b9..13858a6 100644 --- a/import_csv.py +++ b/import_csv.py @@ -1,21 +1,198 @@ """ -将工单 CSV 数据导入 SQLite 数据库 +将工单 CSV 数据导入 SQLite 数据库 —— 增强版 +- 自动检测列名映射(兼容中英文) +- 空值/异常数据容错 +- 数据类型自动推断 +- 导入前完整性校验 """ import csv import sqlite3 import os import sys +import re +from typing import Any, Optional -def import_csv(csv_path: str, db_path: str): - """将工单 CSV 导入 SQLite""" + +# ── 列名别名映射(兼容不同版本 CSV)───────────────── + +COLUMN_ALIASES = { + "工单号": ["工单号", "ticket_id", "ticket_no", "id", "工单编号"], + "来源": ["来源", "source", "渠道"], + "创建日期": ["创建日期", "created_date", "create_date"], + "问题类型": ["问题类型", "issue_type", "type", "问题分类"], + "问题描述": ["问题描述", "description", "描述"], + "处理过程": ["处理过程", "process", "处理流程"], + "跟踪记录": ["跟踪记录", "tracking", "跟踪"], + "严重程度": ["严重程度", "severity", "priority", "优先级"], + "工单状态": ["工单状态", "status", "状态"], + "模块": ["模块", "module", "功能模块"], + "责任人": ["责任人", "assignee", "负责人"], + "关闭日期": ["关闭日期", "closed_date", "close_date"], + "车型": ["车型", "vehicle_model", "car_model"], + "VIN": ["VIN", "vin", "车架号"], + "SIM": ["SIM", "sim", "sim卡号"], + "Notes": ["Notes", "notes", "备注"], + "Attachment": ["Attachment", "attachment", "附件"], + "创建人": ["创建人", "creator", "创建者"], + "关闭时长_天": ["关闭时长(天)", "关闭时长_天", "close_duration", "duration_days"], + "创建日期_解析": ["创建日期_解析", "created_date_parsed"], + "关闭日期_解析": ["关闭日期_解析", "closed_date_parsed"], +} + + +def detect_column_mapping(headers: list[str]) -> dict[str, Optional[str]]: + """ + 自动检测 CSV 列名到标准列名的映射。 + 返回 {标准列名: CSV实际列名},找不到的值为 None。 + """ + # 标准化:去空格、小写 + header_map = {h.strip().lower(): h for h in headers} + mapping = {} + + for std_name, aliases in COLUMN_ALIASES.items(): + found = None + for alias in aliases: + key = alias.strip().lower() + if key in header_map: + found = header_map[key] + break + mapping[std_name] = found + + return mapping + + +def safe_float(val: Any) -> Optional[float]: + """安全转 float""" + if val is None or str(val).strip() == "": + return None + try: + return float(str(val).strip()) + except (ValueError, TypeError): + return None + + +def safe_str(val: Any) -> str: + """安全转 string,None → 空串""" + if val is None: + return "" + return str(val).strip() + + +def validate_row(row: dict, mapping: dict) -> tuple[bool, list[str]]: + """校验单行数据,返回 (是否通过, 问题列表)""" + issues = [] + ticket_id = safe_str(row.get(mapping.get("工单号", ""), "")) + if not ticket_id: + issues.append("缺少工单号") + return len(issues) == 0, issues + + +def import_csv(csv_path: str, db_path: str, dry_run: bool = False) -> dict: + """ + 导入 CSV 到 SQLite。 + 返回统计信息 dict。 + """ + stats = { + "total": 0, "imported": 0, "skipped": 0, + "warnings": [], "columns_detected": {}, "columns_missing": [], + } + + if not os.path.isfile(csv_path): + print(f"❌ CSV 文件不存在: {csv_path}") + return stats + + # ── 读取 CSV ────────────────────────────── + with open(csv_path, "r", encoding="utf-8-sig") as f: + reader = csv.DictReader(f) + headers = reader.fieldnames or [] + rows = list(reader) + + stats["total"] = len(rows) + print(f"📄 读取 CSV: {csv_path}") + print(f" 列: {headers}") + print(f" 行数: {len(rows)}") + + # ── 检测列映射 ──────────────────────────── + mapping = detect_column_mapping(headers) + detected = {k: v for k, v in mapping.items() if v is not None} + missing = [k for k, v in mapping.items() if v is None] + stats["columns_detected"] = detected + stats["columns_missing"] = missing + + print(f"\n🔍 列名映射:") + for std, actual in detected.items(): + print(f" ✅ {std} ← {actual}") + for m in missing: + print(f" ⚠️ {m} ← 未找到(将使用空值)") + + if "工单号" not in detected: + print(f"\n❌ 至少需要「工单号」列,无法继续!") + return stats + + # ── 数据预处理 + 校验 ────────────────────── + processed = [] + for i, row in enumerate(rows): + valid, issues = validate_row(row, mapping) + if not valid: + stats["skipped"] += 1 + if len(stats["warnings"]) < 10: + stats["warnings"].append(f"行 {i+2}: {', '.join(issues)}") + continue + + def get_col(std_name: str, default: str = "") -> str: + actual = mapping.get(std_name) + return safe_str(row.get(actual, default)) if actual else default + + def get_float(std_name: str) -> Optional[float]: + actual = mapping.get(std_name) + if not actual: + return None + return safe_float(row.get(actual)) + + processed.append(( + get_col("工单号"), + get_col("来源"), + get_col("创建日期"), + get_col("问题类型"), + get_col("问题描述"), + get_col("处理过程"), + get_col("跟踪记录"), + get_col("严重程度"), + get_col("工单状态"), + get_col("模块"), + get_col("责任人"), + get_col("关闭日期"), + get_col("车型"), + get_col("VIN"), + get_col("SIM"), + get_col("Notes"), + get_col("Attachment"), + get_col("创建人"), + get_float("关闭时长_天"), + get_col("创建日期_解析"), + get_col("关闭日期_解析"), + )) + + stats["imported"] = len(processed) + print(f"\n✅ 预处理完成: {len(processed)} 条有效, {stats['skipped']} 条跳过") + + if stats["warnings"]: + print(f" 警告:") + for w in stats["warnings"][:5]: + print(f" ⚠️ {w}") + + if dry_run: + print(" (dry_run 模式,未写入数据库)") + return stats + + # ── 写入数据库 ───────────────────────────── if os.path.exists(db_path): os.remove(db_path) - print(f"🗑️ 已删除旧数据库: {db_path}") + print(f"\n🗑️ 已删除旧数据库: {db_path}") conn = sqlite3.connect(db_path) cur = conn.cursor() - # 创建工单表 cur.execute(""" CREATE TABLE tickets ( 工单号 TEXT PRIMARY KEY, @@ -42,62 +219,39 @@ def import_csv(csv_path: str, db_path: str): ) """) - with open(csv_path, "r", encoding="utf-8-sig") as f: - reader = csv.DictReader(f) - rows = list(reader) - - for row in rows: - cur.execute(""" - INSERT INTO tickets VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? - ) - """, ( - row.get("工单号", ""), - row.get("来源", ""), - row.get("创建日期", ""), - row.get("问题类型", ""), - row.get("问题描述", ""), - row.get("处理过程", ""), - row.get("跟踪记录", ""), - row.get("严重程度", ""), - row.get("工单状态", ""), - row.get("模块", ""), - row.get("责任人", ""), - row.get("关闭日期", ""), - row.get("车型", ""), - row.get("VIN", ""), - row.get("SIM", ""), - row.get("Notes", ""), - row.get("Attachment", ""), - row.get("创建人", ""), - float(row["关闭时长(天)"]) if row.get("关闭时长(天)") else None, - row.get("创建日期_解析", ""), - row.get("关闭日期_解析", ""), - )) + cur.executemany( + "INSERT INTO tickets VALUES (" + ",".join(["?"] * 21) + ")", + processed, + ) conn.commit() - print(f"✅ 导入 {len(rows)} 条工单到 {db_path}") - # 验证 + # ── 验证 ────────────────────────────────── cur.execute("SELECT COUNT(*) FROM tickets") - print(f" 数据库中共 {cur.fetchone()[0]} 条记录") + db_count = cur.fetchone()[0] + print(f"\n✅ 写入完成: 数据库中 {db_count} 条记录") - cur.execute("SELECT DISTINCT 问题类型 FROM tickets") - types = [r[0] for r in cur.fetchall()] - print(f" 问题类型: {', '.join(types)}") - - cur.execute("SELECT DISTINCT 工单状态 FROM tickets") - statuses = [r[0] for r in cur.fetchall()] - print(f" 工单状态: {', '.join(statuses)}") - - cur.execute("SELECT DISTINCT 车型 FROM tickets") - models = [r[0] for r in cur.fetchall()] - print(f" 车型: {', '.join(models)}") + # 打印维度信息 + for col in ("问题类型", "工单状态", "车型", "模块", "来源"): + actual = mapping.get(col) + if actual: + cur.execute(f'SELECT DISTINCT "{col}" FROM tickets WHERE "{col}" != ""') + vals = [r[0] for r in cur.fetchall()] + if vals: + print(f" {col}: {', '.join(vals[:10])}{'...' if len(vals) > 10 else ''}") conn.close() + return stats if __name__ == "__main__": csv_path = sys.argv[1] if len(sys.argv) > 1 else "cleaned_data.csv" db_path = os.path.join(os.path.dirname(__file__), "demo.db") - import_csv(csv_path, db_path) + dry_run = "--dry-run" in sys.argv + + stats = import_csv(csv_path, db_path, dry_run=dry_run) + + if stats["columns_missing"]: + print(f"\n💡 提示: 以下列在 CSV 中未找到,已用空值填充:") + for m in stats["columns_missing"]: + print(f" - {m}") diff --git a/layers/context.py b/layers/context.py index 707b83f..0347065 100644 --- a/layers/context.py +++ b/layers/context.py @@ -1,7 +1,10 @@ """ -Layer 4: 上下文管理器 +Layer 4: 上下文管理器 —— 增强版 +- 关键词语义匹配,替代简单取最近 N 条 +- 会话摘要去重 """ import time +import re from dataclasses import dataclass, field from typing import Optional @@ -19,6 +22,26 @@ class AnalysisSession: report: str timestamp: float = field(default_factory=time.time) + @property + def keywords(self) -> set[str]: + """提取会话关键词(中文分字 + 英文词切分)""" + text = f"{self.question} {self.plan.get('intent', '')} {' '.join(self.plan.get('dimensions', []))}" + # 中文字符 + cn_chars = set(re.findall(r'[\u4e00-\u9fff]+', text)) + # 英文单词(小写) + en_words = set(re.findall(r'[a-zA-Z]{2,}', text.lower())) + return cn_chars | en_words + + def similarity(self, question: str) -> float: + """与新问题的关键词相似度(Jaccard-like)""" + q_cn = set(re.findall(r'[\u4e00-\u9fff]+', question)) + q_en = set(re.findall(r'[a-zA-Z]{2,}', question.lower())) + q_kw = q_cn | q_en + if not q_kw: + return 0.0 + overlap = self.keywords & q_kw + return len(overlap) / len(q_kw) + def summary(self) -> str: parts = [f"**问题**: {self.question}"] if self.plan: @@ -48,9 +71,9 @@ class AnalysisSession: class ContextManager: - """上下文管理器""" + """上下文管理器 —— 语义匹配增强版""" - def __init__(self, max_history: int = 10): + def __init__(self, max_history: int = 20): self.sessions: list[AnalysisSession] = [] self.max_history = max_history @@ -63,9 +86,25 @@ class ContextManager: return session def get_context_for(self, new_question: str) -> Optional[str]: + """ + 智能匹配最相关的 1~3 个历史分析作为上下文。 + 相似度 > 0.3 才引用,最多 3 条,按相似度降序。 + """ if not self.sessions: return None - return "\n\n---\n\n".join(s.to_reference_text() for s in self.sessions[-2:]) + + scored = [] + for s in self.sessions: + sim = s.similarity(new_question) + if sim > 0.3: # 相关性阈值 + scored.append((sim, s)) + + if not scored: + # 无相关历史,返回最近 1 条作为兜底 + return self.sessions[-1].to_reference_text() + + scored.sort(key=lambda x: x[0], reverse=True) + return "\n\n---\n\n".join(s.to_reference_text() for _, s in scored[:3]) def get_history_summary(self) -> str: if not self.sessions: diff --git a/layers/explorer.py b/layers/explorer.py index 6309f95..9e40b8c 100644 --- a/layers/explorer.py +++ b/layers/explorer.py @@ -6,7 +6,7 @@ from typing import Any from dataclasses import dataclass, field from core.config import LLM_CONFIG -from core.utils import get_llm_client, extract_json_object +from core.utils import get_llm_client, llm_chat, extract_json_object from core.sandbox import SandboxExecutor @@ -206,10 +206,10 @@ class Explorer: return "\n\n".join(parts) def _llm_decide(self, messages: list[dict]) -> dict: - response = self.client.chat.completions.create( - model=self.model, messages=messages, temperature=0.2, max_tokens=1024, + content = llm_chat( + self.client, self.model, + messages=messages, temperature=0.2, max_tokens=1024, ) - content = response.choices[0].message.content.strip() result = extract_json_object(content) return result if result else {"action": "done", "reasoning": f"无法解析: {content[:100]}"} diff --git a/layers/insights.py b/layers/insights.py index 5c858a0..9f16162 100644 --- a/layers/insights.py +++ b/layers/insights.py @@ -5,7 +5,7 @@ import json from typing import Any from core.config import LLM_CONFIG -from core.utils import get_llm_client, extract_json_array +from core.utils import get_llm_client, llm_chat, extract_json_array from layers.explorer import ExplorationStep @@ -68,15 +68,14 @@ class InsightEngine: return [] history = self._build_history(steps) - response = self.client.chat.completions.create( - model=self.model, + content = llm_chat( + self.client, self.model, messages=[ {"role": "system", "content": INSIGHT_SYSTEM}, {"role": "user", "content": f"## 用户问题\n{question}\n\n## 探索历史\n{history}\n\n请分析以上数据,输出异常和洞察。"}, ], temperature=0.3, max_tokens=2048, ) - content = response.choices[0].message.content.strip() return [Insight(d) for d in extract_json_array(content)] def format_insights(self, insights: list[Insight]) -> str: @@ -109,9 +108,9 @@ class InsightEngine: def quick_detect(steps: list[ExplorationStep]) -> list[str]: - """基于规则的快速异常检测,不调 LLM""" + """基于规则的快速异常检测(零 LLM 成本)""" alerts = [] - seen = set() # 去重 + seen = set() for step in steps: if not step.success or not step.rows: @@ -119,22 +118,21 @@ def quick_detect(steps: list[ExplorationStep]) -> list[str]: for col in step.columns: vals = [r.get(col) for r in step.rows if isinstance(r.get(col), (int, float))] - if not vals: + if len(vals) < 2: continue col_lower = col.lower() - # 占比列:某个分组占比过高 + # ── 占比列:集中度过高 ── if col_lower in ("pct", "percent", "percentage", "占比"): - for v in vals: - if v > 50: - key = f"pct_{step.purpose}" - if key not in seen: - seen.add(key) - alerts.append(f"⚠️ {step.purpose} 中某个分组占比 {v}%,集中度过高") - break + max_pct = max(vals) + if max_pct > 50: + key = f"pct_{step.purpose}" + if key not in seen: + seen.add(key) + alerts.append(f"⚠️ {step.purpose}: 最高占比 {max_pct}%,集中度过高") - # 计数列:极值差异 + # ── 计数列:极值差异 ── if col_lower in ("count", "cnt", "n", "total", "order_count") and len(vals) >= 3: avg = sum(vals) / len(vals) if avg > 0: @@ -143,6 +141,47 @@ def quick_detect(steps: list[ExplorationStep]) -> list[str]: key = f"count_{step.purpose}" if key not in seen: seen.add(key) - alerts.append(f"⚠️ {step.purpose} 中最大值是均值的 {ratio:.1f} 倍") + alerts.append(f"⚠️ {step.purpose}: 最大值是均值的 {ratio:.1f} 倍") + + # ── Z-Score 异常检测 ── + if len(vals) >= 5 and col_lower not in ("id", "year", "month"): + mean = sum(vals) / len(vals) + variance = sum((v - mean) ** 2 for v in vals) / len(vals) + std = variance ** 0.5 + if std > 0: + outliers = [(i, v) for i, v in enumerate(vals) if abs(v - mean) / std > 2] + if outliers: + key = f"zscore_{step.purpose}_{col}" + if key not in seen: + seen.add(key) + outlier_desc = ", ".join(f"{v:.1f}" for _, v in outliers[:3]) + alerts.append( + f"⚠️ {step.purpose}「{col}」发现 {len(outliers)} 个异常值 " + f"(均值={mean:.1f}, σ={std:.1f}, 异常值={outlier_desc})" + ) + + # ── 离散度检测(变异系数 CV)── + if len(vals) >= 3 and col_lower not in ("id", "year", "month"): + mean = sum(vals) / len(vals) + if mean != 0: + variance = sum((v - mean) ** 2 for v in vals) / len(vals) + std = variance ** 0.5 + cv = std / abs(mean) + if cv > 1.0: + key = f"cv_{step.purpose}_{col}" + if key not in seen: + seen.add(key) + alerts.append(f"⚠️ {step.purpose}「{col}」离散度高 (CV={cv:.2f}),数据波动大") + + # ── 零值/缺失检测 ── + if col_lower in ("count", "cnt", "total", "amount", "sum", "关闭时长"): + zero_count = sum(1 for v in vals if v == 0) + if zero_count > 0 and zero_count < len(vals): + pct = zero_count / len(vals) * 100 + if pct > 10: + key = f"zero_{step.purpose}_{col}" + if key not in seen: + seen.add(key) + alerts.append(f"⚠️ {step.purpose}「{col}」有 {zero_count} 个零值 ({pct:.0f}%)") return alerts diff --git a/layers/planner.py b/layers/planner.py index 513de8e..8038854 100644 --- a/layers/planner.py +++ b/layers/planner.py @@ -5,7 +5,7 @@ import json from typing import Any from core.config import LLM_CONFIG -from core.utils import get_llm_client, extract_json_object +from core.utils import get_llm_client, llm_chat, extract_json_object PROMPT = """你是一个数据分析规划专家。 @@ -52,8 +52,8 @@ class Planner: self.client, self.model = get_llm_client(LLM_CONFIG) def plan(self, question: str, schema_text: str) -> dict[str, Any]: - response = self.client.chat.completions.create( - model=self.model, + content = llm_chat( + self.client, self.model, messages=[ {"role": "system", "content": PROMPT}, {"role": "user", "content": f"## Schema\n{schema_text}\n\n## 用户问题\n{question}"}, @@ -61,7 +61,6 @@ class Planner: temperature=0.1, max_tokens=1024, ) - content = response.choices[0].message.content.strip() plan = extract_json_object(content) if not plan: diff --git a/layers/playbook.py b/layers/playbook.py index 3b8864b..adc3510 100644 --- a/layers/playbook.py +++ b/layers/playbook.py @@ -7,7 +7,7 @@ import re from typing import Optional from core.config import LLM_CONFIG -from core.utils import get_llm_client, extract_json_object, extract_json_array +from core.utils import get_llm_client, llm_chat, extract_json_object, extract_json_array class Playbook: @@ -87,15 +87,14 @@ class PlaybookManager: - 直接使用实际表名和列名""" try: - response = self.client.chat.completions.create( - model=self.model, + content = llm_chat( + self.client, self.model, messages=[ {"role": "system", "content": "你是数据分析专家。只输出 JSON,不要其他内容。"}, {"role": "user", "content": prompt}, ], temperature=0.3, max_tokens=4096, ) - content = response.choices[0].message.content.strip() playbooks_data = extract_json_array(content) if not playbooks_data: return [] @@ -150,15 +149,15 @@ class PlaybookManager: 不匹配: {{"matched": false, "reasoning": "原因"}}""" try: - response = self.client.chat.completions.create( - model=self.model, + content = llm_chat( + self.client, self.model, messages=[ {"role": "system", "content": "你是分析计划匹配器。"}, {"role": "user", "content": prompt}, ], temperature=0.1, max_tokens=512, ) - result = extract_json_object(response.choices[0].message.content.strip()) + result = extract_json_object(content) if not result.get("matched"): return None diff --git a/output/chart.py b/output/chart.py index 2d50918..ec9b0a1 100644 --- a/output/chart.py +++ b/output/chart.py @@ -12,11 +12,12 @@ import matplotlib.pyplot as plt import matplotlib.font_manager as fm from core.config import LLM_CONFIG -from core.utils import get_llm_client, extract_json_array +from core.utils import get_llm_client, llm_chat, extract_json_array from layers.explorer import ExplorationStep def _setup_chinese_font(): + """尝试加载中文字体,找不到时用英文显示(不崩溃)""" candidates = [ "SimHei", "Microsoft YaHei", "STHeiti", "WenQuanYi Micro Hei", "Noto Sans CJK SC", "PingFang SC", "Source Han Sans CN", @@ -27,10 +28,27 @@ def _setup_chinese_font(): plt.rcParams["font.sans-serif"] = [font] plt.rcParams["axes.unicode_minus"] = False return font + # 兜底:尝试找任何 CJK 字体 + for f in fm.fontManager.ttflist: + if any(kw in f.name.lower() for kw in ("cjk", "chinese", "hei", "song", "ming", "fang")): + plt.rcParams["font.sans-serif"] = [f.name] + plt.rcParams["axes.unicode_minus"] = False + return f.name plt.rcParams["axes.unicode_minus"] = False - return None + return None # 后续图表标题会用英文 fallback -_setup_chinese_font() + +_CN_FONT = _setup_chinese_font() + + +def _safe_title(title: str) -> str: + """无中文字体时将标题转为安全显示文本""" + if _CN_FONT: + return title + # 简单映射:中文→拼音首字母摘要,保留英文和数字 + import re + clean = re.sub(r'[^\w\s.,;:!?%/()\-+]', '', title) + return clean if clean.strip() else "Chart" CHART_PLAN_PROMPT = """你是一个数据可视化专家。根据以下分析结果,规划需要生成的图表。 @@ -97,15 +115,15 @@ class ChartGenerator: ) try: - response = self.client.chat.completions.create( - model=self.model, + content = llm_chat( + self.client, self.model, messages=[ {"role": "system", "content": "你是数据可视化专家。只输出纯 JSON 数组,不要 markdown 代码块。"}, {"role": "user", "content": CHART_PLAN_PROMPT.format(exploration_summary="\n\n".join(summary_parts))}, ], temperature=0.1, max_tokens=1024, ) - plans = extract_json_array(response.choices[0].message.content.strip()) + plans = extract_json_array(content) return plans if plans else self._fallback_plan(valid_steps) except Exception as e: print(f" ⚠️ 图表规划失败: {e},使用 fallback") @@ -206,11 +224,11 @@ class ChartGenerator: ax.set_xticklabels(x_vals, rotation=45, ha="right", fontsize=9) ax.legend() - ax.set_title(title, fontsize=13, fontweight="bold", pad=12) + ax.set_title(_safe_title(title), fontsize=13, fontweight="bold", pad=12) if chart_type not in ("pie",): - ax.set_xlabel(x_col, fontsize=10) + ax.set_xlabel(_safe_title(x_col), fontsize=10) if chart_type != "horizontal_bar": - ax.set_ylabel(y_col, fontsize=10) + ax.set_ylabel(_safe_title(y_col), fontsize=10) ax.grid(axis="y", alpha=0.3) plt.tight_layout() diff --git a/output/consolidator.py b/output/consolidator.py index b9ad5ee..2c8381f 100644 --- a/output/consolidator.py +++ b/output/consolidator.py @@ -4,7 +4,7 @@ import json from core.config import LLM_CONFIG -from core.utils import get_llm_client +from core.utils import get_llm_client, llm_chat from layers.context import AnalysisSession @@ -47,15 +47,14 @@ class ReportConsolidator: charts_text = "\n".join(f"{i}. {c['title']}: {c['path']}" for i, c in enumerate(charts or [], 1)) or "无图表。" try: - response = self.client.chat.completions.create( - model=self.model, + return llm_chat( + self.client, self.model, messages=[ {"role": "system", "content": "你是高级数据分析总监,整合多维度分析结果。"}, {"role": "user", "content": CONSOLIDATE_PROMPT.format(question=question, sections=sections, charts_text=charts_text)}, ], temperature=0.3, max_tokens=4096, ) - return response.choices[0].message.content except Exception as e: print(f" ⚠️ LLM 整合失败: {e},使用拼接模式") return self._fallback_concat(sessions, charts) diff --git a/output/reporter.py b/output/reporter.py index 13f8ab1..12af3dd 100644 --- a/output/reporter.py +++ b/output/reporter.py @@ -5,7 +5,7 @@ import json from typing import Any from core.config import LLM_CONFIG -from core.utils import get_llm_client +from core.utils import get_llm_client, llm_chat from layers.explorer import ExplorationStep from layers.insights import Insight @@ -58,15 +58,14 @@ class ReportGenerator: charts_text=charts_text, ) - response = self.client.chat.completions.create( - model=self.model, + return llm_chat( + self.client, self.model, messages=[ {"role": "system", "content": "你是专业的数据分析师,撰写清晰、有洞察力的分析报告。"}, {"role": "user", "content": prompt}, ], temperature=0.3, max_tokens=4096, ) - return response.choices[0].message.content def _build_exploration(self, steps: list[ExplorationStep]) -> str: parts = []