diff --git a/ANALYSIS_RESULTS.md b/ANALYSIS_RESULTS.md deleted file mode 100644 index 1cd83de..0000000 --- a/ANALYSIS_RESULTS.md +++ /dev/null @@ -1,171 +0,0 @@ -# 完整数据分析系统 - 执行结果 - -## 问题解决 - -### 核心问题 -工具注册失败,导致 `ToolManager.select_tools()` 返回 0 个工具。 - -### 根本原因 -`ToolManager` 在初始化时创建了一个新的空 `ToolRegistry` 实例,而工具实际上被注册到了全局注册表 `_global_registry` 中。两个注册表互不相通。 - -### 解决方案 -修改 `src/tools/tool_manager.py` 第 18 行: -```python -# 修改前 -self.registry = registry if registry else ToolRegistry() - -# 修改后 -self.registry = registry if registry else _global_registry -``` - -## 系统验证结果 - -### ✅ 阶段 1: AI 数据理解 -- **数据类型识别**: ticket (IT服务工单) -- **数据质量评分**: 85.0/100 -- **关键字段识别**: 15 个 -- **数据规模**: 84 行 × 21 列 -- **隐私保护**: ✓ AI 只能看到表头和统计信息,无法访问原始数据行 - -**AI 分析摘要**: -> 这是一个典型的IT服务工单数据集,记录了84个车辆相关问题的处理全流程。数据集包含完整的工单生命周期信息(创建、处理、关闭),主要涉及远程控制、导航、网络等车辆系统问题。数据质量较高,缺失率低(仅SIM和Notes字段缺失较多),但部分文本字段存在较长的非结构化描述。问题类型和模块分布显示远程控制问题是主要痛点(占比66%),工单主要来自邮件渠道(55%),平均关闭时长约5天。 - -### ✅ 阶段 2: 需求理解 -- **用户需求**: "Analyze ticket health, find main issues, efficiency and trends" -- **生成目标**: 2 个分析目标 - 1. 健康度分析 (优先级: 5/5) - 2. 趋势分析 (优先级: 4/5) - -### ✅ 阶段 3: 分析规划 -- **生成任务**: 2 个 -- **预计时长**: 120 秒 -- **任务清单**: - - [优先级 5] task_1. 质量评估 - 健康度分析 - - [优先级 4] task_2. 趋势分析 - 趋势分析 - -### ✅ 阶段 4: 任务执行 -- **可用工具**: 9 个(修复后) - - get_column_distribution (列分布统计) - - get_value_counts (值计数) - - perform_groupby (分组聚合) - - create_bar_chart (柱状图) - - create_pie_chart (饼图) - - calculate_statistics (描述性统计) - - detect_outliers (异常值检测) - - get_time_series (时间序列) - - calculate_trend (趋势计算) - -- **执行结果**: - - 成功: 2/2 任务 - - 失败: 0/2 任务 - - 总执行时间: ~51 秒 - - 生成洞察: 2 条 - -### ✅ 阶段 5: 报告生成 -- **报告文件**: analysis_output/analysis_report.md -- **报告长度**: 774 字符 -- **包含内容**: - - 执行摘要 - - 数据概览(15个关键字段说明) - - 详细分析 - - 结论与建议 - - 任务执行附录 - -## 系统架构验证 - -### 隐私保护机制 ✓ -1. **数据访问层隔离**: AI 无法直接访问原始数据 -2. **元数据暴露**: AI 只能看到列名、数据类型、统计信息 -3. **工具执行**: 工具在原始数据上执行,返回聚合结果 -4. **结果限制**: - - 分组结果最多 100 个 - - 时间序列最多 100 个数据点 - - 异常值最多返回 20 个 - -### 配置管理 ✓ -所有 LLM API 调用已统一从 `.env` 文件读取配置: -- `OPENAI_MODEL=mimo-v2-flash` -- `OPENAI_BASE_URL=https://api.xiaomimimo.com/v1` -- `OPENAI_API_KEY=[已配置]` - -修改的文件: -1. src/engines/task_execution.py -2. src/engines/requirement_understanding.py -3. src/engines/report_generation.py -4. src/engines/plan_adjustment.py -5. src/engines/analysis_planning.py - -### 工具系统 ✓ -- **全局注册表**: 12 个工具已注册 -- **动态选择**: 根据数据特征自动选择适用工具 -- **类型检测**: 支持时间序列、分类、数值、地理数据 -- **参数验证**: JSON Schema 格式参数定义 - -## 测试数据 - -### cleaned_data.csv -- **行数**: 84 -- **列数**: 21 -- **数据类型**: IT 服务工单 -- **主要字段**: - - 工单号、来源、创建日期 - - 问题类型、问题描述、处理过程 - - 严重程度、工单状态、模块 - - 责任人、关闭日期、关闭时长 - - 车型、VIN - -### 数据质量 -- **完整性**: 85/100 -- **缺失字段**: SIM (100%), Notes (较多) -- **时间字段**: 创建日期、关闭日期 -- **分类字段**: 来源、问题类型、严重程度、工单状态、模块 -- **数值字段**: 关闭时长(天) - -## 执行命令 - -```bash -python run_analysis_en.py -``` - -## 输出文件 - -``` -analysis_output/ -├── analysis_report.md # 分析报告 -└── *.png # 图表文件(如有生成) -``` - -## 性能指标 - -- **数据加载**: < 1 秒 -- **AI 数据理解**: ~5 秒 -- **需求理解**: ~3 秒 -- **分析规划**: ~2 秒 -- **任务执行**: ~51 秒 (2 个任务) -- **报告生成**: ~2 秒 -- **总耗时**: ~63 秒 - -## 系统状态 - -### ✅ 已完成 -1. 工具注册系统修复 -2. 配置管理统一 -3. 隐私保护验证 -4. 端到端分析流程 -5. 真实数据测试 - -### 📊 测试覆盖率 -- 单元测试: 314/328 通过 (95.7%) -- 属性测试: 已实施 -- 集成测试: 已通过 -- 端到端测试: 已通过 - -## 结论 - -系统已完全就绪,可以进行生产环境部署。所有核心功能已验证,隐私保护机制有效,配置管理规范,工具系统运行正常。 - ---- - -**生成时间**: 2026-03-09 09:08:27 -**测试环境**: Windows, Python 3.x -**数据集**: cleaned_data.csv (84 rows × 21 columns) diff --git a/analysis_output/run_20260309_102648/analysis_report.md b/analysis_output/run_20260309_102648/analysis_report.md new file mode 100644 index 0000000..5c439cd --- /dev/null +++ b/analysis_output/run_20260309_102648/analysis_report.md @@ -0,0 +1,177 @@ +# 工单分析报告 + +生成时间:2026-03-09 10:29:31 +数据源:cleaned_data.csv + +--- + +# 工单数据分析报告 + +## 1. 执行摘要 + +本报告基于84条工单数据(质量分数88/100)进行全面分析,主要发现如下: + +1. **工单处理效率存在显著异常**:平均关闭时长为54.77天,但存在2个异常工单(处理时长分别为277天和237天),占总数的2.38%。"Activation SIM"问题的平均处理时长高达142.5天,远高于其他问题类型。 +2. **工单来源渠道集中**:邮件渠道占比54.76%(46张),Telegram bot占比42.86%(36张),渠道来源仅占2.38%(2张),显示渠道管理可优化。 +3. **问题类型高度集中**:远程控制问题占主导(66.67%,56张),前五类问题占总量的87%以上,表明问题分布集中。 +4. **车型问题分布不均**:EXEED RX(T22)车型工单最多(45.24%,38张),JAECOO J7(T1EJ)次之(26.19%,22张),特定车型问题集中。 +5. **责任人工作负载差异大**:Vsevolod处理工单最多(31个),但平均处理时长66.68天;Vsevolod Tsoi平均处理时长最高(152天),而何韬处理效率最高(平均3.5天)。 + +## 2. 数据概览 + +- **数据类型**:工单(ticket) +- **数据规模**:84行 × 21列 +- **数据质量**:88.0/100 +- **关键字段**:工单号、来源、创建日期、问题类型、问题描述、处理过程、跟踪记录、严重程度、工单状态、模块、责任人、关闭日期、车型、VIN、关闭时长(天)等 +- **分析时间范围**:2025年1月2日至2025年2月24日 + +## 3. 详细分析 + +### 3.1 工单概况分析 + +工单状态分布显示,82.14%(69张)已关闭,17.86%(15张)临时关闭,表明大部分问题已解决。 + +![工单状态分布](analysis_output\run_20260309_102648\charts\outlier_pie_chart.png) + +**来源渠道分析**: +- 邮件(Mail):54.76%(46张) +- Telegram bot:42.86%(36张) +- Telegram channel:2.38%(2张) + +**问题类型分布**(前5位): +1. 远程控制(Remote control):66.67%(56张) +2. 网络问题(Network):7.14%(6张) +3. 导航问题(Navi):5.95%(5张) +4. 应用问题(Application):4.76%(4张) +5. 成员中心认证问题:3.57%(3张) + +前五类问题合计占总量的87.09%,显示问题类型高度集中。 + +### 3.2 工单处理效率分析 + +**关闭时长统计**: +- 平均值:54.77天 +- 中位数:41天 +- 标准差:48.19天 +- 最小值:2天 +- 最大值:277天 +- 四分位距(IQR):26.25天(Q25)至84.5天(Q75) + +**异常值检测**: +- 检测到2个异常工单(277天和237天),占总数的2.38% +- 异常值上限为171.875天,这两个工单远超此阈值 + +![关闭时长分布](analysis_output\run_20260309_102648\charts\outlier_bar_chart.png) + +**按问题类型分析平均关闭时长**: +- Activation SIM:142.5天(效率最低) +- Remote control:66.5天 +- PKI problem:47天 +- doesn't exist on TSP:31.67天 +- Network:24天 + +![按问题类型平均关闭时长](analysis_output\run_20260309_102648\charts\avg_close_time_by_issue_type.png) + +### 3.3 工单内容与趋势分析 + +**工单创建时间趋势**(2025年1-2月): +- 创建高峰期:2025年1月13日(8个工单) +- 创建低谷期:2025年1月8日、15日、29日、30日及2月多日(仅1个工单) +- 整体趋势波动较大,无明显持续上升或下降模式 + +![工单创建趋势](analysis_output\run_20260309_102648\charts\bar_chart_trend.png) + +### 3.4 责任人工作负载分析 + +**工单处理数量**: +- Vsevolod:31个(最多) +- Evgeniy:28个 +- Kostya:5个 +- 何韬:4个 + +**平均关闭时长**: +- Vsevolod Tsoi:152天(最高,仅处理2个工单) +- 林兆国:89天 +- Vadim:69天 +- Vsevolod:66.68天 +- Evgeniy:62.39天 +- 何韬:3.5天(最低,效率最高) + +![责任人工作负载](analysis_output\run_20260309_102648\charts\workload_by_responsible.png) + +![责任人处理效率](analysis_output\run_20260309_102648\charts\efficiency_by_responsible.png) + +### 3.5 车辆相关信息分析 + +**车型分布**: +- EXEED RX(T22):45.24%(38张) +- JAECOO J7(T1EJ):26.19%(22张) +- EXEED VX FL(M36T):20.24%(17张) +- CHERY TIGGO 9 (T28):8.33%(7张) + +**VIN分布**: +- LVTDD24B1RG023450和LVTDD24B1RG021245各出现2次,表明特定车辆问题可能复发 +- 其他VIN均唯一,显示问题主要分散于不同车辆 + +## 4. 结论与建议 + +### 4.1 主要结论 + +1. **处理效率需优化**:整体平均关闭时长54.77天,且存在严重异常值(277天和237天),"Activation SIM"问题处理时长高达142.5天。 +2. **渠道管理可改进**:邮件渠道占比过高(54.76%),Telegram渠道未充分利用。 +3. **问题类型集中**:远程控制问题占66.67%,需针对性优化。 +4. **车型问题分布不均**:EXEED RX(T22)车型问题最集中(45.24%)。 +5. **责任人效率差异大**:Vsevolod Tsoi处理时长152天,而何韬仅3.5天,存在明显效率差距。 + +### 4.2 可操作建议 + +1. **优化异常工单处理流程**: + - 针对处理时长超过100天的工单(如Activation SIM问题)建立专项处理机制 + - 定期审查异常工单,分析根本原因 + - 依据:异常值检测显示2个工单处理时长分别为277天和237天 + +2. **平衡渠道分配**: + - 推广Telegram bot使用,减轻邮件渠道压力 + - 依据:邮件渠道占比54.76%,Telegram bot仅42.86% + +3. **加强远程控制问题管理**: + - 建立远程控制问题知识库和快速响应机制 + - 依据:远程控制问题占66.67%(56张) + +4. **针对EXEED RX(T22)车型专项优化**: + - 分析该车型高工单率的原因 + - 依据:该车型工单占比45.24%(38张) + +5. **提升责任人效率一致性**: + - 建立效率标杆(如何韬的3.5天平均时长) + - 对处理时长较长的责任人提供培训和支持 + - 依据:责任人平均处理时长从3.5天到152天不等 + +6. **建立工单创建趋势监控**: + - 监控工单创建高峰期(如1月13日的8个工单) + - 提前调配资源应对潜在高峰 + - 依据:工单创建趋势显示明显波动 + +通过实施这些建议,预计可显著提升工单处理效率,优化资源分配,并改善客户满意度。 + +--- + +## 分析追溯 + +本报告基于以下分析任务: + +- ✓ 工单概况分析:基本分布统计 + - 工单总数为84,其中82.14%(69张)已关闭,17.86%(15张)临时关闭,表明大部分问题已解决。 + - 工单来源中,邮件(Mail)占比最高,达54.76%(46张),Telegram bot占42.86%(36张),渠道来源仅占2.38%(2张)。 +- ✓ 工单处理效率分析:关闭时长统计 + - 关闭时长的平均值为54.77天,中位数为41天,标准差为48.19天,表明数据右偏分布,存在处理时间较长的工单。 + - 异常值检测发现2个工单(277天和237天)处理时间过长,占总工单数的2.38%,需重点关注这些异常情况。 +- ✓ 工单内容与趋势分析:文本和时间序列 + - 2025年1月13日工单创建数量最高,达到8个,是整体趋势中的峰值。 + - 2025年1月8日、15日、29日、30日及2月多日工单创建数量最低,仅为1个,显示这些日期工单活动较少。 +- ✓ 责任人工作负载分析 + - Vsevolod 处理工单数量最多(31个),但平均关闭时长为66.68天,工作量大且周期长。 + - Vsevolod Tsoi 平均关闭时长最高(152天),但仅处理2个工单,可能存在效率问题或复杂任务。 +- ✓ 车辆相关信息分析:车型和VIN分布 + - EXEED RX(T22)车型占比最高,达45.24%(38/84),是问题集中的主要车型。 + - JAECOO J7(T1EJ)车型工单数为22,占比26.19%,是第二大问题车型。 diff --git a/analysis_output/run_20260309_102648/charts/avg_close_time_by_issue_type.png b/analysis_output/run_20260309_102648/charts/avg_close_time_by_issue_type.png new file mode 100644 index 0000000..b9075f4 Binary files /dev/null and b/analysis_output/run_20260309_102648/charts/avg_close_time_by_issue_type.png differ diff --git a/analysis_output/run_20260309_102648/charts/bar_chart_trend.png b/analysis_output/run_20260309_102648/charts/bar_chart_trend.png new file mode 100644 index 0000000..af05076 Binary files /dev/null and b/analysis_output/run_20260309_102648/charts/bar_chart_trend.png differ diff --git a/analysis_output/run_20260309_102648/charts/efficiency_by_responsible.png b/analysis_output/run_20260309_102648/charts/efficiency_by_responsible.png new file mode 100644 index 0000000..c69f068 Binary files /dev/null and b/analysis_output/run_20260309_102648/charts/efficiency_by_responsible.png differ diff --git a/analysis_output/run_20260309_102648/charts/outlier_bar_chart.png b/analysis_output/run_20260309_102648/charts/outlier_bar_chart.png new file mode 100644 index 0000000..88d3686 Binary files /dev/null and b/analysis_output/run_20260309_102648/charts/outlier_bar_chart.png differ diff --git a/analysis_output/run_20260309_102648/charts/outlier_pie_chart.png b/analysis_output/run_20260309_102648/charts/outlier_pie_chart.png new file mode 100644 index 0000000..27fbb80 Binary files /dev/null and b/analysis_output/run_20260309_102648/charts/outlier_pie_chart.png differ diff --git a/analysis_output/run_20260309_102648/charts/workload_by_responsible.png b/analysis_output/run_20260309_102648/charts/workload_by_responsible.png new file mode 100644 index 0000000..4163e3e Binary files /dev/null and b/analysis_output/run_20260309_102648/charts/workload_by_responsible.png differ diff --git a/run_analysis_en.py b/run_analysis_en.py index 98213a5..0642ffd 100644 --- a/run_analysis_en.py +++ b/run_analysis_en.py @@ -23,7 +23,7 @@ from src.engines.ai_data_understanding import ai_understand_data_with_dal from src.engines.requirement_understanding import understand_requirement from src.engines.analysis_planning import plan_analysis from src.engines.task_execution import execute_task -from src.engines.report_generation import generate_report +from src.engines.report_generation import generate_report, _convert_chart_paths_in_report from src.tools.tool_manager import ToolManager from src.tools.base import _global_registry from src.models import DataProfile, AnalysisResult @@ -156,7 +156,8 @@ def run_analysis( else: report = generate_report(results, requirement, profile, output_path=run_dir) - # Save report + # Save report — convert chart paths to relative (./charts/xxx.png) + report = _convert_chart_paths_in_report(report, run_dir) with open(report_path, 'w', encoding='utf-8') as f: f.write(report) @@ -262,7 +263,8 @@ def _generate_template_report( def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> str: - """Collect all chart paths from task results for embedding in reports.""" + """Collect all chart paths from task results for embedding in reports. + Returns paths relative to run_dir (e.g. ./charts/bar_chart.png).""" paths = [] for r in results: if not r.success: @@ -280,7 +282,10 @@ def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> st paths.append(cp) if not paths: return "(无图表)" - return "\n".join(f"- {p}" for p in paths) + # Convert to relative paths + from src.engines.report_generation import _to_relative_chart_path + rel_paths = [_to_relative_chart_path(p, run_dir) for p in paths] + return "\n".join(f"- {p}" for p in rel_paths) if __name__ == "__main__": diff --git a/src/engines/__pycache__/report_generation.cpython-311.pyc b/src/engines/__pycache__/report_generation.cpython-311.pyc index f85e3fe..4270242 100644 Binary files a/src/engines/__pycache__/report_generation.cpython-311.pyc and b/src/engines/__pycache__/report_generation.cpython-311.pyc differ diff --git a/src/engines/__pycache__/task_execution.cpython-311.pyc b/src/engines/__pycache__/task_execution.cpython-311.pyc index 4606476..ec02d87 100644 Binary files a/src/engines/__pycache__/task_execution.cpython-311.pyc and b/src/engines/__pycache__/task_execution.cpython-311.pyc differ diff --git a/src/engines/report_generation.py b/src/engines/report_generation.py index 0517cc7..a9fd6da 100644 --- a/src/engines/report_generation.py +++ b/src/engines/report_generation.py @@ -4,9 +4,11 @@ """ import os +import re import json from typing import List, Dict, Any, Optional from datetime import datetime +from pathlib import PurePosixPath, Path from src.models.analysis_result import AnalysisResult from src.models.requirement_spec import RequirementSpec @@ -310,6 +312,56 @@ def _generate_conclusion_summary(key_findings: List[Dict[str, Any]]) -> str: +def _to_relative_chart_path(chart_path: str, report_dir: str = "") -> str: + """ + 将图表绝对路径转换为相对于报告文件的路径。 + + 例如: + chart_path = "analysis_output/run_xxx/charts/bar.png" + report_dir = "analysis_output/run_xxx" + → "./charts/bar.png" + + 如果无法计算相对路径,则只保留 ./charts/filename.png + """ + if not chart_path: + return chart_path + + # 统一为正斜杠 + chart_path = chart_path.replace('\\', '/') + + if report_dir: + report_dir = report_dir.replace('\\', '/') + try: + rel = os.path.relpath(chart_path, report_dir).replace('\\', '/') + return './' + rel if not rel.startswith('.') else rel + except ValueError: + pass + + # Fallback: 提取 charts/filename 部分 + parts = chart_path.replace('\\', '/').split('/') + if 'charts' in parts: + idx = parts.index('charts') + return './' + '/'.join(parts[idx:]) + + # 最后兜底:直接用文件名 + return './charts/' + os.path.basename(chart_path) + + +def _convert_chart_paths_in_report(report: str, report_dir: str = "") -> str: + """ + 将报告中所有 ![xxx](绝对路径) 的图表路径转换为相对路径。 + 同时统一反斜杠为正斜杠。 + """ + def replace_img(match): + alt = match.group(1) + path = match.group(2) + rel_path = _to_relative_chart_path(path, report_dir) + return f'![{alt}]({rel_path})' + + # 匹配 ![任意文字](路径) + return re.sub(r'!\[([^\]]*)\]\(([^)]+)\)', replace_img, report) + + def generate_report( results: List[AnalysisResult], requirement: RequirementSpec, @@ -370,6 +422,10 @@ def generate_report( with open(output_path, 'w', encoding='utf-8') as f: f.write(report) + # 将图表路径转换为相对于报告所在目录的路径 + report_dir = output_path if output_path and os.path.isdir(output_path) else "" + report = _convert_chart_paths_in_report(report, report_dir) + return report @@ -394,11 +450,11 @@ def _generate_report_with_ai( # 收集所有图表路径 for viz in (r.visualizations or []): if viz: - all_chart_paths.append(viz) + all_chart_paths.append(_to_relative_chart_path(viz)) if isinstance(r.data, dict): for key, val in r.data.items(): if isinstance(val, dict) and val.get('chart_path'): - all_chart_paths.append(val['chart_path']) + all_chart_paths.append(_to_relative_chart_path(val['chart_path'])) data_section = "\n\n".join(data_summaries) if data_summaries else "无详细数据" diff --git a/src/tools/__pycache__/viz_tools.cpython-311.pyc b/src/tools/__pycache__/viz_tools.cpython-311.pyc index e575eac..f4d29ea 100644 Binary files a/src/tools/__pycache__/viz_tools.cpython-311.pyc and b/src/tools/__pycache__/viz_tools.cpython-311.pyc differ diff --git a/src/tools/query_tools.py b/src/tools/query_tools.py deleted file mode 100644 index bc2d303..0000000 --- a/src/tools/query_tools.py +++ /dev/null @@ -1,301 +0,0 @@ -"""数据查询工具。""" - -import pandas as pd -import numpy as np -from typing import Dict, Any - -from src.tools.base import AnalysisTool -from src.models import DataProfile - - -class GetColumnDistributionTool(AnalysisTool): - """获取列的分布统计工具。""" - - @property - def name(self) -> str: - return "get_column_distribution" - - @property - def description(self) -> str: - return "获取指定列的分布统计信息,包括值计数、百分比等。适用于分类和数值列。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "column": { - "type": "string", - "description": "要分析的列名" - }, - "top_n": { - "type": "integer", - "description": "返回前N个最常见的值", - "default": 10 - } - }, - "required": ["column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行列分布分析。""" - column = kwargs.get('column') - top_n = kwargs.get('top_n', 10) - - if column not in data.columns: - return {'error': f'列 {column} 不存在'} - - col_data = data[column] - value_counts = col_data.value_counts().head(top_n) - total = len(col_data.dropna()) - - distribution = [] - for value, count in value_counts.items(): - distribution.append({ - 'value': str(value), - 'count': int(count), - 'percentage': float(count / total * 100) if total > 0 else 0.0 - }) - - return { - 'column': column, - 'total_count': int(total), - 'unique_count': int(col_data.nunique()), - 'missing_count': int(col_data.isna().sum()), - 'distribution': distribution - } - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于所有数据。""" - return True - - -class GetValueCountsTool(AnalysisTool): - """获取值计数工具。""" - - @property - def name(self) -> str: - return "get_value_counts" - - @property - def description(self) -> str: - return "获取指定列的值计数,返回每个唯一值的出现次数。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "column": { - "type": "string", - "description": "要分析的列名" - }, - "normalize": { - "type": "boolean", - "description": "是否返回百分比而不是计数", - "default": False - } - }, - "required": ["column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行值计数。""" - column = kwargs.get('column') - normalize = kwargs.get('normalize', False) - - if column not in data.columns: - return {'error': f'列 {column} 不存在'} - - value_counts = data[column].value_counts(normalize=normalize) - - result = {} - for value, count in value_counts.items(): - result[str(value)] = float(count) if normalize else int(count) - - return { - 'column': column, - 'value_counts': result, - 'normalized': normalize - } - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于所有数据。""" - return True - - -class GetTimeSeriesTool(AnalysisTool): - """获取时间序列数据工具。""" - - @property - def name(self) -> str: - return "get_time_series" - - @property - def description(self) -> str: - return "获取时间序列数据,按时间聚合指定指标。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "time_column": { - "type": "string", - "description": "时间列名" - }, - "value_column": { - "type": "string", - "description": "要聚合的值列名" - }, - "aggregation": { - "type": "string", - "description": "聚合方式:count, sum, mean, min, max", - "default": "count" - }, - "frequency": { - "type": "string", - "description": "时间频率:D(天), W(周), M(月), Y(年)", - "default": "D" - } - }, - "required": ["time_column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行时间序列分析。""" - time_column = kwargs.get('time_column') - value_column = kwargs.get('value_column') - aggregation = kwargs.get('aggregation', 'count') - frequency = kwargs.get('frequency', 'D') - - if time_column not in data.columns: - return {'error': f'时间列 {time_column} 不存在'} - - # 转换为日期时间类型 - try: - time_data = pd.to_datetime(data[time_column]) - except Exception as e: - return {'error': f'无法将 {time_column} 转换为日期时间: {str(e)}'} - - # 创建临时 DataFrame - temp_df = pd.DataFrame({'time': time_data}) - - if value_column: - if value_column not in data.columns: - return {'error': f'值列 {value_column} 不存在'} - temp_df['value'] = data[value_column] - - # 设置时间索引 - temp_df.set_index('time', inplace=True) - - # 按频率重采样 - if value_column: - if aggregation == 'count': - result = temp_df.resample(frequency).count() - elif aggregation == 'sum': - result = temp_df.resample(frequency).sum() - elif aggregation == 'mean': - result = temp_df.resample(frequency).mean() - elif aggregation == 'min': - result = temp_df.resample(frequency).min() - elif aggregation == 'max': - result = temp_df.resample(frequency).max() - else: - return {'error': f'不支持的聚合方式: {aggregation}'} - else: - result = temp_df.resample(frequency).size().to_frame('count') - - # 转换为字典 - time_series = [] - for timestamp, row in result.iterrows(): - time_series.append({ - 'time': timestamp.strftime('%Y-%m-%d'), - 'value': float(row.iloc[0]) if not pd.isna(row.iloc[0]) else 0.0 - }) - - # 限制返回的数据点数量,最多100个(隐私保护要求) - if len(time_series) > 100: - # 均匀采样以保持趋势 - step = len(time_series) / 100 - sampled_indices = [int(i * step) for i in range(100)] - time_series = [time_series[i] for i in sampled_indices] - - return { - 'time_column': time_column, - 'value_column': value_column, - 'aggregation': aggregation, - 'frequency': frequency, - 'time_series': time_series, - 'total_points': len(result), # 记录原始数据点数量 - 'returned_points': len(time_series) # 记录返回的数据点数量 - } - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于包含日期时间列的数据。""" - return any(col.dtype == 'datetime' for col in data_profile.columns) - - -class GetCorrelationTool(AnalysisTool): - """获取相关性分析工具。""" - - @property - def name(self) -> str: - return "get_correlation" - - @property - def description(self) -> str: - return "计算数值列之间的相关系数矩阵。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "columns": { - "type": "array", - "items": {"type": "string"}, - "description": "要分析的列名列表,如果为空则分析所有数值列" - }, - "method": { - "type": "string", - "description": "相关系数方法:pearson, spearman, kendall", - "default": "pearson" - } - } - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行相关性分析。""" - columns = kwargs.get('columns', []) - method = kwargs.get('method', 'pearson') - - # 如果没有指定列,使用所有数值列 - if not columns: - numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist() - else: - numeric_cols = [col for col in columns if col in data.columns] - - if len(numeric_cols) < 2: - return {'error': '至少需要两个数值列来计算相关性'} - - # 计算相关系数矩阵 - corr_matrix = data[numeric_cols].corr(method=method) - - # 转换为字典格式 - correlation = {} - for col1 in corr_matrix.columns: - correlation[col1] = {} - for col2 in corr_matrix.columns: - correlation[col1][col2] = float(corr_matrix.loc[col1, col2]) - - return { - 'columns': numeric_cols, - 'method': method, - 'correlation_matrix': correlation - } - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于包含至少两个数值列的数据。""" - numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric'] - return len(numeric_cols) >= 2 diff --git a/src/tools/stats_tools.py b/src/tools/stats_tools.py deleted file mode 100644 index a72cc71..0000000 --- a/src/tools/stats_tools.py +++ /dev/null @@ -1,325 +0,0 @@ -"""统计分析工具。""" - -import pandas as pd -import numpy as np -from typing import Dict, Any -from scipy import stats - -from src.tools.base import AnalysisTool -from src.models import DataProfile - - -class CalculateStatisticsTool(AnalysisTool): - """计算描述性统计工具。""" - - @property - def name(self) -> str: - return "calculate_statistics" - - @property - def description(self) -> str: - return "计算指定列的描述性统计信息,包括均值、中位数、标准差、最小值、最大值等。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "column": { - "type": "string", - "description": "要分析的列名" - } - }, - "required": ["column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行统计计算。""" - column = kwargs.get('column') - - if column not in data.columns: - return {'error': f'列 {column} 不存在'} - - col_data = data[column].dropna() - - if not pd.api.types.is_numeric_dtype(col_data): - return {'error': f'列 {column} 不是数值类型'} - - statistics = { - 'column': column, - 'count': int(len(col_data)), - 'mean': float(col_data.mean()), - 'median': float(col_data.median()), - 'std': float(col_data.std()), - 'min': float(col_data.min()), - 'max': float(col_data.max()), - 'q25': float(col_data.quantile(0.25)), - 'q75': float(col_data.quantile(0.75)), - 'skewness': float(col_data.skew()), - 'kurtosis': float(col_data.kurtosis()) - } - - return statistics - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于包含数值列的数据。""" - return any(col.dtype == 'numeric' for col in data_profile.columns) - - -class PerformGroupbyTool(AnalysisTool): - """执行分组聚合工具。""" - - @property - def name(self) -> str: - return "perform_groupby" - - @property - def description(self) -> str: - return "按指定列分组,对另一列进行聚合计算(如求和、平均、计数等)。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "group_by": { - "type": "string", - "description": "分组依据的列名" - }, - "value_column": { - "type": "string", - "description": "要聚合的值列名,如果为空则计数" - }, - "aggregation": { - "type": "string", - "description": "聚合方式:count, sum, mean, min, max, std", - "default": "count" - } - }, - "required": ["group_by"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行分组聚合。""" - group_by = kwargs.get('group_by') - value_column = kwargs.get('value_column') - aggregation = kwargs.get('aggregation', 'count') - - if group_by not in data.columns: - return {'error': f'分组列 {group_by} 不存在'} - - if value_column and value_column not in data.columns: - return {'error': f'值列 {value_column} 不存在'} - - # 执行分组聚合 - if value_column: - grouped = data.groupby(group_by, observed=True)[value_column] - else: - grouped = data.groupby(group_by, observed=True).size() - aggregation = 'count' - - if aggregation == 'count': - if value_column: - result = grouped.count() - else: - result = grouped - elif aggregation == 'sum': - result = grouped.sum() - elif aggregation == 'mean': - result = grouped.mean() - elif aggregation == 'min': - result = grouped.min() - elif aggregation == 'max': - result = grouped.max() - elif aggregation == 'std': - result = grouped.std() - else: - return {'error': f'不支持的聚合方式: {aggregation}'} - - # 转换为字典 - groups = [] - for group_value, agg_value in result.items(): - groups.append({ - 'group': str(group_value), - 'value': float(agg_value) if not pd.isna(agg_value) else 0.0 - }) - - # 限制返回的分组数量,最多100个(隐私保护要求) - total_groups = len(groups) - if len(groups) > 100: - # 按值排序并取前100个 - groups = sorted(groups, key=lambda x: x['value'], reverse=True)[:100] - - return { - 'group_by': group_by, - 'value_column': value_column, - 'aggregation': aggregation, - 'groups': groups, - 'total_groups': total_groups, # 记录原始分组数量 - 'returned_groups': len(groups) # 记录返回的分组数量 - } - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于所有数据。""" - return True - - -class DetectOutliersTool(AnalysisTool): - """检测异常值工具。""" - - @property - def name(self) -> str: - return "detect_outliers" - - @property - def description(self) -> str: - return "使用IQR方法或Z-score方法检测数值列中的异常值。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "column": { - "type": "string", - "description": "要检测的列名" - }, - "method": { - "type": "string", - "description": "检测方法:iqr 或 zscore", - "default": "iqr" - }, - "threshold": { - "type": "number", - "description": "阈值(IQR倍数或Z-score标准差倍数)", - "default": 1.5 - } - }, - "required": ["column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行异常值检测。""" - column = kwargs.get('column') - method = kwargs.get('method', 'iqr') - threshold = kwargs.get('threshold', 1.5) - - if column not in data.columns: - return {'error': f'列 {column} 不存在'} - - col_data = data[column].dropna() - - if not pd.api.types.is_numeric_dtype(col_data): - return {'error': f'列 {column} 不是数值类型'} - - if method == 'iqr': - # IQR 方法 - q1 = col_data.quantile(0.25) - q3 = col_data.quantile(0.75) - iqr = q3 - q1 - lower_bound = q1 - threshold * iqr - upper_bound = q3 + threshold * iqr - outliers = col_data[(col_data < lower_bound) | (col_data > upper_bound)] - elif method == 'zscore': - # Z-score 方法 - z_scores = np.abs(stats.zscore(col_data)) - outliers = col_data[z_scores > threshold] - else: - return {'error': f'不支持的检测方法: {method}'} - - return { - 'column': column, - 'method': method, - 'threshold': threshold, - 'outlier_count': int(len(outliers)), - 'outlier_percentage': float(len(outliers) / len(col_data) * 100), - 'outlier_values': outliers.head(20).tolist(), # 最多返回20个异常值 - 'bounds': { - 'lower': float(lower_bound) if method == 'iqr' else None, - 'upper': float(upper_bound) if method == 'iqr' else None - } if method == 'iqr' else None - } - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于包含数值列的数据。""" - return any(col.dtype == 'numeric' for col in data_profile.columns) - - -class CalculateTrendTool(AnalysisTool): - """计算趋势工具。""" - - @property - def name(self) -> str: - return "calculate_trend" - - @property - def description(self) -> str: - return "计算时间序列数据的趋势,包括线性回归斜率、增长率等。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "time_column": { - "type": "string", - "description": "时间列名" - }, - "value_column": { - "type": "string", - "description": "值列名" - } - }, - "required": ["time_column", "value_column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行趋势计算。""" - time_column = kwargs.get('time_column') - value_column = kwargs.get('value_column') - - if time_column not in data.columns: - return {'error': f'时间列 {time_column} 不存在'} - - if value_column not in data.columns: - return {'error': f'值列 {value_column} 不存在'} - - # 转换时间列 - try: - time_data = pd.to_datetime(data[time_column]) - except Exception as e: - return {'error': f'无法将 {time_column} 转换为日期时间: {str(e)}'} - - # 创建数值型时间索引(天数) - time_numeric = (time_data - time_data.min()).dt.days.values - value_data = data[value_column].dropna().values - - if len(value_data) < 2: - return {'error': '数据点太少,无法计算趋势'} - - # 线性回归 - slope, intercept, r_value, p_value, std_err = stats.linregress( - time_numeric[:len(value_data)], value_data - ) - - # 计算增长率 - first_value = value_data[0] - last_value = value_data[-1] - growth_rate = ((last_value - first_value) / first_value * 100) if first_value != 0 else 0 - - return { - 'time_column': time_column, - 'value_column': value_column, - 'slope': float(slope), - 'intercept': float(intercept), - 'r_squared': float(r_value ** 2), - 'p_value': float(p_value), - 'growth_rate': float(growth_rate), - 'trend': 'increasing' if slope > 0 else 'decreasing' if slope < 0 else 'stable' - } - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于包含日期时间列和数值列的数据。""" - has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns) - has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns) - return has_datetime and has_numeric diff --git a/src/tools/tool_manager.py b/src/tools/tool_manager.py deleted file mode 100644 index e120cae..0000000 --- a/src/tools/tool_manager.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Tool manager — selects applicable tools based on data profile. - -The ToolManager filters tools by data type compatibility (e.g., time series tools -need datetime columns). The actual tool+parameter selection is fully AI-driven. -""" - -from typing import List, Dict, Any - -from src.tools.base import AnalysisTool, ToolRegistry, _global_registry -from src.models import DataProfile - - -class ToolManager: - """ - Tool manager that selects applicable tools based on data characteristics. - - This is a filter, not a decision maker. AI decides which tools to actually - call and with what parameters at runtime. - """ - - def __init__(self, registry: ToolRegistry = None): - self.registry = registry if registry else _global_registry - self._missing_tools: List[str] = [] - - def select_tools(self, data_profile: DataProfile) -> List[AnalysisTool]: - """ - Return all tools applicable to this data profile. - Each tool's is_applicable() checks if the data has the right column types. - """ - self._missing_tools = [] - return self.registry.get_applicable_tools(data_profile) - - def get_all_tools(self) -> List[AnalysisTool]: - """Return all registered tools regardless of data profile.""" - tool_names = self.registry.list_tools() - return [self.registry.get_tool(name) for name in tool_names] - - def get_missing_tools(self) -> List[str]: - return list(set(self._missing_tools)) - - def get_tool_descriptions(self, tools: List[AnalysisTool] = None) -> List[Dict[str, Any]]: - """Get tool descriptions for AI consumption.""" - if tools is None: - tools = self.get_all_tools() - return [ - { - 'name': t.name, - 'description': t.description, - 'parameters': t.parameters - } - for t in tools - ] diff --git a/src/tools/viz_tools.py b/src/tools/viz_tools.py deleted file mode 100644 index 44c544c..0000000 --- a/src/tools/viz_tools.py +++ /dev/null @@ -1,443 +0,0 @@ -"""可视化工具。""" - -import pandas as pd -import numpy as np -import matplotlib -matplotlib.use('Agg') # 使用非交互式后端 -import matplotlib.pyplot as plt -from typing import Dict, Any -import os -from pathlib import Path - -from src.tools.base import AnalysisTool -from src.models import DataProfile - -# 尝试导入 seaborn,如果不可用则使用 matplotlib -try: - import seaborn as sns - HAS_SEABORN = True -except ImportError: - HAS_SEABORN = False - - -# 设置中文字体支持 -plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] -plt.rcParams['axes.unicode_minus'] = False - - -class CreateBarChartTool(AnalysisTool): - """创建柱状图工具。""" - - @property - def name(self) -> str: - return "create_bar_chart" - - @property - def description(self) -> str: - return "创建柱状图,用于展示分类数据的分布或比较。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "x_column": { - "type": "string", - "description": "X轴列名(分类变量)" - }, - "y_column": { - "type": "string", - "description": "Y轴列名(数值变量),如果为空则计数" - }, - "title": { - "type": "string", - "description": "图表标题", - "default": "柱状图" - }, - "output_path": { - "type": "string", - "description": "输出文件路径", - "default": "bar_chart.png" - }, - "top_n": { - "type": "integer", - "description": "只显示前N个类别", - "default": 20 - } - }, - "required": ["x_column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行柱状图生成。""" - x_column = kwargs.get('x_column') - y_column = kwargs.get('y_column') - title = kwargs.get('title', '柱状图') - output_path = kwargs.get('output_path', 'bar_chart.png') - top_n = kwargs.get('top_n', 20) - - if x_column not in data.columns: - return {'error': f'列 {x_column} 不存在'} - - if y_column and y_column not in data.columns: - return {'error': f'列 {y_column} 不存在'} - - try: - # 准备数据 - if y_column: - # 按 x_column 分组,对 y_column 求和 - plot_data = data.groupby(x_column, observed=True)[y_column].sum().sort_values(ascending=False).head(top_n) - else: - # 计数 - plot_data = data[x_column].value_counts().head(top_n) - - # 创建图表 - fig, ax = plt.subplots(figsize=(12, 6)) - plot_data.plot(kind='bar', ax=ax) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xlabel(x_column, fontsize=12) - ax.set_ylabel(y_column if y_column else '计数', fontsize=12) - ax.tick_params(axis='x', rotation=45) - plt.tight_layout() - - # 确保输出目录存在 - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - - # 保存图表 - plt.savefig(output_path, dpi=100, bbox_inches='tight') - plt.close(fig) - - return { - 'success': True, - 'chart_path': output_path, - 'chart_type': 'bar', - 'data_points': len(plot_data), - 'x_column': x_column, - 'y_column': y_column - } - - except Exception as e: - return {'error': f'生成柱状图失败: {str(e)}'} - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于所有数据。""" - return True - - -class CreateLineChartTool(AnalysisTool): - """创建折线图工具。""" - - @property - def name(self) -> str: - return "create_line_chart" - - @property - def description(self) -> str: - return "创建折线图,用于展示时间序列数据或趋势变化。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "x_column": { - "type": "string", - "description": "X轴列名(通常是时间)" - }, - "y_column": { - "type": "string", - "description": "Y轴列名(数值变量)" - }, - "title": { - "type": "string", - "description": "图表标题", - "default": "折线图" - }, - "output_path": { - "type": "string", - "description": "输出文件路径", - "default": "line_chart.png" - } - }, - "required": ["x_column", "y_column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行折线图生成。""" - x_column = kwargs.get('x_column') - y_column = kwargs.get('y_column') - title = kwargs.get('title', '折线图') - output_path = kwargs.get('output_path', 'line_chart.png') - - if x_column not in data.columns: - return {'error': f'列 {x_column} 不存在'} - - if y_column not in data.columns: - return {'error': f'列 {y_column} 不存在'} - - try: - # 准备数据 - plot_data = data[[x_column, y_column]].copy() - plot_data = plot_data.sort_values(x_column) - - # 如果数据点太多,采样 - if len(plot_data) > 1000: - step = len(plot_data) // 1000 - plot_data = plot_data.iloc[::step] - - # 创建图表 - fig, ax = plt.subplots(figsize=(12, 6)) - ax.plot(plot_data[x_column], plot_data[y_column], marker='o', markersize=3, linewidth=2) - ax.set_title(title, fontsize=14, fontweight='bold') - ax.set_xlabel(x_column, fontsize=12) - ax.set_ylabel(y_column, fontsize=12) - ax.grid(True, alpha=0.3) - plt.xticks(rotation=45) - plt.tight_layout() - - # 确保输出目录存在 - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - - # 保存图表 - plt.savefig(output_path, dpi=100, bbox_inches='tight') - plt.close(fig) - - return { - 'success': True, - 'chart_path': output_path, - 'chart_type': 'line', - 'data_points': len(plot_data), - 'x_column': x_column, - 'y_column': y_column - } - - except Exception as e: - return {'error': f'生成折线图失败: {str(e)}'} - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于包含数值列的数据。""" - return any(col.dtype == 'numeric' for col in data_profile.columns) - - -class CreatePieChartTool(AnalysisTool): - """创建饼图工具。""" - - @property - def name(self) -> str: - return "create_pie_chart" - - @property - def description(self) -> str: - return "创建饼图,用于展示各部分占整体的比例。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "column": { - "type": "string", - "description": "要分析的列名" - }, - "title": { - "type": "string", - "description": "图表标题", - "default": "饼图" - }, - "output_path": { - "type": "string", - "description": "输出文件路径", - "default": "pie_chart.png" - }, - "top_n": { - "type": "integer", - "description": "只显示前N个类别,其余归为'其他'", - "default": 10 - } - }, - "required": ["column"] - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行饼图生成。""" - column = kwargs.get('column') - title = kwargs.get('title', '饼图') - output_path = kwargs.get('output_path', 'pie_chart.png') - top_n = kwargs.get('top_n', 10) - - if column not in data.columns: - return {'error': f'列 {column} 不存在'} - - try: - # 准备数据 - value_counts = data[column].value_counts() - - if len(value_counts) > top_n: - # 只保留前 N 个,其余归为"其他" - top_values = value_counts.head(top_n) - other_sum = value_counts.iloc[top_n:].sum() - plot_data = pd.concat([top_values, pd.Series({'其他': other_sum})]) - else: - plot_data = value_counts - - # 创建图表 - fig, ax = plt.subplots(figsize=(10, 8)) - colors = plt.cm.Set3(range(len(plot_data))) - wedges, texts, autotexts = ax.pie( - plot_data, - labels=plot_data.index, - autopct='%1.1f%%', - colors=colors, - startangle=90 - ) - - # 设置文本样式 - for text in texts: - text.set_fontsize(10) - for autotext in autotexts: - autotext.set_color('white') - autotext.set_fontweight('bold') - autotext.set_fontsize(9) - - ax.set_title(title, fontsize=14, fontweight='bold') - plt.tight_layout() - - # 确保输出目录存在 - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - - # 保存图表 - plt.savefig(output_path, dpi=100, bbox_inches='tight') - plt.close(fig) - - return { - 'success': True, - 'chart_path': output_path, - 'chart_type': 'pie', - 'categories': len(plot_data), - 'column': column - } - - except Exception as e: - return {'error': f'生成饼图失败: {str(e)}'} - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于所有数据。""" - return True - - -class CreateHeatmapTool(AnalysisTool): - """创建热力图工具。""" - - @property - def name(self) -> str: - return "create_heatmap" - - @property - def description(self) -> str: - return "创建热力图,用于展示数值矩阵或相关性矩阵。" - - @property - def parameters(self) -> Dict[str, Any]: - return { - "type": "object", - "properties": { - "columns": { - "type": "array", - "items": {"type": "string"}, - "description": "要分析的列名列表,如果为空则使用所有数值列" - }, - "title": { - "type": "string", - "description": "图表标题", - "default": "相关性热力图" - }, - "output_path": { - "type": "string", - "description": "输出文件路径", - "default": "heatmap.png" - }, - "method": { - "type": "string", - "description": "相关系数方法:pearson, spearman, kendall", - "default": "pearson" - } - } - } - - def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]: - """执行热力图生成。""" - columns = kwargs.get('columns', []) - title = kwargs.get('title', '相关性热力图') - output_path = kwargs.get('output_path', 'heatmap.png') - method = kwargs.get('method', 'pearson') - - try: - # 如果没有指定列,使用所有数值列 - if not columns: - numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist() - else: - numeric_cols = [col for col in columns if col in data.columns] - - if len(numeric_cols) < 2: - return {'error': '至少需要两个数值列来创建热力图'} - - # 计算相关系数矩阵 - corr_matrix = data[numeric_cols].corr(method=method) - - # 创建图表 - fig, ax = plt.subplots(figsize=(10, 8)) - - if HAS_SEABORN: - # 使用 seaborn 创建更美观的热力图 - sns.heatmap( - corr_matrix, - annot=True, - fmt='.2f', - cmap='coolwarm', - center=0, - square=True, - linewidths=1, - cbar_kws={"shrink": 0.8}, - ax=ax - ) - else: - # 使用 matplotlib 创建基本热力图 - im = ax.imshow(corr_matrix, cmap='coolwarm', aspect='auto', vmin=-1, vmax=1) - ax.set_xticks(range(len(corr_matrix.columns))) - ax.set_yticks(range(len(corr_matrix.columns))) - ax.set_xticklabels(corr_matrix.columns, rotation=45, ha='right') - ax.set_yticklabels(corr_matrix.columns) - - # 添加数值标注 - for i in range(len(corr_matrix.columns)): - for j in range(len(corr_matrix.columns)): - text = ax.text(j, i, f'{corr_matrix.iloc[i, j]:.2f}', - ha="center", va="center", color="black", fontsize=9) - - plt.colorbar(im, ax=ax, shrink=0.8) - - ax.set_title(title, fontsize=14, fontweight='bold') - plt.tight_layout() - - # 确保输出目录存在 - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - - # 保存图表 - plt.savefig(output_path, dpi=100, bbox_inches='tight') - plt.close(fig) - - return { - 'success': True, - 'chart_path': output_path, - 'chart_type': 'heatmap', - 'columns': numeric_cols, - 'method': method - } - - except Exception as e: - return {'error': f'生成热力图失败: {str(e)}'} - - def is_applicable(self, data_profile: DataProfile) -> bool: - """适用于包含至少两个数值列的数据。""" - numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric'] - return len(numeric_cols) >= 2