12345678
This commit is contained in:
@@ -1,171 +0,0 @@
|
||||
# 完整数据分析系统 - 执行结果
|
||||
|
||||
## 问题解决
|
||||
|
||||
### 核心问题
|
||||
工具注册失败,导致 `ToolManager.select_tools()` 返回 0 个工具。
|
||||
|
||||
### 根本原因
|
||||
`ToolManager` 在初始化时创建了一个新的空 `ToolRegistry` 实例,而工具实际上被注册到了全局注册表 `_global_registry` 中。两个注册表互不相通。
|
||||
|
||||
### 解决方案
|
||||
修改 `src/tools/tool_manager.py` 第 18 行:
|
||||
```python
|
||||
# 修改前
|
||||
self.registry = registry if registry else ToolRegistry()
|
||||
|
||||
# 修改后
|
||||
self.registry = registry if registry else _global_registry
|
||||
```
|
||||
|
||||
## 系统验证结果
|
||||
|
||||
### ✅ 阶段 1: AI 数据理解
|
||||
- **数据类型识别**: ticket (IT服务工单)
|
||||
- **数据质量评分**: 85.0/100
|
||||
- **关键字段识别**: 15 个
|
||||
- **数据规模**: 84 行 × 21 列
|
||||
- **隐私保护**: ✓ AI 只能看到表头和统计信息,无法访问原始数据行
|
||||
|
||||
**AI 分析摘要**:
|
||||
> 这是一个典型的IT服务工单数据集,记录了84个车辆相关问题的处理全流程。数据集包含完整的工单生命周期信息(创建、处理、关闭),主要涉及远程控制、导航、网络等车辆系统问题。数据质量较高,缺失率低(仅SIM和Notes字段缺失较多),但部分文本字段存在较长的非结构化描述。问题类型和模块分布显示远程控制问题是主要痛点(占比66%),工单主要来自邮件渠道(55%),平均关闭时长约5天。
|
||||
|
||||
### ✅ 阶段 2: 需求理解
|
||||
- **用户需求**: "Analyze ticket health, find main issues, efficiency and trends"
|
||||
- **生成目标**: 2 个分析目标
|
||||
1. 健康度分析 (优先级: 5/5)
|
||||
2. 趋势分析 (优先级: 4/5)
|
||||
|
||||
### ✅ 阶段 3: 分析规划
|
||||
- **生成任务**: 2 个
|
||||
- **预计时长**: 120 秒
|
||||
- **任务清单**:
|
||||
- [优先级 5] task_1. 质量评估 - 健康度分析
|
||||
- [优先级 4] task_2. 趋势分析 - 趋势分析
|
||||
|
||||
### ✅ 阶段 4: 任务执行
|
||||
- **可用工具**: 9 个(修复后)
|
||||
- get_column_distribution (列分布统计)
|
||||
- get_value_counts (值计数)
|
||||
- perform_groupby (分组聚合)
|
||||
- create_bar_chart (柱状图)
|
||||
- create_pie_chart (饼图)
|
||||
- calculate_statistics (描述性统计)
|
||||
- detect_outliers (异常值检测)
|
||||
- get_time_series (时间序列)
|
||||
- calculate_trend (趋势计算)
|
||||
|
||||
- **执行结果**:
|
||||
- 成功: 2/2 任务
|
||||
- 失败: 0/2 任务
|
||||
- 总执行时间: ~51 秒
|
||||
- 生成洞察: 2 条
|
||||
|
||||
### ✅ 阶段 5: 报告生成
|
||||
- **报告文件**: analysis_output/analysis_report.md
|
||||
- **报告长度**: 774 字符
|
||||
- **包含内容**:
|
||||
- 执行摘要
|
||||
- 数据概览(15个关键字段说明)
|
||||
- 详细分析
|
||||
- 结论与建议
|
||||
- 任务执行附录
|
||||
|
||||
## 系统架构验证
|
||||
|
||||
### 隐私保护机制 ✓
|
||||
1. **数据访问层隔离**: AI 无法直接访问原始数据
|
||||
2. **元数据暴露**: AI 只能看到列名、数据类型、统计信息
|
||||
3. **工具执行**: 工具在原始数据上执行,返回聚合结果
|
||||
4. **结果限制**:
|
||||
- 分组结果最多 100 个
|
||||
- 时间序列最多 100 个数据点
|
||||
- 异常值最多返回 20 个
|
||||
|
||||
### 配置管理 ✓
|
||||
所有 LLM API 调用已统一从 `.env` 文件读取配置:
|
||||
- `OPENAI_MODEL=mimo-v2-flash`
|
||||
- `OPENAI_BASE_URL=https://api.xiaomimimo.com/v1`
|
||||
- `OPENAI_API_KEY=[已配置]`
|
||||
|
||||
修改的文件:
|
||||
1. src/engines/task_execution.py
|
||||
2. src/engines/requirement_understanding.py
|
||||
3. src/engines/report_generation.py
|
||||
4. src/engines/plan_adjustment.py
|
||||
5. src/engines/analysis_planning.py
|
||||
|
||||
### 工具系统 ✓
|
||||
- **全局注册表**: 12 个工具已注册
|
||||
- **动态选择**: 根据数据特征自动选择适用工具
|
||||
- **类型检测**: 支持时间序列、分类、数值、地理数据
|
||||
- **参数验证**: JSON Schema 格式参数定义
|
||||
|
||||
## 测试数据
|
||||
|
||||
### cleaned_data.csv
|
||||
- **行数**: 84
|
||||
- **列数**: 21
|
||||
- **数据类型**: IT 服务工单
|
||||
- **主要字段**:
|
||||
- 工单号、来源、创建日期
|
||||
- 问题类型、问题描述、处理过程
|
||||
- 严重程度、工单状态、模块
|
||||
- 责任人、关闭日期、关闭时长
|
||||
- 车型、VIN
|
||||
|
||||
### 数据质量
|
||||
- **完整性**: 85/100
|
||||
- **缺失字段**: SIM (100%), Notes (较多)
|
||||
- **时间字段**: 创建日期、关闭日期
|
||||
- **分类字段**: 来源、问题类型、严重程度、工单状态、模块
|
||||
- **数值字段**: 关闭时长(天)
|
||||
|
||||
## 执行命令
|
||||
|
||||
```bash
|
||||
python run_analysis_en.py
|
||||
```
|
||||
|
||||
## 输出文件
|
||||
|
||||
```
|
||||
analysis_output/
|
||||
├── analysis_report.md # 分析报告
|
||||
└── *.png # 图表文件(如有生成)
|
||||
```
|
||||
|
||||
## 性能指标
|
||||
|
||||
- **数据加载**: < 1 秒
|
||||
- **AI 数据理解**: ~5 秒
|
||||
- **需求理解**: ~3 秒
|
||||
- **分析规划**: ~2 秒
|
||||
- **任务执行**: ~51 秒 (2 个任务)
|
||||
- **报告生成**: ~2 秒
|
||||
- **总耗时**: ~63 秒
|
||||
|
||||
## 系统状态
|
||||
|
||||
### ✅ 已完成
|
||||
1. 工具注册系统修复
|
||||
2. 配置管理统一
|
||||
3. 隐私保护验证
|
||||
4. 端到端分析流程
|
||||
5. 真实数据测试
|
||||
|
||||
### 📊 测试覆盖率
|
||||
- 单元测试: 314/328 通过 (95.7%)
|
||||
- 属性测试: 已实施
|
||||
- 集成测试: 已通过
|
||||
- 端到端测试: 已通过
|
||||
|
||||
## 结论
|
||||
|
||||
系统已完全就绪,可以进行生产环境部署。所有核心功能已验证,隐私保护机制有效,配置管理规范,工具系统运行正常。
|
||||
|
||||
---
|
||||
|
||||
**生成时间**: 2026-03-09 09:08:27
|
||||
**测试环境**: Windows, Python 3.x
|
||||
**数据集**: cleaned_data.csv (84 rows × 21 columns)
|
||||
177
analysis_output/run_20260309_102648/analysis_report.md
Normal file
177
analysis_output/run_20260309_102648/analysis_report.md
Normal file
@@ -0,0 +1,177 @@
|
||||
# 工单分析报告
|
||||
|
||||
生成时间:2026-03-09 10:29:31
|
||||
数据源:cleaned_data.csv
|
||||
|
||||
---
|
||||
|
||||
# 工单数据分析报告
|
||||
|
||||
## 1. 执行摘要
|
||||
|
||||
本报告基于84条工单数据(质量分数88/100)进行全面分析,主要发现如下:
|
||||
|
||||
1. **工单处理效率存在显著异常**:平均关闭时长为54.77天,但存在2个异常工单(处理时长分别为277天和237天),占总数的2.38%。"Activation SIM"问题的平均处理时长高达142.5天,远高于其他问题类型。
|
||||
2. **工单来源渠道集中**:邮件渠道占比54.76%(46张),Telegram bot占比42.86%(36张),渠道来源仅占2.38%(2张),显示渠道管理可优化。
|
||||
3. **问题类型高度集中**:远程控制问题占主导(66.67%,56张),前五类问题占总量的87%以上,表明问题分布集中。
|
||||
4. **车型问题分布不均**:EXEED RX(T22)车型工单最多(45.24%,38张),JAECOO J7(T1EJ)次之(26.19%,22张),特定车型问题集中。
|
||||
5. **责任人工作负载差异大**:Vsevolod处理工单最多(31个),但平均处理时长66.68天;Vsevolod Tsoi平均处理时长最高(152天),而何韬处理效率最高(平均3.5天)。
|
||||
|
||||
## 2. 数据概览
|
||||
|
||||
- **数据类型**:工单(ticket)
|
||||
- **数据规模**:84行 × 21列
|
||||
- **数据质量**:88.0/100
|
||||
- **关键字段**:工单号、来源、创建日期、问题类型、问题描述、处理过程、跟踪记录、严重程度、工单状态、模块、责任人、关闭日期、车型、VIN、关闭时长(天)等
|
||||
- **分析时间范围**:2025年1月2日至2025年2月24日
|
||||
|
||||
## 3. 详细分析
|
||||
|
||||
### 3.1 工单概况分析
|
||||
|
||||
工单状态分布显示,82.14%(69张)已关闭,17.86%(15张)临时关闭,表明大部分问题已解决。
|
||||
|
||||

