"""Chart generator: automatically build visualization charts from exploration results.

An LLM is asked to plan which charts to draw for the successful query steps;
if that fails (or returns nothing usable) a simple heuristic plan is used.
Charts are rendered with matplotlib (headless "Agg" backend) to PNG files.
"""

import json
import numbers
import os
import re
from decimal import Decimal
from typing import Any

import matplotlib

# Select the non-interactive backend before pyplot is imported so rendering
# works without a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

from core.config import LLM_CONFIG
from core.utils import get_llm_client, extract_json_array
from layers.explorer import ExplorationStep


def _setup_chinese_font() -> str | None:
    """Make matplotlib able to render Chinese text by selecting an installed CJK font.

    The first installed font from a fixed candidate list is put at the front of
    ``font.sans-serif``; the previously configured entries are kept after it so
    they remain available as fallbacks.

    Returns:
        The selected font name, or ``None`` if no candidate font is installed.
    """
    candidates = [
        "SimHei",
        "Microsoft YaHei",
        "STHeiti",
        "WenQuanYi Micro Hei",
        "Noto Sans CJK SC",
        "PingFang SC",
        "Source Han Sans CN",
    ]
    # Render the axis minus sign as ASCII '-' (set regardless of which font is found).
    plt.rcParams["axes.unicode_minus"] = False
    available = {f.name for f in fm.fontManager.ttflist}
    for font in candidates:
        if font in available:
            others = [f for f in plt.rcParams["font.sans-serif"] if f != font]
            plt.rcParams["font.sans-serif"] = [font, *others]
            return font
    return None


_setup_chinese_font()


CHART_PLAN_PROMPT = """你是一个数据可视化专家。根据以下分析结果,规划需要生成的图表。

## 探索结果
{exploration_summary}

## 规划规则
1. 每个有意义的查询结果生成 1 张图,最多 5 张
2. 图表类型:bar / horizontal_bar / pie / line / stacked_bar
3. 跳过数据量太少(<2 行)的结果
4. 标题要简洁

## 输出格式(纯 JSON 数组,不要代码块)
[
  {{
    "step_index": 0,
    "chart_type": "bar",
    "title": "图表标题",
    "x_column": "分类轴列名",
    "y_column": "数值轴列名",
    "y2_column": null,
    "top_n": 10,
    "sort_desc": true
  }}
]"""


