12345678
This commit is contained in:
@@ -1,171 +0,0 @@
|
|||||||
# 完整数据分析系统 - 执行结果
|
|
||||||
|
|
||||||
## 问题解决
|
|
||||||
|
|
||||||
### 核心问题
|
|
||||||
工具注册失败,导致 `ToolManager.select_tools()` 返回 0 个工具。
|
|
||||||
|
|
||||||
### 根本原因
|
|
||||||
`ToolManager` 在初始化时创建了一个新的空 `ToolRegistry` 实例,而工具实际上被注册到了全局注册表 `_global_registry` 中。两个注册表互不相通。
|
|
||||||
|
|
||||||
### 解决方案
|
|
||||||
修改 `src/tools/tool_manager.py` 第 18 行:
|
|
||||||
```python
|
|
||||||
# 修改前
|
|
||||||
self.registry = registry if registry else ToolRegistry()
|
|
||||||
|
|
||||||
# 修改后
|
|
||||||
self.registry = registry if registry else _global_registry
|
|
||||||
```
|
|
||||||
|
|
||||||
## 系统验证结果
|
|
||||||
|
|
||||||
### ✅ 阶段 1: AI 数据理解
|
|
||||||
- **数据类型识别**: ticket (IT服务工单)
|
|
||||||
- **数据质量评分**: 85.0/100
|
|
||||||
- **关键字段识别**: 15 个
|
|
||||||
- **数据规模**: 84 行 × 21 列
|
|
||||||
- **隐私保护**: ✓ AI 只能看到表头和统计信息,无法访问原始数据行
|
|
||||||
|
|
||||||
**AI 分析摘要**:
|
|
||||||
> 这是一个典型的IT服务工单数据集,记录了84个车辆相关问题的处理全流程。数据集包含完整的工单生命周期信息(创建、处理、关闭),主要涉及远程控制、导航、网络等车辆系统问题。数据质量较高,缺失率低(仅SIM和Notes字段缺失较多),但部分文本字段存在较长的非结构化描述。问题类型和模块分布显示远程控制问题是主要痛点(占比66%),工单主要来自邮件渠道(55%),平均关闭时长约5天。
|
|
||||||
|
|
||||||
### ✅ 阶段 2: 需求理解
|
|
||||||
- **用户需求**: "Analyze ticket health, find main issues, efficiency and trends"
|
|
||||||
- **生成目标**: 2 个分析目标
|
|
||||||
1. 健康度分析 (优先级: 5/5)
|
|
||||||
2. 趋势分析 (优先级: 4/5)
|
|
||||||
|
|
||||||
### ✅ 阶段 3: 分析规划
|
|
||||||
- **生成任务**: 2 个
|
|
||||||
- **预计时长**: 120 秒
|
|
||||||
- **任务清单**:
|
|
||||||
- [优先级 5] task_1. 质量评估 - 健康度分析
|
|
||||||
- [优先级 4] task_2. 趋势分析 - 趋势分析
|
|
||||||
|
|
||||||
### ✅ 阶段 4: 任务执行
|
|
||||||
- **可用工具**: 9 个(修复后)
|
|
||||||
- get_column_distribution (列分布统计)
|
|
||||||
- get_value_counts (值计数)
|
|
||||||
- perform_groupby (分组聚合)
|
|
||||||
- create_bar_chart (柱状图)
|
|
||||||
- create_pie_chart (饼图)
|
|
||||||
- calculate_statistics (描述性统计)
|
|
||||||
- detect_outliers (异常值检测)
|
|
||||||
- get_time_series (时间序列)
|
|
||||||
- calculate_trend (趋势计算)
|
|
||||||
|
|
||||||
- **执行结果**:
|
|
||||||
- 成功: 2/2 任务
|
|
||||||
- 失败: 0/2 任务
|
|
||||||
- 总执行时间: ~51 秒
|
|
||||||
- 生成洞察: 2 条
|
|
||||||
|
|
||||||
### ✅ 阶段 5: 报告生成
|
|
||||||
- **报告文件**: analysis_output/analysis_report.md
|
|
||||||
- **报告长度**: 774 字符
|
|
||||||
- **包含内容**:
|
|
||||||
- 执行摘要
|
|
||||||
- 数据概览(15个关键字段说明)
|
|
||||||
- 详细分析
|
|
||||||
- 结论与建议
|
|
||||||
- 任务执行附录
|
|
||||||
|
|
||||||
## 系统架构验证
|
|
||||||
|
|
||||||
### 隐私保护机制 ✓
|
|
||||||
1. **数据访问层隔离**: AI 无法直接访问原始数据
|
|
||||||
2. **元数据暴露**: AI 只能看到列名、数据类型、统计信息
|
|
||||||
3. **工具执行**: 工具在原始数据上执行,返回聚合结果
|
|
||||||
4. **结果限制**:
|
|
||||||
- 分组结果最多 100 个
|
|
||||||
- 时间序列最多 100 个数据点
|
|
||||||
- 异常值最多返回 20 个
|
|
||||||
|
|
||||||
### 配置管理 ✓
|
|
||||||
所有 LLM API 调用已统一从 `.env` 文件读取配置:
|
|
||||||
- `OPENAI_MODEL=mimo-v2-flash`
|
|
||||||
- `OPENAI_BASE_URL=https://api.xiaomimimo.com/v1`
|
|
||||||
- `OPENAI_API_KEY=[已配置]`
|
|
||||||
|
|
||||||
修改的文件:
|
|
||||||
1. src/engines/task_execution.py
|
|
||||||
2. src/engines/requirement_understanding.py
|
|
||||||
3. src/engines/report_generation.py
|
|
||||||
4. src/engines/plan_adjustment.py
|
|
||||||
5. src/engines/analysis_planning.py
|
|
||||||
|
|
||||||
### 工具系统 ✓
|
|
||||||
- **全局注册表**: 12 个工具已注册
|
|
||||||
- **动态选择**: 根据数据特征自动选择适用工具
|
|
||||||
- **类型检测**: 支持时间序列、分类、数值、地理数据
|
|
||||||
- **参数验证**: JSON Schema 格式参数定义
|
|
||||||
|
|
||||||
## 测试数据
|
|
||||||
|
|
||||||
### cleaned_data.csv
|
|
||||||
- **行数**: 84
|
|
||||||
- **列数**: 21
|
|
||||||
- **数据类型**: IT 服务工单
|
|
||||||
- **主要字段**:
|
|
||||||
- 工单号、来源、创建日期
|
|
||||||
- 问题类型、问题描述、处理过程
|
|
||||||
- 严重程度、工单状态、模块
|
|
||||||
- 责任人、关闭日期、关闭时长
|
|
||||||
- 车型、VIN
|
|
||||||
|
|
||||||
### 数据质量
|
|
||||||
- **完整性**: 85/100
|
|
||||||
- **缺失字段**: SIM (100%), Notes (较多)
|
|
||||||
- **时间字段**: 创建日期、关闭日期
|
|
||||||
- **分类字段**: 来源、问题类型、严重程度、工单状态、模块
|
|
||||||
- **数值字段**: 关闭时长(天)
|
|
||||||
|
|
||||||
## 执行命令
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python run_analysis_en.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## 输出文件
|
|
||||||
|
|
||||||
```
|
|
||||||
analysis_output/
|
|
||||||
├── analysis_report.md # 分析报告
|
|
||||||
└── *.png # 图表文件(如有生成)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 性能指标
|
|
||||||
|
|
||||||
- **数据加载**: < 1 秒
|
|
||||||
- **AI 数据理解**: ~5 秒
|
|
||||||
- **需求理解**: ~3 秒
|
|
||||||
- **分析规划**: ~2 秒
|
|
||||||
- **任务执行**: ~51 秒 (2 个任务)
|
|
||||||
- **报告生成**: ~2 秒
|
|
||||||
- **总耗时**: ~63 秒
|
|
||||||
|
|
||||||
## 系统状态
|
|
||||||
|
|
||||||
### ✅ 已完成
|
|
||||||
1. 工具注册系统修复
|
|
||||||
2. 配置管理统一
|
|
||||||
3. 隐私保护验证
|
|
||||||
4. 端到端分析流程
|
|
||||||
5. 真实数据测试
|
|
||||||
|
|
||||||
### 📊 测试覆盖率
|
|
||||||
- 单元测试: 314/328 通过 (95.7%)
|
|
||||||
- 属性测试: 已实施
|
|
||||||
- 集成测试: 已通过
|
|
||||||
- 端到端测试: 已通过
|
|
||||||
|
|
||||||
## 结论
|
|
||||||
|
|
||||||
系统已完全就绪,可以进行生产环境部署。所有核心功能已验证,隐私保护机制有效,配置管理规范,工具系统运行正常。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**生成时间**: 2026-03-09 09:08:27
|
|
||||||
**测试环境**: Windows, Python 3.x
|
|
||||||
**数据集**: cleaned_data.csv (84 rows × 21 columns)
|
|
||||||
177
analysis_output/run_20260309_102648/analysis_report.md
Normal file
177
analysis_output/run_20260309_102648/analysis_report.md
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
# 工单分析报告
|
||||||
|
|
||||||
|
生成时间:2026-03-09 10:29:31
|
||||||
|
数据源:cleaned_data.csv
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 工单数据分析报告
|
||||||
|
|
||||||
|
## 1. 执行摘要
|
||||||
|
|
||||||
|
本报告基于84条工单数据(质量分数88/100)进行全面分析,主要发现如下:
|
||||||
|
|
||||||
|
1. **工单处理效率存在显著异常**:平均关闭时长为54.77天,但存在2个异常工单(处理时长分别为277天和237天),占总数的2.38%。"Activation SIM"问题的平均处理时长高达142.5天,远高于其他问题类型。
|
||||||
|
2. **工单来源渠道集中**:邮件渠道占比54.76%(46张),Telegram bot占比42.86%(36张),渠道来源仅占2.38%(2张),显示渠道管理可优化。
|
||||||
|
3. **问题类型高度集中**:远程控制问题占主导(66.67%,56张),前五类问题占总量的87%以上,表明问题分布集中。
|
||||||
|
4. **车型问题分布不均**:EXEED RX(T22)车型工单最多(45.24%,38张),JAECOO J7(T1EJ)次之(26.19%,22张),特定车型问题集中。
|
||||||
|
5. **责任人工作负载差异大**:Vsevolod处理工单最多(31个),但平均处理时长66.68天;Vsevolod Tsoi平均处理时长最高(152天),而何韬处理效率最高(平均3.5天)。
|
||||||
|
|
||||||
|
## 2. 数据概览
|
||||||
|
|
||||||
|
- **数据类型**:工单(ticket)
|
||||||
|
- **数据规模**:84行 × 21列
|
||||||
|
- **数据质量**:88.0/100
|
||||||
|
- **关键字段**:工单号、来源、创建日期、问题类型、问题描述、处理过程、跟踪记录、严重程度、工单状态、模块、责任人、关闭日期、车型、VIN、关闭时长(天)等
|
||||||
|
- **分析时间范围**:2025年1月2日至2025年2月24日
|
||||||
|
|
||||||
|
## 3. 详细分析
|
||||||
|
|
||||||
|
### 3.1 工单概况分析
|
||||||
|
|
||||||
|
工单状态分布显示,82.14%(69张)已关闭,17.86%(15张)临时关闭,表明大部分问题已解决。
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
**来源渠道分析**:
|
||||||
|
- 邮件(Mail):54.76%(46张)
|
||||||
|
- Telegram bot:42.86%(36张)
|
||||||
|
- Telegram channel:2.38%(2张)
|
||||||
|
|
||||||
|
**问题类型分布**(前5位):
|
||||||
|
1. 远程控制(Remote control):66.67%(56张)
|
||||||
|
2. 网络问题(Network):7.14%(6张)
|
||||||
|
3. 导航问题(Navi):5.95%(5张)
|
||||||
|
4. 应用问题(Application):4.76%(4张)
|
||||||
|
5. 成员中心认证问题:3.57%(3张)
|
||||||
|
|
||||||
|
前五类问题合计占总量的87.09%,显示问题类型高度集中。
|
||||||
|
|
||||||
|
### 3.2 工单处理效率分析
|
||||||
|
|
||||||
|
**关闭时长统计**:
|
||||||
|
- 平均值:54.77天
|
||||||
|
- 中位数:41天
|
||||||
|
- 标准差:48.19天
|
||||||
|
- 最小值:2天
|
||||||
|
- 最大值:277天
|
||||||
|
- 四分位距(IQR):26.25天(Q25)至84.5天(Q75)
|
||||||
|
|
||||||
|
**异常值检测**:
|
||||||
|
- 检测到2个异常工单(277天和237天),占总数的2.38%
|
||||||
|
- 异常值上限为171.875天,这两个工单远超此阈值
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
**按问题类型分析平均关闭时长**:
|
||||||
|
- Activation SIM:142.5天(效率最低)
|
||||||
|
- Remote control:66.5天
|
||||||
|
- PKI problem:47天
|
||||||
|
- doesn't exist on TSP:31.67天
|
||||||
|
- Network:24天
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### 3.3 工单内容与趋势分析
|
||||||
|
|
||||||
|
**工单创建时间趋势**(2025年1-2月):
|
||||||
|
- 创建高峰期:2025年1月13日(8个工单)
|
||||||
|
- 创建低谷期:2025年1月8日、15日、29日、30日及2月多日(仅1个工单)
|
||||||
|
- 整体趋势波动较大,无明显持续上升或下降模式
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### 3.4 责任人工作负载分析
|
||||||
|
|
||||||
|
**工单处理数量**:
|
||||||
|
- Vsevolod:31个(最多)
|
||||||
|
- Evgeniy:28个
|
||||||
|
- Kostya:5个
|
||||||
|
- 何韬:4个
|
||||||
|
|
||||||
|
**平均关闭时长**:
|
||||||
|
- Vsevolod Tsoi:152天(最高,仅处理2个工单)
|
||||||
|
- 林兆国:89天
|
||||||
|
- Vadim:69天
|
||||||
|
- Vsevolod:66.68天
|
||||||
|
- Evgeniy:62.39天
|
||||||
|
- 何韬:3.5天(最低,效率最高)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### 3.5 车辆相关信息分析
|
||||||
|
|
||||||
|
**车型分布**:
|
||||||
|
- EXEED RX(T22):45.24%(38张)
|
||||||
|
- JAECOO J7(T1EJ):26.19%(22张)
|
||||||
|
- EXEED VX FL(M36T):20.24%(17张)
|
||||||
|
- CHERY TIGGO 9 (T28):8.33%(7张)
|
||||||
|
|
||||||
|
**VIN分布**:
|
||||||
|
- LVTDD24B1RG023450和LVTDD24B1RG021245各出现2次,表明特定车辆问题可能复发
|
||||||
|
- 其他VIN均唯一,显示问题主要分散于不同车辆
|
||||||
|
|
||||||
|
## 4. 结论与建议
|
||||||
|
|
||||||
|
### 4.1 主要结论
|
||||||
|
|
||||||
|
1. **处理效率需优化**:整体平均关闭时长54.77天,且存在严重异常值(277天和237天),"Activation SIM"问题处理时长高达142.5天。
|
||||||
|
2. **渠道管理可改进**:邮件渠道占比过高(54.76%),Telegram渠道未充分利用。
|
||||||
|
3. **问题类型集中**:远程控制问题占66.67%,需针对性优化。
|
||||||
|
4. **车型问题分布不均**:EXEED RX(T22)车型问题最集中(45.24%)。
|
||||||
|
5. **责任人效率差异大**:Vsevolod Tsoi处理时长152天,而何韬仅3.5天,存在明显效率差距。
|
||||||
|
|
||||||
|
### 4.2 可操作建议
|
||||||
|
|
||||||
|
1. **优化异常工单处理流程**:
|
||||||
|
- 针对处理时长超过100天的工单(如Activation SIM问题)建立专项处理机制
|
||||||
|
- 定期审查异常工单,分析根本原因
|
||||||
|
- 依据:异常值检测显示2个工单处理时长分别为277天和237天
|
||||||
|
|
||||||
|
2. **平衡渠道分配**:
|
||||||
|
- 推广Telegram bot使用,减轻邮件渠道压力
|
||||||
|
- 依据:邮件渠道占比54.76%,Telegram bot仅42.86%
|
||||||
|
|
||||||
|
3. **加强远程控制问题管理**:
|
||||||
|
- 建立远程控制问题知识库和快速响应机制
|
||||||
|
- 依据:远程控制问题占66.67%(56张)
|
||||||
|
|
||||||
|
4. **针对EXEED RX(T22)车型专项优化**:
|
||||||
|
- 分析该车型高工单率的原因
|
||||||
|
- 依据:该车型工单占比45.24%(38张)
|
||||||
|
|
||||||
|
5. **提升责任人效率一致性**:
|
||||||
|
- 建立效率标杆(如何韬的3.5天平均时长)
|
||||||
|
- 对处理时长较长的责任人提供培训和支持
|
||||||
|
- 依据:责任人平均处理时长从3.5天到152天不等
|
||||||
|
|
||||||
|
6. **建立工单创建趋势监控**:
|
||||||
|
- 监控工单创建高峰期(如1月13日的8个工单)
|
||||||
|
- 提前调配资源应对潜在高峰
|
||||||
|
- 依据:工单创建趋势显示明显波动
|
||||||
|
|
||||||
|
通过实施这些建议,预计可显著提升工单处理效率,优化资源分配,并改善客户满意度。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 分析追溯
|
||||||
|
|
||||||
|
本报告基于以下分析任务:
|
||||||
|
|
||||||
|
- ✓ 工单概况分析:基本分布统计
|
||||||
|
- 工单总数为84,其中82.14%(69张)已关闭,17.86%(15张)临时关闭,表明大部分问题已解决。
|
||||||
|
- 工单来源中,邮件(Mail)占比最高,达54.76%(46张),Telegram bot占42.86%(36张),渠道来源仅占2.38%(2张)。
|
||||||
|
- ✓ 工单处理效率分析:关闭时长统计
|
||||||
|
- 关闭时长的平均值为54.77天,中位数为41天,标准差为48.19天,表明数据右偏分布,存在处理时间较长的工单。
|
||||||
|
- 异常值检测发现2个工单(277天和237天)处理时间过长,占总工单数的2.38%,需重点关注这些异常情况。
|
||||||
|
- ✓ 工单内容与趋势分析:文本和时间序列
|
||||||
|
- 2025年1月13日工单创建数量最高,达到8个,是整体趋势中的峰值。
|
||||||
|
- 2025年1月8日、15日、29日、30日及2月多日工单创建数量最低,仅为1个,显示这些日期工单活动较少。
|
||||||
|
- ✓ 责任人工作负载分析
|
||||||
|
- Vsevolod 处理工单数量最多(31个),但平均关闭时长为66.68天,工作量大且周期长。
|
||||||
|
- Vsevolod Tsoi 平均关闭时长最高(152天),但仅处理2个工单,可能存在效率问题或复杂任务。
|
||||||
|
- ✓ 车辆相关信息分析:车型和VIN分布
|
||||||
|
- EXEED RX(T22)车型占比最高,达45.24%(38/84),是问题集中的主要车型。
|
||||||
|
- JAECOO J7(T1EJ)车型工单数为22,占比26.19%,是第二大问题车型。
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
BIN
analysis_output/run_20260309_102648/charts/bar_chart_trend.png
Normal file
BIN
analysis_output/run_20260309_102648/charts/bar_chart_trend.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 46 KiB |
BIN
analysis_output/run_20260309_102648/charts/outlier_bar_chart.png
Normal file
BIN
analysis_output/run_20260309_102648/charts/outlier_bar_chart.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
BIN
analysis_output/run_20260309_102648/charts/outlier_pie_chart.png
Normal file
BIN
analysis_output/run_20260309_102648/charts/outlier_pie_chart.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 42 KiB |
@@ -23,7 +23,7 @@ from src.engines.ai_data_understanding import ai_understand_data_with_dal
|
|||||||
from src.engines.requirement_understanding import understand_requirement
|
from src.engines.requirement_understanding import understand_requirement
|
||||||
from src.engines.analysis_planning import plan_analysis
|
from src.engines.analysis_planning import plan_analysis
|
||||||
from src.engines.task_execution import execute_task
|
from src.engines.task_execution import execute_task
|
||||||
from src.engines.report_generation import generate_report
|
from src.engines.report_generation import generate_report, _convert_chart_paths_in_report
|
||||||
from src.tools.tool_manager import ToolManager
|
from src.tools.tool_manager import ToolManager
|
||||||
from src.tools.base import _global_registry
|
from src.tools.base import _global_registry
|
||||||
from src.models import DataProfile, AnalysisResult
|
from src.models import DataProfile, AnalysisResult
|
||||||
@@ -156,7 +156,8 @@ def run_analysis(
|
|||||||
else:
|
else:
|
||||||
report = generate_report(results, requirement, profile, output_path=run_dir)
|
report = generate_report(results, requirement, profile, output_path=run_dir)
|
||||||
|
|
||||||
# Save report
|
# Save report — convert chart paths to relative (./charts/xxx.png)
|
||||||
|
report = _convert_chart_paths_in_report(report, run_dir)
|
||||||
with open(report_path, 'w', encoding='utf-8') as f:
|
with open(report_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(report)
|
f.write(report)
|
||||||
|
|
||||||
@@ -262,7 +263,8 @@ def _generate_template_report(
|
|||||||
|
|
||||||
|
|
||||||
def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> str:
|
def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> str:
|
||||||
"""Collect all chart paths from task results for embedding in reports."""
|
"""Collect all chart paths from task results for embedding in reports.
|
||||||
|
Returns paths relative to run_dir (e.g. ./charts/bar_chart.png)."""
|
||||||
paths = []
|
paths = []
|
||||||
for r in results:
|
for r in results:
|
||||||
if not r.success:
|
if not r.success:
|
||||||
@@ -280,7 +282,10 @@ def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> st
|
|||||||
paths.append(cp)
|
paths.append(cp)
|
||||||
if not paths:
|
if not paths:
|
||||||
return "(无图表)"
|
return "(无图表)"
|
||||||
return "\n".join(f"- {p}" for p in paths)
|
# Convert to relative paths
|
||||||
|
from src.engines.report_generation import _to_relative_chart_path
|
||||||
|
rel_paths = [_to_relative_chart_path(p, run_dir) for p in paths]
|
||||||
|
return "\n".join(f"- {p}" for p in rel_paths)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -4,9 +4,11 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import json
|
import json
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import PurePosixPath, Path
|
||||||
|
|
||||||
from src.models.analysis_result import AnalysisResult
|
from src.models.analysis_result import AnalysisResult
|
||||||
from src.models.requirement_spec import RequirementSpec
|
from src.models.requirement_spec import RequirementSpec
|
||||||
@@ -310,6 +312,56 @@ def _generate_conclusion_summary(key_findings: List[Dict[str, Any]]) -> str:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _to_relative_chart_path(chart_path: str, report_dir: str = "") -> str:
|
||||||
|
"""
|
||||||
|
将图表绝对路径转换为相对于报告文件的路径。
|
||||||
|
|
||||||
|
例如:
|
||||||
|
chart_path = "analysis_output/run_xxx/charts/bar.png"
|
||||||
|
report_dir = "analysis_output/run_xxx"
|
||||||
|
→ "./charts/bar.png"
|
||||||
|
|
||||||
|
如果无法计算相对路径,则只保留 ./charts/filename.png
|
||||||
|
"""
|
||||||
|
if not chart_path:
|
||||||
|
return chart_path
|
||||||
|
|
||||||
|
# 统一为正斜杠
|
||||||
|
chart_path = chart_path.replace('\\', '/')
|
||||||
|
|
||||||
|
if report_dir:
|
||||||
|
report_dir = report_dir.replace('\\', '/')
|
||||||
|
try:
|
||||||
|
rel = os.path.relpath(chart_path, report_dir).replace('\\', '/')
|
||||||
|
return './' + rel if not rel.startswith('.') else rel
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: 提取 charts/filename 部分
|
||||||
|
parts = chart_path.replace('\\', '/').split('/')
|
||||||
|
if 'charts' in parts:
|
||||||
|
idx = parts.index('charts')
|
||||||
|
return './' + '/'.join(parts[idx:])
|
||||||
|
|
||||||
|
# 最后兜底:直接用文件名
|
||||||
|
return './charts/' + os.path.basename(chart_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_chart_paths_in_report(report: str, report_dir: str = "") -> str:
|
||||||
|
"""
|
||||||
|
将报告中所有  的图表路径转换为相对路径。
|
||||||
|
同时统一反斜杠为正斜杠。
|
||||||
|
"""
|
||||||
|
def replace_img(match):
|
||||||
|
alt = match.group(1)
|
||||||
|
path = match.group(2)
|
||||||
|
rel_path = _to_relative_chart_path(path, report_dir)
|
||||||
|
return f''
|
||||||
|
|
||||||
|
# 匹配 
|
||||||
|
return re.sub(r'!\[([^\]]*)\]\(([^)]+)\)', replace_img, report)
|
||||||
|
|
||||||
|
|
||||||
def generate_report(
|
def generate_report(
|
||||||
results: List[AnalysisResult],
|
results: List[AnalysisResult],
|
||||||
requirement: RequirementSpec,
|
requirement: RequirementSpec,
|
||||||
@@ -370,6 +422,10 @@ def generate_report(
|
|||||||
with open(output_path, 'w', encoding='utf-8') as f:
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(report)
|
f.write(report)
|
||||||
|
|
||||||
|
# 将图表路径转换为相对于报告所在目录的路径
|
||||||
|
report_dir = output_path if output_path and os.path.isdir(output_path) else ""
|
||||||
|
report = _convert_chart_paths_in_report(report, report_dir)
|
||||||
|
|
||||||
return report
|
return report
|
||||||
|
|
||||||
|
|
||||||
@@ -394,11 +450,11 @@ def _generate_report_with_ai(
|
|||||||
# 收集所有图表路径
|
# 收集所有图表路径
|
||||||
for viz in (r.visualizations or []):
|
for viz in (r.visualizations or []):
|
||||||
if viz:
|
if viz:
|
||||||
all_chart_paths.append(viz)
|
all_chart_paths.append(_to_relative_chart_path(viz))
|
||||||
if isinstance(r.data, dict):
|
if isinstance(r.data, dict):
|
||||||
for key, val in r.data.items():
|
for key, val in r.data.items():
|
||||||
if isinstance(val, dict) and val.get('chart_path'):
|
if isinstance(val, dict) and val.get('chart_path'):
|
||||||
all_chart_paths.append(val['chart_path'])
|
all_chart_paths.append(_to_relative_chart_path(val['chart_path']))
|
||||||
|
|
||||||
data_section = "\n\n".join(data_summaries) if data_summaries else "无详细数据"
|
data_section = "\n\n".join(data_summaries) if data_summaries else "无详细数据"
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
@@ -1,301 +0,0 @@
|
|||||||
"""数据查询工具。"""
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
from src.tools.base import AnalysisTool
|
|
||||||
from src.models import DataProfile
|
|
||||||
|
|
||||||
|
|
||||||
class GetColumnDistributionTool(AnalysisTool):
|
|
||||||
"""获取列的分布统计工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "get_column_distribution"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "获取指定列的分布统计信息,包括值计数、百分比等。适用于分类和数值列。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要分析的列名"
|
|
||||||
},
|
|
||||||
"top_n": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "返回前N个最常见的值",
|
|
||||||
"default": 10
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行列分布分析。"""
|
|
||||||
column = kwargs.get('column')
|
|
||||||
top_n = kwargs.get('top_n', 10)
|
|
||||||
|
|
||||||
if column not in data.columns:
|
|
||||||
return {'error': f'列 {column} 不存在'}
|
|
||||||
|
|
||||||
col_data = data[column]
|
|
||||||
value_counts = col_data.value_counts().head(top_n)
|
|
||||||
total = len(col_data.dropna())
|
|
||||||
|
|
||||||
distribution = []
|
|
||||||
for value, count in value_counts.items():
|
|
||||||
distribution.append({
|
|
||||||
'value': str(value),
|
|
||||||
'count': int(count),
|
|
||||||
'percentage': float(count / total * 100) if total > 0 else 0.0
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
'column': column,
|
|
||||||
'total_count': int(total),
|
|
||||||
'unique_count': int(col_data.nunique()),
|
|
||||||
'missing_count': int(col_data.isna().sum()),
|
|
||||||
'distribution': distribution
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于所有数据。"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class GetValueCountsTool(AnalysisTool):
|
|
||||||
"""获取值计数工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "get_value_counts"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "获取指定列的值计数,返回每个唯一值的出现次数。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要分析的列名"
|
|
||||||
},
|
|
||||||
"normalize": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "是否返回百分比而不是计数",
|
|
||||||
"default": False
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行值计数。"""
|
|
||||||
column = kwargs.get('column')
|
|
||||||
normalize = kwargs.get('normalize', False)
|
|
||||||
|
|
||||||
if column not in data.columns:
|
|
||||||
return {'error': f'列 {column} 不存在'}
|
|
||||||
|
|
||||||
value_counts = data[column].value_counts(normalize=normalize)
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
for value, count in value_counts.items():
|
|
||||||
result[str(value)] = float(count) if normalize else int(count)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'column': column,
|
|
||||||
'value_counts': result,
|
|
||||||
'normalized': normalize
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于所有数据。"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class GetTimeSeriesTool(AnalysisTool):
|
|
||||||
"""获取时间序列数据工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "get_time_series"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "获取时间序列数据,按时间聚合指定指标。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"time_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "时间列名"
|
|
||||||
},
|
|
||||||
"value_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要聚合的值列名"
|
|
||||||
},
|
|
||||||
"aggregation": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "聚合方式:count, sum, mean, min, max",
|
|
||||||
"default": "count"
|
|
||||||
},
|
|
||||||
"frequency": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "时间频率:D(天), W(周), M(月), Y(年)",
|
|
||||||
"default": "D"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["time_column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行时间序列分析。"""
|
|
||||||
time_column = kwargs.get('time_column')
|
|
||||||
value_column = kwargs.get('value_column')
|
|
||||||
aggregation = kwargs.get('aggregation', 'count')
|
|
||||||
frequency = kwargs.get('frequency', 'D')
|
|
||||||
|
|
||||||
if time_column not in data.columns:
|
|
||||||
return {'error': f'时间列 {time_column} 不存在'}
|
|
||||||
|
|
||||||
# 转换为日期时间类型
|
|
||||||
try:
|
|
||||||
time_data = pd.to_datetime(data[time_column])
|
|
||||||
except Exception as e:
|
|
||||||
return {'error': f'无法将 {time_column} 转换为日期时间: {str(e)}'}
|
|
||||||
|
|
||||||
# 创建临时 DataFrame
|
|
||||||
temp_df = pd.DataFrame({'time': time_data})
|
|
||||||
|
|
||||||
if value_column:
|
|
||||||
if value_column not in data.columns:
|
|
||||||
return {'error': f'值列 {value_column} 不存在'}
|
|
||||||
temp_df['value'] = data[value_column]
|
|
||||||
|
|
||||||
# 设置时间索引
|
|
||||||
temp_df.set_index('time', inplace=True)
|
|
||||||
|
|
||||||
# 按频率重采样
|
|
||||||
if value_column:
|
|
||||||
if aggregation == 'count':
|
|
||||||
result = temp_df.resample(frequency).count()
|
|
||||||
elif aggregation == 'sum':
|
|
||||||
result = temp_df.resample(frequency).sum()
|
|
||||||
elif aggregation == 'mean':
|
|
||||||
result = temp_df.resample(frequency).mean()
|
|
||||||
elif aggregation == 'min':
|
|
||||||
result = temp_df.resample(frequency).min()
|
|
||||||
elif aggregation == 'max':
|
|
||||||
result = temp_df.resample(frequency).max()
|
|
||||||
else:
|
|
||||||
return {'error': f'不支持的聚合方式: {aggregation}'}
|
|
||||||
else:
|
|
||||||
result = temp_df.resample(frequency).size().to_frame('count')
|
|
||||||
|
|
||||||
# 转换为字典
|
|
||||||
time_series = []
|
|
||||||
for timestamp, row in result.iterrows():
|
|
||||||
time_series.append({
|
|
||||||
'time': timestamp.strftime('%Y-%m-%d'),
|
|
||||||
'value': float(row.iloc[0]) if not pd.isna(row.iloc[0]) else 0.0
|
|
||||||
})
|
|
||||||
|
|
||||||
# 限制返回的数据点数量,最多100个(隐私保护要求)
|
|
||||||
if len(time_series) > 100:
|
|
||||||
# 均匀采样以保持趋势
|
|
||||||
step = len(time_series) / 100
|
|
||||||
sampled_indices = [int(i * step) for i in range(100)]
|
|
||||||
time_series = [time_series[i] for i in sampled_indices]
|
|
||||||
|
|
||||||
return {
|
|
||||||
'time_column': time_column,
|
|
||||||
'value_column': value_column,
|
|
||||||
'aggregation': aggregation,
|
|
||||||
'frequency': frequency,
|
|
||||||
'time_series': time_series,
|
|
||||||
'total_points': len(result), # 记录原始数据点数量
|
|
||||||
'returned_points': len(time_series) # 记录返回的数据点数量
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于包含日期时间列的数据。"""
|
|
||||||
return any(col.dtype == 'datetime' for col in data_profile.columns)
|
|
||||||
|
|
||||||
|
|
||||||
class GetCorrelationTool(AnalysisTool):
|
|
||||||
"""获取相关性分析工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "get_correlation"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "计算数值列之间的相关系数矩阵。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"columns": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
"description": "要分析的列名列表,如果为空则分析所有数值列"
|
|
||||||
},
|
|
||||||
"method": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "相关系数方法:pearson, spearman, kendall",
|
|
||||||
"default": "pearson"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行相关性分析。"""
|
|
||||||
columns = kwargs.get('columns', [])
|
|
||||||
method = kwargs.get('method', 'pearson')
|
|
||||||
|
|
||||||
# 如果没有指定列,使用所有数值列
|
|
||||||
if not columns:
|
|
||||||
numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
|
|
||||||
else:
|
|
||||||
numeric_cols = [col for col in columns if col in data.columns]
|
|
||||||
|
|
||||||
if len(numeric_cols) < 2:
|
|
||||||
return {'error': '至少需要两个数值列来计算相关性'}
|
|
||||||
|
|
||||||
# 计算相关系数矩阵
|
|
||||||
corr_matrix = data[numeric_cols].corr(method=method)
|
|
||||||
|
|
||||||
# 转换为字典格式
|
|
||||||
correlation = {}
|
|
||||||
for col1 in corr_matrix.columns:
|
|
||||||
correlation[col1] = {}
|
|
||||||
for col2 in corr_matrix.columns:
|
|
||||||
correlation[col1][col2] = float(corr_matrix.loc[col1, col2])
|
|
||||||
|
|
||||||
return {
|
|
||||||
'columns': numeric_cols,
|
|
||||||
'method': method,
|
|
||||||
'correlation_matrix': correlation
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于包含至少两个数值列的数据。"""
|
|
||||||
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
|
|
||||||
return len(numeric_cols) >= 2
|
|
||||||
@@ -1,325 +0,0 @@
|
|||||||
"""统计分析工具。"""
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from typing import Dict, Any
|
|
||||||
from scipy import stats
|
|
||||||
|
|
||||||
from src.tools.base import AnalysisTool
|
|
||||||
from src.models import DataProfile
|
|
||||||
|
|
||||||
|
|
||||||
class CalculateStatisticsTool(AnalysisTool):
|
|
||||||
"""计算描述性统计工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "calculate_statistics"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "计算指定列的描述性统计信息,包括均值、中位数、标准差、最小值、最大值等。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要分析的列名"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行统计计算。"""
|
|
||||||
column = kwargs.get('column')
|
|
||||||
|
|
||||||
if column not in data.columns:
|
|
||||||
return {'error': f'列 {column} 不存在'}
|
|
||||||
|
|
||||||
col_data = data[column].dropna()
|
|
||||||
|
|
||||||
if not pd.api.types.is_numeric_dtype(col_data):
|
|
||||||
return {'error': f'列 {column} 不是数值类型'}
|
|
||||||
|
|
||||||
statistics = {
|
|
||||||
'column': column,
|
|
||||||
'count': int(len(col_data)),
|
|
||||||
'mean': float(col_data.mean()),
|
|
||||||
'median': float(col_data.median()),
|
|
||||||
'std': float(col_data.std()),
|
|
||||||
'min': float(col_data.min()),
|
|
||||||
'max': float(col_data.max()),
|
|
||||||
'q25': float(col_data.quantile(0.25)),
|
|
||||||
'q75': float(col_data.quantile(0.75)),
|
|
||||||
'skewness': float(col_data.skew()),
|
|
||||||
'kurtosis': float(col_data.kurtosis())
|
|
||||||
}
|
|
||||||
|
|
||||||
return statistics
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于包含数值列的数据。"""
|
|
||||||
return any(col.dtype == 'numeric' for col in data_profile.columns)
|
|
||||||
|
|
||||||
|
|
||||||
class PerformGroupbyTool(AnalysisTool):
|
|
||||||
"""执行分组聚合工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "perform_groupby"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "按指定列分组,对另一列进行聚合计算(如求和、平均、计数等)。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"group_by": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "分组依据的列名"
|
|
||||||
},
|
|
||||||
"value_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要聚合的值列名,如果为空则计数"
|
|
||||||
},
|
|
||||||
"aggregation": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "聚合方式:count, sum, mean, min, max, std",
|
|
||||||
"default": "count"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["group_by"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行分组聚合。"""
|
|
||||||
group_by = kwargs.get('group_by')
|
|
||||||
value_column = kwargs.get('value_column')
|
|
||||||
aggregation = kwargs.get('aggregation', 'count')
|
|
||||||
|
|
||||||
if group_by not in data.columns:
|
|
||||||
return {'error': f'分组列 {group_by} 不存在'}
|
|
||||||
|
|
||||||
if value_column and value_column not in data.columns:
|
|
||||||
return {'error': f'值列 {value_column} 不存在'}
|
|
||||||
|
|
||||||
# 执行分组聚合
|
|
||||||
if value_column:
|
|
||||||
grouped = data.groupby(group_by, observed=True)[value_column]
|
|
||||||
else:
|
|
||||||
grouped = data.groupby(group_by, observed=True).size()
|
|
||||||
aggregation = 'count'
|
|
||||||
|
|
||||||
if aggregation == 'count':
|
|
||||||
if value_column:
|
|
||||||
result = grouped.count()
|
|
||||||
else:
|
|
||||||
result = grouped
|
|
||||||
elif aggregation == 'sum':
|
|
||||||
result = grouped.sum()
|
|
||||||
elif aggregation == 'mean':
|
|
||||||
result = grouped.mean()
|
|
||||||
elif aggregation == 'min':
|
|
||||||
result = grouped.min()
|
|
||||||
elif aggregation == 'max':
|
|
||||||
result = grouped.max()
|
|
||||||
elif aggregation == 'std':
|
|
||||||
result = grouped.std()
|
|
||||||
else:
|
|
||||||
return {'error': f'不支持的聚合方式: {aggregation}'}
|
|
||||||
|
|
||||||
# 转换为字典
|
|
||||||
groups = []
|
|
||||||
for group_value, agg_value in result.items():
|
|
||||||
groups.append({
|
|
||||||
'group': str(group_value),
|
|
||||||
'value': float(agg_value) if not pd.isna(agg_value) else 0.0
|
|
||||||
})
|
|
||||||
|
|
||||||
# 限制返回的分组数量,最多100个(隐私保护要求)
|
|
||||||
total_groups = len(groups)
|
|
||||||
if len(groups) > 100:
|
|
||||||
# 按值排序并取前100个
|
|
||||||
groups = sorted(groups, key=lambda x: x['value'], reverse=True)[:100]
|
|
||||||
|
|
||||||
return {
|
|
||||||
'group_by': group_by,
|
|
||||||
'value_column': value_column,
|
|
||||||
'aggregation': aggregation,
|
|
||||||
'groups': groups,
|
|
||||||
'total_groups': total_groups, # 记录原始分组数量
|
|
||||||
'returned_groups': len(groups) # 记录返回的分组数量
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于所有数据。"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class DetectOutliersTool(AnalysisTool):
|
|
||||||
"""检测异常值工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "detect_outliers"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "使用IQR方法或Z-score方法检测数值列中的异常值。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要检测的列名"
|
|
||||||
},
|
|
||||||
"method": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "检测方法:iqr 或 zscore",
|
|
||||||
"default": "iqr"
|
|
||||||
},
|
|
||||||
"threshold": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "阈值(IQR倍数或Z-score标准差倍数)",
|
|
||||||
"default": 1.5
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行异常值检测。"""
|
|
||||||
column = kwargs.get('column')
|
|
||||||
method = kwargs.get('method', 'iqr')
|
|
||||||
threshold = kwargs.get('threshold', 1.5)
|
|
||||||
|
|
||||||
if column not in data.columns:
|
|
||||||
return {'error': f'列 {column} 不存在'}
|
|
||||||
|
|
||||||
col_data = data[column].dropna()
|
|
||||||
|
|
||||||
if not pd.api.types.is_numeric_dtype(col_data):
|
|
||||||
return {'error': f'列 {column} 不是数值类型'}
|
|
||||||
|
|
||||||
if method == 'iqr':
|
|
||||||
# IQR 方法
|
|
||||||
q1 = col_data.quantile(0.25)
|
|
||||||
q3 = col_data.quantile(0.75)
|
|
||||||
iqr = q3 - q1
|
|
||||||
lower_bound = q1 - threshold * iqr
|
|
||||||
upper_bound = q3 + threshold * iqr
|
|
||||||
outliers = col_data[(col_data < lower_bound) | (col_data > upper_bound)]
|
|
||||||
elif method == 'zscore':
|
|
||||||
# Z-score 方法
|
|
||||||
z_scores = np.abs(stats.zscore(col_data))
|
|
||||||
outliers = col_data[z_scores > threshold]
|
|
||||||
else:
|
|
||||||
return {'error': f'不支持的检测方法: {method}'}
|
|
||||||
|
|
||||||
return {
|
|
||||||
'column': column,
|
|
||||||
'method': method,
|
|
||||||
'threshold': threshold,
|
|
||||||
'outlier_count': int(len(outliers)),
|
|
||||||
'outlier_percentage': float(len(outliers) / len(col_data) * 100),
|
|
||||||
'outlier_values': outliers.head(20).tolist(), # 最多返回20个异常值
|
|
||||||
'bounds': {
|
|
||||||
'lower': float(lower_bound) if method == 'iqr' else None,
|
|
||||||
'upper': float(upper_bound) if method == 'iqr' else None
|
|
||||||
} if method == 'iqr' else None
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于包含数值列的数据。"""
|
|
||||||
return any(col.dtype == 'numeric' for col in data_profile.columns)
|
|
||||||
|
|
||||||
|
|
||||||
class CalculateTrendTool(AnalysisTool):
|
|
||||||
"""计算趋势工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "calculate_trend"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "计算时间序列数据的趋势,包括线性回归斜率、增长率等。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"time_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "时间列名"
|
|
||||||
},
|
|
||||||
"value_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "值列名"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["time_column", "value_column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行趋势计算。"""
|
|
||||||
time_column = kwargs.get('time_column')
|
|
||||||
value_column = kwargs.get('value_column')
|
|
||||||
|
|
||||||
if time_column not in data.columns:
|
|
||||||
return {'error': f'时间列 {time_column} 不存在'}
|
|
||||||
|
|
||||||
if value_column not in data.columns:
|
|
||||||
return {'error': f'值列 {value_column} 不存在'}
|
|
||||||
|
|
||||||
# 转换时间列
|
|
||||||
try:
|
|
||||||
time_data = pd.to_datetime(data[time_column])
|
|
||||||
except Exception as e:
|
|
||||||
return {'error': f'无法将 {time_column} 转换为日期时间: {str(e)}'}
|
|
||||||
|
|
||||||
# 创建数值型时间索引(天数)
|
|
||||||
time_numeric = (time_data - time_data.min()).dt.days.values
|
|
||||||
value_data = data[value_column].dropna().values
|
|
||||||
|
|
||||||
if len(value_data) < 2:
|
|
||||||
return {'error': '数据点太少,无法计算趋势'}
|
|
||||||
|
|
||||||
# 线性回归
|
|
||||||
slope, intercept, r_value, p_value, std_err = stats.linregress(
|
|
||||||
time_numeric[:len(value_data)], value_data
|
|
||||||
)
|
|
||||||
|
|
||||||
# 计算增长率
|
|
||||||
first_value = value_data[0]
|
|
||||||
last_value = value_data[-1]
|
|
||||||
growth_rate = ((last_value - first_value) / first_value * 100) if first_value != 0 else 0
|
|
||||||
|
|
||||||
return {
|
|
||||||
'time_column': time_column,
|
|
||||||
'value_column': value_column,
|
|
||||||
'slope': float(slope),
|
|
||||||
'intercept': float(intercept),
|
|
||||||
'r_squared': float(r_value ** 2),
|
|
||||||
'p_value': float(p_value),
|
|
||||||
'growth_rate': float(growth_rate),
|
|
||||||
'trend': 'increasing' if slope > 0 else 'decreasing' if slope < 0 else 'stable'
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于包含日期时间列和数值列的数据。"""
|
|
||||||
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
|
|
||||||
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
|
|
||||||
return has_datetime and has_numeric
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
"""Tool manager — selects applicable tools based on data profile.
|
|
||||||
|
|
||||||
The ToolManager filters tools by data type compatibility (e.g., time series tools
|
|
||||||
need datetime columns). The actual tool+parameter selection is fully AI-driven.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
|
|
||||||
from src.tools.base import AnalysisTool, ToolRegistry, _global_registry
|
|
||||||
from src.models import DataProfile
|
|
||||||
|
|
||||||
|
|
||||||
class ToolManager:
|
|
||||||
"""
|
|
||||||
Tool manager that selects applicable tools based on data characteristics.
|
|
||||||
|
|
||||||
This is a filter, not a decision maker. AI decides which tools to actually
|
|
||||||
call and with what parameters at runtime.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, registry: ToolRegistry = None):
|
|
||||||
self.registry = registry if registry else _global_registry
|
|
||||||
self._missing_tools: List[str] = []
|
|
||||||
|
|
||||||
def select_tools(self, data_profile: DataProfile) -> List[AnalysisTool]:
|
|
||||||
"""
|
|
||||||
Return all tools applicable to this data profile.
|
|
||||||
Each tool's is_applicable() checks if the data has the right column types.
|
|
||||||
"""
|
|
||||||
self._missing_tools = []
|
|
||||||
return self.registry.get_applicable_tools(data_profile)
|
|
||||||
|
|
||||||
def get_all_tools(self) -> List[AnalysisTool]:
|
|
||||||
"""Return all registered tools regardless of data profile."""
|
|
||||||
tool_names = self.registry.list_tools()
|
|
||||||
return [self.registry.get_tool(name) for name in tool_names]
|
|
||||||
|
|
||||||
def get_missing_tools(self) -> List[str]:
|
|
||||||
return list(set(self._missing_tools))
|
|
||||||
|
|
||||||
def get_tool_descriptions(self, tools: List[AnalysisTool] = None) -> List[Dict[str, Any]]:
|
|
||||||
"""Get tool descriptions for AI consumption."""
|
|
||||||
if tools is None:
|
|
||||||
tools = self.get_all_tools()
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
'name': t.name,
|
|
||||||
'description': t.description,
|
|
||||||
'parameters': t.parameters
|
|
||||||
}
|
|
||||||
for t in tools
|
|
||||||
]
|
|
||||||
@@ -1,443 +0,0 @@
|
|||||||
"""可视化工具。"""
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib
|
|
||||||
matplotlib.use('Agg') # 使用非交互式后端
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from typing import Dict, Any
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from src.tools.base import AnalysisTool
|
|
||||||
from src.models import DataProfile
|
|
||||||
|
|
||||||
# 尝试导入 seaborn,如果不可用则使用 matplotlib
|
|
||||||
try:
|
|
||||||
import seaborn as sns
|
|
||||||
HAS_SEABORN = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_SEABORN = False
|
|
||||||
|
|
||||||
|
|
||||||
# 设置中文字体支持
|
|
||||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
|
||||||
plt.rcParams['axes.unicode_minus'] = False
|
|
||||||
|
|
||||||
|
|
||||||
class CreateBarChartTool(AnalysisTool):
|
|
||||||
"""创建柱状图工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "create_bar_chart"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "创建柱状图,用于展示分类数据的分布或比较。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"x_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "X轴列名(分类变量)"
|
|
||||||
},
|
|
||||||
"y_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Y轴列名(数值变量),如果为空则计数"
|
|
||||||
},
|
|
||||||
"title": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "图表标题",
|
|
||||||
"default": "柱状图"
|
|
||||||
},
|
|
||||||
"output_path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "输出文件路径",
|
|
||||||
"default": "bar_chart.png"
|
|
||||||
},
|
|
||||||
"top_n": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "只显示前N个类别",
|
|
||||||
"default": 20
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["x_column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行柱状图生成。"""
|
|
||||||
x_column = kwargs.get('x_column')
|
|
||||||
y_column = kwargs.get('y_column')
|
|
||||||
title = kwargs.get('title', '柱状图')
|
|
||||||
output_path = kwargs.get('output_path', 'bar_chart.png')
|
|
||||||
top_n = kwargs.get('top_n', 20)
|
|
||||||
|
|
||||||
if x_column not in data.columns:
|
|
||||||
return {'error': f'列 {x_column} 不存在'}
|
|
||||||
|
|
||||||
if y_column and y_column not in data.columns:
|
|
||||||
return {'error': f'列 {y_column} 不存在'}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 准备数据
|
|
||||||
if y_column:
|
|
||||||
# 按 x_column 分组,对 y_column 求和
|
|
||||||
plot_data = data.groupby(x_column, observed=True)[y_column].sum().sort_values(ascending=False).head(top_n)
|
|
||||||
else:
|
|
||||||
# 计数
|
|
||||||
plot_data = data[x_column].value_counts().head(top_n)
|
|
||||||
|
|
||||||
# 创建图表
|
|
||||||
fig, ax = plt.subplots(figsize=(12, 6))
|
|
||||||
plot_data.plot(kind='bar', ax=ax)
|
|
||||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
|
||||||
ax.set_xlabel(x_column, fontsize=12)
|
|
||||||
ax.set_ylabel(y_column if y_column else '计数', fontsize=12)
|
|
||||||
ax.tick_params(axis='x', rotation=45)
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# 确保输出目录存在
|
|
||||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# 保存图表
|
|
||||||
plt.savefig(output_path, dpi=100, bbox_inches='tight')
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'success': True,
|
|
||||||
'chart_path': output_path,
|
|
||||||
'chart_type': 'bar',
|
|
||||||
'data_points': len(plot_data),
|
|
||||||
'x_column': x_column,
|
|
||||||
'y_column': y_column
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {'error': f'生成柱状图失败: {str(e)}'}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于所有数据。"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class CreateLineChartTool(AnalysisTool):
|
|
||||||
"""创建折线图工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "create_line_chart"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "创建折线图,用于展示时间序列数据或趋势变化。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"x_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "X轴列名(通常是时间)"
|
|
||||||
},
|
|
||||||
"y_column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Y轴列名(数值变量)"
|
|
||||||
},
|
|
||||||
"title": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "图表标题",
|
|
||||||
"default": "折线图"
|
|
||||||
},
|
|
||||||
"output_path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "输出文件路径",
|
|
||||||
"default": "line_chart.png"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["x_column", "y_column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行折线图生成。"""
|
|
||||||
x_column = kwargs.get('x_column')
|
|
||||||
y_column = kwargs.get('y_column')
|
|
||||||
title = kwargs.get('title', '折线图')
|
|
||||||
output_path = kwargs.get('output_path', 'line_chart.png')
|
|
||||||
|
|
||||||
if x_column not in data.columns:
|
|
||||||
return {'error': f'列 {x_column} 不存在'}
|
|
||||||
|
|
||||||
if y_column not in data.columns:
|
|
||||||
return {'error': f'列 {y_column} 不存在'}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 准备数据
|
|
||||||
plot_data = data[[x_column, y_column]].copy()
|
|
||||||
plot_data = plot_data.sort_values(x_column)
|
|
||||||
|
|
||||||
# 如果数据点太多,采样
|
|
||||||
if len(plot_data) > 1000:
|
|
||||||
step = len(plot_data) // 1000
|
|
||||||
plot_data = plot_data.iloc[::step]
|
|
||||||
|
|
||||||
# 创建图表
|
|
||||||
fig, ax = plt.subplots(figsize=(12, 6))
|
|
||||||
ax.plot(plot_data[x_column], plot_data[y_column], marker='o', markersize=3, linewidth=2)
|
|
||||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
|
||||||
ax.set_xlabel(x_column, fontsize=12)
|
|
||||||
ax.set_ylabel(y_column, fontsize=12)
|
|
||||||
ax.grid(True, alpha=0.3)
|
|
||||||
plt.xticks(rotation=45)
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# 确保输出目录存在
|
|
||||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# 保存图表
|
|
||||||
plt.savefig(output_path, dpi=100, bbox_inches='tight')
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'success': True,
|
|
||||||
'chart_path': output_path,
|
|
||||||
'chart_type': 'line',
|
|
||||||
'data_points': len(plot_data),
|
|
||||||
'x_column': x_column,
|
|
||||||
'y_column': y_column
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {'error': f'生成折线图失败: {str(e)}'}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于包含数值列的数据。"""
|
|
||||||
return any(col.dtype == 'numeric' for col in data_profile.columns)
|
|
||||||
|
|
||||||
|
|
||||||
class CreatePieChartTool(AnalysisTool):
|
|
||||||
"""创建饼图工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "create_pie_chart"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "创建饼图,用于展示各部分占整体的比例。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"column": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要分析的列名"
|
|
||||||
},
|
|
||||||
"title": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "图表标题",
|
|
||||||
"default": "饼图"
|
|
||||||
},
|
|
||||||
"output_path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "输出文件路径",
|
|
||||||
"default": "pie_chart.png"
|
|
||||||
},
|
|
||||||
"top_n": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "只显示前N个类别,其余归为'其他'",
|
|
||||||
"default": 10
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["column"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行饼图生成。"""
|
|
||||||
column = kwargs.get('column')
|
|
||||||
title = kwargs.get('title', '饼图')
|
|
||||||
output_path = kwargs.get('output_path', 'pie_chart.png')
|
|
||||||
top_n = kwargs.get('top_n', 10)
|
|
||||||
|
|
||||||
if column not in data.columns:
|
|
||||||
return {'error': f'列 {column} 不存在'}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 准备数据
|
|
||||||
value_counts = data[column].value_counts()
|
|
||||||
|
|
||||||
if len(value_counts) > top_n:
|
|
||||||
# 只保留前 N 个,其余归为"其他"
|
|
||||||
top_values = value_counts.head(top_n)
|
|
||||||
other_sum = value_counts.iloc[top_n:].sum()
|
|
||||||
plot_data = pd.concat([top_values, pd.Series({'其他': other_sum})])
|
|
||||||
else:
|
|
||||||
plot_data = value_counts
|
|
||||||
|
|
||||||
# 创建图表
|
|
||||||
fig, ax = plt.subplots(figsize=(10, 8))
|
|
||||||
colors = plt.cm.Set3(range(len(plot_data)))
|
|
||||||
wedges, texts, autotexts = ax.pie(
|
|
||||||
plot_data,
|
|
||||||
labels=plot_data.index,
|
|
||||||
autopct='%1.1f%%',
|
|
||||||
colors=colors,
|
|
||||||
startangle=90
|
|
||||||
)
|
|
||||||
|
|
||||||
# 设置文本样式
|
|
||||||
for text in texts:
|
|
||||||
text.set_fontsize(10)
|
|
||||||
for autotext in autotexts:
|
|
||||||
autotext.set_color('white')
|
|
||||||
autotext.set_fontweight('bold')
|
|
||||||
autotext.set_fontsize(9)
|
|
||||||
|
|
||||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# 确保输出目录存在
|
|
||||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# 保存图表
|
|
||||||
plt.savefig(output_path, dpi=100, bbox_inches='tight')
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'success': True,
|
|
||||||
'chart_path': output_path,
|
|
||||||
'chart_type': 'pie',
|
|
||||||
'categories': len(plot_data),
|
|
||||||
'column': column
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {'error': f'生成饼图失败: {str(e)}'}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于所有数据。"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class CreateHeatmapTool(AnalysisTool):
|
|
||||||
"""创建热力图工具。"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "create_heatmap"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return "创建热力图,用于展示数值矩阵或相关性矩阵。"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"columns": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
"description": "要分析的列名列表,如果为空则使用所有数值列"
|
|
||||||
},
|
|
||||||
"title": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "图表标题",
|
|
||||||
"default": "相关性热力图"
|
|
||||||
},
|
|
||||||
"output_path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "输出文件路径",
|
|
||||||
"default": "heatmap.png"
|
|
||||||
},
|
|
||||||
"method": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "相关系数方法:pearson, spearman, kendall",
|
|
||||||
"default": "pearson"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
|
||||||
"""执行热力图生成。"""
|
|
||||||
columns = kwargs.get('columns', [])
|
|
||||||
title = kwargs.get('title', '相关性热力图')
|
|
||||||
output_path = kwargs.get('output_path', 'heatmap.png')
|
|
||||||
method = kwargs.get('method', 'pearson')
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 如果没有指定列,使用所有数值列
|
|
||||||
if not columns:
|
|
||||||
numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
|
|
||||||
else:
|
|
||||||
numeric_cols = [col for col in columns if col in data.columns]
|
|
||||||
|
|
||||||
if len(numeric_cols) < 2:
|
|
||||||
return {'error': '至少需要两个数值列来创建热力图'}
|
|
||||||
|
|
||||||
# 计算相关系数矩阵
|
|
||||||
corr_matrix = data[numeric_cols].corr(method=method)
|
|
||||||
|
|
||||||
# 创建图表
|
|
||||||
fig, ax = plt.subplots(figsize=(10, 8))
|
|
||||||
|
|
||||||
if HAS_SEABORN:
|
|
||||||
# 使用 seaborn 创建更美观的热力图
|
|
||||||
sns.heatmap(
|
|
||||||
corr_matrix,
|
|
||||||
annot=True,
|
|
||||||
fmt='.2f',
|
|
||||||
cmap='coolwarm',
|
|
||||||
center=0,
|
|
||||||
square=True,
|
|
||||||
linewidths=1,
|
|
||||||
cbar_kws={"shrink": 0.8},
|
|
||||||
ax=ax
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 使用 matplotlib 创建基本热力图
|
|
||||||
im = ax.imshow(corr_matrix, cmap='coolwarm', aspect='auto', vmin=-1, vmax=1)
|
|
||||||
ax.set_xticks(range(len(corr_matrix.columns)))
|
|
||||||
ax.set_yticks(range(len(corr_matrix.columns)))
|
|
||||||
ax.set_xticklabels(corr_matrix.columns, rotation=45, ha='right')
|
|
||||||
ax.set_yticklabels(corr_matrix.columns)
|
|
||||||
|
|
||||||
# 添加数值标注
|
|
||||||
for i in range(len(corr_matrix.columns)):
|
|
||||||
for j in range(len(corr_matrix.columns)):
|
|
||||||
text = ax.text(j, i, f'{corr_matrix.iloc[i, j]:.2f}',
|
|
||||||
ha="center", va="center", color="black", fontsize=9)
|
|
||||||
|
|
||||||
plt.colorbar(im, ax=ax, shrink=0.8)
|
|
||||||
|
|
||||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# 确保输出目录存在
|
|
||||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# 保存图表
|
|
||||||
plt.savefig(output_path, dpi=100, bbox_inches='tight')
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'success': True,
|
|
||||||
'chart_path': output_path,
|
|
||||||
'chart_type': 'heatmap',
|
|
||||||
'columns': numeric_cols,
|
|
||||||
'method': method
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {'error': f'生成热力图失败: {str(e)}'}
|
|
||||||
|
|
||||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
|
||||||
"""适用于包含至少两个数值列的数据。"""
|
|
||||||
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
|
|
||||||
return len(numeric_cols) >= 2
|
|
||||||
Reference in New Issue
Block a user