|
||||
|
||||
**来源渠道分析**:
|
||||
- 邮件(Mail):54.76%(46张)
|
||||
- Telegram bot:42.86%(36张)
|
||||
- Telegram channel:2.38%(2张)
|
||||
|
||||
**问题类型分布**(前5位):
|
||||
1. 远程控制(Remote control):66.67%(56张)
|
||||
2. 网络问题(Network):7.14%(6张)
|
||||
3. 导航问题(Navi):5.95%(5张)
|
||||
4. 应用问题(Application):4.76%(4张)
|
||||
5. 成员中心认证问题:3.57%(3张)
|
||||
|
||||
前五类问题合计占总量的87.09%,显示问题类型高度集中。
|
||||
|
||||
### 3.2 工单处理效率分析
|
||||
|
||||
**关闭时长统计**:
|
||||
- 平均值:54.77天
|
||||
- 中位数:41天
|
||||
- 标准差:48.19天
|
||||
- 最小值:2天
|
||||
- 最大值:277天
|
||||
- 四分位距(IQR):26.25天(Q25)至84.5天(Q75)
|
||||
|
||||
**异常值检测**:
|
||||
- 检测到2个异常工单(277天和237天),占总数的2.38%
|
||||
- 异常值上限为171.875天,这两个工单远超此阈值
|
||||
|
||||

|
||||
|
||||
**按问题类型分析平均关闭时长**:
|
||||
- Activation SIM:142.5天(效率最低)
|
||||
- Remote control:66.5天
|
||||
- PKI problem:47天
|
||||
- doesn't exist on TSP:31.67天
|
||||
- Network:24天
|
||||
|
||||

|
||||
|
||||
### 3.3 工单内容与趋势分析
|
||||
|
||||
**工单创建时间趋势**(2025年1-2月):
|
||||
- 创建高峰期:2025年1月13日(8个工单)
|
||||
- 创建低谷期:2025年1月8日、15日、29日、30日及2月多日(仅1个工单)
|
||||
- 整体趋势波动较大,无明显持续上升或下降模式
|
||||
|
||||