class ChartGenerator:
    """Plan (via LLM, with a heuristic fallback) and render charts for exploration results."""

    # The planning prompt asks for at most 5 charts; enforce it even if the LLM does not.
    MAX_CHARTS = 5

    # Chart types _render_chart can draw; any other requested type is drawn as "bar".
    SUPPORTED_CHART_TYPES = ("bar", "horizontal_bar", "pie", "line", "stacked_bar")

    _PRIMARY_COLOR = "#4C78A8"
    _SECONDARY_COLOR = "#E45756"

    def __init__(self, output_dir: str = "charts"):
        """Create a generator that writes its PNG files into ``output_dir``."""
        self.output_dir = output_dir
        self.client, self.model = get_llm_client(LLM_CONFIG)

    def generate(self, steps: list[ExplorationStep], question: str) -> list[dict]:
        """Plan and render charts for the successful steps of an exploration.

        Args:
            steps: All exploration steps. Only successful, non-"done" steps with
                at least 2 result rows are considered for charting.
            question: The user's original question (passed on to the planner).

        Returns:
            One ``{"path": ..., "title": ...}`` dict per chart that rendered,
            in plan order; empty if nothing was chartable.
        """
        valid_steps = [
            (i, s)
            for i, s in enumerate(steps)
            if s.success and s.rows and s.row_count >= 2 and s.action != "done"
        ]
        if not valid_steps:
            return []

        plans = self._plan_charts(valid_steps, question)
        if not plans:
            return []

        # Remove charts from a previous run before writing the new ones.
        self._clean_old_charts()
        os.makedirs(self.output_dir, exist_ok=True)

        charts = []
        for i, plan in enumerate(plans[: self.MAX_CHARTS]):
            try:
                path = self._render_chart(plan, steps, i)
            except Exception as e:  # one bad plan must not abort the remaining charts
                print(f" ⚠️ 图表生成失败: {e}")
                continue
            if path:
                charts.append({"path": path, "title": plan.get("title") or f"图表 {i+1}"})
        return charts

    def _plan_charts(
        self, valid_steps: list[tuple[int, ExplorationStep]], question: str
    ) -> list[dict]:
        """Ask the LLM for chart plans, falling back to ``_fallback_plan``.

        Args:
            valid_steps: ``(index_into_steps, step)`` pairs eligible for charting.
            question: The user's question. NOTE(review): it is not currently
                sent to the LLM; consider including it so plans match the ask.

        Returns:
            A list of plan dicts (keys as in ``CHART_PLAN_PROMPT``).
        """
        # default=str: rows may contain values json cannot encode (e.g. Decimal,
        # datetime). This is only a preview for the prompt, so strings suffice,
        # and it must not raise here because this code is outside the try below.
        summary_parts = [
            f"### 步骤 {idx}: {step.purpose}\n列: {step.columns}\n行数: {step.row_count}\n"
            f"前 5 行: {json.dumps(step.rows[:5], ensure_ascii=False, default=str)}"
            for idx, step in valid_steps
        ]
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "你是数据可视化专家。只输出纯 JSON 数组,不要 markdown 代码块。"},
                    {
                        "role": "user",
                        "content": CHART_PLAN_PROMPT.format(
                            exploration_summary="\n\n".join(summary_parts)
                        ),
                    },
                ],
                temperature=0.1,
                max_tokens=1024,
            )
            plans = extract_json_array(response.choices[0].message.content.strip())
        except Exception as e:  # API/network/parse failure: degrade to the heuristic plan
            print(f" ⚠️ 图表规划失败: {e},使用 fallback")
            return self._fallback_plan(valid_steps)

        # LLM output is untrusted: keep only plan objects that are dicts.
        if isinstance(plans, list):
            plans = [p for p in plans if isinstance(p, dict)]
        else:
            plans = []
        return plans or self._fallback_plan(valid_steps)

    def _fallback_plan(self, valid_steps: list[tuple[int, ExplorationStep]]) -> list[dict]:
        """Build chart plans without the LLM.

        For each of the first 4 eligible steps, the first column is the category
        axis and the first numeric column after it (judged from the first row) is
        the value axis. The chart type is picked from the category column name,
        the row count and the label length.
        """
        plans = []
        for idx, step in valid_steps[:4]:
            if len(step.columns) < 2 or step.row_count < 2:
                continue
            x_col = step.columns[0]
            first_row = step.rows[0]
            y_col = next(
                (col for col in step.columns[1:] if self._is_numeric_value(first_row.get(col))),
                None,
            )
            if y_col is None:
                continue

            x_name = str(x_col).lower()
            if any(kw in x_name for kw in ("月", "日期", "时间", "month", "date")):
                chart_type = "line"  # time-like category axis: keep natural order
            elif step.row_count <= 6:
                chart_type = "pie"
            elif len(str(first_row.get(x_col, ""))) > 10:
                chart_type = "horizontal_bar"  # long labels read better on the y axis
            else:
                chart_type = "bar"

            plans.append({
                "step_index": idx,
                "chart_type": chart_type,
                "title": f"各{x_col}的{y_col}",
                "x_column": x_col,
                "y_column": y_col,
                "y2_column": None,
                "top_n": 10,
                "sort_desc": chart_type != "line",
            })
        return plans

    @staticmethod
    def _is_numeric_value(val: Any) -> bool:
        """Return True for real-number cell values (including ``Decimal``), but not bools."""
        return isinstance(val, (numbers.Real, Decimal)) and not isinstance(val, bool)

    @staticmethod
    def _parse_top_n(raw: Any) -> int:
        """Coerce a plan's ``top_n`` to a row limit; 0 means "no limit"."""
        try:
            n = int(raw) if raw else 0
        except (TypeError, ValueError):
            return 0
        return max(n, 0)

    def _set_category_xticks(self, ax, labels) -> None:
        """Label the x axis with one rotated tick per category."""
        ax.set_xticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=9)

    def _render_chart(self, plan: dict, steps: list[ExplorationStep], chart_idx: int) -> str | None:
        """Render one planned chart to ``<output_dir>/chart_<chart_idx + 1>.png``.

        Args:
            plan: Plan dict (keys as in ``CHART_PLAN_PROMPT``). It may come from
                the LLM, so every field is validated or defaulted here.
            steps: All exploration steps; ``plan["step_index"]`` indexes into it.
            chart_idx: 0-based chart position, used for the file name and the
                default title.

        Returns:
            The written file path, or ``None`` if the plan references an unusable
            step or leaves nothing to plot.
        """
        try:
            step_idx = int(plan.get("step_index", 0))
        except (TypeError, ValueError):
            return None
        # Also reject negative indexes: they would silently pick a step from the end.
        if not 0 <= step_idx < len(steps):
            return None
        step = steps[step_idx]
        if not step.success or not step.rows:
            return None

        chart_type = plan.get("chart_type", "bar")
        if chart_type not in self.SUPPORTED_CHART_TYPES:
            chart_type = "bar"  # otherwise nothing would be drawn and an empty PNG saved
        title = plan.get("title") or f"图表 {chart_idx + 1}"
        x_col, y_col = plan.get("x_column", ""), plan.get("y_column", "")
        y2_col = plan.get("y2_column")
        sort_desc = plan.get("sort_desc", True)
        top_n = self._parse_top_n(plan.get("top_n", 15))

        rows = list(step.rows)
        if sort_desc and chart_type != "line":
            # Sort whole rows (so the y2 series stays aligned with x/y) and sort
            # before truncating (so top_n keeps the largest values, not the first rows).
            rows.sort(key=lambda r: self._to_number(r.get(y_col, 0)), reverse=True)
        if top_n:
            rows = rows[:top_n]

        x_vals = [str(r.get(x_col, "")) for r in rows]
        y_vals = [self._to_number(r.get(y_col, 0)) for r in rows]
        y2_vals = [self._to_number(r.get(y2_col, 0)) for r in rows] if y2_col else None
        positions = range(len(rows))

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            if chart_type == "pie":
                # Pie wedges need positive sizes; drop zero/negative values.
                slices = [(x, y) for x, y in zip(x_vals, y_vals) if y > 0]
                if not slices:
                    return None
                labels, sizes = zip(*slices)
                ax.pie(sizes, labels=labels, autopct="%1.1f%%", startangle=90,
                       textprops={"fontsize": 9})
            elif chart_type == "horizontal_bar":
                ax.barh(positions, y_vals, color=self._PRIMARY_COLOR)
                ax.set_yticks(positions)
                ax.set_yticklabels(x_vals, fontsize=9)
                ax.invert_yaxis()  # first row at the top
            elif chart_type == "line":
                ax.plot(positions, y_vals, marker="o", color=self._PRIMARY_COLOR, linewidth=2)
                self._set_category_xticks(ax, x_vals)
                ax.fill_between(positions, y_vals, alpha=0.1, color=self._PRIMARY_COLOR)
                if y2_vals is not None:
                    # Secondary series on its own y scale.
                    ax2 = ax.twinx()
                    ax2.plot(positions, y2_vals, marker="s", color=self._SECONDARY_COLOR,
                             linewidth=2, linestyle="--")
                    ax2.set_ylabel(y2_col, fontsize=10, color=self._SECONDARY_COLOR)
            elif chart_type == "stacked_bar":
                ax.bar(positions, y_vals, label=y_col, color=self._PRIMARY_COLOR)
                if y2_vals is not None:
                    ax.bar(positions, y2_vals, bottom=y_vals, label=y2_col,
                           color=self._SECONDARY_COLOR)
                self._set_category_xticks(ax, x_vals)
                ax.legend()
            else:  # "bar"
                bars = ax.bar(positions, y_vals, color=self._PRIMARY_COLOR)
                self._set_category_xticks(ax, x_vals)
                self._add_bar_labels(ax, bars)

            ax.set_title(title, fontsize=13, fontweight="bold", pad=12)
            if chart_type == "horizontal_bar":
                # Values run along the x axis here; categories are the y tick labels.
                ax.set_xlabel(y_col, fontsize=10)
                ax.grid(axis="x", alpha=0.3)
            elif chart_type != "pie":
                ax.set_xlabel(x_col, fontsize=10)
                ax.set_ylabel(y_col, fontsize=10)
                ax.grid(axis="y", alpha=0.3)

            fig.tight_layout()
            fpath = os.path.join(self.output_dir, f"chart_{chart_idx + 1}.png")
            fig.savefig(fpath, dpi=150, bbox_inches="tight")
            return fpath
        finally:
            # Release the figure on every path (success, early return, error) so
            # pyplot does not accumulate open figures across calls.
            plt.close(fig)

    def _clean_old_charts(self):
        """Best-effort removal of charts written by a previous run.

        Only files following this class's ``chart_*.png`` naming are deleted, so
        unrelated images in a shared output directory are left alone.
        """
        if not os.path.isdir(self.output_dir):
            return
        for name in os.listdir(self.output_dir):
            if name.startswith("chart_") and name.endswith(".png"):
                try:
                    os.remove(os.path.join(self.output_dir, name))
                except OSError:
                    pass  # deliberate: a leftover file is not worth failing chart generation

    def _add_bar_labels(self, ax, bars):
        """Write each positive bar's value just above it; whole numbers are shown without decimals."""
        for bar in bars:
            height = bar.get_height()
            if height > 0:
                # Heights here are floats (values pass through _to_number), so test the
                # value rather than the type to decide whether to print decimals.
                label = str(int(height)) if float(height).is_integer() else f"{height:.1f}"
                ax.text(bar.get_x() + bar.get_width() / 2, height, label,
                        ha="center", va="bottom", fontsize=8)

    def _to_number(self, val) -> float:
        """Coerce a cell value to float for plotting; unparseable values become 0.0.

        Strings have "<" and thousands separators removed first (e.g. "<5", "1,234").
        Any other value convertible with ``float()`` (int, float, ``Decimal`` ...)
        is converted directly; ``None`` and other types yield 0.0.
        """
        if isinstance(val, str):
            try:
                return float(val.replace("<", "").replace(",", "").strip())
            except ValueError:
                return 0.0
        try:
            return float(val)
        except (TypeError, ValueError, OverflowError):
            return 0.0