# 安全与稳定性:
# - 移除硬编码 API Key,改用 .env + 环境变量
# - LLM 调用统一重试机制(指数退避,3 次重试,处理 429/5xx/超时)
# - 中文字体检测增强(CJK 关键词兜底 + 无字体时英文 fallback)
# - 缺失 API Key 给出友好提示而非崩溃
# 分析能力提升:
# - 异常检测新增 z-score 检测(标准差>2 标记异常)
# - 新增变异系数 CV 检测(数据波动性)
# - 新增零值/缺失检测
# - 上下文管理器升级为关键词语义匹配(替代简单取最近 2 条)
# 用户体验:
# - 报告自动保存为 Markdown(reports/ 目录)
# - 新增 export 命令导出查询结果为 CSV
# - 新增 reports 命令查看已保存报告
# - CLI 支持 readline 命令历史(方向键翻阅)
# - CSV 导入工具重写:自动列名映射、容错处理、dry-run 模式
# - 新增 .env.example 配置模板
# (266 lines, 10 KiB, Python)
"""
图表生成器 —— 根据探索结果自动生成可视化图表
"""

import json
import math
import os
import re
from decimal import Decimal
from typing import Any

import matplotlib

matplotlib.use("Agg")

import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

from core.config import LLM_CONFIG
from core.utils import get_llm_client, llm_chat, extract_json_array
from layers.explorer import ExplorationStep


def _setup_chinese_font():
    """Configure matplotlib to use an installed CJK-capable font, if any.

    Known Chinese font names are tried first (in order). If none is installed,
    any font whose name contains a CJK-ish keyword is used. The chosen font is
    *prepended* to ``font.sans-serif`` instead of replacing the list, so glyphs
    it lacks can still fall back to the previously configured fonts.

    Returns:
        The selected font name, or ``None`` when no suitable font exists
        (callers then use the ASCII-only title fallback).
    """
    # Render minus signs with ASCII "-" whichever font ends up active.
    plt.rcParams["axes.unicode_minus"] = False

    candidates = [
        "SimHei", "Microsoft YaHei", "STHeiti", "WenQuanYi Micro Hei",
        "Noto Sans CJK SC", "PingFang SC", "Source Han Sans CN",
    ]
    installed = [f.name for f in fm.fontManager.ttflist]
    available = set(installed)

    chosen = next((name for name in candidates if name in available), None)
    if chosen is None:
        # Fallback: any installed font whose name suggests CJK coverage.
        keywords = ("cjk", "chinese", "hei", "song", "ming", "fang")
        chosen = next(
            (name for name in installed if any(kw in name.lower() for kw in keywords)),
            None,
        )
    if chosen is None:
        return None  # titles will use the English/ASCII fallback

    remaining = [name for name in plt.rcParams["font.sans-serif"] if name != chosen]
    plt.rcParams["font.sans-serif"] = [chosen, *remaining]
    return chosen


# Font chosen once at import time by _setup_chinese_font(); None means no CJK
# font is installed and chart text goes through the _safe_title fallback.
_CN_FONT: str | None = _setup_chinese_font()


def _safe_title(title: str) -> str:
    """Return chart text in a form the active font can display.

    When a CJK font is configured (``_CN_FONT`` is truthy) *title* is returned
    unchanged. Otherwise every non-ASCII character (e.g. Chinese) and any
    punctuation outside a small safe set is removed, whitespace runs are
    collapsed, and ``"Chart"`` is returned if nothing readable remains.
    """
    if _CN_FONT:
        return title
    text = "" if title is None else str(title)
    # re.ASCII is essential: without it \w also matches CJK characters, which
    # would be kept and rendered as missing-glyph boxes.
    clean = re.sub(r"[^\w\s.,;:!?%/()\-+]", "", text, flags=re.ASCII)
    clean = " ".join(clean.split())
    return clean or "Chart"


# Prompt template for the LLM chart planner.
# Placeholder: {exploration_summary}. Literal JSON braces are doubled ({{ }})
# so that str.format() leaves them intact.
CHART_PLAN_PROMPT = """你是一个数据可视化专家。根据以下分析结果,规划需要生成的图表。

## 探索结果
{exploration_summary}

## 规划规则
1. 每个有意义的查询结果生成 1 张图,最多 5 张
2. 图表类型:bar / horizontal_bar / pie / line / stacked_bar
3. 跳过数据量太少(<2 行)的结果
4. 标题要简洁

## 输出格式(纯 JSON 数组,不要代码块)
[
{{
"step_index": 0,
"chart_type": "bar",
"title": "图表标题",
"x_column": "分类轴列名",
"y_column": "数值轴列名",
"y2_column": null,
"top_n": 10,
"sort_desc": true
}}
]"""