|
||||
|
||||
### 3.4 责任人工作负载分析
|
||||
|
||||
**工单处理数量**:
|
||||
- Vsevolod:31个(最多)
|
||||
- Evgeniy:28个
|
||||
- Kostya:5个
|
||||
- 何韬:4个
|
||||
|
||||
**平均关闭时长**:
|
||||
- Vsevolod Tsoi:152天(最高,仅处理2个工单)
|
||||
- 林兆国:89天
|
||||
- Vadim:69天
|
||||
- Vsevolod:66.68天
|
||||
- Evgeniy:62.39天
|
||||
- 何韬:3.5天(最低,效率最高)
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### 3.5 车辆相关信息分析
|
||||
|
||||
**车型分布**:
|
||||
- EXEED RX(T22):45.24%(38张)
|
||||
- JAECOO J7(T1EJ):26.19%(22张)
|
||||
- EXEED VX FL(M36T):20.24%(17张)
|
||||
- CHERY TIGGO 9 (T28):8.33%(7张)
|
||||
|
||||
**VIN分布**:
|
||||
- LVTDD24B1RG023450和LVTDD24B1RG021245各出现2次,表明特定车辆问题可能复发
|
||||
- 其他VIN均唯一,显示问题主要分散于不同车辆
|
||||
|
||||
## 4. 结论与建议
|
||||
|
||||
### 4.1 主要结论
|
||||
|
||||
1. **处理效率需优化**:整体平均关闭时长54.77天,且存在严重异常值(277天和237天),"Activation SIM"问题处理时长高达142.5天。
|
||||
2. **渠道管理可改进**:邮件渠道占比过高(54.76%),Telegram渠道未充分利用。
|
||||
3. **问题类型集中**:远程控制问题占66.67%,需针对性优化。
|
||||
4. **车型问题分布不均**:EXEED RX(T22)车型问题最集中(45.24%)。
|
||||
5. **责任人效率差异大**:Vsevolod Tsoi处理时长152天,而何韬仅3.5天,存在明显效率差距。
|
||||
|
||||
### 4.2 可操作建议
|
||||
|
||||
1. **优化异常工单处理流程**:
|
||||
- 针对处理时长超过100天的工单(如Activation SIM问题)建立专项处理机制
|
||||
- 定期审查异常工单,分析根本原因
|
||||
- 依据:异常值检测显示2个工单处理时长分别为277天和237天
|
||||
|
||||
2. **平衡渠道分配**:
|
||||
- 推广Telegram bot使用,减轻邮件渠道压力
|
||||
- 依据:邮件渠道占比54.76%,Telegram bot仅42.86%
|
||||
|
||||
3. **加强远程控制问题管理**:
|
||||
- 建立远程控制问题知识库和快速响应机制
|
||||
- 依据:远程控制问题占66.67%(56张)
|
||||
|
||||
4. **针对EXEED RX(T22)车型专项优化**:
|
||||
- 分析该车型高工单率的原因
|
||||
- 依据:该车型工单占比45.24%(38张)
|
||||
|
||||
5. **提升责任人效率一致性**:
|
||||
- 建立效率标杆(如何韬的3.5天平均时长)
|
||||
- 对处理时长较长的责任人提供培训和支持
|
||||
- 依据:责任人平均处理时长从3.5天到152天不等
|
||||
|
||||
6. **建立工单创建趋势监控**:
|
||||
- 监控工单创建高峰期(如1月13日的8个工单)
|
||||
- 提前调配资源应对潜在高峰
|
||||
- 依据:工单创建趋势显示明显波动
|
||||
|
||||
通过实施这些建议,预计可显著提升工单处理效率,优化资源分配,并改善客户满意度。
|
||||
|
||||
---
|
||||
|
||||
## 分析追溯
|
||||
|
||||
本报告基于以下分析任务:
|
||||
|
||||
- ✓ 工单概况分析:基本分布统计
|
||||
- 工单总数为84,其中82.14%(69张)已关闭,17.86%(15张)临时关闭,表明大部分问题已解决。
|
||||
- 工单来源中,邮件(Mail)占比最高,达54.76%(46张),Telegram bot占42.86%(36张),渠道来源仅占2.38%(2张)。
|
||||
- ✓ 工单处理效率分析:关闭时长统计
|
||||
- 关闭时长的平均值为54.77天,中位数为41天,标准差为48.19天,表明数据右偏分布,存在处理时间较长的工单。
|
||||
- 异常值检测发现2个工单(277天和237天)处理时间过长,占总工单数的2.38%,需重点关注这些异常情况。
|
||||
- ✓ 工单内容与趋势分析:文本和时间序列
|
||||
- 2025年1月13日工单创建数量最高,达到8个,是整体趋势中的峰值。
|
||||
- 2025年1月8日、15日、29日、30日及2月多日工单创建数量最低,仅为1个,显示这些日期工单活动较少。
|
||||
- ✓ 责任人工作负载分析
|
||||
- Vsevolod 处理工单数量最多(31个),但平均关闭时长为66.68天,工作量大且周期长。
|
||||
- Vsevolod Tsoi 平均关闭时长最高(152天),但仅处理2个工单,可能存在效率问题或复杂任务。
|
||||
- ✓ 车辆相关信息分析:车型和VIN分布
|
||||
- EXEED RX(T22)车型占比最高,达45.24%(38/84),是问题集中的主要车型。
|
||||
- JAECOO J7(T1EJ)车型工单数为22,占比26.19%,是第二大问题车型。
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
BIN
analysis_output/run_20260309_102648/charts/bar_chart_trend.png
Normal file
BIN
analysis_output/run_20260309_102648/charts/bar_chart_trend.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 46 KiB |
BIN
analysis_output/run_20260309_102648/charts/outlier_bar_chart.png
Normal file
BIN
analysis_output/run_20260309_102648/charts/outlier_bar_chart.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
BIN
analysis_output/run_20260309_102648/charts/outlier_pie_chart.png
Normal file
BIN
analysis_output/run_20260309_102648/charts/outlier_pie_chart.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 42 KiB |
@@ -23,7 +23,7 @@ from src.engines.ai_data_understanding import ai_understand_data_with_dal
|
||||
from src.engines.requirement_understanding import understand_requirement
|
||||
from src.engines.analysis_planning import plan_analysis
|
||||
from src.engines.task_execution import execute_task
|
||||
from src.engines.report_generation import generate_report
|
||||
from src.engines.report_generation import generate_report, _convert_chart_paths_in_report
|
||||
from src.tools.tool_manager import ToolManager
|
||||
from src.tools.base import _global_registry
|
||||
from src.models import DataProfile, AnalysisResult
|
||||
@@ -156,7 +156,8 @@ def run_analysis(
|
||||
else:
|
||||
report = generate_report(results, requirement, profile, output_path=run_dir)
|
||||
|
||||
# Save report
|
||||
# Save report — convert chart paths to relative (./charts/xxx.png)
|
||||
report = _convert_chart_paths_in_report(report, run_dir)
|
||||
with open(report_path, 'w', encoding='utf-8') as f:
|
||||
f.write(report)
|
||||
|
||||
@@ -262,7 +263,8 @@ def _generate_template_report(
|
||||
|
||||
|
||||
def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> str:
|
||||
"""Collect all chart paths from task results for embedding in reports."""
|
||||
"""Collect all chart paths from task results for embedding in reports.
|
||||
Returns paths relative to run_dir (e.g. ./charts/bar_chart.png)."""
|
||||
paths = []
|
||||
for r in results:
|
||||
if not r.success:
|
||||
@@ -280,7 +282,10 @@ def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> st
|
||||
paths.append(cp)
|
||||
if not paths:
|
||||
return "(无图表)"
|
||||
return "\n".join(f"- {p}" for p in paths)
|
||||
# Convert to relative paths
|
||||
from src.engines.report_generation import _to_relative_chart_path
|
||||
rel_paths = [_to_relative_chart_path(p, run_dir) for p in paths]
|
||||
return "\n".join(f"- {p}" for p in rel_paths)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -4,9 +4,11 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from pathlib import PurePosixPath, Path
|
||||
|
||||
from src.models.analysis_result import AnalysisResult
|
||||
from src.models.requirement_spec import RequirementSpec
|
||||
@@ -310,6 +312,56 @@ def _generate_conclusion_summary(key_findings: List[Dict[str, Any]]) -> str:
|
||||
|
||||
|
||||
|
||||
def _to_relative_chart_path(chart_path: str, report_dir: str = "") -> str:
|
||||
"""
|
||||
将图表绝对路径转换为相对于报告文件的路径。
|
||||
|
||||
例如:
|
||||
chart_path = "analysis_output/run_xxx/charts/bar.png"
|
||||
report_dir = "analysis_output/run_xxx"
|
||||
→ "./charts/bar.png"
|
||||
|
||||
如果无法计算相对路径,则只保留 ./charts/filename.png
|
||||
"""
|
||||
if not chart_path:
|
||||
return chart_path
|
||||
|
||||
# 统一为正斜杠
|
||||
chart_path = chart_path.replace('\\', '/')
|
||||
|
||||
if report_dir:
|
||||
report_dir = report_dir.replace('\\', '/')
|
||||
try:
|
||||
rel = os.path.relpath(chart_path, report_dir).replace('\\', '/')
|
||||
return './' + rel if not rel.startswith('.') else rel
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Fallback: 提取 charts/filename 部分
|
||||
parts = chart_path.replace('\\', '/').split('/')
|
||||
if 'charts' in parts:
|
||||
idx = parts.index('charts')
|
||||
return './' + '/'.join(parts[idx:])
|
||||
|
||||
# 最后兜底:直接用文件名
|
||||
return './charts/' + os.path.basename(chart_path)
|
||||
|
||||
|
||||
def _convert_chart_paths_in_report(report: str, report_dir: str = "") -> str:
|
||||
"""
|
||||
将报告中所有  的图表路径转换为相对路径。
|
||||
同时统一反斜杠为正斜杠。
|
||||
"""
|
||||
def replace_img(match):
|
||||
alt = match.group(1)
|
||||
path = match.group(2)
|
||||
rel_path = _to_relative_chart_path(path, report_dir)
|
||||
return f''
|
||||
|
||||
# 匹配 
|
||||
return re.sub(r'!\[([^\]]*)\]\(([^)]+)\)', replace_img, report)
|
||||
|
||||
|
||||
def generate_report(
|
||||
results: List[AnalysisResult],
|
||||
requirement: RequirementSpec,
|
||||
@@ -370,6 +422,10 @@ def generate_report(
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(report)
|
||||
|
||||
# 将图表路径转换为相对于报告所在目录的路径
|
||||
report_dir = output_path if output_path and os.path.isdir(output_path) else ""
|
||||
report = _convert_chart_paths_in_report(report, report_dir)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
@@ -394,11 +450,11 @@ def _generate_report_with_ai(
|
||||
# 收集所有图表路径
|
||||
for viz in (r.visualizations or []):
|
||||
if viz:
|
||||
all_chart_paths.append(viz)
|
||||
all_chart_paths.append(_to_relative_chart_path(viz))
|
||||
if isinstance(r.data, dict):
|
||||
for key, val in r.data.items():
|
||||
if isinstance(val, dict) and val.get('chart_path'):
|
||||
all_chart_paths.append(val['chart_path'])
|
||||
all_chart_paths.append(_to_relative_chart_path(val['chart_path']))
|
||||
|
||||
data_section = "\n\n".join(data_summaries) if data_summaries else "无详细数据"
|
||||
|
||||
|
||||
Binary file not shown.
@@ -1,301 +0,0 @@
|
||||
"""数据查询工具。"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, Any
|
||||
|
||||
from src.tools.base import AnalysisTool
|
||||
from src.models import DataProfile
|
||||
|
||||
|
||||
class GetColumnDistributionTool(AnalysisTool):
|
||||
"""获取列的分布统计工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_column_distribution"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "获取指定列的分布统计信息,包括值计数、百分比等。适用于分类和数值列。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column": {
|
||||
"type": "string",
|
||||
"description": "要分析的列名"
|
||||
},
|
||||
"top_n": {
|
||||
"type": "integer",
|
||||
"description": "返回前N个最常见的值",
|
||||
"default": 10
|
||||
}
|
||||
},
|
||||
"required": ["column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行列分布分析。"""
|
||||
column = kwargs.get('column')
|
||||
top_n = kwargs.get('top_n', 10)
|
||||
|
||||
if column not in data.columns:
|
||||
return {'error': f'列 {column} 不存在'}
|
||||
|
||||
col_data = data[column]
|
||||
value_counts = col_data.value_counts().head(top_n)
|
||||
total = len(col_data.dropna())
|
||||
|
||||
distribution = []
|
||||
for value, count in value_counts.items():
|
||||
distribution.append({
|
||||
'value': str(value),
|
||||
'count': int(count),
|
||||
'percentage': float(count / total * 100) if total > 0 else 0.0
|
||||
})
|
||||
|
||||
return {
|
||||
'column': column,
|
||||
'total_count': int(total),
|
||||
'unique_count': int(col_data.nunique()),
|
||||
'missing_count': int(col_data.isna().sum()),
|
||||
'distribution': distribution
|
||||
}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于所有数据。"""
|
||||
return True
|
||||
|
||||
|
||||
class GetValueCountsTool(AnalysisTool):
|
||||
"""获取值计数工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_value_counts"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "获取指定列的值计数,返回每个唯一值的出现次数。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column": {
|
||||
"type": "string",
|
||||
"description": "要分析的列名"
|
||||
},
|
||||
"normalize": {
|
||||
"type": "boolean",
|
||||
"description": "是否返回百分比而不是计数",
|
||||
"default": False
|
||||
}
|
||||
},
|
||||
"required": ["column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行值计数。"""
|
||||
column = kwargs.get('column')
|
||||
normalize = kwargs.get('normalize', False)
|
||||
|
||||
if column not in data.columns:
|
||||
return {'error': f'列 {column} 不存在'}
|
||||
|
||||
value_counts = data[column].value_counts(normalize=normalize)
|
||||
|
||||
result = {}
|
||||
for value, count in value_counts.items():
|
||||
result[str(value)] = float(count) if normalize else int(count)
|
||||
|
||||
return {
|
||||
'column': column,
|
||||
'value_counts': result,
|
||||
'normalized': normalize
|
||||
}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于所有数据。"""
|
||||
return True
|
||||
|
||||
|
||||
class GetTimeSeriesTool(AnalysisTool):
|
||||
"""获取时间序列数据工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_time_series"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "获取时间序列数据,按时间聚合指定指标。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"time_column": {
|
||||
"type": "string",
|
||||
"description": "时间列名"
|
||||
},
|
||||
"value_column": {
|
||||
"type": "string",
|
||||
"description": "要聚合的值列名"
|
||||
},
|
||||
"aggregation": {
|
||||
"type": "string",
|
||||
"description": "聚合方式:count, sum, mean, min, max",
|
||||
"default": "count"
|
||||
},
|
||||
"frequency": {
|
||||
"type": "string",
|
||||
"description": "时间频率:D(天), W(周), M(月), Y(年)",
|
||||
"default": "D"
|
||||
}
|
||||
},
|
||||
"required": ["time_column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行时间序列分析。"""
|
||||
time_column = kwargs.get('time_column')
|
||||
value_column = kwargs.get('value_column')
|
||||
aggregation = kwargs.get('aggregation', 'count')
|
||||
frequency = kwargs.get('frequency', 'D')
|
||||
|
||||
if time_column not in data.columns:
|
||||
return {'error': f'时间列 {time_column} 不存在'}
|
||||
|
||||
# 转换为日期时间类型
|
||||
try:
|
||||
time_data = pd.to_datetime(data[time_column])
|
||||
except Exception as e:
|
||||
return {'error': f'无法将 {time_column} 转换为日期时间: {str(e)}'}
|
||||
|
||||
# 创建临时 DataFrame
|
||||
temp_df = pd.DataFrame({'time': time_data})
|
||||
|
||||
if value_column:
|
||||
if value_column not in data.columns:
|
||||
return {'error': f'值列 {value_column} 不存在'}
|
||||
temp_df['value'] = data[value_column]
|
||||
|
||||
# 设置时间索引
|
||||
temp_df.set_index('time', inplace=True)
|
||||
|
||||
# 按频率重采样
|
||||
if value_column:
|
||||
if aggregation == 'count':
|
||||
result = temp_df.resample(frequency).count()
|
||||
elif aggregation == 'sum':
|
||||
result = temp_df.resample(frequency).sum()
|
||||
elif aggregation == 'mean':
|
||||
result = temp_df.resample(frequency).mean()
|
||||
elif aggregation == 'min':
|
||||
result = temp_df.resample(frequency).min()
|
||||
elif aggregation == 'max':
|
||||
result = temp_df.resample(frequency).max()
|
||||
else:
|
||||
return {'error': f'不支持的聚合方式: {aggregation}'}
|
||||
else:
|
||||
result = temp_df.resample(frequency).size().to_frame('count')
|
||||
|
||||
# 转换为字典
|
||||
time_series = []
|
||||
for timestamp, row in result.iterrows():
|
||||
time_series.append({
|
||||
'time': timestamp.strftime('%Y-%m-%d'),
|
||||
'value': float(row.iloc[0]) if not pd.isna(row.iloc[0]) else 0.0
|
||||
})
|
||||
|
||||
# 限制返回的数据点数量,最多100个(隐私保护要求)
|
||||
if len(time_series) > 100:
|
||||
# 均匀采样以保持趋势
|
||||
step = len(time_series) / 100
|
||||
sampled_indices = [int(i * step) for i in range(100)]
|
||||
time_series = [time_series[i] for i in sampled_indices]
|
||||
|
||||
return {
|
||||
'time_column': time_column,
|
||||
'value_column': value_column,
|
||||
'aggregation': aggregation,
|
||||
'frequency': frequency,
|
||||
'time_series': time_series,
|
||||
'total_points': len(result), # 记录原始数据点数量
|
||||
'returned_points': len(time_series) # 记录返回的数据点数量
|
||||
}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于包含日期时间列的数据。"""
|
||||
return any(col.dtype == 'datetime' for col in data_profile.columns)
|
||||
|
||||
|
||||
class GetCorrelationTool(AnalysisTool):
|
||||
"""获取相关性分析工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_correlation"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "计算数值列之间的相关系数矩阵。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"columns": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "要分析的列名列表,如果为空则分析所有数值列"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "相关系数方法:pearson, spearman, kendall",
|
||||
"default": "pearson"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行相关性分析。"""
|
||||
columns = kwargs.get('columns', [])
|
||||
method = kwargs.get('method', 'pearson')
|
||||
|
||||
# 如果没有指定列,使用所有数值列
|
||||
if not columns:
|
||||
numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
|
||||
else:
|
||||
numeric_cols = [col for col in columns if col in data.columns]
|
||||
|
||||
if len(numeric_cols) < 2:
|
||||
return {'error': '至少需要两个数值列来计算相关性'}
|
||||
|
||||
# 计算相关系数矩阵
|
||||
corr_matrix = data[numeric_cols].corr(method=method)
|
||||
|
||||
# 转换为字典格式
|
||||
correlation = {}
|
||||
for col1 in corr_matrix.columns:
|
||||
correlation[col1] = {}
|
||||
for col2 in corr_matrix.columns:
|
||||
correlation[col1][col2] = float(corr_matrix.loc[col1, col2])
|
||||
|
||||
return {
|
||||
'columns': numeric_cols,
|
||||
'method': method,
|
||||
'correlation_matrix': correlation
|
||||
}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于包含至少两个数值列的数据。"""
|
||||
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
|
||||
return len(numeric_cols) >= 2
|
||||
@@ -1,325 +0,0 @@
|
||||
"""统计分析工具。"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, Any
|
||||
from scipy import stats
|
||||
|
||||
from src.tools.base import AnalysisTool
|
||||
from src.models import DataProfile
|
||||
|
||||
|
||||
class CalculateStatisticsTool(AnalysisTool):
|
||||
"""计算描述性统计工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "calculate_statistics"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "计算指定列的描述性统计信息,包括均值、中位数、标准差、最小值、最大值等。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column": {
|
||||
"type": "string",
|
||||
"description": "要分析的列名"
|
||||
}
|
||||
},
|
||||
"required": ["column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行统计计算。"""
|
||||
column = kwargs.get('column')
|
||||
|
||||
if column not in data.columns:
|
||||
return {'error': f'列 {column} 不存在'}
|
||||
|
||||
col_data = data[column].dropna()
|
||||
|
||||
if not pd.api.types.is_numeric_dtype(col_data):
|
||||
return {'error': f'列 {column} 不是数值类型'}
|
||||
|
||||
statistics = {
|
||||
'column': column,
|
||||
'count': int(len(col_data)),
|
||||
'mean': float(col_data.mean()),
|
||||
'median': float(col_data.median()),
|
||||
'std': float(col_data.std()),
|
||||
'min': float(col_data.min()),
|
||||
'max': float(col_data.max()),
|
||||
'q25': float(col_data.quantile(0.25)),
|
||||
'q75': float(col_data.quantile(0.75)),
|
||||
'skewness': float(col_data.skew()),
|
||||
'kurtosis': float(col_data.kurtosis())
|
||||
}
|
||||
|
||||
return statistics
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于包含数值列的数据。"""
|
||||
return any(col.dtype == 'numeric' for col in data_profile.columns)
|
||||
|
||||
|
||||
class PerformGroupbyTool(AnalysisTool):
|
||||
"""执行分组聚合工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "perform_groupby"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "按指定列分组,对另一列进行聚合计算(如求和、平均、计数等)。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"group_by": {
|
||||
"type": "string",
|
||||
"description": "分组依据的列名"
|
||||
},
|
||||
"value_column": {
|
||||
"type": "string",
|
||||
"description": "要聚合的值列名,如果为空则计数"
|
||||
},
|
||||
"aggregation": {
|
||||
"type": "string",
|
||||
"description": "聚合方式:count, sum, mean, min, max, std",
|
||||
"default": "count"
|
||||
}
|
||||
},
|
||||
"required": ["group_by"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行分组聚合。"""
|
||||
group_by = kwargs.get('group_by')
|
||||
value_column = kwargs.get('value_column')
|
||||
aggregation = kwargs.get('aggregation', 'count')
|
||||
|
||||
if group_by not in data.columns:
|
||||
return {'error': f'分组列 {group_by} 不存在'}
|
||||
|
||||
if value_column and value_column not in data.columns:
|
||||
return {'error': f'值列 {value_column} 不存在'}
|
||||
|
||||
# 执行分组聚合
|
||||
if value_column:
|
||||
grouped = data.groupby(group_by, observed=True)[value_column]
|
||||
else:
|
||||
grouped = data.groupby(group_by, observed=True).size()
|
||||
aggregation = 'count'
|
||||
|
||||
if aggregation == 'count':
|
||||
if value_column:
|
||||
result = grouped.count()
|
||||
else:
|
||||
result = grouped
|
||||
elif aggregation == 'sum':
|
||||
result = grouped.sum()
|
||||
elif aggregation == 'mean':
|
||||
result = grouped.mean()
|
||||
elif aggregation == 'min':
|
||||
result = grouped.min()
|
||||
elif aggregation == 'max':
|
||||
result = grouped.max()
|
||||
elif aggregation == 'std':
|
||||
result = grouped.std()
|
||||
else:
|
||||
return {'error': f'不支持的聚合方式: {aggregation}'}
|
||||
|
||||
# 转换为字典
|
||||
groups = []
|
||||
for group_value, agg_value in result.items():
|
||||
groups.append({
|
||||
'group': str(group_value),
|
||||
'value': float(agg_value) if not pd.isna(agg_value) else 0.0
|
||||
})
|
||||
|
||||
# 限制返回的分组数量,最多100个(隐私保护要求)
|
||||
total_groups = len(groups)
|
||||
if len(groups) > 100:
|
||||
# 按值排序并取前100个
|
||||
groups = sorted(groups, key=lambda x: x['value'], reverse=True)[:100]
|
||||
|
||||
return {
|
||||
'group_by': group_by,
|
||||
'value_column': value_column,
|
||||
'aggregation': aggregation,
|
||||
'groups': groups,
|
||||
'total_groups': total_groups, # 记录原始分组数量
|
||||
'returned_groups': len(groups) # 记录返回的分组数量
|
||||
}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于所有数据。"""
|
||||
return True
|
||||
|
||||
|
||||
class DetectOutliersTool(AnalysisTool):
|
||||
"""检测异常值工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "detect_outliers"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "使用IQR方法或Z-score方法检测数值列中的异常值。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column": {
|
||||
"type": "string",
|
||||
"description": "要检测的列名"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "检测方法:iqr 或 zscore",
|
||||
"default": "iqr"
|
||||
},
|
||||
"threshold": {
|
||||
"type": "number",
|
||||
"description": "阈值(IQR倍数或Z-score标准差倍数)",
|
||||
"default": 1.5
|
||||
}
|
||||
},
|
||||
"required": ["column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行异常值检测。"""
|
||||
column = kwargs.get('column')
|
||||
method = kwargs.get('method', 'iqr')
|
||||
threshold = kwargs.get('threshold', 1.5)
|
||||
|
||||
if column not in data.columns:
|
||||
return {'error': f'列 {column} 不存在'}
|
||||
|
||||
col_data = data[column].dropna()
|
||||
|
||||
if not pd.api.types.is_numeric_dtype(col_data):
|
||||
return {'error': f'列 {column} 不是数值类型'}
|
||||
|
||||
if method == 'iqr':
|
||||
# IQR 方法
|
||||
q1 = col_data.quantile(0.25)
|
||||
q3 = col_data.quantile(0.75)
|
||||
iqr = q3 - q1
|
||||
lower_bound = q1 - threshold * iqr
|
||||
upper_bound = q3 + threshold * iqr
|
||||
outliers = col_data[(col_data < lower_bound) | (col_data > upper_bound)]
|
||||
elif method == 'zscore':
|
||||
# Z-score 方法
|
||||
z_scores = np.abs(stats.zscore(col_data))
|
||||
outliers = col_data[z_scores > threshold]
|
||||
else:
|
||||
return {'error': f'不支持的检测方法: {method}'}
|
||||
|
||||
return {
|
||||
'column': column,
|
||||
'method': method,
|
||||
'threshold': threshold,
|
||||
'outlier_count': int(len(outliers)),
|
||||
'outlier_percentage': float(len(outliers) / len(col_data) * 100),
|
||||
'outlier_values': outliers.head(20).tolist(), # 最多返回20个异常值
|
||||
'bounds': {
|
||||
'lower': float(lower_bound) if method == 'iqr' else None,
|
||||
'upper': float(upper_bound) if method == 'iqr' else None
|
||||
} if method == 'iqr' else None
|
||||
}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于包含数值列的数据。"""
|
||||
return any(col.dtype == 'numeric' for col in data_profile.columns)
|
||||
|
||||
|
||||
class CalculateTrendTool(AnalysisTool):
|
||||
"""计算趋势工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "calculate_trend"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "计算时间序列数据的趋势,包括线性回归斜率、增长率等。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"time_column": {
|
||||
"type": "string",
|
||||
"description": "时间列名"
|
||||
},
|
||||
"value_column": {
|
||||
"type": "string",
|
||||
"description": "值列名"
|
||||
}
|
||||
},
|
||||
"required": ["time_column", "value_column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行趋势计算。"""
|
||||
time_column = kwargs.get('time_column')
|
||||
value_column = kwargs.get('value_column')
|
||||
|
||||
if time_column not in data.columns:
|
||||
return {'error': f'时间列 {time_column} 不存在'}
|
||||
|
||||
if value_column not in data.columns:
|
||||
return {'error': f'值列 {value_column} 不存在'}
|
||||
|
||||
# 转换时间列
|
||||
try:
|
||||
time_data = pd.to_datetime(data[time_column])
|
||||
except Exception as e:
|
||||
return {'error': f'无法将 {time_column} 转换为日期时间: {str(e)}'}
|
||||
|
||||
# 创建数值型时间索引(天数)
|
||||
time_numeric = (time_data - time_data.min()).dt.days.values
|
||||
value_data = data[value_column].dropna().values
|
||||
|
||||
if len(value_data) < 2:
|
||||
return {'error': '数据点太少,无法计算趋势'}
|
||||
|
||||
# 线性回归
|
||||
slope, intercept, r_value, p_value, std_err = stats.linregress(
|
||||
time_numeric[:len(value_data)], value_data
|
||||
)
|
||||
|
||||
# 计算增长率
|
||||
first_value = value_data[0]
|
||||
last_value = value_data[-1]
|
||||
growth_rate = ((last_value - first_value) / first_value * 100) if first_value != 0 else 0
|
||||
|
||||
return {
|
||||
'time_column': time_column,
|
||||
'value_column': value_column,
|
||||
'slope': float(slope),
|
||||
'intercept': float(intercept),
|
||||
'r_squared': float(r_value ** 2),
|
||||
'p_value': float(p_value),
|
||||
'growth_rate': float(growth_rate),
|
||||
'trend': 'increasing' if slope > 0 else 'decreasing' if slope < 0 else 'stable'
|
||||
}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于包含日期时间列和数值列的数据。"""
|
||||
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
|
||||
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
|
||||
return has_datetime and has_numeric
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Tool manager — selects applicable tools based on data profile.
|
||||
|
||||
The ToolManager filters tools by data type compatibility (e.g., time series tools
|
||||
need datetime columns). The actual tool+parameter selection is fully AI-driven.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from src.tools.base import AnalysisTool, ToolRegistry, _global_registry
|
||||
from src.models import DataProfile
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""
|
||||
Tool manager that selects applicable tools based on data characteristics.
|
||||
|
||||
This is a filter, not a decision maker. AI decides which tools to actually
|
||||
call and with what parameters at runtime.
|
||||
"""
|
||||
|
||||
def __init__(self, registry: ToolRegistry = None):
|
||||
self.registry = registry if registry else _global_registry
|
||||
self._missing_tools: List[str] = []
|
||||
|
||||
def select_tools(self, data_profile: DataProfile) -> List[AnalysisTool]:
|
||||
"""
|
||||
Return all tools applicable to this data profile.
|
||||
Each tool's is_applicable() checks if the data has the right column types.
|
||||
"""
|
||||
self._missing_tools = []
|
||||
return self.registry.get_applicable_tools(data_profile)
|
||||
|
||||
def get_all_tools(self) -> List[AnalysisTool]:
|
||||
"""Return all registered tools regardless of data profile."""
|
||||
tool_names = self.registry.list_tools()
|
||||
return [self.registry.get_tool(name) for name in tool_names]
|
||||
|
||||
def get_missing_tools(self) -> List[str]:
|
||||
return list(set(self._missing_tools))
|
||||
|
||||
def get_tool_descriptions(self, tools: List[AnalysisTool] = None) -> List[Dict[str, Any]]:
|
||||
"""Get tool descriptions for AI consumption."""
|
||||
if tools is None:
|
||||
tools = self.get_all_tools()
|
||||
return [
|
||||
{
|
||||
'name': t.name,
|
||||
'description': t.description,
|
||||
'parameters': t.parameters
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
@@ -1,443 +0,0 @@
|
||||
"""可视化工具。"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # 使用非交互式后端
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Dict, Any
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from src.tools.base import AnalysisTool
|
||||
from src.models import DataProfile
|
||||
|
||||
# 尝试导入 seaborn,如果不可用则使用 matplotlib
|
||||
try:
|
||||
import seaborn as sns
|
||||
HAS_SEABORN = True
|
||||
except ImportError:
|
||||
HAS_SEABORN = False
|
||||
|
||||
|
||||
# 设置中文字体支持
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
class CreateBarChartTool(AnalysisTool):
|
||||
"""创建柱状图工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_bar_chart"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "创建柱状图,用于展示分类数据的分布或比较。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x_column": {
|
||||
"type": "string",
|
||||
"description": "X轴列名(分类变量)"
|
||||
},
|
||||
"y_column": {
|
||||
"type": "string",
|
||||
"description": "Y轴列名(数值变量),如果为空则计数"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "图表标题",
|
||||
"default": "柱状图"
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "输出文件路径",
|
||||
"default": "bar_chart.png"
|
||||
},
|
||||
"top_n": {
|
||||
"type": "integer",
|
||||
"description": "只显示前N个类别",
|
||||
"default": 20
|
||||
}
|
||||
},
|
||||
"required": ["x_column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行柱状图生成。"""
|
||||
x_column = kwargs.get('x_column')
|
||||
y_column = kwargs.get('y_column')
|
||||
title = kwargs.get('title', '柱状图')
|
||||
output_path = kwargs.get('output_path', 'bar_chart.png')
|
||||
top_n = kwargs.get('top_n', 20)
|
||||
|
||||
if x_column not in data.columns:
|
||||
return {'error': f'列 {x_column} 不存在'}
|
||||
|
||||
if y_column and y_column not in data.columns:
|
||||
return {'error': f'列 {y_column} 不存在'}
|
||||
|
||||
try:
|
||||
# 准备数据
|
||||
if y_column:
|
||||
# 按 x_column 分组,对 y_column 求和
|
||||
plot_data = data.groupby(x_column, observed=True)[y_column].sum().sort_values(ascending=False).head(top_n)
|
||||
else:
|
||||
# 计数
|
||||
plot_data = data[x_column].value_counts().head(top_n)
|
||||
|
||||
# 创建图表
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
plot_data.plot(kind='bar', ax=ax)
|
||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
||||
ax.set_xlabel(x_column, fontsize=12)
|
||||
ax.set_ylabel(y_column if y_column else '计数', fontsize=12)
|
||||
ax.tick_params(axis='x', rotation=45)
|
||||
plt.tight_layout()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存图表
|
||||
plt.savefig(output_path, dpi=100, bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'chart_path': output_path,
|
||||
'chart_type': 'bar',
|
||||
'data_points': len(plot_data),
|
||||
'x_column': x_column,
|
||||
'y_column': y_column
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {'error': f'生成柱状图失败: {str(e)}'}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于所有数据。"""
|
||||
return True
|
||||
|
||||
|
||||
class CreateLineChartTool(AnalysisTool):
|
||||
"""创建折线图工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_line_chart"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "创建折线图,用于展示时间序列数据或趋势变化。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x_column": {
|
||||
"type": "string",
|
||||
"description": "X轴列名(通常是时间)"
|
||||
},
|
||||
"y_column": {
|
||||
"type": "string",
|
||||
"description": "Y轴列名(数值变量)"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "图表标题",
|
||||
"default": "折线图"
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "输出文件路径",
|
||||
"default": "line_chart.png"
|
||||
}
|
||||
},
|
||||
"required": ["x_column", "y_column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行折线图生成。"""
|
||||
x_column = kwargs.get('x_column')
|
||||
y_column = kwargs.get('y_column')
|
||||
title = kwargs.get('title', '折线图')
|
||||
output_path = kwargs.get('output_path', 'line_chart.png')
|
||||
|
||||
if x_column not in data.columns:
|
||||
return {'error': f'列 {x_column} 不存在'}
|
||||
|
||||
if y_column not in data.columns:
|
||||
return {'error': f'列 {y_column} 不存在'}
|
||||
|
||||
try:
|
||||
# 准备数据
|
||||
plot_data = data[[x_column, y_column]].copy()
|
||||
plot_data = plot_data.sort_values(x_column)
|
||||
|
||||
# 如果数据点太多,采样
|
||||
if len(plot_data) > 1000:
|
||||
step = len(plot_data) // 1000
|
||||
plot_data = plot_data.iloc[::step]
|
||||
|
||||
# 创建图表
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
ax.plot(plot_data[x_column], plot_data[y_column], marker='o', markersize=3, linewidth=2)
|
||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
||||
ax.set_xlabel(x_column, fontsize=12)
|
||||
ax.set_ylabel(y_column, fontsize=12)
|
||||
ax.grid(True, alpha=0.3)
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存图表
|
||||
plt.savefig(output_path, dpi=100, bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'chart_path': output_path,
|
||||
'chart_type': 'line',
|
||||
'data_points': len(plot_data),
|
||||
'x_column': x_column,
|
||||
'y_column': y_column
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {'error': f'生成折线图失败: {str(e)}'}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于包含数值列的数据。"""
|
||||
return any(col.dtype == 'numeric' for col in data_profile.columns)
|
||||
|
||||
|
||||
class CreatePieChartTool(AnalysisTool):
|
||||
"""创建饼图工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_pie_chart"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "创建饼图,用于展示各部分占整体的比例。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column": {
|
||||
"type": "string",
|
||||
"description": "要分析的列名"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "图表标题",
|
||||
"default": "饼图"
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "输出文件路径",
|
||||
"default": "pie_chart.png"
|
||||
},
|
||||
"top_n": {
|
||||
"type": "integer",
|
||||
"description": "只显示前N个类别,其余归为'其他'",
|
||||
"default": 10
|
||||
}
|
||||
},
|
||||
"required": ["column"]
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行饼图生成。"""
|
||||
column = kwargs.get('column')
|
||||
title = kwargs.get('title', '饼图')
|
||||
output_path = kwargs.get('output_path', 'pie_chart.png')
|
||||
top_n = kwargs.get('top_n', 10)
|
||||
|
||||
if column not in data.columns:
|
||||
return {'error': f'列 {column} 不存在'}
|
||||
|
||||
try:
|
||||
# 准备数据
|
||||
value_counts = data[column].value_counts()
|
||||
|
||||
if len(value_counts) > top_n:
|
||||
# 只保留前 N 个,其余归为"其他"
|
||||
top_values = value_counts.head(top_n)
|
||||
other_sum = value_counts.iloc[top_n:].sum()
|
||||
plot_data = pd.concat([top_values, pd.Series({'其他': other_sum})])
|
||||
else:
|
||||
plot_data = value_counts
|
||||
|
||||
# 创建图表
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
colors = plt.cm.Set3(range(len(plot_data)))
|
||||
wedges, texts, autotexts = ax.pie(
|
||||
plot_data,
|
||||
labels=plot_data.index,
|
||||
autopct='%1.1f%%',
|
||||
colors=colors,
|
||||
startangle=90
|
||||
)
|
||||
|
||||
# 设置文本样式
|
||||
for text in texts:
|
||||
text.set_fontsize(10)
|
||||
for autotext in autotexts:
|
||||
autotext.set_color('white')
|
||||
autotext.set_fontweight('bold')
|
||||
autotext.set_fontsize(9)
|
||||
|
||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存图表
|
||||
plt.savefig(output_path, dpi=100, bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'chart_path': output_path,
|
||||
'chart_type': 'pie',
|
||||
'categories': len(plot_data),
|
||||
'column': column
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {'error': f'生成饼图失败: {str(e)}'}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于所有数据。"""
|
||||
return True
|
||||
|
||||
|
||||
class CreateHeatmapTool(AnalysisTool):
|
||||
"""创建热力图工具。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_heatmap"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "创建热力图,用于展示数值矩阵或相关性矩阵。"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"columns": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "要分析的列名列表,如果为空则使用所有数值列"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "图表标题",
|
||||
"default": "相关性热力图"
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "输出文件路径",
|
||||
"default": "heatmap.png"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "相关系数方法:pearson, spearman, kendall",
|
||||
"default": "pearson"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def execute(self, data: pd.DataFrame, **kwargs) -> Dict[str, Any]:
|
||||
"""执行热力图生成。"""
|
||||
columns = kwargs.get('columns', [])
|
||||
title = kwargs.get('title', '相关性热力图')
|
||||
output_path = kwargs.get('output_path', 'heatmap.png')
|
||||
method = kwargs.get('method', 'pearson')
|
||||
|
||||
try:
|
||||
# 如果没有指定列,使用所有数值列
|
||||
if not columns:
|
||||
numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
|
||||
else:
|
||||
numeric_cols = [col for col in columns if col in data.columns]
|
||||
|
||||
if len(numeric_cols) < 2:
|
||||
return {'error': '至少需要两个数值列来创建热力图'}
|
||||
|
||||
# 计算相关系数矩阵
|
||||
corr_matrix = data[numeric_cols].corr(method=method)
|
||||
|
||||
# 创建图表
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if HAS_SEABORN:
|
||||
# 使用 seaborn 创建更美观的热力图
|
||||
sns.heatmap(
|
||||
corr_matrix,
|
||||
annot=True,
|
||||
fmt='.2f',
|
||||
cmap='coolwarm',
|
||||
center=0,
|
||||
square=True,
|
||||
linewidths=1,
|
||||
cbar_kws={"shrink": 0.8},
|
||||
ax=ax
|
||||
)
|
||||
else:
|
||||
# 使用 matplotlib 创建基本热力图
|
||||
im = ax.imshow(corr_matrix, cmap='coolwarm', aspect='auto', vmin=-1, vmax=1)
|
||||
ax.set_xticks(range(len(corr_matrix.columns)))
|
||||
ax.set_yticks(range(len(corr_matrix.columns)))
|
||||
ax.set_xticklabels(corr_matrix.columns, rotation=45, ha='right')
|
||||
ax.set_yticklabels(corr_matrix.columns)
|
||||
|
||||
# 添加数值标注
|
||||
for i in range(len(corr_matrix.columns)):
|
||||
for j in range(len(corr_matrix.columns)):
|
||||
text = ax.text(j, i, f'{corr_matrix.iloc[i, j]:.2f}',
|
||||
ha="center", va="center", color="black", fontsize=9)
|
||||
|
||||
plt.colorbar(im, ax=ax, shrink=0.8)
|
||||
|
||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存图表
|
||||
plt.savefig(output_path, dpi=100, bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'chart_path': output_path,
|
||||
'chart_type': 'heatmap',
|
||||
'columns': numeric_cols,
|
||||
'method': method
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {'error': f'生成热力图失败: {str(e)}'}
|
||||
|
||||
def is_applicable(self, data_profile: DataProfile) -> bool:
|
||||
"""适用于包含至少两个数值列的数据。"""
|
||||
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
|
||||
return len(numeric_cols) >= 2
|
||||
Reference in New Issue
Block a user