Files
iov_ana/output/chart.py
openclaw e8f8e2f1ba feat: 四层架构全面增强
安全与稳定性:
- 移除硬编码 API Key,改用 .env + 环境变量
- LLM 调用统一重试机制(指数退避,3 次重试,处理 429/5xx/超时)
- 中文字体检测增强(CJK 关键词兜底 + 无字体时英文 fallback)
- 缺失 API Key 给出友好提示而非崩溃

分析能力提升:
- 异常检测新增 z-score 检测(标准差>2 标记异常)
- 新增变异系数 CV 检测(数据波动性)
- 新增零值/缺失检测
- 上下文管理器升级为关键词语义匹配(替代简单取最近 2 条)

用户体验:
- 报告自动保存为 Markdown(reports/ 目录)
- 新增 export 命令导出查询结果为 CSV
- 新增 reports 命令查看已保存报告
- CLI 支持 readline 命令历史(方向键翻阅)
- CSV 导入工具重写:自动列名映射、容错处理、dry-run 模式
- 新增 .env.example 配置模板
2026-03-31 14:39:17 +08:00

266 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
图表生成器 —— 根据探索结果自动生成可视化图表
"""
import json
import os
import re
from typing import Any
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from core.config import LLM_CONFIG
from core.utils import get_llm_client, llm_chat, extract_json_array
from layers.explorer import ExplorationStep
def _setup_chinese_font():
    """Configure matplotlib to render CJK text, if a suitable font is installed.

    Tries a list of well-known CJK font family names first, then any installed
    font whose name contains a CJK-ish keyword. The chosen family is put at the
    front of ``font.sans-serif`` so the previously configured families stay in
    the list as glyph fallbacks (a one-element list would drop them).
    ``axes.unicode_minus`` is disabled in every case, so minus signs use the
    ASCII hyphen, which CJK fonts reliably contain.

    Returns:
        The chosen font family name, or ``None`` when no CJK font was found
        (chart titles then fall back to ASCII-only text).
    """
    candidates = (
        "SimHei", "Microsoft YaHei", "STHeiti", "WenQuanYi Micro Hei",
        "Noto Sans CJK SC", "PingFang SC", "Source Han Sans CN",
    )
    cjk_keywords = ("cjk", "chinese", "hei", "song", "ming", "fang")

    plt.rcParams["axes.unicode_minus"] = False

    installed = [f.name for f in fm.fontManager.ttflist]
    available = set(installed)
    chosen = next((name for name in candidates if name in available), None)
    if chosen is None:
        # Last resort: any installed font whose name hints at CJK coverage.
        chosen = next(
            (name for name in installed if any(kw in name.lower() for kw in cjk_keywords)),
            None,
        )
    if chosen is None:
        return None  # callers use ASCII-only titles instead

    # Prepend rather than replace, keeping the existing families as fallbacks.
    others = [name for name in plt.rcParams["font.sans-serif"] if name != chosen]
    plt.rcParams["font.sans-serif"] = [chosen, *others]
    return chosen


_CN_FONT = _setup_chinese_font()
def _safe_title(title: str) -> str:
"""无中文字体时将标题转为安全显示文本"""
if _CN_FONT:
return title
# 简单映射:中文→拼音首字母摘要,保留英文和数字
import re
clean = re.sub(r'[^\w\s.,;:!?%/()\-+]', '', title)
return clean if clean.strip() else "Chart"
# Prompt template for LLM chart planning. It is rendered with str.format(), so the
# only placeholder is {exploration_summary}; literal braces in the JSON example are
# doubled ({{ }}) so they survive formatting.
CHART_PLAN_PROMPT = """你是一个数据可视化专家。根据以下分析结果,规划需要生成的图表。
## 探索结果
{exploration_summary}
## 规划规则
1. 每个有意义的查询结果生成 1 张图,最多 5 张
2. 图表类型:bar / horizontal_bar / pie / line / stacked_bar
3. 跳过数据量太少(<2 行)的结果
4. 标题要简洁
5. step_index 必须是上面列出的步骤编号,x_column / y_column / y2_column 必须是该步骤中的列名
## 输出格式(纯 JSON 数组,不要代码块)
[
{{
"step_index": 0,
"chart_type": "bar",
"title": "图表标题",
"x_column": "分类轴列名",
"y_column": "数值轴列名",
"y2_column": null,
"top_n": 10,
"sort_desc": true
}}
]"""
class ChartGenerator:
    """Turn exploration results into PNG charts.

    Chart selection is delegated to the LLM (``_plan_charts``), with a
    heuristic fallback (``_fallback_plan``). Each plan is rendered with
    matplotlib into ``output_dir`` as ``chart_<n>.png``.
    """

    # Hard cap on rendered charts. The planning prompt asks for at most 5, but
    # the LLM output is not guaranteed to respect that.
    max_charts = 5
    # Chart types _render_chart can draw; any other value is drawn as "bar".
    _CHART_TYPES = ("bar", "horizontal_bar", "pie", "line", "stacked_bar")
    # Substrings of a category column name that suggest a time axis (-> line chart).
    _TIME_KEYWORDS = ("月", "日期", "时间", "month", "date")
    # Names of the files this class writes, and is therefore allowed to delete.
    _CHART_FILE_RE = re.compile(r"^chart_\d+\.png$")

    def __init__(self, output_dir: str = "charts"):
        """Create a generator that writes into *output_dir*, using the LLM from LLM_CONFIG."""
        self.output_dir = output_dir
        self.client, self.model = get_llm_client(LLM_CONFIG)

    def generate(self, steps: list[ExplorationStep], question: str) -> list[dict]:
        """Plan and render charts for *steps*.

        Only steps that succeeded, have at least 2 rows, and whose action is not
        ``"done"`` are considered. Before rendering, previously generated
        ``chart_<n>.png`` files in ``output_dir`` are removed. A chart that fails
        to render is reported and skipped; it does not abort the others.

        Returns:
            One ``{"path": ..., "title": ...}`` dict per rendered chart. The list
            is empty when nothing is worth charting or no plan was produced.
        """
        valid_steps = [
            (i, s)
            for i, s in enumerate(steps)
            if s.success and s.rows and s.row_count >= 2 and s.action != "done"
        ]
        if not valid_steps:
            return []
        plans = self._plan_charts(valid_steps, question)
        if not plans:
            return []
        self._clean_old_charts()
        os.makedirs(self.output_dir, exist_ok=True)
        charts = []
        for i, plan in enumerate(plans[: self.max_charts]):
            if not isinstance(plan, dict):
                continue  # malformed item in the LLM output
            try:
                path = self._render_chart(plan, steps, i)
            except Exception as e:  # one bad chart must not abort the rest
                print(f" ⚠️ 图表生成失败: {e}")
                continue
            if path:
                charts.append({"path": path, "title": plan.get("title") or f"图表 {i+1}"})
        return charts

    def _plan_charts(self, valid_steps: list[tuple[int, ExplorationStep]], question: str) -> list[dict]:
        """Ask the LLM for chart plans; use the heuristic fallback on error or empty output.

        For each step, the prompt includes its purpose, columns, row count and
        first 5 rows. The user's question, when given, is placed first so the
        planner can favour charts that are relevant to it.
        """
        # default=str: rows may contain values that json.dumps cannot encode
        # natively (e.g. Decimal or datetime from a database driver). Without it,
        # this call raised outside the try block and crashed generate().
        summary_parts = [
            f"### 步骤 {idx}: {step.purpose}\n列: {step.columns}\n行数: {step.row_count}\n"
            f"前 5 行: {json.dumps(step.rows[:5], ensure_ascii=False, default=str)}"
            for idx, step in valid_steps
        ]
        summary = "\n\n".join(summary_parts)
        if question:
            summary = f"用户问题: {question}\n\n{summary}"
        try:
            content = llm_chat(
                self.client, self.model,
                messages=[
                    {"role": "system", "content": "你是数据可视化专家。只输出纯 JSON 数组,不要 markdown 代码块。"},
                    {"role": "user", "content": CHART_PLAN_PROMPT.format(exploration_summary=summary)},
                ],
                temperature=0.1, max_tokens=1024,
            )
            plans = extract_json_array(content)
        except Exception as e:
            print(f" ⚠️ 图表规划失败: {e},使用 fallback")
            return self._fallback_plan(valid_steps)
        return plans if plans else self._fallback_plan(valid_steps)

    def _fallback_plan(self, valid_steps: list[tuple[int, ExplorationStep]]) -> list[dict]:
        """Build chart plans heuristically for up to the first 4 valid steps.

        The x axis is the first column and the y axis is the first later column
        whose value in the first row is numeric. The chart type is chosen as
        follows:
        - "line" for time-like x columns;
        - "pie" for at most 6 rows;
        - "horizontal_bar" for long category labels (more than 10 characters);
        - "bar" otherwise.
        """
        plans = []
        for idx, step in valid_steps[:4]:
            if len(step.columns) < 2 or step.row_count < 2:
                continue
            x_col = step.columns[0]
            first_row = step.rows[0]
            y_col = next(
                (col for col in step.columns[1:] if self._is_numeric(first_row.get(col))),
                None,
            )
            if y_col is None:
                continue
            x_name = str(x_col).lower()
            # NOTE: the keyword tuple used to contain "", which matches every
            # column and forced every fallback chart to be a line chart.
            if any(kw in x_name for kw in self._TIME_KEYWORDS):
                chart_type = "line"
            elif step.row_count <= 6:
                chart_type = "pie"
            elif len(str(first_row.get(x_col, ""))) > 10:
                chart_type = "horizontal_bar"
            else:
                chart_type = "bar"
            plans.append({
                "step_index": idx, "chart_type": chart_type,
                "title": f"{x_col}{y_col}",
                "x_column": x_col, "y_column": y_col,
                "y2_column": None, "top_n": 10,
                "sort_desc": chart_type != "line",
            })
        return plans

    def _render_chart(self, plan: dict, steps: list[ExplorationStep], chart_idx: int) -> str | None:
        """Render one *plan* to ``output_dir/chart_<chart_idx + 1>.png``.

        Returns the file path, or ``None`` in these cases:
        - the plan points at a missing or unsuccessful step;
        - the plan names columns that the step's rows do not have;
        - there is nothing to draw, e.g. a pie chart with no positive values.
        """
        step_idx = self._as_int(plan.get("step_index", 0), -1)
        if not 0 <= step_idx < len(steps):  # also rejects negative indices
            return None
        step = steps[step_idx]
        if not step.success or not step.rows:
            return None

        chart_type = plan.get("chart_type", "bar")
        if chart_type not in self._CHART_TYPES:
            chart_type = "bar"
        title = plan.get("title") or f"图表 {chart_idx + 1}"
        x_col = plan.get("x_column") or ""
        y_col = plan.get("y_column") or ""
        y2_col = plan.get("y2_column")
        first_row = step.rows[0]
        # Unknown columns would otherwise be charted as all-zero values.
        if x_col not in first_row or y_col not in first_row:
            return None
        if y2_col and y2_col not in first_row:
            y2_col = None
        # A missing top_n means 15; an explicit null/0 means "all rows".
        top_n = self._as_int(plan.get("top_n", 15), 0)
        sort_desc = plan.get("sort_desc", True)

        rows = list(step.rows)
        if sort_desc and chart_type != "line":
            # Sort whole rows BEFORE truncating. This makes top_n keep the
            # largest values, and keeps y2 aligned with x and y.
            rows.sort(key=lambda r: self._to_number(r.get(y_col)), reverse=True)
        if top_n > 0:
            rows = rows[:top_n]
        if not rows:
            return None

        x_vals = [str(r.get(x_col, "")) for r in rows]
        y_vals = [self._to_number(r.get(y_col)) for r in rows]
        y2_vals = [self._to_number(r.get(y2_col)) for r in rows] if y2_col else None

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            if not self._draw(ax, chart_type, x_vals, y_vals, y_col, y2_col, y2_vals):
                return None
            self._decorate(ax, chart_type, title, x_col, y_col)
            fig.tight_layout()
            fpath = os.path.join(self.output_dir, f"chart_{chart_idx + 1}.png")
            fig.savefig(fpath, dpi=150, bbox_inches="tight")
            return fpath
        finally:
            # Always release the figure, even if drawing raises; otherwise pyplot
            # keeps it alive.
            plt.close(fig)

    def _draw(self, ax, chart_type, x_vals, y_vals, y_col, y2_col, y2_vals) -> bool:
        """Draw the series on *ax*. Returns False when there is nothing to draw."""
        positions = range(len(x_vals))
        if chart_type == "pie":
            slices = [(x, y) for x, y in zip(x_vals, y_vals) if y > 0]
            if not slices:
                return False
            labels, sizes = zip(*slices)
            ax.pie(sizes, labels=labels, autopct="%1.1f%%", startangle=90, textprops={"fontsize": 9})
            return True
        if chart_type == "horizontal_bar":
            ax.barh(positions, y_vals, color="#4C78A8")
            ax.set_yticks(positions)
            ax.set_yticklabels(x_vals, fontsize=9)
            ax.invert_yaxis()  # first category at the top
            return True
        if chart_type == "line":
            ax.plot(positions, y_vals, marker="o", color="#4C78A8", linewidth=2)
            ax.fill_between(positions, y_vals, alpha=0.1, color="#4C78A8")
            if y2_vals is not None:
                ax2 = ax.twinx()  # secondary y axis for the second series
                ax2.plot(positions, y2_vals, marker="s", color="#E45756", linewidth=2, linestyle="--")
                ax2.set_ylabel(_safe_title(y2_col), fontsize=10, color="#E45756")
        elif chart_type == "stacked_bar":
            ax.bar(positions, y_vals, label=_safe_title(y_col), color="#4C78A8")
            if y2_vals is not None:
                ax.bar(positions, y2_vals, bottom=y_vals, label=_safe_title(y2_col), color="#E45756")
            ax.legend()
        else:  # "bar"
            bars = ax.bar(positions, y_vals, color="#4C78A8")
            self._add_bar_labels(ax, bars)
        ax.set_xticks(positions)
        ax.set_xticklabels(x_vals, rotation=45, ha="right", fontsize=9)
        return True

    def _decorate(self, ax, chart_type, title, x_col, y_col):
        """Set the title, the axis labels and the value-axis grid."""
        ax.set_title(_safe_title(title), fontsize=13, fontweight="bold", pad=12)
        if chart_type == "pie":
            return
        if chart_type == "horizontal_bar":
            # Categories are on the y axis and values on the x axis here.
            ax.set_xlabel(_safe_title(y_col), fontsize=10)
            ax.set_ylabel(_safe_title(x_col), fontsize=10)
            ax.grid(axis="x", alpha=0.3)
        else:
            ax.set_xlabel(_safe_title(x_col), fontsize=10)
            ax.set_ylabel(_safe_title(y_col), fontsize=10)
            ax.grid(axis="y", alpha=0.3)

    def _clean_old_charts(self):
        """Best-effort removal of ``chart_<n>.png`` files left in output_dir by earlier runs.

        Only this class's own output is deleted. Other PNGs that happen to live
        in the same directory are left alone.
        """
        if not os.path.isdir(self.output_dir):
            return
        for name in os.listdir(self.output_dir):
            if not self._CHART_FILE_RE.match(name):
                continue
            try:
                os.remove(os.path.join(self.output_dir, name))
            except OSError:
                pass  # deliberate best effort: a stale file will simply be overwritten

    def _add_bar_labels(self, ax, bars):
        """Write each positive bar's height above it.

        Whole numbers are shown without decimals; other values with 1 decimal.
        """
        for bar in bars:
            height = bar.get_height()
            if not height > 0:  # also skips NaN
                continue
            # Heights are always floats here, so test the value, not the type.
            label = str(int(height)) if float(height).is_integer() else f"{height:.1f}"
            ax.text(bar.get_x() + bar.get_width() / 2, height, label, ha="center", va="bottom", fontsize=8)

    def _to_number(self, val) -> float:
        """Convert a cell value to float on a best-effort basis; returns 0.0 when it is not numeric.

        Strings may carry thousands separators, a "<" or a "%" (e.g. "1,234",
        "<5", "12.5%"). Other values go through float(), so Decimal and numpy
        scalars are converted as well.
        """
        if isinstance(val, str):
            val = val.replace("<", "").replace(",", "").replace("%", "").strip()
        try:
            return float(val)
        except (TypeError, ValueError, OverflowError):
            return 0.0

    @staticmethod
    def _as_int(val, default: int) -> int:
        """Coerce an LLM-supplied value (int, float, or integer string) to int; else return *default*."""
        if isinstance(val, bool):
            return default
        try:
            return int(val)
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _is_numeric(val) -> bool:
        """Return True for numeric scalars (int, float, Decimal, numpy); False for bool, text and None."""
        if val is None or isinstance(val, (bool, str, bytes)):
            return False
        try:
            float(val)
        except (TypeError, ValueError, OverflowError):
            return False
        return True