class ChartGenerator:
    """Turn exploration results into PNG charts.

    Chart plans come from the LLM (``_plan_charts``) with a heuristic fallback
    (``_fallback_plan``); each plan is rendered with matplotlib into
    ``output_dir`` as ``chart_<n>.png``.
    """

    # Chart types _draw_series can render; unknown plan types are drawn as "bar".
    _CHART_TYPES = ("bar", "horizontal_bar", "pie", "line", "stacked_bar")
    # Cap on LLM-planned charts, mirroring the 5-chart limit in CHART_PLAN_PROMPT.
    _MAX_CHARTS = 5
    # Cap on heuristically planned charts.
    _MAX_FALLBACK_CHARTS = 4
    # Primary / secondary series colours.
    _COLOR = "#4C78A8"
    _COLOR_2 = "#E45756"

    def __init__(self, output_dir: str = "charts"):
        """Create a generator writing charts into *output_dir*.

        The LLM client and model name are obtained from
        ``get_llm_client(LLM_CONFIG)``.
        """
        self.output_dir = output_dir
        self.client, self.model = get_llm_client(LLM_CONFIG)

    def generate(self, steps: list[ExplorationStep], question: str) -> list[dict]:
        """Plan and render charts for the chartable steps in *steps*.

        A step is chartable when it succeeded, is not the "done" action and has
        at least two result rows. Chart PNGs left by a previous run are removed
        before new ones are written. A chart that fails to render is reported
        and skipped; the remaining charts are still produced.

        Returns:
            One ``{"path": ..., "title": ...}`` dict per chart written, or an
            empty list when nothing could be charted.
        """
        valid_steps = [
            (i, s) for i, s in enumerate(steps)
            if s.success and s.rows and s.row_count >= 2 and s.action != "done"
        ]
        if not valid_steps:
            return []

        plans = self._plan_charts(valid_steps, question)
        if not plans:
            return []

        self._clean_old_charts()
        os.makedirs(self.output_dir, exist_ok=True)

        charts = []
        for i, plan in enumerate(plans):
            try:
                path = self._render_chart(plan, steps, i)
            except Exception as e:  # best effort: one bad plan must not abort the rest
                print(f" ⚠️ 图表生成失败: {e}")
                continue
            if path:
                charts.append({"path": path, "title": plan.get("title") or f"图表 {i+1}"})
        return charts

    def _plan_charts(self, valid_steps: list[tuple[int, ExplorationStep]], question: str) -> list[dict]:
        """Ask the LLM which charts to draw; fall back to heuristics on failure.

        Args:
            valid_steps: ``(index into steps, step)`` pairs that are chartable.
            question: The user's question. NOTE(review): it is currently not
                included in the prompt sent to the LLM.

        Returns:
            At most ``_MAX_CHARTS`` plan dicts. Non-dict items in the LLM
            answer are dropped; if nothing usable remains, or building the
            prompt, calling the LLM or parsing its answer fails,
            ``_fallback_plan`` is used instead.
        """
        try:
            # default=str: query rows may hold Decimal/date/datetime values that
            # json cannot serialise natively; they only need to be readable here.
            summary = "\n\n".join(
                f"### 步骤 {idx}: {step.purpose}\n列: {step.columns}\n行数: {step.row_count}\n"
                f"前 5 行: {json.dumps(step.rows[:5], ensure_ascii=False, default=str)}"
                for idx, step in valid_steps
            )
            content = llm_chat(
                self.client, self.model,
                messages=[
                    {"role": "system", "content": "你是数据可视化专家。只输出纯 JSON 数组,不要 markdown 代码块。"},
                    {"role": "user", "content": CHART_PLAN_PROMPT.format(exploration_summary=summary)},
                ],
                temperature=0.1, max_tokens=1024,
            )
            plans = extract_json_array(content)
        except Exception as e:
            print(f" ⚠️ 图表规划失败: {e},使用 fallback")
            return self._fallback_plan(valid_steps)

        # LLM output is untrusted: keep only dict items and respect the chart cap.
        usable = [p for p in plans if isinstance(p, dict)] if isinstance(plans, list) else []
        usable = usable[: self._MAX_CHARTS]
        return usable if usable else self._fallback_plan(valid_steps)

    def _fallback_plan(self, valid_steps: list[tuple[int, ExplorationStep]]) -> list[dict]:
        """Build chart plans heuristically, without the LLM.

        For each step the first column is the category axis and the first later
        column whose first-row value is a real number (int/float/Decimal, not
        bool) is the value axis. Steps without such a column are skipped.
        At most ``_MAX_FALLBACK_CHARTS`` plans are returned.
        """
        plans = []
        for idx, step in valid_steps:
            if len(plans) >= self._MAX_FALLBACK_CHARTS:
                break
            if len(step.columns) < 2 or step.row_count < 2 or not step.rows:
                continue
            x_col = step.columns[0]
            first_row = step.rows[0]
            y_col = next((c for c in step.columns[1:] if self._is_numeric(first_row.get(c))), None)
            if y_col is None:
                continue

            chart_type = self._guess_chart_type(step, x_col)
            plans.append({
                "step_index": idx, "chart_type": chart_type,
                "title": f"各{x_col}的{y_col}",
                "x_column": x_col, "y_column": y_col,
                "y2_column": None, "top_n": 10,
                # Time axes keep query order; everything else is ranked.
                "sort_desc": chart_type != "line",
            })
        return plans

    def _render_chart(self, plan: dict, steps: list[ExplorationStep], chart_idx: int) -> str | None:
        """Render one plan to ``<output_dir>/chart_<chart_idx + 1>.png``.

        Args:
            plan: Planner dict (step_index, chart_type, title, x_column,
                y_column, y2_column, top_n, sort_desc); all keys optional.
            steps: All exploration steps; ``plan["step_index"]`` indexes it.
            chart_idx: Zero-based chart position (file name, default title).

        Returns:
            The written file path, or ``None`` when the plan cannot be drawn
            (invalid step index, unusable step, unknown columns, no data).
        """
        step_idx = self._as_int(plan.get("step_index", 0))
        # Reject out-of-range and negative indices (a negative index would
        # otherwise silently select a step from the end of the list).
        if step_idx is None or not 0 <= step_idx < len(steps):
            return None
        step = steps[step_idx]
        if not step.success or not step.rows or step.action == "done":
            return None

        chart_type = str(plan.get("chart_type") or "bar").strip().lower()
        if chart_type not in self._CHART_TYPES:
            chart_type = "bar"
        title = plan.get("title") or f"图表 {chart_idx + 1}"
        x_col, y_col = plan.get("x_column", ""), plan.get("y_column", "")
        y2_col = plan.get("y2_column")

        # Plans may name columns that do not exist; don't plot empty/zero data.
        first_row = step.rows[0]
        if x_col not in first_row or y_col not in first_row:
            return None
        if not y2_col or y2_col not in first_row:
            y2_col = None

        rows = self._select_rows(step.rows, y_col, chart_type, plan)
        if not rows:
            return None
        # x, y and y2 are all read from the same ordered rows, so they stay aligned.
        x_vals = [str(r.get(x_col, "")) for r in rows]
        y_vals = [self._to_number(r.get(y_col, 0)) for r in rows]
        y2_vals = [self._to_number(r.get(y2_col, 0)) for r in rows] if y2_col else None

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            if not self._draw_series(ax, chart_type, x_vals, y_vals, y_col, y2_col, y2_vals):
                return None
            self._label_axes(ax, chart_type, title, x_col, y_col)
            fig.tight_layout()
            fpath = os.path.join(self.output_dir, f"chart_{chart_idx + 1}.png")
            fig.savefig(fpath, dpi=150, bbox_inches="tight")
            return fpath
        finally:
            # Release the figure on every path, including errors, so pyplot
            # does not accumulate open figures.
            plt.close(fig)

    def _select_rows(self, rows: list[dict], y_col: str, chart_type: str, plan: dict) -> list[dict]:
        """Return the rows to plot, in plotting order.

        Unless ``sort_desc`` is false or the chart is a line chart, rows are
        sorted by ``y_col`` (descending) *before* truncation to ``top_n``
        (default 15), so "top N" keeps the N largest values. A missing, zero,
        negative or non-integer ``top_n`` keeps every row.
        """
        if plan.get("sort_desc", True) and chart_type != "line":
            rows = sorted(rows, key=lambda r: self._to_number(r.get(y_col, 0)), reverse=True)
        top_n = self._as_int(plan.get("top_n", 15))
        return list(rows[:top_n]) if top_n and top_n > 0 else list(rows)

    def _draw_series(self, ax, chart_type, x_vals, y_vals, y_col, y2_col, y2_vals) -> bool:
        """Draw the data onto *ax* in the requested chart style.

        Returns False when there is nothing drawable (a pie chart without any
        positive value), True otherwise.
        """
        positions = range(len(x_vals))
        if chart_type == "horizontal_bar":
            ax.barh(positions, y_vals, color=self._COLOR)
            ax.set_yticks(positions)
            ax.set_yticklabels(x_vals, fontsize=9)
            ax.invert_yaxis()  # first row (largest when sorted) at the top
        elif chart_type == "pie":
            # Pie wedges need positive sizes; zero/negative categories are dropped.
            slices = [(x, y) for x, y in zip(x_vals, y_vals) if y > 0]
            if not slices:
                return False
            labels, sizes = zip(*slices)
            ax.pie(sizes, labels=labels, autopct="%1.1f%%", startangle=90, textprops={"fontsize": 9})
        elif chart_type == "line":
            ax.plot(positions, y_vals, marker="o", color=self._COLOR, linewidth=2)
            self._set_category_xticks(ax, x_vals)
            ax.fill_between(positions, y_vals, alpha=0.1, color=self._COLOR)
            if y2_vals is not None:
                # Secondary series on a twin y axis.
                ax2 = ax.twinx()
                ax2.plot(positions, y2_vals, marker="s", color=self._COLOR_2, linewidth=2, linestyle="--")
                ax2.set_ylabel(_safe_title(y2_col), fontsize=10, color=self._COLOR_2)
        elif chart_type == "stacked_bar":
            ax.bar(positions, y_vals, label=_safe_title(y_col), color=self._COLOR)
            if y2_vals is not None:
                ax.bar(positions, y2_vals, bottom=y_vals, label=_safe_title(y2_col), color=self._COLOR_2)
            self._set_category_xticks(ax, x_vals)
            ax.legend()
        else:  # "bar"
            bars = ax.bar(positions, y_vals, color=self._COLOR)
            self._set_category_xticks(ax, x_vals)
            self._add_bar_labels(ax, bars)
        return True

    @staticmethod
    def _label_axes(ax, chart_type, title, x_col, y_col):
        """Set the title and, for non-pie charts, axis labels and grid."""
        ax.set_title(_safe_title(title), fontsize=13, fontweight="bold", pad=12)
        if chart_type == "pie":
            return
        if chart_type == "horizontal_bar":
            # barh puts categories on the y axis and values on the x axis.
            ax.set_xlabel(_safe_title(y_col), fontsize=10)
            ax.set_ylabel(_safe_title(x_col), fontsize=10)
            ax.grid(axis="x", alpha=0.3)
        else:
            ax.set_xlabel(_safe_title(x_col), fontsize=10)
            ax.set_ylabel(_safe_title(y_col), fontsize=10)
            ax.grid(axis="y", alpha=0.3)

    @staticmethod
    def _set_category_xticks(ax, labels):
        """Label x positions 0..n-1 with *labels*, rotated for readability."""
        ax.set_xticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=9)

    def _clean_old_charts(self):
        """Delete chart PNGs a previous run left in ``output_dir``.

        Only files following this class's naming scheme (``chart_*.png``) are
        removed, so unrelated images in a shared directory survive. Deletion is
        best effort: files that cannot be removed are skipped.
        """
        if not os.path.isdir(self.output_dir):
            return
        for name in os.listdir(self.output_dir):
            if name.startswith("chart_") and name.endswith(".png"):
                try:
                    os.remove(os.path.join(self.output_dir, name))
                except OSError:
                    pass  # best effort: file in use or already gone

    def _add_bar_labels(self, ax, bars):
        """Write each positive bar's height above it; whole numbers without decimals."""
        for bar in bars:
            h = bar.get_height()
            if h > 0:
                label = str(int(h)) if float(h).is_integer() else f"{h:.1f}"
                ax.text(bar.get_x() + bar.get_width() / 2, h, label, ha="center", va="bottom", fontsize=8)

    def _to_number(self, val) -> float:
        """Best-effort conversion of a cell value to a finite float.

        Numbers of any kind (including Decimal) are converted directly. Strings
        may contain thousands separators, a "<" marker (e.g. "<0.01") or a "%"
        sign. None, unparseable values, NaN and infinities become 0.0.
        """
        if isinstance(val, str):
            val = val.replace("<", "").replace(",", "").replace("%", "").strip()
        try:
            num = float(val)
        except (TypeError, ValueError, OverflowError):
            return 0.0
        return num if math.isfinite(num) else 0.0

    @staticmethod
    def _is_numeric(val) -> bool:
        """Return True for real numbers (int, float, Decimal), excluding bools."""
        return isinstance(val, (int, float, Decimal)) and not isinstance(val, bool)

    @staticmethod
    def _as_int(val):
        """Coerce an LLM-supplied integer field (int or numeric string) to int; None if impossible."""
        if isinstance(val, bool):
            return None
        try:
            return int(val)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _guess_chart_type(step, x_col) -> str:
        """Pick a chart type from the category column name and the row count."""
        name = str(x_col).lower()
        if any(kw in name for kw in ("月", "日期", "时间", "month", "date")):
            return "line"
        if step.row_count <= 6:
            return "pie"
        if len(str(step.rows[0].get(x_col, ""))) > 10:
            return "horizontal_bar"
        return "bar"